From 01d45fe964d323e7f66358c2db57d00a44bf2274 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 2 Aug 2021 14:37:25 +0100 Subject: Prune inbound federation queues if they get too long (#10390) --- changelog.d/10390.misc | 1 + synapse/federation/federation_server.py | 17 ++++ synapse/storage/databases/main/event_federation.py | 104 ++++++++++++++++++++- tests/storage/test_event_federation.py | 57 +++++++++++ 4 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10390.misc diff --git a/changelog.d/10390.misc b/changelog.d/10390.misc new file mode 100644 index 0000000000..911a5733ee --- /dev/null +++ b/changelog.d/10390.misc @@ -0,0 +1 @@ +Prune inbound federation inbound queues for a room if they get too large. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2892a11d7d..145b9161d9 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1024,6 +1024,23 @@ class FederationServer(FederationBase): origin, event = next + # Prune the event queue if it's getting large. + # + # We do this *after* handling the first event as the common case is + # that the queue is empty (/has the single event in), and so there's + # no need to do this check. + pruned = await self.store.prune_staged_events_in_room(room_id, room_version) + if pruned: + # If we have pruned the queue check we need to refetch the next + # event to handle. + next = await self.store.get_next_staged_event_for_room( + room_id, room_version + ) + if not next: + break + + origin, event = next + lock = await self.store.try_acquire_lock( _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 547e43ab98..44018c1c31 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -16,11 +16,11 @@ import logging from queue import Empty, PriorityQueue from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple -from prometheus_client import Gauge +from prometheus_client import Counter, Gauge from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError -from synapse.api.room_versions import RoomVersion +from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -44,6 +44,12 @@ number_pdus_in_federation_queue = Gauge( "The total number of events in the inbound federation staging", ) +pdus_pruned_from_federation_queue = Counter( + "synapse_federation_server_number_inbound_pdu_pruned", + "The number of events in the inbound federation staging that have been " + "pruned due to the queue getting too long", +) + logger = logging.getLogger(__name__) @@ -1277,6 +1283,100 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return origin, event + async def prune_staged_events_in_room( + self, + room_id: str, + room_version: RoomVersion, + ) -> bool: + """Checks if there are lots of staged events for the room, and if so + prune them down. + + Returns: + Whether any events were pruned + """ + + # First check the size of the queue. 
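+        # (The queue is considered "too long" once it holds 100 or more
+        # staged events for this room; see the threshold check just below.)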
+ count = await self.db_pool.simple_select_one_onecol( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + retcol="COALESCE(COUNT(*), 0)", + desc="prune_staged_events_in_room_count", + ) + + if count < 100: + return False + + # If the queue is too large, then we want clear the entire queue, + # keeping only the forward extremities (i.e. the events not referenced + # by other events in the queue). We do this so that we can always + # backpaginate in all the events we have dropped. + rows = await self.db_pool.simple_select_list( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + retcols=("event_id", "event_json"), + desc="prune_staged_events_in_room_fetch", + ) + + # Find the set of events referenced by those in the queue, as well as + # collecting all the event IDs in the queue. + referenced_events: Set[str] = set() + seen_events: Set[str] = set() + for row in rows: + event_id = row["event_id"] + seen_events.add(event_id) + event_d = db_to_json(row["event_json"]) + + # We don't bother parsing the dicts into full blown event objects, + # as that is needlessly expensive. + + # We haven't checked that the `prev_events` have the right format + # yet, so we check as we go. + prev_events = event_d.get("prev_events", []) + if not isinstance(prev_events, list): + logger.info("Invalid prev_events for %s", event_id) + continue + + if room_version.event_format == EventFormatVersions.V1: + for prev_event_tuple in prev_events: + if not isinstance(prev_event_tuple, list) or len(prev_events) != 2: + logger.info("Invalid prev_events for %s", event_id) + break + + prev_event_id = prev_event_tuple[0] + if not isinstance(prev_event_id, str): + logger.info("Invalid prev_events for %s", event_id) + break + + referenced_events.add(prev_event_id) + else: + for prev_event_id in prev_events: + if not isinstance(prev_event_id, str): + logger.info("Invalid prev_events for %s", event_id) + break + + referenced_events.add(prev_event_id) + + to_delete = referenced_events & seen_events + if not to_delete: + return False + + pdus_pruned_from_federation_queue.inc(len(to_delete)) + logger.info( + "Pruning %d events in room %s from federation queue", + len(to_delete), + room_id, + ) + + await self.db_pool.simple_delete_many( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + iterable=to_delete, + column="event_id", + desc="prune_staged_events_in_room_delete", + ) + + return True + async def get_all_rooms_with_staged_incoming_events(self) -> List[str]: """Get the room IDs of all events currently staged.""" return await self.db_pool.simple_select_onecol( diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a0e2259478..c3fcf7e7b4 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -15,7 +15,9 @@ import attr from parameterized import parameterized +from synapse.api.room_versions import RoomVersions from synapse.events import _EventInternalMetadata +from synapse.util import json_encoder import tests.unittest import tests.utils @@ -504,6 +506,61 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) self.assertSetEqual(difference, set()) + def test_prune_inbound_federation_queue(self): + "Test that pruning of inbound federation queues work" + + room_id = "some_room_id" + + # Insert a bunch of events that all reference the previous one. 
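+        # (Each $fake_event_id_{i+1} lists $fake_event_id_{i} as its only
+        # prev_event, so after pruning only the newest event can remain.)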
+ self.get_success( + self.store.db_pool.simple_insert_many( + table="federation_inbound_events_staging", + values=[ + { + "origin": "some_origin", + "room_id": room_id, + "received_ts": 0, + "event_id": f"$fake_event_id_{i + 1}", + "event_json": json_encoder.encode( + {"prev_events": [f"$fake_event_id_{i}"]} + ), + "internal_metadata": "{}", + } + for i in range(500) + ], + desc="test_prune_inbound_federation_queue", + ) + ) + + # Calling prune once should return True, i.e. a prune happen. The second + # time it shouldn't. + pruned = self.get_success( + self.store.prune_staged_events_in_room(room_id, RoomVersions.V6) + ) + self.assertTrue(pruned) + + pruned = self.get_success( + self.store.prune_staged_events_in_room(room_id, RoomVersions.V6) + ) + self.assertFalse(pruned) + + # Assert that we only have a single event left in the queue, and that it + # is the last one. + count = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="federation_inbound_events_staging", + keyvalues={"room_id": room_id}, + retcol="COALESCE(COUNT(*), 0)", + desc="test_prune_inbound_federation_queue", + ) + ) + self.assertEqual(count, 1) + + _, event_id = self.get_success( + self.store.get_next_staged_event_id_for_room(room_id) + ) + self.assertEqual(event_id, "$fake_event_id_500") + @attr.s class FakeEvent: -- cgit 1.5.1 From fb086edaeddf8cdb8a03b8564d1b6883ac5cac6e Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 2 Aug 2021 16:50:22 +0100 Subject: Fix codestyle CI from #10440 (#10511) Co-authored-by: Erik Johnston --- changelog.d/10511.feature | 1 + tests/storage/test_txn_limit.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10511.feature diff --git a/changelog.d/10511.feature b/changelog.d/10511.feature new file mode 100644 index 0000000000..f1833b0bd7 --- /dev/null +++ b/changelog.d/10511.feature @@ -0,0 +1 @@ +Allow setting transaction limit for database connections. diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index 9be51f9ebd..6ff3ebb137 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -32,5 +32,5 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): db_pool = self.hs.get_datastores().databases[0] # force txn limit to roll over at least once - for i in range(0, 1001): + for _ in range(0, 1001): self.get_success_or_raise(db_pool.runInteraction("test_select", do_select)) -- cgit 1.5.1 From a6ea32a79893b6ee694d036f3bc29a02a79d51e8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 2 Aug 2021 21:06:34 +0100 Subject: Fix the `tests-done` github actions step, again (#10512) --- .github/workflows/tests.yml | 21 ++++++++++++--------- changelog.d/10512.misc | 1 + 2 files changed, 13 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10512.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0a62c62d02..239553ae13 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -367,13 +367,16 @@ jobs: - name: Set build result env: NEEDS_CONTEXT: ${{ toJSON(needs) }} - # the `jq` incantation dumps out a series of " " lines + # the `jq` incantation dumps out a series of " " lines. + # we set it to an intermediate variable to avoid a pipe, which makes it + # hard to set $rc. 
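+      # (Any non-success result below sets rc=1, so the step fails via
+      # the final `exit $rc`.)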
run: | - set -o pipefail - jq -r 'to_entries[] | [.key,.value.result] | join(" ")' \ - <<< $NEEDS_CONTEXT | - while read job result; do - if [ "$result" != "success" ]; then - echo "::set-failed ::Job $job returned $result" - fi - done + rc=0 + results=$(jq -r 'to_entries[] | [.key,.value.result] | join(" ")' <<< $NEEDS_CONTEXT) + while read job result ; do + if [ "$result" != "success" ]; then + echo "::set-failed ::Job $job returned $result" + rc=1 + fi + done <<< $results + exit $rc diff --git a/changelog.d/10512.misc b/changelog.d/10512.misc new file mode 100644 index 0000000000..c012e89f4b --- /dev/null +++ b/changelog.d/10512.misc @@ -0,0 +1 @@ +Update the `tests-done` Github Actions status. -- cgit 1.5.1 From 2bae2c632ff595bda770212678521e04288f00a9 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 3 Aug 2021 05:08:57 -0500 Subject: Add developer documentation to explain room DAG concepts like `outliers` and `state_groups` (#10464) --- changelog.d/10464.doc | 1 + docs/SUMMARY.md | 1 + docs/development/room-dag-concepts.md | 79 +++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+) create mode 100644 changelog.d/10464.doc create mode 100644 docs/development/room-dag-concepts.md diff --git a/changelog.d/10464.doc b/changelog.d/10464.doc new file mode 100644 index 0000000000..764fb9f65c --- /dev/null +++ b/changelog.d/10464.doc @@ -0,0 +1 @@ +Add some developer docs to explain room DAG concepts like `outliers`, `state_groups`, `depth`, etc. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index f1bde91420..10be12d638 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -79,6 +79,7 @@ - [Single Sign-On]() - [SAML](development/saml.md) - [CAS](development/cas.md) + - [Room DAG concepts](development/room-dag-concepts.md) - [State Resolution]() - [The Auth Chain Difference Algorithm](auth_chain_difference_algorithm.md) - [Media Repository](media_repository.md) diff --git a/docs/development/room-dag-concepts.md b/docs/development/room-dag-concepts.md new file mode 100644 index 0000000000..5eed72bec6 --- /dev/null +++ b/docs/development/room-dag-concepts.md @@ -0,0 +1,79 @@ +# Room DAG concepts + +## Edges + +The word "edge" comes from graph theory lingo. An edge is just a connection +between two events. In Synapse, we connect events by specifying their +`prev_events`. A subsequent event points back at a previous event. + +``` +A (oldest) <---- B <---- C (most recent) +``` + + +## Depth and stream ordering + +Events are normally sorted by `(topological_ordering, stream_ordering)` where +`topological_ordering` is just `depth`. In other words, we first sort by `depth` +and then tie-break based on `stream_ordering`. `depth` is incremented as new +messages are added to the DAG. Normally, `stream_ordering` is an auto +incrementing integer, but backfilled events start with `stream_ordering=-1` and decrement. + +--- + + - `/sync` returns things in the order they arrive at the server (`stream_ordering`). + - `/messages` (and `/backfill` in the federation API) return them in the order determined by the event graph `(topological_ordering, stream_ordering)`. + +The general idea is that, if you're following a room in real-time (i.e. +`/sync`), you probably want to see the messages as they arrive at your server, +rather than skipping any that arrived late; whereas if you're looking at a +historical section of timeline (i.e. `/messages`), you want to see the best +representation of the state of the room as others were seeing it at the time. 
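+
+A minimal sketch of that ordering, using hypothetical events (the tuples
+below are invented purely for illustration):
+
+```
+# (event_id, depth, stream_ordering); the backfilled event has a
+# negative stream_ordering, as described above.
+events = [("$late", 2, 5), ("$backfilled", 2, -1), ("$old", 1, 3)]
+
+# /messages ordering: depth first, with stream_ordering as the tie-break.
+events.sort(key=lambda e: (e[1], e[2]))
+
+# /sync ordering: by stream_ordering (arrival order for live events).
+by_arrival = sorted(events, key=lambda e: e[2])
+```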
+ + +## Forward extremity + +Most-recent-in-time events in the DAG which are not referenced by any other events' `prev_events` yet. + +The forward extremities of a room are used as the `prev_events` when the next event is sent. + + +## Backwards extremity + +The current marker of where we have backfilled up to and will generally be the +oldest-in-time events we know of in the DAG. + +This is an event where we haven't fetched all of the `prev_events` for. + +Once we have fetched all of its `prev_events`, it's unmarked as a backwards +extremity (although we may have formed new backwards extremities from the prev +events during the backfilling process). + + +## Outliers + +We mark an event as an `outlier` when we haven't figured out the state for the +room at that point in the DAG yet. + +We won't *necessarily* have the `prev_events` of an `outlier` in the database, +but it's entirely possible that we *might*. The status of whether we have all of +the `prev_events` is marked as a [backwards extremity](#backwards-extremity). + +For example, when we fetch the event auth chain or state for a given event, we +mark all of those claimed auth events as outliers because we haven't done the +state calculation ourself. + + +## State groups + +For every non-outlier event we need to know the state at that event. Instead of +storing the full state for each event in the DB (i.e. a `event_id -> state` +mapping), which is *very* space inefficient when state doesn't change, we +instead assign each different set of state a "state group" and then have +mappings of `event_id -> state_group` and `state_group -> state`. + + +### Stage group edges + +TODO: `state_group_edges` is a further optimization... + notes from @Azrenbeth, https://pastebin.com/seUGVGeT -- cgit 1.5.1 From a7bacccd8550b45fc1fa3dcff90f36125827b4ba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Aug 2021 11:23:45 +0100 Subject: Extend the release script to tag and create the releases. (#10496) --- changelog.d/10496.misc | 1 + scripts-dev/release.py | 311 +++++++++++++++++++++++++++++++++++++++++++------ setup.py | 2 + 3 files changed, 278 insertions(+), 36 deletions(-) create mode 100644 changelog.d/10496.misc diff --git a/changelog.d/10496.misc b/changelog.d/10496.misc new file mode 100644 index 0000000000..6d0d3e5391 --- /dev/null +++ b/changelog.d/10496.misc @@ -0,0 +1 @@ +Extend release script to also tag and create GitHub releases. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index cff433af2a..e864dc6ed5 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -14,29 +14,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""An interactive script for doing a release. See `run()` below. +"""An interactive script for doing a release. See `cli()` below. """ +import re import subprocess import sys -from typing import Optional +import urllib.request +from os import path +from tempfile import TemporaryDirectory +from typing import List, Optional, Tuple +import attr import click +import commonmark import git +import redbaron +from click.exceptions import ClickException +from github import Github from packaging import version -from redbaron import RedBaron -@click.command() -def run(): - """An interactive script to walk through the initial stages of creating a - release, including creating release branch, updating changelog and pushing to - GitHub. +@click.group() +def cli(): + """An interactive script to walk through the parts of creating a release. 
Requires the dev dependencies be installed, which can be done via: pip install -e .[dev] + Then to use: + + ./scripts-dev/release.py prepare + + # ... ask others to look at the changelog ... + + ./scripts-dev/release.py tag + + # ... wait for asssets to build ... + + ./scripts-dev/release.py publish + ./scripts-dev/release.py upload + + If the env var GH_TOKEN (or GITHUB_TOKEN) is set, or passed into the + `tag`/`publish` command, then a new draft release will be created/published. + """ + + +@cli.command() +def prepare(): + """Do the initial stages of creating a release, including creating release + branch, updating changelog and pushing to GitHub. """ # Make sure we're in a git repo. @@ -51,32 +79,8 @@ def run(): click.secho("Updating git repo...") repo.remote().fetch() - # Parse the AST and load the `__version__` node so that we can edit it - # later. - with open("synapse/__init__.py") as f: - red = RedBaron(f.read()) - - version_node = None - for node in red: - if node.type != "assignment": - continue - - if node.target.type != "name": - continue - - if node.target.value != "__version__": - continue - - version_node = node - break - - if not version_node: - print("Failed to find '__version__' definition in synapse/__init__.py") - sys.exit(1) - - # Parse the current version. - current_version = version.parse(version_node.value.value.strip('"')) - assert isinstance(current_version, version.Version) + # Get the current version and AST from root Synapse module. + current_version, parsed_synapse_ast, version_node = parse_version_from_module() # Figure out what sort of release we're doing and calcuate the new version. rc = click.confirm("RC", default=True) @@ -190,7 +194,7 @@ def run(): # Update the `__version__` variable and write it back to the file. version_node.value = '"' + new_version + '"' with open("synapse/__init__.py", "w") as f: - f.write(red.dumps()) + f.write(parsed_synapse_ast.dumps()) # Generate changelogs subprocess.run("python3 -m towncrier", shell=True) @@ -240,6 +244,180 @@ def run(): ) +@cli.command() +@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"]) +def tag(gh_token: Optional[str]): + """Tags the release and generates a draft GitHub release""" + + # Make sure we're in a git repo. + try: + repo = git.Repo() + except git.InvalidGitRepositoryError: + raise click.ClickException("Not in Synapse repo.") + + if repo.is_dirty(): + raise click.ClickException("Uncommitted changes exist.") + + click.secho("Updating git repo...") + repo.remote().fetch() + + # Find out the version and tag name. + current_version, _, _ = parse_version_from_module() + tag_name = f"v{current_version}" + + # Check we haven't released this version. + if tag_name in repo.tags: + raise click.ClickException(f"Tag {tag_name} already exists!\n") + + # Get the appropriate changelogs and tag. 
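+    # (get_changes_for_version, defined further down in this file, pulls
+    # the matching section(s) out of CHANGES.md.)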
+ changes = get_changes_for_version(current_version) + + click.echo_via_pager(changes) + if click.confirm("Edit text?", default=False): + changes = click.edit(changes, require_save=False) + + repo.create_tag(tag_name, message=changes) + + if not click.confirm("Push tag to GitHub?", default=True): + print("") + print("Run when ready to push:") + print("") + print(f"\tgit push {repo.remote().name} tag {current_version}") + print("") + return + + repo.git.push(repo.remote().name, "tag", tag_name) + + # If no token was given, we bail here + if not gh_token: + click.launch(f"https://github.com/matrix-org/synapse/releases/edit/{tag_name}") + return + + # Create a new draft release + gh = Github(gh_token) + gh_repo = gh.get_repo("matrix-org/synapse") + release = gh_repo.create_git_release( + tag=tag_name, + name=tag_name, + message=changes, + draft=True, + prerelease=current_version.is_prerelease, + ) + + # Open the release and the actions where we are building the assets. + click.launch(release.url) + click.launch( + f"https://github.com/matrix-org/synapse/actions?query=branch%3A{tag_name}" + ) + + click.echo("Wait for release assets to be built") + + +@cli.command() +@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=True) +def publish(gh_token: str): + """Publish release.""" + + # Make sure we're in a git repo. + try: + repo = git.Repo() + except git.InvalidGitRepositoryError: + raise click.ClickException("Not in Synapse repo.") + + if repo.is_dirty(): + raise click.ClickException("Uncommitted changes exist.") + + current_version, _, _ = parse_version_from_module() + tag_name = f"v{current_version}" + + if not click.confirm(f"Publish {tag_name}?", default=True): + return + + # Publish the draft release + gh = Github(gh_token) + gh_repo = gh.get_repo("matrix-org/synapse") + for release in gh_repo.get_releases(): + if release.title == tag_name: + break + else: + raise ClickException(f"Failed to find GitHub release for {tag_name}") + + assert release.title == tag_name + + if not release.draft: + click.echo("Release already published.") + return + + release = release.update_release( + name=release.title, + message=release.body, + tag_name=release.tag_name, + prerelease=release.prerelease, + draft=False, + ) + + +@cli.command() +def upload(): + """Upload release to pypi.""" + + current_version, _, _ = parse_version_from_module() + tag_name = f"v{current_version}" + + pypi_asset_names = [ + f"matrix_synapse-{current_version}-py3-none-any.whl", + f"matrix-synapse-{current_version}.tar.gz", + ] + + with TemporaryDirectory(prefix=f"synapse_upload_{tag_name}_") as tmpdir: + for name in pypi_asset_names: + filename = path.join(tmpdir, name) + url = f"https://github.com/matrix-org/synapse/releases/download/{tag_name}/{name}" + + click.echo(f"Downloading {name} into {filename}") + urllib.request.urlretrieve(url, filename=filename) + + if click.confirm("Upload to PyPI?", default=True): + subprocess.run("twine upload *", shell=True, cwd=tmpdir) + + click.echo( + f"Done! Remember to merge the tag {tag_name} into the appropriate branches" + ) + + +def parse_version_from_module() -> Tuple[ + version.Version, redbaron.RedBaron, redbaron.Node +]: + # Parse the AST and load the `__version__` node so that we can edit it + # later. 
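+    # (redbaron keeps the file's original formatting, so dumping the tree
+    # back out after editing only changes the version string itself.)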
+ with open("synapse/__init__.py") as f: + red = redbaron.RedBaron(f.read()) + + version_node = None + for node in red: + if node.type != "assignment": + continue + + if node.target.type != "name": + continue + + if node.target.value != "__version__": + continue + + version_node = node + break + + if not version_node: + print("Failed to find '__version__' definition in synapse/__init__.py") + sys.exit(1) + + # Parse the current version. + current_version = version.parse(version_node.value.value.strip('"')) + assert isinstance(current_version, version.Version) + + return current_version, red, version_node + + def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]: """Find the branch/ref, looking first locally then in the remote.""" if ref_name in repo.refs: @@ -256,5 +434,66 @@ def update_branch(repo: git.Repo): repo.git.merge(repo.active_branch.tracking_branch().name) +def get_changes_for_version(wanted_version: version.Version) -> str: + """Get the changelogs for the given version. + + If an RC then will only get the changelog for that RC version, otherwise if + its a full release will get the changelog for the release and all its RCs. + """ + + with open("CHANGES.md") as f: + changes = f.read() + + # First we parse the changelog so that we can split it into sections based + # on the release headings. + ast = commonmark.Parser().parse(changes) + + @attr.s(auto_attribs=True) + class VersionSection: + title: str + + # These are 0-based. + start_line: int + end_line: Optional[int] = None # Is none if its the last entry + + headings: List[VersionSection] = [] + for node, _ in ast.walker(): + # We look for all text nodes that are in a level 1 heading. + if node.t != "text": + continue + + if node.parent.t != "heading" or node.parent.level != 1: + continue + + # If we have a previous heading then we update its `end_line`. + if headings: + headings[-1].end_line = node.parent.sourcepos[0][0] - 1 + + headings.append(VersionSection(node.literal, node.parent.sourcepos[0][0] - 1)) + + changes_by_line = changes.split("\n") + + version_changelog = [] # The lines we want to include in the changelog + + # Go through each section and find any that match the requested version. + regex = re.compile(r"^Synapse v?(\S+)") + for section in headings: + groups = regex.match(section.title) + if not groups: + continue + + heading_version = version.parse(groups.group(1)) + heading_base_version = version.parse(heading_version.base_version) + + # Check if heading version matches the requested version, or if its an + # RC of the requested version. 
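+        # (e.g. for a "1.40.0rc1" heading, base_version is "1.40.0", so its
+        # changes are also included when releasing 1.40.0 proper.)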
+ if wanted_version not in (heading_version, heading_base_version): + continue + + version_changelog.extend(changes_by_line[section.start_line : section.end_line]) + + return "\n".join(version_changelog) + + if __name__ == "__main__": - run() + cli() diff --git a/setup.py b/setup.py index 1081548e00..c478563510 100755 --- a/setup.py +++ b/setup.py @@ -108,6 +108,8 @@ CONDITIONAL_REQUIREMENTS["dev"] = CONDITIONAL_REQUIREMENTS["lint"] + [ "click==7.1.2", "redbaron==0.9.2", "GitPython==3.1.14", + "commonmark==0.9.1", + "pygithub==1.55", ] CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.13"] -- cgit 1.5.1 From f4ac934afe1d91bb0cc1bb6dafc77a94165aa740 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 3 Aug 2021 11:30:39 +0100 Subject: Revert use of PeriodicallyFlushingMemoryHandler by default (#10515) --- changelog.d/10515.feature | 1 + docs/sample_log_config.yaml | 5 +---- synapse/config/logger.py | 5 +---- 3 files changed, 3 insertions(+), 8 deletions(-) create mode 100644 changelog.d/10515.feature diff --git a/changelog.d/10515.feature b/changelog.d/10515.feature new file mode 100644 index 0000000000..db277d9ecd --- /dev/null +++ b/changelog.d/10515.feature @@ -0,0 +1 @@ +Add a buffered logging handler which periodically flushes itself. diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index b088c83405..669e600081 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -28,7 +28,7 @@ handlers: # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR # logs will still be flushed immediately. buffer: - class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler + class: logging.handlers.MemoryHandler target: file # The capacity is the number of log lines that are buffered before # being written to disk. Increasing this will lead to better @@ -36,9 +36,6 @@ handlers: # be written to disk. capacity: 10 flushLevel: 30 # Flush for WARNING logs as well - # The period of time, in seconds, between forced flushes. - # Messages will not be delayed for longer than this time. - period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index dcd3ed1dac..ad4e6e61c3 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -71,7 +71,7 @@ handlers: # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR # logs will still be flushed immediately. buffer: - class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler + class: logging.handlers.MemoryHandler target: file # The capacity is the number of log lines that are buffered before # being written to disk. Increasing this will lead to better @@ -79,9 +79,6 @@ handlers: # be written to disk. capacity: 10 flushLevel: 30 # Flush for WARNING logs as well - # The period of time, in seconds, between forced flushes. - # Messages will not be delayed for longer than this time. - period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. 
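For reference, a minimal standalone sketch of the buffered handler configured above, using the stdlib `MemoryHandler` that the config now points at (the target handler here is a stand-in for the `file` handler in the sample config):

```
import logging
from logging.handlers import MemoryHandler

# Stand-in target; the sample config points the buffer at a file handler.
target = logging.StreamHandler()

# Buffer up to 10 records, flushing early at WARNING (level 30) or above,
# mirroring the capacity/flushLevel values in the YAML above.
buffer = MemoryHandler(capacity=10, flushLevel=logging.WARNING, target=target)

logging.getLogger("synapse").addHandler(buffer)
logging.getLogger("synapse").warning("flushes the buffer immediately")
```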
-- cgit 1.5.1 From c8566191fcd7edf224db468f860c1b638fb8e763 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Aug 2021 11:32:10 +0100 Subject: 1.40.0rc1 --- CHANGES.md | 60 +++++++++++++++++++++++++++++++++++++++++++++++ changelog.d/10245.feature | 1 - changelog.d/10254.feature | 1 - changelog.d/10283.feature | 1 - changelog.d/10390.misc | 1 - changelog.d/10407.feature | 1 - changelog.d/10408.misc | 1 - changelog.d/10410.bugfix | 1 - changelog.d/10411.feature | 1 - changelog.d/10413.feature | 1 - changelog.d/10415.misc | 1 - changelog.d/10426.feature | 1 - changelog.d/10429.misc | 1 - changelog.d/10431.misc | 1 - changelog.d/10432.misc | 1 - changelog.d/10437.misc | 1 - changelog.d/10438.misc | 1 - changelog.d/10439.bugfix | 1 - changelog.d/10440.feature | 1 - changelog.d/10442.misc | 1 - changelog.d/10444.misc | 1 - changelog.d/10445.doc | 1 - changelog.d/10446.misc | 1 - changelog.d/10447.feature | 1 - changelog.d/10448.feature | 1 - changelog.d/10450.misc | 1 - changelog.d/10451.misc | 1 - changelog.d/10453.doc | 1 - changelog.d/10455.bugfix | 1 - changelog.d/10463.misc | 1 - changelog.d/10464.doc | 1 - changelog.d/10468.misc | 1 - changelog.d/10482.misc | 1 - changelog.d/10483.doc | 1 - changelog.d/10488.misc | 1 - changelog.d/10489.feature | 1 - changelog.d/10490.misc | 1 - changelog.d/10491.misc | 1 - changelog.d/10496.misc | 1 - changelog.d/10499.bugfix | 1 - changelog.d/10500.misc | 1 - changelog.d/10511.feature | 1 - changelog.d/10512.misc | 1 - changelog.d/10515.feature | 1 - changelog.d/9918.feature | 1 - debian/changelog | 8 +++++-- synapse/__init__.py | 2 +- 47 files changed, 67 insertions(+), 47 deletions(-) delete mode 100644 changelog.d/10245.feature delete mode 100644 changelog.d/10254.feature delete mode 100644 changelog.d/10283.feature delete mode 100644 changelog.d/10390.misc delete mode 100644 changelog.d/10407.feature delete mode 100644 changelog.d/10408.misc delete mode 100644 changelog.d/10410.bugfix delete mode 100644 changelog.d/10411.feature delete mode 100644 changelog.d/10413.feature delete mode 100644 changelog.d/10415.misc delete mode 100644 changelog.d/10426.feature delete mode 100644 changelog.d/10429.misc delete mode 100644 changelog.d/10431.misc delete mode 100644 changelog.d/10432.misc delete mode 100644 changelog.d/10437.misc delete mode 100644 changelog.d/10438.misc delete mode 100644 changelog.d/10439.bugfix delete mode 100644 changelog.d/10440.feature delete mode 100644 changelog.d/10442.misc delete mode 100644 changelog.d/10444.misc delete mode 100644 changelog.d/10445.doc delete mode 100644 changelog.d/10446.misc delete mode 100644 changelog.d/10447.feature delete mode 100644 changelog.d/10448.feature delete mode 100644 changelog.d/10450.misc delete mode 100644 changelog.d/10451.misc delete mode 100644 changelog.d/10453.doc delete mode 100644 changelog.d/10455.bugfix delete mode 100644 changelog.d/10463.misc delete mode 100644 changelog.d/10464.doc delete mode 100644 changelog.d/10468.misc delete mode 100644 changelog.d/10482.misc delete mode 100644 changelog.d/10483.doc delete mode 100644 changelog.d/10488.misc delete mode 100644 changelog.d/10489.feature delete mode 100644 changelog.d/10490.misc delete mode 100644 changelog.d/10491.misc delete mode 100644 changelog.d/10496.misc delete mode 100644 changelog.d/10499.bugfix delete mode 100644 changelog.d/10500.misc delete mode 100644 changelog.d/10511.feature delete mode 100644 changelog.d/10512.misc delete mode 100644 changelog.d/10515.feature delete mode 100644 changelog.d/9918.feature 
diff --git a/CHANGES.md b/CHANGES.md index 6533249281..34274cfe9c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,63 @@ +Synapse 1.40.0rc1 (2021-08-03) +============================== + +Features +-------- + +- Add support for [MSC2033](https://github.com/matrix-org/matrix-doc/pull/2033): `device_id` on `/account/whoami`. ([\#9918](https://github.com/matrix-org/synapse/issues/9918)) +- Make historical events discoverable from backfill for servers without any scrollback history (part of MSC2716). ([\#10245](https://github.com/matrix-org/synapse/issues/10245)) +- Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. ([\#10254](https://github.com/matrix-org/synapse/issues/10254), [\#10447](https://github.com/matrix-org/synapse/issues/10447), [\#10489](https://github.com/matrix-org/synapse/issues/10489)) +- Initial support for MSC3244, Room version capabilities over the /capabilities API. ([\#10283](https://github.com/matrix-org/synapse/issues/10283)) +- Add a buffered logging handler which periodically flushes itself. ([\#10407](https://github.com/matrix-org/synapse/issues/10407), [\#10515](https://github.com/matrix-org/synapse/issues/10515)) +- Add support for https connections to a proxy server. Contributed by @Bubu and @dklimpel. ([\#10411](https://github.com/matrix-org/synapse/issues/10411)) +- Support for [MSC2285 (hidden read receipts)](https://github.com/matrix-org/matrix-doc/pull/2285). Contributed by @SimonBrandner. ([\#10413](https://github.com/matrix-org/synapse/issues/10413)) +- Email notifications now state whether an invitation is to a room or a space. ([\#10426](https://github.com/matrix-org/synapse/issues/10426)) +- Allow setting transaction limit for database connections. ([\#10440](https://github.com/matrix-org/synapse/issues/10440), [\#10511](https://github.com/matrix-org/synapse/issues/10511)) +- Add `creation_ts` to list users admin API. ([\#10448](https://github.com/matrix-org/synapse/issues/10448)) + + +Bugfixes +-------- + +- Improve character set detection in URL previews by supporting underscores (in addition to hyphens). Contributed by @srividyut. ([\#10410](https://github.com/matrix-org/synapse/issues/10410)) +- Fix events with floating outlier state being rejected over federation. ([\#10439](https://github.com/matrix-org/synapse/issues/10439)) +- Fix `synapse_federation_server_oldest_inbound_pdu_in_staging` Prometheus metric to not report a max age of 51 years when the queue is empty. ([\#10455](https://github.com/matrix-org/synapse/issues/10455)) +- Fix a bug which caused an explicit assignment of power-level 0 to a user to be misinterpreted in rare circumstances. ([\#10499](https://github.com/matrix-org/synapse/issues/10499)) + + +Improved Documentation +---------------------- + +- Fix hierarchy of providers on the OpenID page. ([\#10445](https://github.com/matrix-org/synapse/issues/10445)) +- Consolidate development documentation to `docs/development/`. ([\#10453](https://github.com/matrix-org/synapse/issues/10453)) +- Add some developer docs to explain room DAG concepts like `outliers`, `state_groups`, `depth`, etc. ([\#10464](https://github.com/matrix-org/synapse/issues/10464)) +- Document how to use Complement while developing a new Synapse feature. ([\#10483](https://github.com/matrix-org/synapse/issues/10483)) + + +Internal Changes +---------------- + +- Prune inbound federation inbound queues for a room if they get too large. 
([\#10390](https://github.com/matrix-org/synapse/issues/10390)) +- Add type hints to `synapse.federation.transport.client` module. ([\#10408](https://github.com/matrix-org/synapse/issues/10408)) +- Remove shebang line from module files. ([\#10415](https://github.com/matrix-org/synapse/issues/10415)) +- Drop backwards-compatibility code that was required to support Ubuntu Xenial. ([\#10429](https://github.com/matrix-org/synapse/issues/10429)) +- Use a docker image cache for the prerequisites for the debian package build. ([\#10431](https://github.com/matrix-org/synapse/issues/10431)) +- Connect historical chunks together with chunk events instead of a content field (MSC2716). ([\#10432](https://github.com/matrix-org/synapse/issues/10432)) +- Improve servlet type hints. ([\#10437](https://github.com/matrix-org/synapse/issues/10437), [\#10438](https://github.com/matrix-org/synapse/issues/10438)) +- Replace usage of `or_ignore` in `simple_insert` with `simple_upsert` usage, to stop spamming postgres logs with spurious ERROR messages. ([\#10442](https://github.com/matrix-org/synapse/issues/10442)) +- Update the `tests-done` Github Actions status. ([\#10444](https://github.com/matrix-org/synapse/issues/10444), [\#10512](https://github.com/matrix-org/synapse/issues/10512)) +- Update type annotations to work with forthcoming Twisted 21.7.0 release. ([\#10446](https://github.com/matrix-org/synapse/issues/10446), [\#10450](https://github.com/matrix-org/synapse/issues/10450)) +- Cancel redundant GHA workflows when a new commit is pushed. ([\#10451](https://github.com/matrix-org/synapse/issues/10451)) +- Disable `msc2716` Complement tests until Complement updates are merged. ([\#10463](https://github.com/matrix-org/synapse/issues/10463)) +- Mitigate media repo XSS attacks on IE11 via the non-standard X-Content-Security-Policy header. ([\#10468](https://github.com/matrix-org/synapse/issues/10468)) +- Additional type hints in the state handler. ([\#10482](https://github.com/matrix-org/synapse/issues/10482)) +- Update syntax used to run complement tests. ([\#10488](https://github.com/matrix-org/synapse/issues/10488)) +- Fix up type annotations to work with Twisted 21.7. ([\#10490](https://github.com/matrix-org/synapse/issues/10490)) +- Improve type annotations for `ObservableDeferred`. ([\#10491](https://github.com/matrix-org/synapse/issues/10491)) +- Extend release script to also tag and create GitHub releases. ([\#10496](https://github.com/matrix-org/synapse/issues/10496)) +- Fix a bug which caused production debian packages to be incorrectly marked as 'prerelease'. ([\#10500](https://github.com/matrix-org/synapse/issues/10500)) + + Synapse 1.39.0 (2021-07-29) =========================== diff --git a/changelog.d/10245.feature b/changelog.d/10245.feature deleted file mode 100644 index b3c48cc2cc..0000000000 --- a/changelog.d/10245.feature +++ /dev/null @@ -1 +0,0 @@ -Make historical events discoverable from backfill for servers without any scrollback history (part of MSC2716). diff --git a/changelog.d/10254.feature b/changelog.d/10254.feature deleted file mode 100644 index df8bb51167..0000000000 --- a/changelog.d/10254.feature +++ /dev/null @@ -1 +0,0 @@ -Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. 
diff --git a/changelog.d/10283.feature b/changelog.d/10283.feature deleted file mode 100644 index 99d633dbfb..0000000000 --- a/changelog.d/10283.feature +++ /dev/null @@ -1 +0,0 @@ -Initial support for MSC3244, Room version capabilities over the /capabilities API. \ No newline at end of file diff --git a/changelog.d/10390.misc b/changelog.d/10390.misc deleted file mode 100644 index 911a5733ee..0000000000 --- a/changelog.d/10390.misc +++ /dev/null @@ -1 +0,0 @@ -Prune inbound federation inbound queues for a room if they get too large. diff --git a/changelog.d/10407.feature b/changelog.d/10407.feature deleted file mode 100644 index db277d9ecd..0000000000 --- a/changelog.d/10407.feature +++ /dev/null @@ -1 +0,0 @@ -Add a buffered logging handler which periodically flushes itself. diff --git a/changelog.d/10408.misc b/changelog.d/10408.misc deleted file mode 100644 index abccd210a9..0000000000 --- a/changelog.d/10408.misc +++ /dev/null @@ -1 +0,0 @@ -Add type hints to `synapse.federation.transport.client` module. diff --git a/changelog.d/10410.bugfix b/changelog.d/10410.bugfix deleted file mode 100644 index 65b418fd35..0000000000 --- a/changelog.d/10410.bugfix +++ /dev/null @@ -1 +0,0 @@ -Improve character set detection in URL previews by supporting underscores (in addition to hyphens). Contributed by @srividyut. diff --git a/changelog.d/10411.feature b/changelog.d/10411.feature deleted file mode 100644 index ef0ab84b17..0000000000 --- a/changelog.d/10411.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for https connections to a proxy server. Contributed by @Bubu and @dklimpel. \ No newline at end of file diff --git a/changelog.d/10413.feature b/changelog.d/10413.feature deleted file mode 100644 index 3964db7e0e..0000000000 --- a/changelog.d/10413.feature +++ /dev/null @@ -1 +0,0 @@ -Support for [MSC2285 (hidden read receipts)](https://github.com/matrix-org/matrix-doc/pull/2285). Contributed by @SimonBrandner. diff --git a/changelog.d/10415.misc b/changelog.d/10415.misc deleted file mode 100644 index 3b9501acbb..0000000000 --- a/changelog.d/10415.misc +++ /dev/null @@ -1 +0,0 @@ -Remove shebang line from module files. diff --git a/changelog.d/10426.feature b/changelog.d/10426.feature deleted file mode 100644 index 9cca6dc456..0000000000 --- a/changelog.d/10426.feature +++ /dev/null @@ -1 +0,0 @@ -Email notifications now state whether an invitation is to a room or a space. diff --git a/changelog.d/10429.misc b/changelog.d/10429.misc deleted file mode 100644 index ccb2217f64..0000000000 --- a/changelog.d/10429.misc +++ /dev/null @@ -1 +0,0 @@ -Drop backwards-compatibility code that was required to support Ubuntu Xenial. diff --git a/changelog.d/10431.misc b/changelog.d/10431.misc deleted file mode 100644 index 34b9b49da6..0000000000 --- a/changelog.d/10431.misc +++ /dev/null @@ -1 +0,0 @@ -Use a docker image cache for the prerequisites for the debian package build. diff --git a/changelog.d/10432.misc b/changelog.d/10432.misc deleted file mode 100644 index 3a8cdf0ae0..0000000000 --- a/changelog.d/10432.misc +++ /dev/null @@ -1 +0,0 @@ -Connect historical chunks together with chunk events instead of a content field (MSC2716). diff --git a/changelog.d/10437.misc b/changelog.d/10437.misc deleted file mode 100644 index a557578499..0000000000 --- a/changelog.d/10437.misc +++ /dev/null @@ -1 +0,0 @@ -Improve servlet type hints. 
diff --git a/changelog.d/10438.misc b/changelog.d/10438.misc deleted file mode 100644 index a557578499..0000000000 --- a/changelog.d/10438.misc +++ /dev/null @@ -1 +0,0 @@ -Improve servlet type hints. diff --git a/changelog.d/10439.bugfix b/changelog.d/10439.bugfix deleted file mode 100644 index 74e5a25126..0000000000 --- a/changelog.d/10439.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix events with floating outlier state being rejected over federation. diff --git a/changelog.d/10440.feature b/changelog.d/10440.feature deleted file mode 100644 index f1833b0bd7..0000000000 --- a/changelog.d/10440.feature +++ /dev/null @@ -1 +0,0 @@ -Allow setting transaction limit for database connections. diff --git a/changelog.d/10442.misc b/changelog.d/10442.misc deleted file mode 100644 index b8d412d732..0000000000 --- a/changelog.d/10442.misc +++ /dev/null @@ -1 +0,0 @@ -Replace usage of `or_ignore` in `simple_insert` with `simple_upsert` usage, to stop spamming postgres logs with spurious ERROR messages. diff --git a/changelog.d/10444.misc b/changelog.d/10444.misc deleted file mode 100644 index c012e89f4b..0000000000 --- a/changelog.d/10444.misc +++ /dev/null @@ -1 +0,0 @@ -Update the `tests-done` Github Actions status. diff --git a/changelog.d/10445.doc b/changelog.d/10445.doc deleted file mode 100644 index 4c023ded7c..0000000000 --- a/changelog.d/10445.doc +++ /dev/null @@ -1 +0,0 @@ -Fix hierarchy of providers on the OpenID page. diff --git a/changelog.d/10446.misc b/changelog.d/10446.misc deleted file mode 100644 index a5a0ca80eb..0000000000 --- a/changelog.d/10446.misc +++ /dev/null @@ -1 +0,0 @@ -Update type annotations to work with forthcoming Twisted 21.7.0 release. diff --git a/changelog.d/10447.feature b/changelog.d/10447.feature deleted file mode 100644 index df8bb51167..0000000000 --- a/changelog.d/10447.feature +++ /dev/null @@ -1 +0,0 @@ -Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. diff --git a/changelog.d/10448.feature b/changelog.d/10448.feature deleted file mode 100644 index f6579e0ca8..0000000000 --- a/changelog.d/10448.feature +++ /dev/null @@ -1 +0,0 @@ -Add `creation_ts` to list users admin API. \ No newline at end of file diff --git a/changelog.d/10450.misc b/changelog.d/10450.misc deleted file mode 100644 index aa646f0841..0000000000 --- a/changelog.d/10450.misc +++ /dev/null @@ -1 +0,0 @@ - Update type annotations to work with forthcoming Twisted 21.7.0 release. diff --git a/changelog.d/10451.misc b/changelog.d/10451.misc deleted file mode 100644 index e38f4b476d..0000000000 --- a/changelog.d/10451.misc +++ /dev/null @@ -1 +0,0 @@ -Cancel redundant GHA workflows when a new commit is pushed. diff --git a/changelog.d/10453.doc b/changelog.d/10453.doc deleted file mode 100644 index 5d4db9bca2..0000000000 --- a/changelog.d/10453.doc +++ /dev/null @@ -1 +0,0 @@ -Consolidate development documentation to `docs/development/`. diff --git a/changelog.d/10455.bugfix b/changelog.d/10455.bugfix deleted file mode 100644 index 23c74a3c89..0000000000 --- a/changelog.d/10455.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix `synapse_federation_server_oldest_inbound_pdu_in_staging` Prometheus metric to not report a max age of 51 years when the queue is empty. diff --git a/changelog.d/10463.misc b/changelog.d/10463.misc deleted file mode 100644 index d7b4d2222e..0000000000 --- a/changelog.d/10463.misc +++ /dev/null @@ -1 +0,0 @@ -Disable `msc2716` Complement tests until Complement updates are merged. 
diff --git a/changelog.d/10464.doc b/changelog.d/10464.doc deleted file mode 100644 index 764fb9f65c..0000000000 --- a/changelog.d/10464.doc +++ /dev/null @@ -1 +0,0 @@ -Add some developer docs to explain room DAG concepts like `outliers`, `state_groups`, `depth`, etc. diff --git a/changelog.d/10468.misc b/changelog.d/10468.misc deleted file mode 100644 index b9854bb4c1..0000000000 --- a/changelog.d/10468.misc +++ /dev/null @@ -1 +0,0 @@ -Mitigate media repo XSS attacks on IE11 via the non-standard X-Content-Security-Policy header. diff --git a/changelog.d/10482.misc b/changelog.d/10482.misc deleted file mode 100644 index 4e9e2126e1..0000000000 --- a/changelog.d/10482.misc +++ /dev/null @@ -1 +0,0 @@ -Additional type hints in the state handler. diff --git a/changelog.d/10483.doc b/changelog.d/10483.doc deleted file mode 100644 index 0f699fafdd..0000000000 --- a/changelog.d/10483.doc +++ /dev/null @@ -1 +0,0 @@ -Document how to use Complement while developing a new Synapse feature. diff --git a/changelog.d/10488.misc b/changelog.d/10488.misc deleted file mode 100644 index a55502c163..0000000000 --- a/changelog.d/10488.misc +++ /dev/null @@ -1 +0,0 @@ -Update syntax used to run complement tests. diff --git a/changelog.d/10489.feature b/changelog.d/10489.feature deleted file mode 100644 index df8bb51167..0000000000 --- a/changelog.d/10489.feature +++ /dev/null @@ -1 +0,0 @@ -Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. diff --git a/changelog.d/10490.misc b/changelog.d/10490.misc deleted file mode 100644 index 630c31adae..0000000000 --- a/changelog.d/10490.misc +++ /dev/null @@ -1 +0,0 @@ -Fix up type annotations to work with Twisted 21.7. diff --git a/changelog.d/10491.misc b/changelog.d/10491.misc deleted file mode 100644 index 3867cf2682..0000000000 --- a/changelog.d/10491.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type annotations for `ObservableDeferred`. diff --git a/changelog.d/10496.misc b/changelog.d/10496.misc deleted file mode 100644 index 6d0d3e5391..0000000000 --- a/changelog.d/10496.misc +++ /dev/null @@ -1 +0,0 @@ -Extend release script to also tag and create GitHub releases. diff --git a/changelog.d/10499.bugfix b/changelog.d/10499.bugfix deleted file mode 100644 index 6487af6c96..0000000000 --- a/changelog.d/10499.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug which caused an explicit assignment of power-level 0 to a user to be misinterpreted in rare circumstances. diff --git a/changelog.d/10500.misc b/changelog.d/10500.misc deleted file mode 100644 index dbaff57364..0000000000 --- a/changelog.d/10500.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a bug which caused production debian packages to be incorrectly marked as 'prerelease'. diff --git a/changelog.d/10511.feature b/changelog.d/10511.feature deleted file mode 100644 index f1833b0bd7..0000000000 --- a/changelog.d/10511.feature +++ /dev/null @@ -1 +0,0 @@ -Allow setting transaction limit for database connections. diff --git a/changelog.d/10512.misc b/changelog.d/10512.misc deleted file mode 100644 index c012e89f4b..0000000000 --- a/changelog.d/10512.misc +++ /dev/null @@ -1 +0,0 @@ -Update the `tests-done` Github Actions status. diff --git a/changelog.d/10515.feature b/changelog.d/10515.feature deleted file mode 100644 index db277d9ecd..0000000000 --- a/changelog.d/10515.feature +++ /dev/null @@ -1 +0,0 @@ -Add a buffered logging handler which periodically flushes itself. 
diff --git a/changelog.d/9918.feature b/changelog.d/9918.feature deleted file mode 100644 index 98f0a50893..0000000000 --- a/changelog.d/9918.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for [MSC2033](https://github.com/matrix-org/matrix-doc/pull/2033): `device_id` on `/account/whoami`. \ No newline at end of file diff --git a/debian/changelog b/debian/changelog index 341c1ac992..f0557c35ef 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,8 +1,12 @@ -matrix-synapse-py3 (1.39.0ubuntu1) UNRELEASED; urgency=medium +matrix-synapse-py3 (1.40.0~rc1) stable; urgency=medium + [ Richard van der Hoff ] * Drop backwards-compatibility code that was required to support Ubuntu Xenial. - -- Richard van der Hoff Tue, 20 Jul 2021 00:10:03 +0100 + [ Synapse Packaging team ] + * New synapse release 1.40.0~rc1. + + -- Synapse Packaging team Tue, 03 Aug 2021 11:31:49 +0100 matrix-synapse-py3 (1.39.0) stable; urgency=medium diff --git a/synapse/__init__.py b/synapse/__init__.py index 5da6c924fc..d6c1765508 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.39.0" +__version__ = "1.40.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From c80ec5d15386a8d3db03d0064b4e87d52d38dff2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Aug 2021 11:48:48 +0100 Subject: Fixup changelog --- CHANGES.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 34274cfe9c..8b78fe92f6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,22 +5,22 @@ Features -------- - Add support for [MSC2033](https://github.com/matrix-org/matrix-doc/pull/2033): `device_id` on `/account/whoami`. ([\#9918](https://github.com/matrix-org/synapse/issues/9918)) -- Make historical events discoverable from backfill for servers without any scrollback history (part of MSC2716). ([\#10245](https://github.com/matrix-org/synapse/issues/10245)) +- Make historical events discoverable from backfill for servers without any scrollback history (part of [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716)). ([\#10245](https://github.com/matrix-org/synapse/issues/10245)) - Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. ([\#10254](https://github.com/matrix-org/synapse/issues/10254), [\#10447](https://github.com/matrix-org/synapse/issues/10447), [\#10489](https://github.com/matrix-org/synapse/issues/10489)) -- Initial support for MSC3244, Room version capabilities over the /capabilities API. ([\#10283](https://github.com/matrix-org/synapse/issues/10283)) +- Initial support for [MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244), Room version capabilities over the /capabilities API. ([\#10283](https://github.com/matrix-org/synapse/issues/10283)) - Add a buffered logging handler which periodically flushes itself. ([\#10407](https://github.com/matrix-org/synapse/issues/10407), [\#10515](https://github.com/matrix-org/synapse/issues/10515)) - Add support for https connections to a proxy server. Contributed by @Bubu and @dklimpel. ([\#10411](https://github.com/matrix-org/synapse/issues/10411)) - Support for [MSC2285 (hidden read receipts)](https://github.com/matrix-org/matrix-doc/pull/2285). Contributed by @SimonBrandner. 
([\#10413](https://github.com/matrix-org/synapse/issues/10413)) - Email notifications now state whether an invitation is to a room or a space. ([\#10426](https://github.com/matrix-org/synapse/issues/10426)) - Allow setting transaction limit for database connections. ([\#10440](https://github.com/matrix-org/synapse/issues/10440), [\#10511](https://github.com/matrix-org/synapse/issues/10511)) -- Add `creation_ts` to list users admin API. ([\#10448](https://github.com/matrix-org/synapse/issues/10448)) +- Add `creation_ts` to "list users" admin API. ([\#10448](https://github.com/matrix-org/synapse/issues/10448)) Bugfixes -------- - Improve character set detection in URL previews by supporting underscores (in addition to hyphens). Contributed by @srividyut. ([\#10410](https://github.com/matrix-org/synapse/issues/10410)) -- Fix events with floating outlier state being rejected over federation. ([\#10439](https://github.com/matrix-org/synapse/issues/10439)) +- Fix events being incorrectly rejected over federation if they reference auth events that the server needed to fetch. ([\#10439](https://github.com/matrix-org/synapse/issues/10439)) - Fix `synapse_federation_server_oldest_inbound_pdu_in_staging` Prometheus metric to not report a max age of 51 years when the queue is empty. ([\#10455](https://github.com/matrix-org/synapse/issues/10455)) - Fix a bug which caused an explicit assignment of power-level 0 to a user to be misinterpreted in rare circumstances. ([\#10499](https://github.com/matrix-org/synapse/issues/10499)) @@ -37,12 +37,12 @@ Improved Documentation Internal Changes ---------------- -- Prune inbound federation inbound queues for a room if they get too large. ([\#10390](https://github.com/matrix-org/synapse/issues/10390)) +- Prune inbound federation queues for a room if they get too large. ([\#10390](https://github.com/matrix-org/synapse/issues/10390)) - Add type hints to `synapse.federation.transport.client` module. ([\#10408](https://github.com/matrix-org/synapse/issues/10408)) - Remove shebang line from module files. ([\#10415](https://github.com/matrix-org/synapse/issues/10415)) - Drop backwards-compatibility code that was required to support Ubuntu Xenial. ([\#10429](https://github.com/matrix-org/synapse/issues/10429)) - Use a docker image cache for the prerequisites for the debian package build. ([\#10431](https://github.com/matrix-org/synapse/issues/10431)) -- Connect historical chunks together with chunk events instead of a content field (MSC2716). ([\#10432](https://github.com/matrix-org/synapse/issues/10432)) +- Connect historical chunks together with chunk events instead of a content field ([MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716)). ([\#10432](https://github.com/matrix-org/synapse/issues/10432)) - Improve servlet type hints. ([\#10437](https://github.com/matrix-org/synapse/issues/10437), [\#10438](https://github.com/matrix-org/synapse/issues/10438)) - Replace usage of `or_ignore` in `simple_insert` with `simple_upsert` usage, to stop spamming postgres logs with spurious ERROR messages. ([\#10442](https://github.com/matrix-org/synapse/issues/10442)) - Update the `tests-done` Github Actions status. 
([\#10444](https://github.com/matrix-org/synapse/issues/10444), [\#10512](https://github.com/matrix-org/synapse/issues/10512)) -- cgit 1.5.1 From da6cd82106636d4c8b5143d7c2839f11fb40fbd2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Aug 2021 12:11:26 +0100 Subject: Fixup changelog --- CHANGES.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 8b78fe92f6..f2d6945886 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,7 +5,7 @@ Features -------- - Add support for [MSC2033](https://github.com/matrix-org/matrix-doc/pull/2033): `device_id` on `/account/whoami`. ([\#9918](https://github.com/matrix-org/synapse/issues/9918)) -- Make historical events discoverable from backfill for servers without any scrollback history (part of [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716)). ([\#10245](https://github.com/matrix-org/synapse/issues/10245)) +- Add support for [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#10245](https://github.com/matrix-org/synapse/issues/10245), [\#10432](https://github.com/matrix-org/synapse/issues/10432), [\#10463](https://github.com/matrix-org/synapse/issues/10463)) - Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. ([\#10254](https://github.com/matrix-org/synapse/issues/10254), [\#10447](https://github.com/matrix-org/synapse/issues/10447), [\#10489](https://github.com/matrix-org/synapse/issues/10489)) - Initial support for [MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244), Room version capabilities over the /capabilities API. ([\#10283](https://github.com/matrix-org/synapse/issues/10283)) - Add a buffered logging handler which periodically flushes itself. ([\#10407](https://github.com/matrix-org/synapse/issues/10407), [\#10515](https://github.com/matrix-org/synapse/issues/10515)) @@ -42,13 +42,11 @@ Internal Changes - Remove shebang line from module files. ([\#10415](https://github.com/matrix-org/synapse/issues/10415)) - Drop backwards-compatibility code that was required to support Ubuntu Xenial. ([\#10429](https://github.com/matrix-org/synapse/issues/10429)) - Use a docker image cache for the prerequisites for the debian package build. ([\#10431](https://github.com/matrix-org/synapse/issues/10431)) -- Connect historical chunks together with chunk events instead of a content field ([MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716)). ([\#10432](https://github.com/matrix-org/synapse/issues/10432)) - Improve servlet type hints. ([\#10437](https://github.com/matrix-org/synapse/issues/10437), [\#10438](https://github.com/matrix-org/synapse/issues/10438)) - Replace usage of `or_ignore` in `simple_insert` with `simple_upsert` usage, to stop spamming postgres logs with spurious ERROR messages. ([\#10442](https://github.com/matrix-org/synapse/issues/10442)) - Update the `tests-done` Github Actions status. ([\#10444](https://github.com/matrix-org/synapse/issues/10444), [\#10512](https://github.com/matrix-org/synapse/issues/10512)) - Update type annotations to work with forthcoming Twisted 21.7.0 release. ([\#10446](https://github.com/matrix-org/synapse/issues/10446), [\#10450](https://github.com/matrix-org/synapse/issues/10450)) - Cancel redundant GHA workflows when a new commit is pushed. ([\#10451](https://github.com/matrix-org/synapse/issues/10451)) -- Disable `msc2716` Complement tests until Complement updates are merged. 
([\#10463](https://github.com/matrix-org/synapse/issues/10463)) - Mitigate media repo XSS attacks on IE11 via the non-standard X-Content-Security-Policy header. ([\#10468](https://github.com/matrix-org/synapse/issues/10468)) - Additional type hints in the state handler. ([\#10482](https://github.com/matrix-org/synapse/issues/10482)) - Update syntax used to run complement tests. ([\#10488](https://github.com/matrix-org/synapse/issues/10488)) -- cgit 1.5.1 From 42225aa421efa9dd87fc63286f24f4697e4d2572 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Aug 2021 12:12:50 +0100 Subject: Fixup changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index f2d6945886..7ce28c4c18 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,7 +5,7 @@ Features -------- - Add support for [MSC2033](https://github.com/matrix-org/matrix-doc/pull/2033): `device_id` on `/account/whoami`. ([\#9918](https://github.com/matrix-org/synapse/issues/9918)) -- Add support for [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#10245](https://github.com/matrix-org/synapse/issues/10245), [\#10432](https://github.com/matrix-org/synapse/issues/10432), [\#10463](https://github.com/matrix-org/synapse/issues/10463)) +- Update support for [MSC2716 - Incrementally importing history into existing rooms](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#10245](https://github.com/matrix-org/synapse/issues/10245), [\#10432](https://github.com/matrix-org/synapse/issues/10432), [\#10463](https://github.com/matrix-org/synapse/issues/10463)) - Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. ([\#10254](https://github.com/matrix-org/synapse/issues/10254), [\#10447](https://github.com/matrix-org/synapse/issues/10447), [\#10489](https://github.com/matrix-org/synapse/issues/10489)) - Initial support for [MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244), Room version capabilities over the /capabilities API. ([\#10283](https://github.com/matrix-org/synapse/issues/10283)) - Add a buffered logging handler which periodically flushes itself. ([\#10407](https://github.com/matrix-org/synapse/issues/10407), [\#10515](https://github.com/matrix-org/synapse/issues/10515)) -- cgit 1.5.1 From 6878e1065308caf0f79e380b4de1433ab1487a34 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 3 Aug 2021 13:29:17 +0100 Subject: Fix release script URL (#10516) --- changelog.d/10516.misc | 1 + scripts-dev/release.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10516.misc diff --git a/changelog.d/10516.misc b/changelog.d/10516.misc new file mode 100644 index 0000000000..4d8c5e4805 --- /dev/null +++ b/changelog.d/10516.misc @@ -0,0 +1 @@ +Fix release script to open correct URL for the release. diff --git a/scripts-dev/release.py b/scripts-dev/release.py index e864dc6ed5..a339260c43 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -305,7 +305,7 @@ def tag(gh_token: Optional[str]): ) # Open the release and the actions where we are building the assets. 
- click.launch(release.url) + click.launch(release.html_url) click.launch( f"https://github.com/matrix-org/synapse/actions?query=branch%3A{tag_name}" ) -- cgit 1.5.1 From 903db99ed552d06f0a9e0379e55e655c5761355b Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 3 Aug 2021 14:28:30 +0100 Subject: Fix PeriodicallyFlushingMemoryHandler inhibiting application shutdown (#10517) --- changelog.d/10517.bugfix | 1 + synapse/logging/handlers.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/10517.bugfix diff --git a/changelog.d/10517.bugfix b/changelog.d/10517.bugfix new file mode 100644 index 0000000000..5b044bb34d --- /dev/null +++ b/changelog.d/10517.bugfix @@ -0,0 +1 @@ +Fix the `PeriodicallyFlushingMemoryHandler` inhibiting application shutdown because of its background thread. diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py index a6c212f300..af5fc407a8 100644 --- a/synapse/logging/handlers.py +++ b/synapse/logging/handlers.py @@ -45,6 +45,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): self._flushing_thread: Thread = Thread( name="PeriodicallyFlushingMemoryHandler flushing thread", target=self._flush_periodically, + daemon=True, ) self._flushing_thread.start() -- cgit 1.5.1 From dc46f12725001dde99c536a9189045709cf7e06c Mon Sep 17 00:00:00 2001 From: Dagfinn Ilmari Mannsåker Date: Tue, 3 Aug 2021 14:35:49 +0100 Subject: Include room ID in ignored EDU log messages (#10507) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dagfinn Ilmari Mannsåker --- changelog.d/10507.misc | 1 + synapse/handlers/receipts.py | 3 ++- synapse/handlers/typing.py | 3 ++- 3 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10507.misc diff --git a/changelog.d/10507.misc b/changelog.d/10507.misc new file mode 100644 index 0000000000..5dfd116e60 --- /dev/null +++ b/changelog.d/10507.misc @@ -0,0 +1 @@ +Include room ID in ignored EDU log messages. Contributed by @ilmari. diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index b9085bbccb..5fd4525700 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -70,7 +70,8 @@ class ReceiptsHandler(BaseHandler): ) if not is_in_room: logger.info( - "Ignoring receipt from %s as we're not in the room", + "Ignoring receipt for room %r from server %s as we're not in the room", + room_id, origin, ) continue diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0cb651a400..a97c448595 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -335,7 +335,8 @@ class TypingWriterHandler(FollowerTypingHandler): ) if not is_in_room: logger.info( - "Ignoring typing update from %s as we're not in the room", + "Ignoring typing update for room %r from server %s as we're not in the room", + room_id, origin, ) return -- cgit 1.5.1 From 4b10880da363efed5d066191190237f1c64fddfd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 3 Aug 2021 14:45:04 +0100 Subject: Make sync response cache time configurable. 
(#10513) --- changelog.d/10513.feature | 1 + docs/sample_config.yaml | 9 +++++++++ synapse/config/cache.py | 13 +++++++++++++ synapse/handlers/sync.py | 14 +++++++++++--- 4 files changed, 34 insertions(+), 3 deletions(-) create mode 100644 changelog.d/10513.feature diff --git a/changelog.d/10513.feature b/changelog.d/10513.feature new file mode 100644 index 0000000000..153b2df7b2 --- /dev/null +++ b/changelog.d/10513.feature @@ -0,0 +1 @@ +Add a configuration setting for the time a `/sync` response is cached for. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1a217f35db..a2efc14100 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -711,6 +711,15 @@ caches: # #expiry_time: 30m + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. + # + #sync_response_cache_duration: 2m + ## Database ## diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 8d5f38b5d9..d119427ad8 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -151,6 +151,15 @@ class CacheConfig(Config): # entries are never evicted based on time. # #expiry_time: 30m + + # Controls how long the results of a /sync request are cached for after + # a successful response is returned. A higher duration can help clients with + # intermittent connections, at the cost of higher memory usage. + # + # By default, this is zero, which means that sync responses are not cached + # at all. + # + #sync_response_cache_duration: 2m """ def read_config(self, config, **kwargs): @@ -212,6 +221,10 @@ class CacheConfig(Config): else: self.expiry_time_msec = None + self.sync_response_cache_duration = self.parse_duration( + cache_config.get("sync_response_cache_duration", 0) + ) + # Resize all caches (if necessary) with the new factors we've loaded self.resize_all_caches() diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f30bfcc93c..590642f510 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -269,14 +269,22 @@ class SyncHandler: self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() - self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( - hs.get_clock(), "sync" - ) self.state = hs.get_state_handler() self.auth = hs.get_auth() self.storage = hs.get_storage() self.state_store = self.storage.state + # TODO: flush cache entries on subsequent sync request. + # Once we get the next /sync request (ie, one with the same access token + # that sets 'since' to 'next_batch'), we know that device won't need a + # cached result any more, and we could flush the entry from the cache to save + # memory. + self.response_cache: ResponseCache[SyncRequestKey] = ResponseCache( + hs.get_clock(), + "sync", + timeout_ms=hs.config.caches.sync_response_cache_duration, + ) + # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) self.lazy_loaded_members_cache: ExpiringCache[ Tuple[str, Optional[str]], LruCache[str, str] -- cgit 1.5.1 From 951648f26a75948a3b8de8989c98c51698043d71 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 3 Aug 2021 14:45:21 +0100 Subject: Fix debian package triggers (#10481) Replace the outdated list of dpkg triggers with an autogenerated one. 
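A note on the `/sync` caching change above before the triggers diff: the new `sync_response_cache_duration` is read with `parse_duration` and handed to the handler's `ResponseCache` as `timeout_ms`, so a completed response sticks around for that long before eviction. A simplified stand-in for those semantics (this sketch is not Synapse's `ResponseCache`; the class and names below are illustrative assumptions):

```python
# Simplified stand-in for a response cache with a post-completion timeout;
# illustrates the timeout_ms semantics only, not Synapse's ResponseCache.
import time
from typing import Any, Dict, Optional, Tuple


class TimedResponseCache:
    def __init__(self, timeout_ms: int):
        self._timeout_s = timeout_ms / 1000.0
        self._entries: Dict[Any, Tuple[float, Any]] = {}

    def set(self, key: Any, value: Any) -> None:
        # A timeout of zero means responses are not cached at all,
        # matching the documented default above.
        if self._timeout_s > 0:
            self._entries[key] = (time.monotonic() + self._timeout_s, value)

    def get(self, key: Any) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() >= expires_at:
            del self._entries[key]
            return None
        return value
```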
--- debian/build_virtualenv | 15 +++++++++++++++ debian/changelog | 2 ++ debian/matrix-synapse-py3.triggers | 9 --------- 3 files changed, 17 insertions(+), 9 deletions(-) delete mode 100644 debian/matrix-synapse-py3.triggers diff --git a/debian/build_virtualenv b/debian/build_virtualenv index 68c8659953..801ecb9086 100755 --- a/debian/build_virtualenv +++ b/debian/build_virtualenv @@ -100,3 +100,18 @@ esac # add a dependency on the right version of python to substvars. PYPKG=`basename $SNAKE` echo "synapse:pydepends=$PYPKG" >> debian/matrix-synapse-py3.substvars + + +# add a couple of triggers. This is needed so that dh-virtualenv can rebuild +# the venv when the system python changes (see +# https://dh-virtualenv.readthedocs.io/en/latest/tutorial.html#step-2-set-up-packaging-for-your-project) +# +# we do it here rather than the more conventional way of just adding it to +# debian/matrix-synapse-py3.triggers, because we need to add a trigger on the +# right version of python. +cat >>"debian/.debhelper/generated/matrix-synapse-py3/triggers" <<EOF +# triggers for dh-virtualenv +interest-noawait $SNAKE +interest dh-virtualenv-interpreter-update + +EOF -- cgit 1.5.1 From: Kento Okamoto Date: Tue, 3 Aug 2021 11:13:34 -0700 Subject: Add warnings to ip_range_blacklist usage with proxies (#10129) Per issue #9812, using `url_preview_ip_range_blacklist` with a proxy via the `HTTPS_PROXY` or `HTTP_PROXY` environment variables behaves inconsistently with what is documented. This PR changes the following: - Changes the Sample Config file to include a note mentioning that `url_preview_ip_range_blacklist` and `ip_range_blacklist` are ignored when using a proxy - Changes some logic in synapse/config/repository.py to send a warning when both `*ip_range_blacklist` configs and a proxy environment variable are set, but no longer throws an error. Signed-off-by: Kento Okamoto --- changelog.d/10129.bugfix | 1 + docs/sample_config.yaml | 4 ++++ synapse/config/repository.py | 24 +++++++++++++++++++----- synapse/config/server.py | 2 ++ 4 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10129.bugfix diff --git a/changelog.d/10129.bugfix b/changelog.d/10129.bugfix new file mode 100644 index 0000000000..292676ec8d --- /dev/null +++ b/changelog.d/10129.bugfix @@ -0,0 +1 @@ +Add some clarification to the sample config file. Contributed by @Kentokamoto. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a2efc14100..16843dd8c9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -210,6 +210,8 @@ presence: # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # +# Note: The value is ignored when an HTTP proxy is in use +# #ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' @@ -972,6 +974,8 @@ media_store_path: "DATADIR/media_store" # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # +# Note: The value is ignored when an HTTP proxy is in use +# #url_preview_ip_range_blacklist: # - '127.0.0.0/8' # - '10.0.0.0/8' diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 0dfb3a227a..7481f3bf5f 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License.
+import logging import os from collections import namedtuple from typing import Dict, List +from urllib.request import getproxies_environment # type: ignore from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements @@ -22,6 +24,8 @@ from synapse.util.module_loader import load_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + DEFAULT_THUMBNAIL_SIZES = [ {"width": 32, "height": 32, "method": "crop"}, {"width": 96, "height": 96, "method": "crop"}, @@ -36,6 +40,9 @@ THUMBNAIL_SIZE_YAML = """\ # method: %(method)s """ +HTTP_PROXY_SET_WARNING = """\ +The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" + ThumbnailRequirement = namedtuple( "ThumbnailRequirement", ["width", "height", "method", "media_type"] ) @@ -180,12 +187,17 @@ class ContentRepositoryConfig(Config): e.message # noqa: B306, DependencyException.message is a property ) + proxy_env = getproxies_environment() if "url_preview_ip_range_blacklist" not in config: - raise ConfigError( - "For security, you must specify an explicit target IP address " - "blacklist in url_preview_ip_range_blacklist for url previewing " - "to work" - ) + if "http" not in proxy_env or "https" not in proxy_env: + raise ConfigError( + "For security, you must specify an explicit target IP address " + "blacklist in url_preview_ip_range_blacklist for url previewing " + "to work" + ) + else: + if "http" in proxy_env or "https" in proxy_env: + logger.warning("".join(HTTP_PROXY_SET_WARNING)) # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. @@ -292,6 +304,8 @@ class ContentRepositoryConfig(Config): # This must be specified if url_preview_enabled is set. It is recommended that # you uncomment the following list as a starting point. # + # Note: The value is ignored when an HTTP proxy is in use + # #url_preview_ip_range_blacklist: %(ip_range_blacklist)s diff --git a/synapse/config/server.py b/synapse/config/server.py index b9e0c0b300..187b4301a0 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -960,6 +960,8 @@ class ServerConfig(Config): # # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. # + # Note: The value is ignored when an HTTP proxy is in use + # #ip_range_blacklist: %(ip_range_blacklist)s -- cgit 1.5.1 From c2000ab35b76288a625f598d2382d4e3f29f65f6 Mon Sep 17 00:00:00 2001 From: Jason Robinson Date: Wed, 4 Aug 2021 13:40:25 +0300 Subject: Add `get_userinfo_by_id` method to `ModuleApi` (#9581) Makes it easier to fetch user details in, for example, spam checker modules, without needing to use api._store or figure out database interactions. Signed-off-by: Jason Robinson --- changelog.d/9581.feature | 1 + synapse/module_api/__init__.py | 12 ++++++++++- synapse/storage/databases/main/registration.py | 30 +++++++++++++++++++++++++- synapse/types.py | 29 +++++++++++++++++++++++++ tests/module_api/test_api.py | 10 +++++++++ 5 files changed, 80 insertions(+), 2 deletions(-) create mode 100644 changelog.d/9581.feature diff --git a/changelog.d/9581.feature b/changelog.d/9581.feature new file mode 100644 index 0000000000..fa1949cd4b --- /dev/null +++ b/changelog.d/9581.feature @@ -0,0 +1 @@ +Add `get_userinfo_by_id` method to ModuleApi.
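To make the new module-facing surface concrete before the diff, a sketch of how a module might call it; the spam-checker scaffolding below is an illustrative assumption, and only `get_userinfo_by_id` and the `UserInfo` fields come from this change:

```python
# Hypothetical module using the new API; only get_userinfo_by_id and the
# UserInfo attributes come from the diff below, the rest is scaffolding.
from synapse.module_api import ModuleApi


class ExampleSpamChecker:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api

    async def user_may_invite(
        self, inviter: str, invitee: str, room_id: str
    ) -> bool:
        # get_userinfo_by_id returns None if the user does not exist locally.
        user_info = await self._api.get_userinfo_by_id(inviter)
        # Block invites from unknown or deactivated accounts.
        return user_info is not None and not user_info.is_deactivated
```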
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 473812b8e2..1cc13fc97b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -45,7 +45,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.state import StateFilter -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, UserInfo, create_requester from synapse.util import Clock from synapse.util.caches.descriptors import cached @@ -174,6 +174,16 @@ class ModuleApi: """The application name configured in the homeserver's configuration.""" return self._hs.config.email.email_app_name + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get user info by user_id + + Args: + user_id: Fully qualified user id. + Returns: + UserInfo object if a user was found, otherwise None + """ + return await self._store.get_userinfo_by_id(user_id) + async def get_user_by_req( self, req: SynapseRequest, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 6ad1a0cf7f..14670c2881 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import UserID +from synapse.types import UserID, UserInfo from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): @cached() async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + """Deprecated: use get_userinfo_by_id instead""" return await self.db_pool.simple_select_one( table="users", keyvalues={"name": user_id}, @@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_id", ) + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: + """Get a UserInfo object for a user by user ID. + + Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed, + this method should be cached. + + Args: + user_id: The user to fetch user info for. + Returns: + `UserInfo` object if user found, otherwise `None`. + """ + user_data = await self.get_user_by_id(user_id) + if not user_data: + return None + return UserInfo( + appservice_id=user_data["appservice_id"], + consent_server_notice_sent=user_data["consent_server_notice_sent"], + consent_version=user_data["consent_version"], + creation_ts=user_data["creation_ts"], + is_admin=bool(user_data["admin"]), + is_deactivated=bool(user_data["deactivated"]), + is_guest=bool(user_data["is_guest"]), + is_shadow_banned=bool(user_data["shadow_banned"]), + user_id=UserID.from_string(user_data["name"]), + user_type=user_data["user_type"], + ) + async def is_trial_user(self, user_id: str) -> bool: """Checks if user is in the "trial" period, i.e. 
within the first N days of registration defined by `mau_trial_days` config diff --git a/synapse/types.py b/synapse/types.py index 429bb013d2..80fa903c4b 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -751,3 +751,32 @@ def get_verify_key_from_cross_signing_key(key_info): # and return that one key for key_id, key_data in keys.items(): return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 81d9e2f484..0b817cc701 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_get_userinfo_by_id(self): + user_id = self.register_user("alice", "1234") + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, False) + + def test_get_userinfo_by_id__no_user_found(self): + found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) + self.assertIsNone(found_user) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent -- cgit 1.5.1 From 11540be55ed15da920fa6f3ea805315517c02c76 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Aug 2021 13:09:04 +0100 Subject: Fix `could not serialize access` errors for `claim_e2e_one_time_keys` (#10504) --- changelog.d/10504.misc | 1 + synapse/storage/databases/main/end_to_end_keys.py | 188 +++++++++++++++------- 2 files changed, 127 insertions(+), 62 deletions(-) create mode 100644 changelog.d/10504.misc diff --git a/changelog.d/10504.misc b/changelog.d/10504.misc new file mode 100644 index 0000000000..1479a5022d --- /dev/null +++ b/changelog.d/10504.misc @@ -0,0 +1 @@ +Reduce errors in PostgreSQL logs due to concurrent serialization errors. 
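The heart of the change below is collapsing the claim into a single `DELETE ... RETURNING` statement that can run in autocommit mode, removing the long-lived transaction that was losing serialization races. A standalone sketch of that pattern (the DSN, table, and values are assumptions, not Synapse code):

```python
# Standalone illustration of atomically claiming a row with
# DELETE ... RETURNING in autocommit mode; DSN and table are assumptions.
import psycopg2

conn = psycopg2.connect("dbname=example")
conn.autocommit = True  # one statement, no transaction window to conflict in

with conn.cursor() as cur:
    cur.execute(
        """
        DELETE FROM one_time_keys
        WHERE user_id = %s AND algorithm = %s AND key_id IN (
            SELECT key_id FROM one_time_keys
            WHERE user_id = %s AND algorithm = %s
            LIMIT 1
        )
        RETURNING key_id, key_json
        """,
        ("@alice:example.org", "signed_curve25519") * 2,
    )
    claimed = cur.fetchone()  # None if no one-time key was available
```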
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 1edc96042b..1f0a39eac4 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -755,81 +755,145 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): """ @trace - def _claim_e2e_one_time_keys(txn): - sql = ( - "SELECT key_id, key_json FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + def _claim_e2e_one_time_key_simple( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that don't support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + sql = """ + SELECT key_id, key_json FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + """ + + txn.execute(sql, (user_id, device_id, algorithm)) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + + self.db_pool.simple_delete_one_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + "key_id": key_id, + }, ) - fallback_sql = ( - "SELECT key_id, key_json, used FROM e2e_fallback_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " LIMIT 1" + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - result = {} - delete = [] - used_fallbacks = [] - for user_id, device_id, algorithm in query_list: - user_result = result.setdefault(user_id, {}) - device_result = user_result.setdefault(device_id, {}) - txn.execute(sql, (user_id, device_id, algorithm)) - otk_row = txn.fetchone() - if otk_row is not None: - key_id, key_json = otk_row - device_result[algorithm + ":" + key_id] = key_json - delete.append((user_id, device_id, algorithm, key_id)) - else: - # no one-time key available, so see if there's a fallback - # key - txn.execute(fallback_sql, (user_id, device_id, algorithm)) - fallback_row = txn.fetchone() - if fallback_row is not None: - key_id, key_json, used = fallback_row - device_result[algorithm + ":" + key_id] = key_json - if not used: - used_fallbacks.append( - (user_id, device_id, algorithm, key_id) - ) - - # drop any one-time keys that were claimed - sql = ( - "DELETE FROM e2e_one_time_keys_json" - " WHERE user_id = ? AND device_id = ? AND algorithm = ?" - " AND key_id = ?" + + return f"{algorithm}:{key_id}", key_json + + @trace + def _claim_e2e_one_time_key_returning( + txn, user_id: str, device_id: str, algorithm: str + ) -> Optional[Tuple[str, str]]: + """Claim OTK for device for DBs that support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + # We can use RETURNING to do the fetch and DELETE in once step. + sql = """ + DELETE FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + AND key_id IN ( + SELECT key_id FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT 1 + ) + RETURNING key_id, key_json + """ + + txn.execute( + sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) ) - for user_id, device_id, algorithm, key_id in delete: - log_kv( - { - "message": "Executing claim e2e_one_time_keys transaction on database." 
- } - ) - txn.execute(sql, (user_id, device_id, algorithm, key_id)) - log_kv({"message": "finished executing and invalidating cache"}) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) + otk_row = txn.fetchone() + if otk_row is None: + return None + + key_id, key_json = otk_row + return f"{algorithm}:{key_id}", key_json + + results = {} + for user_id, device_id, algorithm in query_list: + if self.database_engine.supports_returning: + # If we support RETURNING clause we can use a single query that + # allows us to use autocommit mode. + _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning + db_autocommit = True + else: + _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple + db_autocommit = False + + row = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + _claim_e2e_one_time_key, + user_id, + device_id, + algorithm, + db_autocommit=db_autocommit, + ) + if row: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} ) - # mark fallback keys as used - for user_id, device_id, algorithm, key_id in used_fallbacks: - self.db_pool.simple_update_txn( - txn, - "e2e_fallback_keys_json", - { + device_results[row[0]] = row[1] + continue + + # No one-time key available, so see if there's a fallback + # key + row = await self.db_pool.simple_select_one( + table="e2e_fallback_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + retcols=("key_id", "key_json", "used"), + desc="_get_fallback_key", + allow_none=True, + ) + if row is None: + continue + + key_id = row["key_id"] + key_json = row["key_json"] + used = row["used"] + + # Mark fallback key as used if not already. + if not used: + await self.db_pool.simple_update_one( + table="e2e_fallback_keys_json", + keyvalues={ "user_id": user_id, "device_id": device_id, "algorithm": algorithm, "key_id": key_id, }, - {"used": True}, + updatevalues={"used": True}, + desc="_get_fallback_key_set_used", ) - self._invalidate_cache_and_stream( - txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) + await self.invalidate_cache_and_stream( + "get_e2e_unused_fallback_key_types", (user_id, device_id) ) - return result + device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) + device_results[f"{algorithm}:{key_id}"] = key_json - return await self.db_pool.runInteraction( - "claim_e2e_one_time_keys", _claim_e2e_one_time_keys - ) + return results class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): -- cgit 1.5.1 From c37dad67ab04980ac934554399f52a27e54292ab Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Aug 2021 13:54:51 +0100 Subject: Improve event caching code (#10119) Ensure we only load an event from the DB once when the same event is requested multiple times at once. --- changelog.d/10119.misc | 1 + synapse/storage/databases/main/events_worker.py | 144 +++++++++++++++------ synapse/storage/databases/main/roommember.py | 6 +- tests/storage/databases/main/test_events_worker.py | 50 +++++++ 4 files changed, 158 insertions(+), 43 deletions(-) create mode 100644 changelog.d/10119.misc diff --git a/changelog.d/10119.misc b/changelog.d/10119.misc new file mode 100644 index 0000000000..f70dc6496f --- /dev/null +++ b/changelog.d/10119.misc @@ -0,0 +1 @@ +Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. 
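The diff below coalesces concurrent requests for the same event behind a shared `ObservableDeferred`, so however many callers race, the database is only hit once. A rough asyncio analogue of that request-coalescing pattern (the loader callback is an assumption; Synapse itself uses Twisted deferreds):

```python
# asyncio analogue of the deduplication added in the diff below; the
# loader callback is illustrative, Synapse uses ObservableDeferred.
import asyncio
from typing import Awaitable, Callable, Dict


class CoalescingLoader:
    def __init__(self, load: Callable[[str], Awaitable[dict]]):
        self._load = load
        self._in_flight: Dict[str, asyncio.Future] = {}

    async def get(self, key: str) -> dict:
        fut = self._in_flight.get(key)
        if fut is None:
            # First caller kicks off the real load; later callers share it.
            fut = asyncio.ensure_future(self._load(key))
            self._in_flight[key] = fut
            # Drop the entry once the load settles (success or failure).
            fut.add_done_callback(lambda _: self._in_flight.pop(key, None))
        return await fut
```

As in the diff, a failure propagates to every waiting caller, and the in-flight entry is always cleared so a later request retries the load.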
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 3c86adab56..375463e4e9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,7 +14,6 @@ import logging import threading -from collections import namedtuple from typing import ( Collection, Container, @@ -27,6 +26,7 @@ from typing import ( overload, ) +import attr from constantly import NamedConstant, Names from typing_extensions import Literal @@ -42,7 +42,11 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event -from synapse.logging.context import PreserveLoggingContext, current_context +from synapse.logging.context import ( + PreserveLoggingContext, + current_context, + make_deferred_yieldable, +) from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events -_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +@attr.s(slots=True, auto_attribs=True) +class _EventCacheEntry: + event: EventBase + redacted_event: Optional[EventBase] class EventRedactBehaviour(Names): @@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore): max_size=hs.config.caches.event_cache_size, ) + # Map from event ID to a deferred that will result in a map from event + # ID to cache entry. Note that the returned dict may not have the + # requested event in it if the event isn't in the DB. + self._current_event_fetches: Dict[ + str, ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = {} + self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore): return events - async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False): + async def _get_events_from_cache_or_db( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the cache or the database. If events are pulled from the database, they will be cached for future lookups. @@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore): Args: - event_ids (Iterable[str]): The event_ids of the events to fetch + event_ids: The event_ids of the events to fetch - allow_rejected (bool): Whether to include rejected events. If False, + allow_rejected: Whether to include rejected events. If False, rejected events are omitted from the response. 
Returns: - Dict[str, _EventCacheEntry]: - map from event id to result + map from event id to result """ event_entry_map = self._get_events_from_cache( - event_ids, allow_rejected=allow_rejected + event_ids, ) - missing_events_ids = [e for e in event_ids if e not in event_entry_map] + missing_events_ids = {e for e in event_ids if e not in event_entry_map} + + # We now look up if we're already fetching some of the events in the DB, + # if so we wait for those lookups to finish instead of pulling the same + # events out of the DB multiple times. + already_fetching: Dict[str, defer.Deferred] = {} + + for event_id in missing_events_ids: + deferred = self._current_event_fetches.get(event_id) + if deferred is not None: + # We're already pulling the event out of the DB. Add the deferred + # to the collection of deferreds to wait on. + already_fetching[event_id] = deferred.observe() + + missing_events_ids.difference_update(already_fetching) if missing_events_ids: log_ctx = current_context() log_ctx.record_event_fetch(len(missing_events_ids)) + # Add entries to `self._current_event_fetches` for each event we're + # going to pull from the DB. We use a single deferred that resolves + # to all the events we pulled from the DB (this will result in this + # function returning more events than requested, but that can happen + # already due to `_get_events_from_db`). + fetching_deferred: ObservableDeferred[ + Dict[str, _EventCacheEntry] + ] = ObservableDeferred(defer.Deferred()) + for event_id in missing_events_ids: + self._current_event_fetches[event_id] = fetching_deferred + # Note that _get_events_from_db is also responsible for turning db rows # into FrozenEvents (via _get_event_from_row), which involves seeing if # the events have been redacted, and if so pulling the redaction event out # of the database to check it. # - missing_events = await self._get_events_from_db( - missing_events_ids, allow_rejected=allow_rejected - ) + try: + missing_events = await self._get_events_from_db( + missing_events_ids, + ) - event_entry_map.update(missing_events) + event_entry_map.update(missing_events) + except Exception as e: + with PreserveLoggingContext(): + fetching_deferred.errback(e) + raise e + finally: + # Ensure that we mark these events as no longer being fetched. + for event_id in missing_events_ids: + self._current_event_fetches.pop(event_id, None) + + with PreserveLoggingContext(): + fetching_deferred.callback(missing_events) + + if already_fetching: + # Wait for the other event requests to finish and add their results + # to ours. + results = await make_deferred_yieldable( + defer.gatherResults( + already_fetching.values(), + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) + + for result in results: + event_entry_map.update(result) + + if not allow_rejected: + event_entry_map = { + event_id: entry + for event_id, entry in event_entry_map.items() + if not entry.event.rejected_reason + } return event_entry_map def _invalidate_get_event_cache(self, event_id): self._get_event_cache.invalidate((event_id,)) - def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): - """Fetch events from the caches + def _get_events_from_cache( + self, events: Iterable[str], update_metrics: bool = True + ) -> Dict[str, _EventCacheEntry]: + """Fetch events from the caches. 
- Args: - events (Iterable[str]): list of event_ids to fetch - allow_rejected (bool): Whether to return events that were rejected - update_metrics (bool): Whether to update the cache hit ratio metrics + May return rejected events. - Returns: - dict of event_id -> _EventCacheEntry for each event_id in cache. If - allow_rejected is `False` then there will still be an entry but it - will be `None` + Args: + events: list of event_ids to fetch + update_metrics: Whether to update the cache hit ratio metrics """ event_map = {} @@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore): if not ret: continue - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None + event_map[event_id] = ret return event_map @@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): self.hs.get_reactor().callFromThread(fire, event_list, e) - async def _get_events_from_db(self, event_ids, allow_rejected=False): + async def _get_events_from_db( + self, event_ids: Iterable[str] + ) -> Dict[str, _EventCacheEntry]: """Fetch a bunch of events from the database. + May return rejected events. + Returned events will be added to the cache for future lookups. Unknown events are omitted from the response. Args: - event_ids (Iterable[str]): The event_ids of the events to fetch - - allow_rejected (bool): Whether to include rejected events. If False, - rejected events are omitted from the response. + event_ids: The event_ids of the events to fetch Returns: - Dict[str, _EventCacheEntry]: - map from event id to result. May return extra events which - weren't asked for. + map from event id to result. May return extra events which + weren't asked for. """ fetched_events = {} events_to_fetch = event_ids @@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason = row["rejected_reason"] - if not allow_rejected and rejected_reason: - continue - # If the event or metadata cannot be parsed, log the error and act # as if the event is unknown. try: diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 68f1b40ea6..e8157ba3d4 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. 
After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache( - member_event_ids, allow_rejected=False, update_metrics=False - ) + event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) missing_member_event_ids = [] for event_id in member_event_ids: ev_entry = event_map.get(event_id) - if ev_entry: + if ev_entry and not ev_entry.event.rejected_reason: if ev_entry.event.membership == Membership.JOIN: users_in_room[ev_entry.event.state_key] = ProfileInfo( display_name=ev_entry.event.content.get("displayname", None), diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 932970fd9a..d05d367685 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -14,7 +14,10 @@ import json from synapse.logging.context import LoggingContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success(self.store.have_seen_events("room1", ["event10"])) self.assertEquals(res, {"event10"}) self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + +class EventCacheTestCase(unittest.HomeserverTestCase): + """Test that the various layers of event cache works.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + def test_simple(self): + """Test that we cache events that we pull from the DB.""" + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + def test_dedupe(self): + """Test that if we request the same event multiple times we only pull it + out once. + """ + + with LoggingContext("test") as ctx: + d = yieldable_gather_results( + self.store.get_event, [self.event_id, self.event_id] + ) + self.get_success(d) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) -- cgit 1.5.1 From e8a3e8140291be0548ad80d0e942a9aaae6c2434 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 4 Aug 2021 16:13:24 +0200 Subject: Don't fail on empty bodies when sending out read receipts (#10531) Fixes a bug introduced in rc1 that would cause Synapse to 400 on read receipts requests with empty bodies. 
Broken in #10413 --- changelog.d/10531.bugfix | 1 + synapse/rest/client/v2_alpha/receipts.py | 2 +- tests/rest/client/v2_alpha/test_sync.py | 12 ++++++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10531.bugfix diff --git a/changelog.d/10531.bugfix b/changelog.d/10531.bugfix new file mode 100644 index 0000000000..aaa921ee91 --- /dev/null +++ b/changelog.d/10531.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.40.0rc1 that would cause Synapse to respond with an error when clients would update their read receipts. diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 4b98979b47..d9ab836cd8 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -43,7 +43,7 @@ class ReceiptRestServlet(RestServlet): if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") - body = parse_json_object_from_request(request) + body = parse_json_object_from_request(request, allow_empty_body=True) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index f6ae9ae181..15748ed4fd 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -418,6 +418,18 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that the first user can't see the other user's hidden read receipt self.assertEqual(self._get_read_receipt(), None) + def test_read_receipt_with_empty_body(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt for this message with an empty body + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + def _get_read_receipt(self): """Syncs and returns the read receipt.""" -- cgit 1.5.1 From 02c2f631aed5cc2ef4bcaea25b443b625616f816 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 4 Aug 2021 17:09:27 +0100 Subject: 1.40.0rc2 --- CHANGES.md | 16 ++++++++++++++++ changelog.d/10516.misc | 1 - changelog.d/10517.bugfix | 1 - changelog.d/10531.bugfix | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 6 files changed, 23 insertions(+), 4 deletions(-) delete mode 100644 changelog.d/10516.misc delete mode 100644 changelog.d/10517.bugfix delete mode 100644 changelog.d/10531.bugfix diff --git a/CHANGES.md b/CHANGES.md index 7ce28c4c18..75031986d3 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,19 @@ +Synapse 1.40.0rc2 (2021-08-04) +============================== + +Bugfixes +-------- + +- Fix the `PeriodicallyFlushingMemoryHandler` inhibiting application shutdown because of its background thread. ([\#10517](https://github.com/matrix-org/synapse/issues/10517)) +- Fix a bug introduced in Synapse v1.40.0rc1 that would cause Synapse to respond with an error when clients would update their read receipts. ([\#10531](https://github.com/matrix-org/synapse/issues/10531)) + + +Internal Changes +---------------- + +- Fix release script to open correct URL for the release. 
([\#10516](https://github.com/matrix-org/synapse/issues/10516)) + + Synapse 1.40.0rc1 (2021-08-03) ============================== diff --git a/changelog.d/10516.misc b/changelog.d/10516.misc deleted file mode 100644 index 4d8c5e4805..0000000000 --- a/changelog.d/10516.misc +++ /dev/null @@ -1 +0,0 @@ -Fix release script to open correct URL for the release. diff --git a/changelog.d/10517.bugfix b/changelog.d/10517.bugfix deleted file mode 100644 index 5b044bb34d..0000000000 --- a/changelog.d/10517.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix the `PeriodicallyFlushingMemoryHandler` inhibiting application shutdown because of its background thread. diff --git a/changelog.d/10531.bugfix b/changelog.d/10531.bugfix deleted file mode 100644 index aaa921ee91..0000000000 --- a/changelog.d/10531.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in Synapse v1.40.0rc1 that would cause Synapse to respond with an error when clients would update their read receipts. diff --git a/debian/changelog b/debian/changelog index f0557c35ef..c523101f9a 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.40.0~rc2) stable; urgency=medium + + * New synapse release 1.40.0~rc2. + + -- Synapse Packaging team Wed, 04 Aug 2021 17:08:55 +0100 + matrix-synapse-py3 (1.40.0~rc1) stable; urgency=medium [ Richard van der Hoff ] diff --git a/synapse/__init__.py b/synapse/__init__.py index d6c1765508..da52463531 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.40.0rc1" +__version__ = "1.40.0rc2" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From 167335bd3da0fcfa0b2ba5ca3dc9d2f7c953c1eb Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 4 Aug 2021 17:11:23 +0100 Subject: Fixup changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 75031986d3..052ab49599 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,7 +5,7 @@ Bugfixes -------- - Fix the `PeriodicallyFlushingMemoryHandler` inhibiting application shutdown because of its background thread. ([\#10517](https://github.com/matrix-org/synapse/issues/10517)) -- Fix a bug introduced in Synapse v1.40.0rc1 that would cause Synapse to respond with an error when clients would update their read receipts. ([\#10531](https://github.com/matrix-org/synapse/issues/10531)) +- Fix a bug introduced in Synapse v1.40.0rc1 that could cause Synapse to respond with an error when clients would update their a receipts. ([\#10531](https://github.com/matrix-org/synapse/issues/10531)) Internal Changes -- cgit 1.5.1 From cc1cb0ab54654b1a1d938ae464a3471dd1b588a5 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 4 Aug 2021 17:14:55 +0100 Subject: Fixup changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 052ab49599..7d3bdebbde 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -5,7 +5,7 @@ Bugfixes -------- - Fix the `PeriodicallyFlushingMemoryHandler` inhibiting application shutdown because of its background thread. ([\#10517](https://github.com/matrix-org/synapse/issues/10517)) -- Fix a bug introduced in Synapse v1.40.0rc1 that could cause Synapse to respond with an error when clients would update their a receipts. 
([\#10531](https://github.com/matrix-org/synapse/issues/10531)) +- Fix a bug introduced in Synapse v1.40.0rc1 that could cause Synapse to respond with an error when clients would update read receipts. ([\#10531](https://github.com/matrix-org/synapse/issues/10531)) Internal Changes -- cgit 1.5.1 From 05111f8f26252cc936fce685846322289039128d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 4 Aug 2021 17:16:08 +0100 Subject: Fixup changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 7d3bdebbde..62ea684e58 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,7 +11,7 @@ Bugfixes Internal Changes ---------------- -- Fix release script to open correct URL for the release. ([\#10516](https://github.com/matrix-org/synapse/issues/10516)) +- Fix release script to open the correct URL for the release. ([\#10516](https://github.com/matrix-org/synapse/issues/10516)) Synapse 1.40.0rc1 (2021-08-03) -- cgit 1.5.1 From 684d19a11c3b93c9dd5fb90f43d38aa7e8c6005f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 4 Aug 2021 12:07:57 -0500 Subject: Add support for MSC2716 marker events (#10498) * Make historical messages available to federated servers Part of MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 Follow-up to https://github.com/matrix-org/synapse/pull/9247 * Debug message not available on federation * Add base starting insertion point when no chunk ID is provided * Fix messages from multiple senders in historical chunk Follow-up to https://github.com/matrix-org/synapse/pull/9247 Part of MSC2716: https://github.com/matrix-org/matrix-doc/pull/2716 --- Previously, Synapse would throw a 403, `Cannot force another user to join.`, because we were trying to use `?user_id` from a single virtual user which did not match with messages from other users in the chunk. 
* Remove debug lines * Messing with selecting insertion event extremities * Move db schema change to new version * Add more and better comments * Make a fake requester with just what we need See https://github.com/matrix-org/synapse/pull/10276#discussion_r660999080 * Store insertion events in table * Make base insertion event float off on its own See https://github.com/matrix-org/synapse/pull/10250#issuecomment-875711889 Conflicts: synapse/rest/client/v1/room.py * Validate that the app service can actually control the given user See https://github.com/matrix-org/synapse/pull/10276#issuecomment-876316455 Conflicts: synapse/rest/client/v1/room.py * Add some better comments on what we're trying to check for * Continue debugging * Share validation logic * Add inserted historical messages to /backfill response * Remove debug sql queries * Some marker event implementation trials * Clean up PR * Rename insertion_event_id to just event_id * Add some better sql comments * More accurate description * Add changelog * Make it clear what MSC the change is part of * Add more detail on which insertion event came through * Address review and improve sql queries * Only use event_id as unique constraint * Fix test case where insertion event is already in the normal DAG * Remove debug changes * Add support for MSC2716 marker events * Process markers when we receive them over federation * WIP: make hs2 backfill historical messages after marker event * hs2 to better ask for insertion event extremity, but running into the `sqlite3.IntegrityError: NOT NULL constraint failed: event_to_state_groups.state_group` error * Add insertion_event_extremities table * Switch to chunk events so we can auth via power_levels Previously, we were using `content.chunk_id` to connect one chunk to another. But these events can be from any `sender` and we can't tell who should be able to send historical events. We know we only want the application service to do it, but these events have the sender of a real historical message, not the application service user ID as the sender. Other federated homeservers also have no indicator which senders are an application service on the originating homeserver. So we want to auth all of the MSC2716 events via power_levels and have them be sent by the application service with proper PL levels in the room.
* Switch to chunk events for federation * Add unstable room version to support new historical PL * Messy: Fix undefined state_group for federated historical events ``` 2021-07-13 02:27:57,810 - synapse.handlers.federation - 1248 - ERROR - GET-4 - Failed to backfill from hs1 because NOT NULL constraint failed: event_to_state_groups.state_group Traceback (most recent call last): File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 1216, in try_backfill await self.backfill( File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 1035, in backfill await self._auth_and_persist_event(dest, event, context, backfilled=True) File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 2222, in _auth_and_persist_event await self._run_push_actions_and_persist_event(event, context, backfilled) File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 2244, in _run_push_actions_and_persist_event await self.persist_events_and_notify( File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 3290, in persist_events_and_notify events, max_stream_token = await self.storage.persistence.persist_events( File "/usr/local/lib/python3.8/site-packages/synapse/logging/opentracing.py", line 774, in _trace_inner return await func(*args, **kwargs) File "/usr/local/lib/python3.8/site-packages/synapse/storage/persist_events.py", line 320, in persist_events ret_vals = await yieldable_gather_results(enqueue, partitioned.items()) File "/usr/local/lib/python3.8/site-packages/synapse/storage/persist_events.py", line 237, in handle_queue_loop ret = await self._per_item_callback( File "/usr/local/lib/python3.8/site-packages/synapse/storage/persist_events.py", line 577, in _persist_event_batch await self.persist_events_store._persist_events_and_state_updates( File "/usr/local/lib/python3.8/site-packages/synapse/storage/databases/main/events.py", line 176, in _persist_events_and_state_updates await self.db_pool.runInteraction( File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 681, in runInteraction result = await self.runWithConnection( File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 770, in runWithConnection return await make_deferred_yieldable( File "/usr/local/lib/python3.8/site-packages/twisted/python/threadpool.py", line 238, in inContext result = inContext.theWork() # type: ignore[attr-defined] File "/usr/local/lib/python3.8/site-packages/twisted/python/threadpool.py", line 254, in <lambda> inContext.theWork = lambda: context.call( # type: ignore[attr-defined] File "/usr/local/lib/python3.8/site-packages/twisted/python/context.py", line 118, in callWithContext return self.currentContext().callWithContext(ctx, func, *args, **kw) File "/usr/local/lib/python3.8/site-packages/twisted/python/context.py", line 83, in callWithContext return func(*args, **kw) File "/usr/local/lib/python3.8/site-packages/twisted/enterprise/adbapi.py", line 293, in _runWithConnection compat.reraise(excValue, excTraceback) File "/usr/local/lib/python3.8/site-packages/twisted/python/deprecate.py", line 298, in deprecatedFunction return function(*args, **kwargs) File "/usr/local/lib/python3.8/site-packages/twisted/python/compat.py", line 403, in reraise raise exception.with_traceback(traceback) File "/usr/local/lib/python3.8/site-packages/twisted/enterprise/adbapi.py", line 284, in _runWithConnection result = func(conn, *args, **kw) File
"/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 765, in inner_func return func(db_conn, *args, **kwargs) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 549, in new_transaction r = func(cursor, *args, **kwargs) File "/usr/local/lib/python3.8/site-packages/synapse/logging/utils.py", line 69, in wrapped return f(*args, **kwargs) File "/usr/local/lib/python3.8/site-packages/synapse/storage/databases/main/events.py", line 385, in _persist_events_txn self._store_event_state_mappings_txn(txn, events_and_contexts) File "/usr/local/lib/python3.8/site-packages/synapse/storage/databases/main/events.py", line 2065, in _store_event_state_mappings_txn self.db_pool.simple_insert_many_txn( File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 923, in simple_insert_many_txn txn.execute_batch(sql, vals) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 280, in execute_batch self.executemany(sql, args) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 300, in executemany self._do_execute(self.txn.executemany, sql, *args) File "/usr/local/lib/python3.8/site-packages/synapse/storage/database.py", line 330, in _do_execute return func(sql, *args) sqlite3.IntegrityError: NOT NULL constraint failed: event_to_state_groups.state_group ``` * Revert "Messy: Fix undefined state_group for federated historical events" This reverts commit 187ab28611546321e02770944c86f30ee2bc742a. * Fix federated events being rejected for no state_groups Add fix from https://github.com/matrix-org/synapse/pull/10439 until it merges. * Adapting to experimental room version * Some log cleanup * Add better comments around extremity fetching code and why * Rename to be more accurate to what the function returns * Add changelog * Ignore rejected events * Use simplified upsert * Add Erik's explanation of extra event checks See https://github.com/matrix-org/synapse/pull/10498#discussion_r680880332 * Clarify that the depth is not directly correlated to the backwards extremity that we return See https://github.com/matrix-org/synapse/pull/10498#discussion_r681725404 * lock only matters for sqlite See https://github.com/matrix-org/synapse/pull/10498#discussion_r681728061 * Move new SQL changes to its own delta file * Clean up upsert docstring * Bump database schema version (62) --- changelog.d/10498.feature | 1 + scripts-dev/complement.sh | 2 +- synapse/handlers/federation.py | 119 +++++++++++++++++++-- synapse/storage/database.py | 14 +-- synapse/storage/databases/main/event_federation.py | 114 +++++++++++++++++--- synapse/storage/databases/main/events.py | 24 ++++- synapse/storage/schema/__init__.py | 2 +- .../delta/62/01insertion_event_extremities.sql | 24 +++++ 8 files changed, 265 insertions(+), 35 deletions(-) create mode 100644 changelog.d/10498.feature create mode 100644 synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql diff --git a/changelog.d/10498.feature b/changelog.d/10498.feature new file mode 100644 index 0000000000..5df896572d --- /dev/null +++ b/changelog.d/10498.feature @@ -0,0 +1 @@ +Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of MSC2716). diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index cba015d942..5d0ef8dd3a 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! 
-go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... +go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8197b60b76..8b602e3813 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -42,6 +42,7 @@ from twisted.internet import defer from synapse import event_auth from synapse.api.constants import ( + EventContentFields, EventTypes, Membership, RejectedReason, @@ -262,7 +263,12 @@ class FederationHandler(BaseHandler): state = None - # Get missing pdus if necessary. + # Check that the event passes auth based on the state at the event. This is + # done for events that are to be added to the timeline (non-outliers). + # + # Get missing pdus if necessary: + # - Fetching any missing prev events to fill in gaps in the graph + # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. min_depth = await self.get_min_depth_for_context(pdu.room_id) @@ -432,6 +438,13 @@ class FederationHandler(BaseHandler): affected=event_id, ) + # A second round of checks for all events. Check that the event passes auth + # based on `auth_events`, this allows us to assert that the event would + # have been allowed at some point. If an event passes this check its OK + # for it to be used as part of a returned `/state` request, as either + # a) we received the event as part of the original join and so trust it, or + # b) we'll do a state resolution with existing state before it becomes + # part of the "current state", which adds more protection. await self._process_received_pdu(origin, pdu, state=state) async def _get_missing_events_for_pdu( @@ -889,6 +902,79 @@ class FederationHandler(BaseHandler): "resync_device_due_to_pdu", self._resync_device, event.sender ) + await self._handle_marker_event(origin, event) + + async def _handle_marker_event(self, origin: str, marker_event: EventBase): + """Handles backfilling the insertion event when we receive a marker + event that points to one. + + Args: + origin: Origin of the event. Will be called to get the insertion event + marker_event: The event to process + """ + + if marker_event.type != EventTypes.MSC2716_MARKER: + # Not a marker event + return + + if marker_event.rejected_reason is not None: + # Rejected event + return + + # Skip processing a marker event if the room version doesn't + # support it. 
+ room_version = await self.store.get_room_version(marker_event.room_id) + if not room_version.msc2716_historical: + return + + logger.debug("_handle_marker_event: received %s", marker_event) + + insertion_event_id = marker_event.content.get( + EventContentFields.MSC2716_MARKER_INSERTION + ) + + if insertion_event_id is None: + # Nothing to retrieve then (invalid marker) + return + + logger.debug( + "_handle_marker_event: backfilling insertion event %s", insertion_event_id + ) + + await self._get_events_and_persist( + origin, + marker_event.room_id, + [insertion_event_id], + ) + + insertion_event = await self.store.get_event( + insertion_event_id, allow_none=True + ) + if insertion_event is None: + logger.warning( + "_handle_marker_event: server %s didn't return insertion event %s for marker %s", + origin, + insertion_event_id, + marker_event.event_id, + ) + return + + logger.debug( + "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", + insertion_event, + marker_event, + ) + + await self.store.insert_insertion_extremity( + insertion_event_id, marker_event.room_id + ) + + logger.debug( + "_handle_marker_event: insertion extremity added for %s from marker event %s", + insertion_event, + marker_event, + ) + async def _resync_device(self, sender: str) -> None: """We have detected that the device list for the given user may be out of sync, so we try and resync them. @@ -1057,9 +1143,19 @@ class FederationHandler(BaseHandler): async def _maybe_backfill_inner( self, room_id: str, current_depth: int, limit: int ) -> bool: - extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) + oldest_events_with_depth = ( + await self.store.get_oldest_event_ids_with_depth_in_room(room_id) + ) + insertion_events_to_be_backfilled = ( + await self.store.get_insertion_event_backwards_extremities_in_room(room_id) + ) + logger.debug( + "_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s", + oldest_events_with_depth, + insertion_events_to_be_backfilled, + ) - if not extremities: + if not oldest_events_with_depth and not insertion_events_to_be_backfilled: logger.debug("Not backfilling as no extremeties found.") return False @@ -1089,10 +1185,12 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = await self.store.get_successor_events(list(extremities)) + forward_event_ids = await self.store.get_successor_events( + list(oldest_events_with_depth) + ) extremities_events = await self.store.get_events( - forward_events, + forward_event_ids, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, ) @@ -1106,10 +1204,19 @@ class FederationHandler(BaseHandler): redact=False, check_history_visibility_only=True, ) + logger.debug( + "_maybe_backfill_inner: filtered_extremities %s", filtered_extremities + ) - if not filtered_extremities: + if not filtered_extremities and not insertion_events_to_be_backfilled: return False + extremities = { + **oldest_events_with_depth, + # TODO: insertion_events_to_be_backfilled is currently skipping the filtered_extremities checks + **insertion_events_to_be_backfilled, + } + # Check if we reached a point where we should start backfilling. 
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1])) max_depth = sorted_extremeties_tuple[0][1] diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c8015a3848..95d2caff62 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -941,13 +941,13 @@ class DatabasePool: `lock` should generally be set to True (the default), but can be set to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. + 1. there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + 2. we somehow know that we are the only thread which will be updating + this table. + As an additional note, this parameter only matters for old SQLite versions + because we will use native upserts otherwise. Args: table: The table to upsert into diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 44018c1c31..bddf5ef192 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -671,27 +671,97 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} - async def get_oldest_events_with_depth_in_room(self, room_id): + async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]: + """Gets the oldest events(backwards extremities) in the room along with the + aproximate depth. + + We use this function so that we can compare and see if someones current + depth at their current scrollback is within pagination range of the + event extremeties. If the current depth is close to the depth of given + oldest event, we can trigger a backfill. + + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): + # Assemble a dictionary with event_id -> depth for the oldest events + # we know of in the room. Backwards extremeties are the oldest + # events we know of in the room but we only know of them because + # some other event referenced them by prev_event and aren't peristed + # in our database yet (meaning we don't know their depth + # specifically). So we need to look for the aproximate depth from + # the events connected to the current backwards extremeties. + sql = """ + SELECT b.event_id, MAX(e.depth) FROM events as e + /** + * Get the edge connections from the event_edges table + * so we can see whether this event's prev_events points + * to a backward extremity in the next join. + */ + INNER JOIN event_edges as g + ON g.event_id = e.event_id + /** + * We find the "oldest" events in the room by looking for + * events connected to backwards extremeties (oldest events + * in the room that we know of so far). + */ + INNER JOIN event_backward_extremities as b + ON g.prev_event_id = b.event_id + WHERE b.room_id = ? AND g.is_state is ? 
+ GROUP BY b.event_id + """ + + txn.execute(sql, (room_id, False)) + + return dict(txn) + return await self.db_pool.runInteraction( - "get_oldest_events_with_depth_in_room", - self.get_oldest_events_with_depth_in_room_txn, + "get_oldest_event_ids_with_depth_in_room", + get_oldest_event_ids_with_depth_in_room_txn, room_id, ) - def get_oldest_events_with_depth_in_room_txn(self, txn, room_id): - sql = ( - "SELECT b.event_id, MAX(e.depth) FROM events as e" - " INNER JOIN event_edges as g" - " ON g.event_id = e.event_id" - " INNER JOIN event_backward_extremities as b" - " ON g.prev_event_id = b.event_id" - " WHERE b.room_id = ? AND g.is_state is ?" - " GROUP BY b.event_id" - ) + async def get_insertion_event_backwards_extremities_in_room( + self, room_id + ) -> Dict[str, int]: + """Get the insertion events we know about that we haven't backfilled yet. - txn.execute(sql, (room_id, False)) + We use this function so that we can compare and see if someones current + depth at their current scrollback is within pagination range of the + insertion event. If the current depth is close to the depth of given + insertion event, we can trigger a backfill. - return dict(txn) + Args: + room_id: Room where we want to find the oldest events + + Returns: + Map from event_id to depth + """ + + def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id): + sql = """ + SELECT b.event_id, MAX(e.depth) FROM insertion_events as i + /* We only want insertion events that are also marked as backwards extremities */ + INNER JOIN insertion_event_extremities as b USING (event_id) + /* Get the depth of the insertion event from the events table */ + INNER JOIN events AS e USING (event_id) + WHERE b.room_id = ? + GROUP BY b.event_id + """ + + txn.execute(sql, (room_id,)) + + return dict(txn) + + return await self.db_pool.runInteraction( + "get_insertion_event_backwards_extremities_in_room", + get_insertion_event_backwards_extremities_in_room_txn, + room_id, + ) async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs @@ -1041,7 +1111,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas if row[1] not in event_results: queue.put((-row[0], row[1])) - # Navigate up the DAG by prev_event txn.execute(query, (event_id, False, limit - len(event_results))) prev_event_id_results = txn.fetchall() logger.debug( @@ -1136,6 +1205,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None: + await self.db_pool.simple_upsert( + table="insertion_event_extremities", + keyvalues={"event_id": event_id}, + values={ + "event_id": event_id, + "room_id": room_id, + }, + insertion_values={}, + desc="insert_insertion_extremity", + lock=False, + ) + async def insert_received_event_to_staging( self, origin: str, event: EventBase ) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 86baf397fb..40b53274fb 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1845,6 +1845,18 @@ class PersistEventsStore: }, ) + # When we receive an event with a `chunk_id` referencing the + # `next_chunk_id` of the insertion event, we can remove it from the + # `insertion_event_extremities` table. 
+ sql = """ + DELETE FROM insertion_event_extremities WHERE event_id IN ( + SELECT event_id FROM insertion_events + WHERE next_chunk_id = ? + ) + """ + + txn.execute(sql, (chunk_id,)) + def _handle_redaction(self, txn, redacted_event_id): """Handles receiving a redaction and checking whether we need to remove any redacted relations from the database. @@ -2101,15 +2113,17 @@ class PersistEventsStore: Forward extremities are handled when we first start persisting the events. """ + # From the events passed in, add all of the prev events as backwards extremities. + # Ignore any events that are already backwards extrems or outliers. query = ( "INSERT INTO event_backward_extremities (event_id, room_id)" " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" " )" " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" " )" ) @@ -2123,6 +2137,8 @@ class PersistEventsStore: ], ) + # Delete all these events that we've already fetched and now know that their + # prev events are the new backwards extremeties. query = ( "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 36340a652a..fd4dd67d91 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 61 +SCHEMA_VERSION = 62 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the diff --git a/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql new file mode 100644 index 0000000000..b731ef284a --- /dev/null +++ b/synapse/storage/schema/main/delta/62/01insertion_event_extremities.sql @@ -0,0 +1,24 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +-- Add a table that keeps track of which "insertion" events need to be backfilled +CREATE TABLE IF NOT EXISTS insertion_event_extremities( + event_id TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS insertion_event_extremities_event_id ON insertion_event_extremities(event_id); +CREATE INDEX IF NOT EXISTS insertion_event_extremities_room_id ON insertion_event_extremities(room_id); -- cgit 1.5.1 From 9db24cc50d252b1685a4ac69a736b49ed225dcb6 Mon Sep 17 00:00:00 2001 From: Michael Telatynski <7t3chguy@gmail.com> Date: Wed, 4 Aug 2021 18:39:57 +0100 Subject: Send unstable-prefixed room_type in store-invite IS API requests (#10435) The room type is per MSC3288 to allow the identity-server to change invitation wording based on whether the invitation is to a room or a space. The prefixed key will be replaced once MSC3288 is accepted into the spec. --- changelog.d/10435.feature | 1 + synapse/handlers/identity.py | 6 ++++++ synapse/handlers/room_member.py | 13 ++++++++++++- 3 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10435.feature diff --git a/changelog.d/10435.feature b/changelog.d/10435.feature new file mode 100644 index 0000000000..f93ef5b415 --- /dev/null +++ b/changelog.d/10435.feature @@ -0,0 +1 @@ +Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 0961dec5ab..8ffeabacf9 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -824,6 +824,7 @@ class IdentityHandler(BaseHandler): room_avatar_url: str, room_join_rules: str, room_name: str, + room_type: Optional[str], inviter_display_name: str, inviter_avatar_url: str, id_access_token: Optional[str] = None, @@ -843,6 +844,7 @@ class IdentityHandler(BaseHandler): notifications. room_join_rules: The join rules of the email (e.g. "public"). room_name: The m.room.name of the room. + room_type: The type of the room from its m.room.create event (e.g "m.space"). inviter_display_name: The current display name of the inviter. inviter_avatar_url: The URL of the inviter's avatar. @@ -869,6 +871,10 @@ class IdentityHandler(BaseHandler): "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, } + + if room_type is not None: + invite_config["org.matrix.msc3288.room_type"] = room_type + # If a custom web client location is available, include it in the request. 
if self._web_client_location: invite_config["org.matrix.web_client_location"] = self._web_client_location diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 65ad3efa6a..ba13196218 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -19,7 +19,12 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple from synapse import types -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import ( + AccountDataTypes, + EventContentFields, + EventTypes, + Membership, +) from synapse.api.errors import ( AuthError, Codes, @@ -1237,6 +1242,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if room_name_event: room_name = room_name_event.content.get("name", "") + room_type = None + room_create_event = room_state.get((EventTypes.Create, "")) + if room_create_event: + room_type = room_create_event.content.get(EventContentFields.ROOM_TYPE) + room_join_rules = "" join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: @@ -1263,6 +1273,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_avatar_url=room_avatar_url, room_join_rules=room_join_rules, room_name=room_name, + room_type=room_type, inviter_display_name=inviter_display_name, inviter_avatar_url=inviter_avatar_url, id_access_token=id_access_token, -- cgit 1.5.1 From e33f14e8d51e33cb86d7791495b73ae4c1e784f9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 5 Aug 2021 11:22:27 +0100 Subject: Don't fail CI when lint-newfile job was skipped (#10529) --- .github/workflows/tests.yml | 7 ++++++- changelog.d/10529.misc | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10529.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 239553ae13..75c2976a25 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,7 +8,7 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true - + jobs: lint: runs-on: ubuntu-latest @@ -374,6 +374,11 @@ jobs: rc=0 results=$(jq -r 'to_entries[] | [.key,.value.result] | join(" ")' <<< $NEEDS_CONTEXT) while read job result ; do + # The newsfile lint may be skipped on non PR builds + if [ $result == "skipped" ] && [ $job == "lint-newsfile" ]; then + continue + fi + if [ "$result" != "success" ]; then echo "::set-failed ::Job $job returned $result" rc=1 diff --git a/changelog.d/10529.misc b/changelog.d/10529.misc new file mode 100644 index 0000000000..4caf22523c --- /dev/null +++ b/changelog.d/10529.misc @@ -0,0 +1 @@ +Fix CI to not break when run against branches rather than pull requests. -- cgit 1.5.1 From 834cdc3606c9193f7b5a5e93936193b359222690 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 5 Aug 2021 13:20:05 +0200 Subject: Add documentation for configuring a forward proxy. (#10443) --- changelog.d/10443.doc | 1 + docs/SUMMARY.md | 1 + docs/setup/forward_proxy.md | 74 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) create mode 100644 changelog.d/10443.doc create mode 100644 docs/setup/forward_proxy.md diff --git a/changelog.d/10443.doc b/changelog.d/10443.doc new file mode 100644 index 0000000000..3588e5487f --- /dev/null +++ b/changelog.d/10443.doc @@ -0,0 +1 @@ +Add documentation for configuration a forward proxy. 
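Stepping back to the MSC2716 patch above: the `get_oldest_event_ids_with_depth_in_room` query has to *approximate* a depth for each backward extremity, because the extremity itself is not persisted and so has no depth of its own; the query therefore takes the maximum depth of the persisted events whose `prev_events` point at it. A minimal pure-Python model of that computation (event IDs and depths are invented for illustration):

```python
# Illustrative model of the SQL above: map each backward extremity
# (an event ID referenced in prev_events but not persisted) to the
# maximum depth of the events that reference it.
from typing import Dict, List, Set, Tuple

# (event_id, depth, prev_event_ids) for events we have persisted.
events: List[Tuple[str, int, List[str]]] = [
    ("$b", 2, ["$a"]),
    ("$c", 3, ["$b"]),
    ("$d", 5, ["$a"]),  # a second reference to the missing "$a"
]

persisted: Set[str] = {event_id for event_id, _, _ in events}

extremity_depths: Dict[str, int] = {}
for event_id, depth, prev_event_ids in events:
    for prev_event_id in prev_event_ids:
        if prev_event_id not in persisted:
            extremity_depths[prev_event_id] = max(
                extremity_depths.get(prev_event_id, 0), depth
            )

print(extremity_depths)  # {'$a': 5}, i.e. MAX(e.depth) per extremity
```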
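Likewise, to make the MSC3288 change above concrete: when the invited room's `m.room.create` event carries a room type, it is now forwarded to the identity server under an unstable-prefixed key so invitation emails can distinguish rooms from spaces. A hypothetical `/store-invite` payload after this change (all values are invented; only the prefixed key is taken from the patch, and the other keys mirror the `invite_config` fields built in `identity.py`):

```python
# Hypothetical invite_config sent to the identity server; the values
# are made up for the example.
invite_config = {
    "medium": "email",
    "address": "invitee@example.com",
    "room_id": "!room:example.com",
    "room_name": "A space for examples",
    "sender_display_name": "Alice",
    "sender_avatar_url": "mxc://example.com/abc123",
    # Only present when the m.room.create event has a type, e.g.
    # "m.space"; the prefix goes away once MSC3288 is in the spec.
    "org.matrix.msc3288.room_type": "m.space",
}
```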
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 10be12d638..3d320a1c43 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -7,6 +7,7 @@ - [Installation](setup/installation.md) - [Using Postgres](postgres.md) - [Configuring a Reverse Proxy](reverse_proxy.md) + - [Configuring a Forward/Outbound Proxy](setup/forward_proxy.md) - [Configuring a Turn Server](turn-howto.md) - [Delegation](delegate.md) diff --git a/docs/setup/forward_proxy.md b/docs/setup/forward_proxy.md new file mode 100644 index 0000000000..a0720ab342 --- /dev/null +++ b/docs/setup/forward_proxy.md @@ -0,0 +1,74 @@ +# Using a forward proxy with Synapse + +You can use Synapse with a forward or outbound proxy. An example of when +this is necessary is in corporate environments behind a DMZ (demilitarized zone). +Synapse supports routing outbound HTTP(S) requests via a proxy. Only HTTP(S) +proxy is supported, not SOCKS proxy or anything else. + +## Configure + +The `http_proxy`, `https_proxy`, `no_proxy` environment variables are used to +specify proxy settings. The environment variable is not case sensitive. +- `http_proxy`: Proxy server to use for HTTP requests. +- `https_proxy`: Proxy server to use for HTTPS requests. +- `no_proxy`: Comma-separated list of hosts, IP addresses, or IP ranges in CIDR + format which should not use the proxy. Synapse will directly connect to these hosts. + +The `http_proxy` and `https_proxy` environment variables have the form: `[scheme://][:@][:]` +- Supported schemes are `http://` and `https://`. The default scheme is `http://` + for compatibility reasons; it is recommended to set a scheme. If scheme is set + to `https://` the connection uses TLS between Synapse and the proxy. + + **NOTE**: Synapse validates the certificates. If the certificate is not + valid, then the connection is dropped. +- Default port if not given is `1080`. +- Username and password are optional and will be used to authenticate against + the proxy. + +**Examples** +- HTTP_PROXY=http://USERNAME:PASSWORD@10.0.1.1:8080/ +- HTTPS_PROXY=http://USERNAME:PASSWORD@proxy.example.com:8080/ +- NO_PROXY=master.hostname.example.com,10.1.0.0/16,172.30.0.0/16 + +**NOTE**: +Synapse does not apply the IP blacklist to connections through the proxy (since +the DNS resolution is done by the proxy). It is expected that the proxy or firewall +will apply blacklisting of IP addresses. + +## Connection types + +The proxy will be **used** for: + +- push +- url previews +- phone-home stats +- recaptcha validation +- CAS auth validation +- OpenID Connect +- Federation (checking public key revocation) + +It will **not be used** for: + +- Application Services +- Identity servers +- Outbound federation +- In worker configurations + - connections between workers + - connections from workers to Redis +- Fetching public keys of other servers +- Downloading remote media + +## Troubleshooting + +If a proxy server is used with TLS (HTTPS) and no connections are established, +it is most likely due to the proxy's certificates. To test this, the validation +in Synapse can be deactivated. + +**NOTE**: This has an impact on security and is for testing purposes only! + +To deactivate the certificate validation, the following setting must be made in +[homserver.yaml](../usage/configuration/homeserver_sample_config.md). 
+ +```yaml +use_insecure_ssl_client_just_for_testing_do_not_use: true +``` -- cgit 1.5.1 From a8a27b2b8bac2995c3edd20518680366eb543ac9 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Thu, 5 Aug 2021 13:22:14 +0100 Subject: Only return an appservice protocol if it has a service providing it. (#10532) If there are no services providing a protocol, omit it completely instead of returning an empty dictionary. This fixes a long-standing spec compliance bug. --- changelog.d/10532.bugfix | 1 + synapse/handlers/appservice.py | 7 +-- tests/handlers/test_appservice.py | 122 +++++++++++++++++++++++++++++++++++++- 3 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10532.bugfix diff --git a/changelog.d/10532.bugfix b/changelog.d/10532.bugfix new file mode 100644 index 0000000000..d95e3d9b59 --- /dev/null +++ b/changelog.d/10532.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 21a17cd2e8..4ab4046650 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -392,9 +392,6 @@ class ApplicationServicesHandler: protocols[p].append(info) def _merge_instances(infos: List[JsonDict]) -> JsonDict: - if not infos: - return {} - # Merge the 'instances' lists of multiple results, but just take # the other fields from the first as they ought to be identical # copy the result so as not to corrupt the cached one @@ -406,7 +403,9 @@ class ApplicationServicesHandler: return combined - return {p: _merge_instances(protocols[p]) for p in protocols.keys()} + return { + p: _merge_instances(protocols[p]) for p in protocols.keys() if protocols[p] + } async def _get_services_for_event( self, event: EventBase diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 024c5e963c..43998020b2 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -133,11 +133,131 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - def _mkservice(self, is_interested): + def test_get_3pe_protocols_no_appservices(self): + self.mock_store.get_app_services.return_value = [] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_no_protocols(self): + service = self._mkservice(False, []) + self.mock_store.get_app_services.return_value = [service] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_protocol_no_response(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals(response, {}) + + def test_get_3pe_protocols_select_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + 
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_multiple_protocol(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["other-protocol"]) + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called() + self.assertEquals( + response, + { + "my-protocol": {"x-protocol-data": 42, "instances": []}, + "other-protocol": {"x-protocol-data": 42, "instances": []}, + }, + ) + + def test_get_3pe_protocols_multiple_info(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["my-protocol"]) + + async def get_3pe_protocol(service, unusedProtocol): + if service == service_one: + return { + "x-protocol-data": 42, + "instances": [{"desc": "Alice's service"}], + } + if service == service_two: + return { + "x-protocol-data": 36, + "x-not-used": 45, + "instances": [{"desc": "Bob's service"}], + } + raise Exception("Unexpected service") + + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol = get_3pe_protocol + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + # It's expected that the second service's data doesn't appear in the response + self.assertEquals( + response, + { + "my-protocol": { + "x-protocol-data": 42, + "instances": [ + { + "desc": "Alice's service", + }, + {"desc": "Bob's service"}, + ], + }, + }, + ) + + def _mkservice(self, is_interested, protocols=None): service = Mock() service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" + service.protocols = protocols return service def _mkservice_alias(self, is_interested_in_alias): -- cgit 1.5.1 From 3b354faad0e6b1f41ed5dd0269a1785d3f505465 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 5 Aug 2021 08:39:17 -0400 Subject: Refactoring before implementing the updated spaces summary. (#10527) This should have no user-visible changes, but refactors some pieces of the SpaceSummaryHandler before adding support for the updated MSC2946. 
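The appservice change above is worth restating, since the diff is small but the behaviour is subtle: when several services implement the same third-party protocol, their `instances` lists are concatenated while the remaining fields are taken from the first response, and, with the fix, protocols for which no service returned anything are now omitted entirely rather than surfacing as empty objects. A standalone sketch of that merge:

```python
# Standalone sketch of the merge behaviour in get_3pe_protocols above;
# the data is invented, the logic mirrors _merge_instances.
import copy
from typing import Any, Dict, List

JsonDict = Dict[str, Any]

def merge_instances(infos: List[JsonDict]) -> JsonDict:
    # All fields come from the first response; only "instances" are
    # accumulated across the rest.
    combined = copy.deepcopy(infos[0])
    for info in infos[1:]:
        combined["instances"] += info["instances"]
    return combined

protocols: Dict[str, List[JsonDict]] = {
    "my-protocol": [
        {"x-protocol-data": 42, "instances": [{"desc": "Alice's service"}]},
        {"x-protocol-data": 36, "instances": [{"desc": "Bob's service"}]},
    ],
    # No appservice responded for this protocol, so the fixed code
    # drops it from the response instead of returning {}.
    "unimplemented-protocol": [],
}

result = {p: merge_instances(infos) for p, infos in protocols.items() if infos}
print(result)
# {'my-protocol': {'x-protocol-data': 42,
#   'instances': [{'desc': "Alice's service"}, {'desc': "Bob's service"}]}}
```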
--- changelog.d/10527.misc | 1 + synapse/federation/federation_client.py | 23 ++-- synapse/handlers/space_summary.py | 125 ++++++++++++--------- tests/handlers/test_space_summary.py | 185 ++++++++++++++++++-------------- 4 files changed, 198 insertions(+), 136 deletions(-) create mode 100644 changelog.d/10527.misc diff --git a/changelog.d/10527.misc b/changelog.d/10527.misc new file mode 100644 index 0000000000..3cf22f9daf --- /dev/null +++ b/changelog.d/10527.misc @@ -0,0 +1 @@ +Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b7a10da15a..007d1a27dc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1290,7 +1290,7 @@ class FederationClient(FederationBase): ) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: """Represents a single event in the result of a successful get_space_summary call. @@ -1299,12 +1299,13 @@ class FederationSpaceSummaryEventResult: object attributes. """ - event_type = attr.ib(type=str) - state_key = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + event_type: str + room_id: str + state_key: str + via: Sequence[str] # the raw data, including the above keys - data = attr.ib(type=JsonDict) + data: JsonDict @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryEventResult": @@ -1321,6 +1322,10 @@ class FederationSpaceSummaryEventResult: if not isinstance(event_type, str): raise ValueError("Invalid event: 'event_type' must be a str") + room_id = d.get("room_id") + if not isinstance(room_id, str): + raise ValueError("Invalid event: 'room_id' must be a str") + state_key = d.get("state_key") if not isinstance(state_key, str): raise ValueError("Invalid event: 'state_key' must be a str") @@ -1335,15 +1340,15 @@ class FederationSpaceSummaryEventResult: if any(not isinstance(v, str) for v in via): raise ValueError("Invalid event: 'via' must be a list of strings") - return cls(event_type, state_key, via, d) + return cls(event_type, room_id, state_key, via, d) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryResult: """Represents the data returned by a successful get_space_summary call.""" - rooms = attr.ib(type=Sequence[JsonDict]) - events = attr.ib(type=Sequence[FederationSpaceSummaryEventResult]) + rooms: Sequence[JsonDict] + events: Sequence[FederationSpaceSummaryEventResult] @classmethod def from_json_dict(cls, d: JsonDict) -> "FederationSpaceSummaryResult": diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 5f7d4602bd..3eb232c83e 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -16,7 +16,17 @@ import itertools import logging import re from collections import deque -from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) import attr @@ -116,20 +126,22 @@ class SpaceSummaryHandler: max_children = max_rooms_per_space if processed_rooms else None if is_in_room: - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( requester, None, room_id, suggested_only, max_children ) + events: Collection[JsonDict] = [] + if 
room_entry: + rooms_result.append(room_entry.room) + events = room_entry.children + logger.debug( "Query of local room %s returned events %s", room_id, ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) - - if room: - rooms_result.append(room) else: - fed_rooms, fed_events = await self._summarize_remote_room( + fed_rooms = await self._summarize_remote_room( queue_entry, suggested_only, max_children, @@ -141,12 +153,10 @@ class SpaceSummaryHandler: # user is not permitted see. # # Filter the returned results to only what is accessible to the user. - room_ids = set() events = [] - for room in fed_rooms: - fed_room_id = room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue + for room_entry in fed_rooms: + room = room_entry.room + fed_room_id = room_entry.room_id # The room should only be included in the summary if: # a. the user is in the room; @@ -189,21 +199,17 @@ class SpaceSummaryHandler: # The user can see the room, include it! if include_room: rooms_result.append(room) - room_ids.add(fed_room_id) + events.extend(room_entry.children) # All rooms returned don't need visiting again (even if the user # didn't have access to them). processed_rooms.add(fed_room_id) - for event in fed_events: - if event.get("room_id") in room_ids: - events.append(event) - logger.debug( "Query of %s returned rooms %s, events %s", room_id, - [room.get("room_id") for room in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in fed_events], + [room_entry.room.get("room_id") for room_entry in fed_rooms], + ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], ) # the room we queried may or may not have been returned, but don't process @@ -283,20 +289,20 @@ class SpaceSummaryHandler: # already done this room continue - logger.debug("Processing room %s", room_id) - - room, events = await self._summarize_local_room( + room_entry = await self._summarize_local_room( None, origin, room_id, suggested_only, max_rooms_per_space ) processed_rooms.add(room_id) - if room: - rooms_result.append(room) - events_result.extend(events) + if room_entry: + rooms_result.append(room_entry.room) + events_result.extend(room_entry.children) - # add any children to the queue - room_queue.extend(edge_event["state_key"] for edge_event in events) + # add any children to the queue + room_queue.extend( + edge_event["state_key"] for edge_event in room_entry.children + ) return {"rooms": rooms_result, "events": events_result} @@ -307,7 +313,7 @@ class SpaceSummaryHandler: room_id: str, suggested_only: bool, max_children: Optional[int], - ) -> Tuple[Optional[JsonDict], Sequence[JsonDict]]: + ) -> Optional["_RoomEntry"]: """ Generate a room entry and a list of event entries for a given room. @@ -326,21 +332,16 @@ class SpaceSummaryHandler: to a server-set limit. Returns: - A tuple of: - The room information, if the room should be returned to the - user. None, otherwise. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + A room entry if the room should be returned. None, otherwise. """ if not await self._is_room_accessible(room_id, requester, origin): - return None, () + return None room_entry = await self._build_room_entry(room_id) # If the room is not a space, return just the room information. if room_entry.get("room_type") != RoomTypes.SPACE: - return room_entry, () + return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. 
child_events = await self._get_child_events(room_id) @@ -363,7 +364,7 @@ class SpaceSummaryHandler: ) ) - return room_entry, events_result + return _RoomEntry(room_id, room_entry, events_result) async def _summarize_remote_room( self, @@ -371,7 +372,7 @@ class SpaceSummaryHandler: suggested_only: bool, max_children: Optional[int], exclude_rooms: Iterable[str], - ) -> Tuple[Sequence[JsonDict], Sequence[JsonDict]]: + ) -> Iterable["_RoomEntry"]: """ Request room entries and a list of event entries for a given room by querying a remote server. @@ -386,11 +387,7 @@ class SpaceSummaryHandler: Rooms IDs which do not need to be summarized. Returns: - A tuple of: - An iterable of rooms. - - An iterable of the sorted children events. This may be limited - to a maximum size or may include all children. + An iterable of room entries. """ room_id = room.room_id logger.info("Requesting summary for %s via %s", room_id, room.via) @@ -414,11 +411,30 @@ class SpaceSummaryHandler: e, exc_info=logger.isEnabledFor(logging.DEBUG), ) - return (), () + return () + + # Group the events by their room. + children_by_room: Dict[str, List[JsonDict]] = {} + for ev in res.events: + if ev.event_type == EventTypes.SpaceChild: + children_by_room.setdefault(ev.room_id, []).append(ev.data) + + # Generate the final results. + results = [] + for fed_room in res.rooms: + fed_room_id = fed_room.get("room_id") + if not fed_room_id or not isinstance(fed_room_id, str): + continue - return res.rooms, tuple( - ev.data for ev in res.events if ev.event_type == EventTypes.SpaceChild - ) + results.append( + _RoomEntry( + fed_room_id, + fed_room, + children_by_room.get(fed_room_id, []), + ) + ) + + return results async def _is_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] @@ -606,10 +622,21 @@ class SpaceSummaryHandler: return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class _RoomQueueEntry: - room_id = attr.ib(type=str) - via = attr.ib(type=Sequence[str]) + room_id: str + via: Sequence[str] + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class _RoomEntry: + room_id: str + # The room summary for this room. + room: JsonDict + # An iterable of the sorted, stripped children events for children of this room. + # + # This may not include all children. + children: Collection[JsonDict] = () def _has_valid_via(e: EventBase) -> bool: diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 3f73ad7f94..f982a8c8b4 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -26,7 +26,7 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key +from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.server import HomeServer @@ -351,26 +351,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # events before child events). # Note that these entries are brief, but should contain enough info. 
- rooms = [ - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - { - "room_id": subroom, - "world_readable": True, - }, - ] - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": subroom, - "content": event_content, - }, + return [ + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ), + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), ] - return rooms, events # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) @@ -436,70 +440,95 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ): # Note that these entries are brief, but should contain enough info. rooms = [ - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, - "allowed_spaces": [], - }, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, - "allowed_spaces": [self.room], - }, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ] - - # Place each room in the sub-space. - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": room["room_id"], - "content": event_content, - } - for room in rooms + _RoomEntry( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + _RoomEntry( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + _RoomEntry( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [], + }, + ), + _RoomEntry( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.MSC3083_RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + _RoomEntry( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), ] # Also include the subspace. rooms.insert( 0, - { - "room_id": subspace, - "world_readable": True, - }, + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. 
+ [ + { + "room_id": subspace, + "state_key": room.room_id, + "content": {"via": [fed_hostname]}, + } + for room in rooms + ], + ), ) - return rooms, events + return rooms # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) -- cgit 1.5.1 From 457853100240fc5e015e10a62ecffdd799128b0c Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 5 Aug 2021 20:00:44 +0200 Subject: fix broken links in `upgrade.md` (#10543) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10543.doc | 1 + docs/upgrade.md | 51 +++++++++++++++++++++++---------------------------- 2 files changed, 24 insertions(+), 28 deletions(-) create mode 100644 changelog.d/10543.doc diff --git a/changelog.d/10543.doc b/changelog.d/10543.doc new file mode 100644 index 0000000000..6c06722eb4 --- /dev/null +++ b/changelog.d/10543.doc @@ -0,0 +1 @@ +Fix broken links in `upgrade.md`. Contributed by @dklimpel. diff --git a/docs/upgrade.md b/docs/upgrade.md index c8f4a2c171..ce9167e6de 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -142,9 +142,9 @@ SQLite databases are unaffected by this change. The current spam checker interface is deprecated in favour of a new generic modules system. Authors of spam checker modules can refer to [this -documentation](https://matrix-org.github.io/synapse/develop/modules.html#porting-an-existing-module-that-uses-the-old-interface) +documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface) to update their modules. Synapse administrators can refer to [this -documentation](https://matrix-org.github.io/synapse/develop/modules.html#using-modules) +documentation](modules.md#using-modules) to update their configuration once the modules they are using have been updated. We plan to remove support for the current spam checker interface in August 2021. @@ -217,8 +217,7 @@ Instructions for doing so are provided ## Dropping support for old Python, Postgres and SQLite versions -In line with our [deprecation -policy](https://github.com/matrix-org/synapse/blob/release-v1.32.0/docs/deprecation_policy.md), +In line with our [deprecation policy](deprecation_policy.md), we've dropped support for Python 3.5 and PostgreSQL 9.5, as they are no longer supported upstream. @@ -231,8 +230,7 @@ The deprecated v1 "list accounts" admin API (`GET /_synapse/admin/v1/users/`) has been removed in this version. -The [v2 list accounts -API](https://github.com/matrix-org/synapse/blob/master/docs/admin_api/user_admin_api.rst#list-accounts) +The [v2 list accounts API](admin_api/user_admin_api.md#list-accounts) has been available since Synapse 1.7.0 (2019-12-13), and is accessible under `GET /_synapse/admin/v2/users`. @@ -267,7 +265,7 @@ by the client. Synapse also requires the [Host]{.title-ref} header to be preserved. -See the [reverse proxy documentation](../reverse_proxy.md), where the +See the [reverse proxy documentation](reverse_proxy.md), where the example configurations have been updated to show how to set these headers. @@ -286,7 +284,7 @@ identity providers: `[synapse public baseurl]/_synapse/client/oidc/callback` to the list of permitted "redirect URIs" at the identity provider. - See the [OpenID docs](../openid.md) for more information on setting + See the [OpenID docs](openid.md) for more information on setting up OpenID Connect. - If your server is configured for single sign-on via a SAML2 identity @@ -486,8 +484,7 @@ lock down external access to the Admin API endpoints. 
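Returning briefly to the spaces-summary refactor above: the key structural change in `_summarize_remote_room` is that child events from a federation response are now bucketed by their `room_id`, so each `_RoomEntry` carries only its own children instead of the caller re-filtering one flat event list. Roughly (event data invented):

```python
# Sketch of the grouping step added to _summarize_remote_room; the
# events here are invented stand-ins for res.events.
from typing import Dict, List

events: List[dict] = [
    {"type": "m.space.child", "room_id": "!space:a", "state_key": "!room1:a"},
    {"type": "m.space.child", "room_id": "!space:a", "state_key": "!room2:a"},
    {"type": "m.space.child", "room_id": "!sub:a", "state_key": "!room3:a"},
]

children_by_room: Dict[str, List[dict]] = {}
for ev in events:
    if ev["type"] == "m.space.child":
        children_by_room.setdefault(ev["room_id"], []).append(ev)

# Each room entry then picks up children_by_room.get(room_id, []).
print({room: len(children) for room, children in children_by_room.items()})
# {'!space:a': 2, '!sub:a': 1}
```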
This release deprecates use of the `structured: true` logging configuration for structured logging. If your logging configuration contains `structured: true` then it should be modified based on the -[structured logging -documentation](../structured_logging.md). +[structured logging documentation](structured_logging.md). The `structured` and `drains` logging options are now deprecated and should be replaced by standard logging configuration of `handlers` and @@ -517,14 +514,13 @@ acts the same as the `http_client` argument previously passed to ## Forwarding `/_synapse/client` through your reverse proxy -The [reverse proxy -documentation](https://github.com/matrix-org/synapse/blob/develop/docs/reverse_proxy.md) +The [reverse proxy documentation](reverse_proxy.md) has been updated to include reverse proxy directives for `/_synapse/client/*` endpoints. As the user password reset flow now uses endpoints under this prefix, **you must update your reverse proxy configurations for user password reset to work**. -Additionally, note that the [Synapse worker documentation](https://github.com/matrix-org/synapse/blob/develop/docs/workers.md) has been updated to +Additionally, note that the [Synapse worker documentation](workers.md) has been updated to : state that the `/_synapse/client/password_reset/email/submit_token` endpoint can be handled @@ -588,7 +584,7 @@ updated. When setting up worker processes, we now recommend the use of a Redis server for replication. **The old direct TCP connection method is deprecated and will be removed in a future release.** See -[workers](../workers.md) for more details. +[workers](workers.md) for more details. # Upgrading to v1.14.0 @@ -720,8 +716,7 @@ participating in many rooms. omitting the `CONCURRENTLY` keyword. Note however that this operation may in itself cause Synapse to stop running for some time. Synapse admins are reminded that [SQLite is not recommended for use - outside a test - environment](https://github.com/matrix-org/synapse/blob/master/README.rst#using-postgresql). + outside a test environment](postgres.md). 3. Once the index has been created, the `SELECT` query in step 1 above should complete quickly. It is therefore safe to upgrade to Synapse @@ -739,7 +734,7 @@ participating in many rooms. Synapse will now log a warning on start up if used with a PostgreSQL database that has a non-recommended locale set. -See [Postgres](../postgres.md) for details. +See [Postgres](postgres.md) for details. # Upgrading to v1.8.0 @@ -856,8 +851,8 @@ section headed `email`, and be sure to have at least the You may also need to set `smtp_user`, `smtp_pass`, and `require_transport_security`. -See the [sample configuration file](docs/sample_config.yaml) for more -details on these settings. +See the [sample configuration file](usage/configuration/homeserver_sample_config.md) +for more details on these settings. #### Delegate email to an identity server @@ -959,7 +954,7 @@ back to v1.3.1, subject to the following: Some counter metrics have been renamed, with the old names deprecated. See [the metrics -documentation](../metrics-howto.md#renaming-of-metrics--deprecation-of-old-names-in-12) +documentation](metrics-howto.md#renaming-of-metrics--deprecation-of-old-names-in-12) for details. # Upgrading to v1.1.0 @@ -995,7 +990,7 @@ more details on upgrading your database. Synapse v1.0 is the first release to enforce validation of TLS certificates for the federation API. It is therefore essential that your certificates are correctly configured. 
See the -[FAQ](../MSC1711_certificates_FAQ.md) for more information. +[FAQ](MSC1711_certificates_FAQ.md) for more information. Note, v1.0 installations will also no longer be able to federate with servers that have not correctly configured their certificates. @@ -1010,8 +1005,8 @@ ways:- - Configure a whitelist of server domains to trust via `federation_certificate_verification_whitelist`. -See the [sample configuration file](docs/sample_config.yaml) for more -details on these settings. +See the [sample configuration file](usage/configuration/homeserver_sample_config.md) +for more details on these settings. ## Email @@ -1036,8 +1031,8 @@ If you are absolutely certain that you wish to continue using an identity server for password resets, set `trust_identity_server_for_password_resets` to `true`. -See the [sample configuration file](docs/sample_config.yaml) for more -details on these settings. +See the [sample configuration file](usage/configuration/homeserver_sample_config.md) +for more details on these settings. ## New email templates @@ -1057,11 +1052,11 @@ sent to them. Please be aware that, before Synapse v1.0 is released around March 2019, you will need to replace any self-signed certificates with those -verified by a root CA. Information on how to do so can be found at [the -ACME docs](../ACME.md). +verified by a root CA. Information on how to do so can be found at the +ACME docs. For more information on configuring TLS certificates see the -[FAQ](../MSC1711_certificates_FAQ.md). +[FAQ](MSC1711_certificates_FAQ.md). # Upgrading to v0.34.0 -- cgit 1.5.1 From f5a368bb48df85dd488afdead01a39f77f50de99 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 5 Aug 2021 20:35:53 -0500 Subject: Mark all MSC2716 events as historical (#10537) * Mark all MSC2716 events as historical --- changelog.d/10537.misc | 1 + synapse/rest/client/v1/room.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10537.misc diff --git a/changelog.d/10537.misc b/changelog.d/10537.misc new file mode 100644 index 0000000000..c9e045300c --- /dev/null +++ b/changelog.d/10537.misc @@ -0,0 +1 @@ +Mark all events stemming from the MSC2716 `/batch_send` endpoint as historical. 
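For context on the diff that follows: the fix makes `/batch_send` mark *every* event it creates as historical, including the floating state events and the chunk event, where previously only the ordinary events in `events_to_create` were marked. The marking itself is a single content-field assignment; a sketch (the literal field name is an assumption here, the code spells it `EventContentFields.MSC2716_HISTORICAL`):

```python
# Sketch of the uniform marking step; the constant's value is assumed
# to be the unstable-prefixed field name used by MSC2716.
MSC2716_HISTORICAL = "org.matrix.msc2716.historical"

event_dict = {
    "type": "m.room.message",
    "sender": "@app_service:example.com",
    "room_id": "!room:example.com",
    "content": {"msgtype": "m.text", "body": "a historical message"},
}

# Applied alike to state events, the chunk event and ordinary events.
event_dict["content"][MSC2716_HISTORICAL] = True
```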
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 502a917588..982f134148 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -458,6 +458,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "state_key": state_event["state_key"], } + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + # Make the state events float off on their own fake_prev_event_id = "$" + random_string(43) @@ -562,7 +565,10 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "type": EventTypes.MSC2716_CHUNK, "sender": requester.user.to_string(), "room_id": room_id, - "content": {EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to}, + "content": { + EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, # Since the chunk event is put at the end of the chunk, # where the newest-in-time event is, copy the origin_server_ts from # the last event we're inserting @@ -589,10 +595,6 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): for ev in events_to_create: assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - # Mark all events as historical - # This has important semantics within the Synapse internals to backfill properly - ev["content"][EventContentFields.MSC2716_HISTORICAL] = True - event_dict = { "type": ev["type"], "origin_server_ts": ev["origin_server_ts"], @@ -602,6 +604,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): "prev_events": prev_event_ids.copy(), } + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + event, context = await self.event_creation_handler.create_event( await self._create_requester_for_user_id_from_app_service( ev["sender"], requester.app_service -- cgit 1.5.1 From 74d7336686e7de1d0923d67af61b510ec801fa84 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 6 Aug 2021 11:13:34 +0100 Subject: Add a setting to disable TLS for sending email (#10546) This is mostly useful in case the server offers TLS, but doesn't present a valid certificate. --- changelog.d/10546.feature | 1 + docs/sample_config.yaml | 8 +++ synapse/config/emailconfig.py | 14 +++++ synapse/handlers/send_email.py | 94 +++++++++++++++++++++++------ synapse/server.py | 6 -- tests/push/test_email.py | 20 +++--- tests/rest/client/v2_alpha/test_account.py | 33 ++++++---- tests/rest/client/v2_alpha/test_register.py | 12 ++-- 8 files changed, 138 insertions(+), 50 deletions(-) create mode 100644 changelog.d/10546.feature diff --git a/changelog.d/10546.feature b/changelog.d/10546.feature new file mode 100644 index 0000000000..7709d010b3 --- /dev/null +++ b/changelog.d/10546.feature @@ -0,0 +1 @@ +Add a setting to disable TLS when sending email. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 16843dd8c9..aeebcaf45f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2242,6 +2242,14 @@ email: # #require_transport_security: true + # Uncomment the following to disable TLS for SMTP. + # + # By default, if the server supports TLS, it will be used, and the server + # must present a certificate that is valid for 'smtp_host'. If this option + # is set to false, TLS will not be used. + # + #enable_tls: false + # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. 
# diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 8d8f166e9b..42526502f0 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -80,6 +80,12 @@ class EmailConfig(Config): self.require_transport_security = email_config.get( "require_transport_security", False ) + self.enable_smtp_tls = email_config.get("enable_tls", True) + if self.require_transport_security and not self.enable_smtp_tls: + raise ConfigError( + "email.require_transport_security requires email.enable_tls to be true" + ) + if "app_name" in email_config: self.email_app_name = email_config["app_name"] else: @@ -368,6 +374,14 @@ class EmailConfig(Config): # #require_transport_security: true + # Uncomment the following to disable TLS for SMTP. + # + # By default, if the server supports TLS, it will be used, and the server + # must present a certificate that is valid for 'smtp_host'. If this option + # is set to false, TLS will not be used. + # + #enable_tls: false + # notif_from defines the "From" address to use when sending emails. # It must be set if email sending is enabled. # diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index e9f6aef06f..dda9659c11 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -16,7 +16,12 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING +from io import BytesIO +from typing import TYPE_CHECKING, Optional + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IReactorTCP +from twisted.mail.smtp import ESMTPSenderFactory from synapse.logging.context import make_deferred_yieldable @@ -26,19 +31,75 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +async def _sendmail( + reactor: IReactorTCP, + smtphost: str, + smtpport: int, + from_addr: str, + to_addr: str, + msg_bytes: bytes, + username: Optional[bytes] = None, + password: Optional[bytes] = None, + require_auth: bool = False, + require_tls: bool = False, + tls_hostname: Optional[str] = None, +) -> None: + """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests + + Params: + reactor: reactor to use to make the outbound connection + smtphost: hostname to connect to + smtpport: port to connect to + from_addr: "From" address for email + to_addr: "To" address for email + msg_bytes: Message content + username: username to authenticate with, if auth is enabled + password: password to give when authenticating + require_auth: if auth is not offered, fail the request + require_tls: if TLS is not offered, fail the request + tls_hostname: TLS hostname to check for. None to disable TLS.
+ """ + msg = BytesIO(msg_bytes) + + d: "Deferred[object]" = Deferred() + + factory = ESMTPSenderFactory( + username, + password, + from_addr, + to_addr, + msg, + d, + heloFallback=True, + requireAuthentication=require_auth, + requireTransportSecurity=require_tls, + hostname=tls_hostname, + ) + + # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong + reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] + + await make_deferred_yieldable(d) + + class SendEmailHandler: def __init__(self, hs: "HomeServer"): self.hs = hs - self._sendmail = hs.get_sendmail() self._reactor = hs.get_reactor() self._from = hs.config.email.email_notif_from self._smtp_host = hs.config.email.email_smtp_host self._smtp_port = hs.config.email.email_smtp_port - self._smtp_user = hs.config.email.email_smtp_user - self._smtp_pass = hs.config.email.email_smtp_pass + + user = hs.config.email.email_smtp_user + self._smtp_user = user.encode("utf-8") if user is not None else None + passwd = hs.config.email.email_smtp_pass + self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None self._require_transport_security = hs.config.email.require_transport_security + self._enable_tls = hs.config.email.enable_smtp_tls + + self._sendmail = _sendmail async def send_email( self, @@ -82,17 +143,16 @@ class SendEmailHandler: logger.info("Sending email to %s" % email_address) - await make_deferred_yieldable( - self._sendmail( - self._smtp_host, - raw_from, - raw_to, - multipart_msg.as_string().encode("utf8"), - reactor=self._reactor, - port=self._smtp_port, - requireAuthentication=self._smtp_user is not None, - username=self._smtp_user, - password=self._smtp_pass, - requireTransportSecurity=self._require_transport_security, - ) + await self._sendmail( + self._reactor, + self._smtp_host, + self._smtp_port, + raw_from, + raw_to, + multipart_msg.as_string().encode("utf8"), + username=self._smtp_user, + password=self._smtp_pass, + require_auth=self._smtp_user is not None, + require_tls=self._require_transport_security, + tls_hostname=self._smtp_host if self._enable_tls else None, ) diff --git a/synapse/server.py b/synapse/server.py index 095dba9ad0..6c867f0f47 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -34,8 +34,6 @@ from typing import ( ) import twisted.internet.tcp -from twisted.internet import defer -from twisted.mail.smtp import sendmail from twisted.web.iweb import IPolicyForHTTPS from twisted.web.resource import IResource @@ -442,10 +440,6 @@ class HomeServer(metaclass=abc.ABCMeta): def get_room_shutdown_handler(self) -> RoomShutdownHandler: return RoomShutdownHandler(self) - @cache_in_self - def get_sendmail(self) -> Callable[..., defer.Deferred]: - return sendmail - @cache_in_self def get_state_handler(self) -> StateHandler: return StateHandler(self) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e04bc5c9a6..a487706758 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -45,14 +45,6 @@ class EmailPusherTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): - # List[Tuple[Deferred, args, kwargs]] - self.email_attempts = [] - - def sendmail(*args, **kwargs): - d = Deferred() - self.email_attempts.append((d, args, kwargs)) - return d - config = self.default_config() config["email"] = { "enable_notifs": True, @@ -75,7 +67,17 @@ class EmailPusherTests(HomeserverTestCase): config["public_baseurl"] = "aaa" config["start_pushers"] = True - hs = 
self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts = [] + + def sendmail(*args, **kwargs): + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + hs.get_send_email_handler()._sendmail = sendmail return hs diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 317a2287e3..e7e617e9df 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -47,12 +47,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -67,7 +61,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + return hs def prepare(self, reactor, clock, hs): @@ -511,11 +514,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -530,7 +528,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + return self.hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 1cad5f00eb..a52e5e608a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -509,10 +509,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } # Email config. - self.email_attempts = [] - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) config["email"] = { "enable_notifs": True, @@ -532,7 +528,13 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "aaa" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail self.store = self.hs.get_datastore() -- cgit 1.5.1 From f4ade972ada6d61ca9370d26784ac9f3ed8e5282 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 6 Aug 2021 07:40:29 -0400 Subject: Update the API response for spaces summary over federation. 
(#10530) This adds 'allowed_room_ids' (in addition to 'allowed_spaces', for backwards compatibility) to the federation response of the spaces summary. A future PR will remove the 'allowed_spaces' flag. --- changelog.d/10530.misc | 1 + synapse/handlers/space_summary.py | 57 ++++++++++++++++++++++++++------------- 2 files changed, 39 insertions(+), 19 deletions(-) create mode 100644 changelog.d/10530.misc diff --git a/changelog.d/10530.misc b/changelog.d/10530.misc new file mode 100644 index 0000000000..3cf22f9daf --- /dev/null +++ b/changelog.d/10530.misc @@ -0,0 +1 @@ +Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 3eb232c83e..2517f278b6 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -179,7 +179,9 @@ class SpaceSummaryHandler: # Check if the user is a member of any of the allowed spaces # from the response. - allowed_rooms = room.get("allowed_spaces") + allowed_rooms = room.get("allowed_room_ids") or room.get( + "allowed_spaces" + ) if ( not include_room and allowed_rooms @@ -198,6 +200,11 @@ class SpaceSummaryHandler: # The user can see the room, include it! if include_room: + # Before returning to the client, remove the allowed_room_ids + # and allowed_spaces keys. + room.pop("allowed_room_ids", None) + room.pop("allowed_spaces", None) + rooms_result.append(room) events.extend(room_entry.children) @@ -236,11 +243,6 @@ class SpaceSummaryHandler: ) processed_events.add(ev_key) - # Before returning to the client, remove the allowed_spaces key for any - # rooms. - for room in rooms_result: - room.pop("allowed_spaces", None) - return {"rooms": rooms_result, "events": events_result} async def federation_space_summary( @@ -337,7 +339,7 @@ class SpaceSummaryHandler: if not await self._is_room_accessible(room_id, requester, origin): return None - room_entry = await self._build_room_entry(room_id) + room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) # If the room is not a space, return just the room information. if room_entry.get("room_type") != RoomTypes.SPACE: @@ -548,8 +550,18 @@ class SpaceSummaryHandler: ) return False - async def _build_room_entry(self, room_id: str) -> JsonDict: - """Generate en entry suitable for the 'rooms' list in the summary response""" + async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: + """ + Generate an entry suitable for the 'rooms' list in the summary response. + + Args: + room_id: The room ID to summarize. + for_federation: True if this is a summary requested over federation + (which includes additional fields). + + Returns: + The JSON dictionary for the room.
+ """ stats = await self._store.get_room_with_stats(room_id) # currently this should be impossible because we call @@ -562,15 +574,6 @@ class SpaceSummaryHandler: current_state_ids[(EventTypes.Create, "")] ) - room_version = await self._store.get_room_version(room_id) - allowed_rooms = None - if await self._event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version - ): - allowed_rooms = await self._event_auth_handler.get_rooms_that_allow_join( - current_state_ids - ) - entry = { "room_id": stats["room_id"], "name": stats["name"], @@ -585,9 +588,25 @@ class SpaceSummaryHandler: "guest_can_join": stats["guest_access"] == "can_join", "creation_ts": create_event.origin_server_ts, "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), - "allowed_spaces": allowed_rooms, } + # Federation requests need to provide additional information so the + # requested server is able to filter the response appropriately. + if for_federation: + room_version = await self._store.get_room_version(room_id) + if await self._event_auth_handler.has_restricted_join_rules( + current_state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join( + current_state_ids + ) + ) + if allowed_rooms: + entry["allowed_room_ids"] = allowed_rooms + # TODO Remove this key once the API is stable. + entry["allowed_spaces"] = allowed_rooms + # Filter out Nones – rather omit the field altogether room_entry = {k: v for k, v in entry.items() if v is not None} -- cgit 1.5.1 From 1bebc0b78cbedffb6b69fd76327f0eb7663c3c96 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 6 Aug 2021 13:54:23 +0100 Subject: Clean up federation event auth code (#10539) * drop old-room hack pretty sure we don't need this any more. * Remove incorrect comment about modifying `context` It doesn't look like the supplied context is ever modified. * Stop `_auth_and_persist_event` modifying its parameters This is only called in three places. Two of them don't pass `auth_events`, and the third doesn't use the dict after passing it in, so this should be non-functional. * Stop `_check_event_auth` modifying its parameters `_check_event_auth` is only called in three places. `on_send_membership_event` doesn't pass an `auth_events`, and `prep` and `_auth_and_persist_event` do not use the map after passing it in. * Stop `_update_auth_events_and_context_for_auth` modifying its parameters Return the updated auth event dict, rather than modifying the parameter. This is only called from `_check_event_auth`. * Improve documentation on `_auth_and_persist_event` Rename `auth_events` parameter to better reflect what it contains. * Improve documentation on `_NewEventInfo` * Improve documentation on `_check_event_auth` rename `auth_events` parameter to better describe what it contains * changelog --- changelog.d/10539.misc | 1 + synapse/handlers/federation.py | 118 +++++++++++++++++++++++------------------ tests/test_federation.py | 6 +-- 3 files changed, 69 insertions(+), 56 deletions(-) create mode 100644 changelog.d/10539.misc diff --git a/changelog.d/10539.misc b/changelog.d/10539.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10539.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8b602e3813..9a5e726533 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -109,21 +109,33 @@ soft_failed_event_counter = Counter( ) -@attr.s(slots=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class _NewEventInfo: """Holds information about a received event, ready for passing to _auth_and_persist_events Attributes: event: the received event - state: the state at that event + state: the state at that event, according to /state_ids from a remote + homeserver. Only populated for backfilled events which are going to be a + new backwards extremity. + + claimed_auth_event_map: a map of (type, state_key) => event for the event's + claimed auth_events. + + This can include events which have not yet been persisted, in the case that + we are backfilling a batch of events. + + Note: may be incomplete if we were unable to find all of the claimed auth + events. Also, treat the contents with caution: the events might also have + been rejected, might not yet have been authorized themselves, or they might + be in the wrong room. - auth_events: the auth_event map for that event """ - event = attr.ib(type=EventBase) - state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) + event: EventBase + state: Optional[Sequence[EventBase]] + claimed_auth_event_map: StateMap[EventBase] class FederationHandler(BaseHandler): @@ -1086,7 +1098,7 @@ class FederationHandler(BaseHandler): _NewEventInfo( event=ev, state=events_to_state[e_id], - auth_events={ + claimed_auth_event_map={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -2315,7 +2327,7 @@ class FederationHandler(BaseHandler): event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - auth_events: Optional[MutableStateMap[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> None: """ @@ -2327,17 +2339,18 @@ class FederationHandler(BaseHandler): context: The event context. - NB that this function potentially modifies it. state: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers. + backfilled: True if the event was backfilled.
""" context = await self._check_event_auth( @@ -2345,7 +2358,7 @@ class FederationHandler(BaseHandler): event, context, state=state, - auth_events=auth_events, + claimed_auth_event_map=claimed_auth_event_map, backfilled=backfilled, ) @@ -2409,7 +2422,7 @@ class FederationHandler(BaseHandler): event, res, state=ev_info.state, - auth_events=ev_info.auth_events, + claimed_auth_event_map=ev_info.claimed_auth_event_map, backfilled=backfilled, ) return res @@ -2675,7 +2688,7 @@ class FederationHandler(BaseHandler): event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - auth_events: Optional[MutableStateMap[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> EventContext: """ @@ -2687,21 +2700,19 @@ class FederationHandler(BaseHandler): context: The event context. - NB that this function potentially modifies it. state: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - auth_events: - Map from (event_type, state_key) to event - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. - Also NB that this function adds entries to it. + Only populated where we may not already have persisted these events - + for example, when populating outliers, or the state for a backwards + extremity. - If this is not provided, it is calculated from the previous state IDs. backfilled: True if the event was backfilled. Returns: @@ -2710,7 +2721,12 @@ class FederationHandler(BaseHandler): room_version = await self.store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if not auth_events: + if claimed_auth_event_map: + # if we have a copy of the auth events from the event, use that as the + # basis for auth. + auth_events = claimed_auth_event_map + else: + # otherwise, we calculate what the auth events *should* be, and use that prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True @@ -2718,18 +2734,11 @@ class FederationHandler(BaseHandler): auth_events_x = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - # This is a hack to fix some old rooms where the initial join event - # didn't reference the create event in its auth events. 
- if event.type == EventTypes.Member and not event.auth_event_ids(): - if len(event.prev_event_ids()) == 1 and event.depth < 5: - c = await self.store.get_event( - event.prev_event_ids()[0], allow_none=True - ) - if c and c.type == EventTypes.Create: - auth_events[(c.type, c.state_key)] = c - try: - context = await self._update_auth_events_and_context_for_auth( + ( + context, + auth_events_for_auth, + ) = await self._update_auth_events_and_context_for_auth( origin, event, context, auth_events ) except Exception: @@ -2742,9 +2751,10 @@ class FederationHandler(BaseHandler): "Ignoring failure and continuing processing of event.", event.event_id, ) + auth_events_for_auth = auth_events try: - event_auth.check(room_version_obj, event, auth_events=auth_events) + event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) except AuthError as e: logger.warning("Failed auth resolution for %r because %s", event, e) context.rejected = RejectedReason.AUTH_ERROR @@ -2769,8 +2779,8 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - auth_events: MutableStateMap[EventBase], - ) -> EventContext: + input_auth_events: StateMap[EventBase], + ) -> Tuple[EventContext, StateMap[EventBase]]: """Helper for _check_event_auth. See there for docs. Checks whether a given event has the expected auth events. If it @@ -2787,7 +2797,7 @@ class FederationHandler(BaseHandler): event: context: - auth_events: + input_auth_events: Map from (event_type, state_key) to event Normally, our calculated auth_events based on the state of the room @@ -2795,11 +2805,12 @@ class FederationHandler(BaseHandler): event is an outlier), may be the auth events claimed by the remote server. - Also NB that this function adds entries to it. - Returns: - updated context + updated context, updated auth event map """ + # take a copy of input_auth_events before we modify it. + auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + event_auth_events = set(event.auth_event_ids()) # missing_auth is the set of the event's auth_events which we don't yet have @@ -2828,7 +2839,7 @@ class FederationHandler(BaseHandler): # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e1) - return context + return context, auth_events seen_remotes = await self.store.have_seen_events( event.room_id, [e.event_id for e in remote_auth_chain] @@ -2859,7 +2870,10 @@ class FederationHandler(BaseHandler): await self.state_handler.compute_event_context(e) ) await self._auth_and_persist_event( - origin, e, missing_auth_event_context, auth_events=auth + origin, + e, + missing_auth_event_context, + claimed_auth_event_map=auth, ) if e.event_id in event_auth_events: @@ -2877,14 +2891,14 @@ class FederationHandler(BaseHandler): # obviously be empty # (b) alternatively, why don't we do it earlier? logger.info("Skipping auth_event fetch for outlier") - return context + return context, auth_events different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) if not different_auth: - return context + return context, auth_events logger.info( "auth_events refers to events which are not in our calculated auth " @@ -2910,7 +2924,7 @@ class FederationHandler(BaseHandler): # XXX: should we reject the event in this case? It feels like we should, # but then shouldn't we also do so if we've failed to fetch any of the # auth events? 
- return context + return context, auth_events # now we state-resolve between our own idea of the auth events, and the remote's # idea of them. @@ -2940,7 +2954,7 @@ class FederationHandler(BaseHandler): event, context, auth_events ) - return context + return context, auth_events async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] diff --git a/tests/test_federation.py b/tests/test_federation.py index 0ed8326f55..3785799f46 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -75,10 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = ( - lambda origin, event, context, state, auth_events, backfilled: succeed( - context - ) + self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( + context ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( -- cgit 1.5.1 From 60f0534b6e910a497800da2454638bcf4aae006e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Aug 2021 14:05:41 +0100 Subject: Fix exceptions in logs when failing to get remote room list (#10541) --- changelog.d/10541.bugfix | 1 + synapse/federation/federation_client.py | 3 +- synapse/handlers/room_list.py | 46 ++++++++++------- synapse/rest/client/v1/room.py | 30 +++++------ tests/rest/client/v1/test_rooms.py | 92 ++++++++++++++++++++++++++++++++- 5 files changed, 134 insertions(+), 38 deletions(-) create mode 100644 changelog.d/10541.bugfix diff --git a/changelog.d/10541.bugfix b/changelog.d/10541.bugfix new file mode 100644 index 0000000000..bb946e0920 --- /dev/null +++ b/changelog.d/10541.bugfix @@ -0,0 +1 @@ +Fix exceptions in logs when failing to get remote room list. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 007d1a27dc..2eefac04fd 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1108,7 +1108,8 @@ class FederationClient(FederationBase): The response from the remote server. 
Raises: - HttpResponseException: There was an exception returned from the remote server + HttpResponseException / RequestSendFailed: There was an exception + returned from the remote server SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom requests over federation diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index fae2c098e3..6d433fad41 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -356,6 +356,12 @@ class RoomListHandler(BaseHandler): include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, ) -> JsonDict: + """Get the public room list from remote server + + Raises: + SynapseError + """ + if not self.enable_room_list_search: return {"chunk": [], "total_room_count_estimate": 0} @@ -395,13 +401,16 @@ class RoomListHandler(BaseHandler): limit = None since_token = None - res = await self._get_remote_list_cached( - server_name, - limit=limit, - since_token=since_token, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) + try: + res = await self._get_remote_list_cached( + server_name, + limit=limit, + since_token=since_token, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + except (RequestSendFailed, HttpResponseException): + raise SynapseError(502, "Failed to fetch room list") if search_filter: res = { @@ -423,20 +432,21 @@ class RoomListHandler(BaseHandler): include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, ) -> JsonDict: + """Wrapper around FederationClient.get_public_rooms that caches the + result. + """ + repl_layer = self.hs.get_federation_client() if search_filter: # We can't cache when asking for search - try: - return await repl_layer.get_public_rooms( - server_name, - limit=limit, - since_token=since_token, - search_filter=search_filter, - include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - except (RequestSendFailed, HttpResponseException): - raise SynapseError(502, "Failed to fetch room list") + return await repl_layer.get_public_rooms( + server_name, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) key = ( server_name, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 982f134148..f887970b76 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -23,7 +23,6 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, - HttpResponseException, InvalidClientCredentialsError, ShadowBanError, SynapseError, @@ -783,12 +782,9 @@ class PublicRoomListRestServlet(TransactionRestServlet): Codes.INVALID_PARAM, ) - try: - data = await handler.get_remote_public_room_list( - server, limit=limit, since_token=since_token - ) - except HttpResponseException as e: - raise e.to_synapse_error() + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) else: data = await handler.get_local_public_room_list( limit=limit, since_token=since_token @@ -836,17 +832,15 @@ class PublicRoomListRestServlet(TransactionRestServlet): Codes.INVALID_PARAM, ) - try: - data = await handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, - search_filter=search_filter, - 
include_all_networks=include_all_networks, - third_party_instance_id=third_party_instance_id, - ) - except HttpResponseException as e: - raise e.to_synapse_error() + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + else: data = await handler.get_local_public_room_list( limit=limit, diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 3df070c936..1a9528ec20 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -19,11 +19,14 @@ import json from typing import Iterable -from unittest.mock import Mock +from unittest.mock import Mock, call from urllib import parse as urlparse +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room @@ -1124,6 +1127,93 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fall back to local filtering if a remote server + doesn't support search. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters.
+ self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From 1de26b346796ec8d6b51b4395017f8107f640c47 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 6 Aug 2021 09:39:59 -0400 Subject: Convert Transaction and Edu objects to attrs (#10542) Instead of wrapping the JSON into an object, this creates concrete instances for Transaction and Edu. This allows for improved type hints and simplified code. --- changelog.d/10542.misc | 1 + synapse/federation/federation_server.py | 50 ++++++----- synapse/federation/persistence.py | 4 +- synapse/federation/sender/transaction_manager.py | 9 +- synapse/federation/transport/client.py | 2 +- synapse/federation/transport/server.py | 11 +-- synapse/federation/units.py | 90 ++++++++------------ synapse/util/jsonobject.py | 102 ----------------------- 8 files changed, 75 insertions(+), 194 deletions(-) create mode 100644 changelog.d/10542.misc delete mode 100644 synapse/util/jsonobject.py diff --git a/changelog.d/10542.misc b/changelog.d/10542.misc new file mode 100644 index 0000000000..44b70b4730 --- /dev/null +++ b/changelog.d/10542.misc @@ -0,0 +1 @@ +Convert `Transaction` and `Edu` objects to attrs. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 145b9161d9..0385aadefa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -195,13 +195,17 @@ class FederationServer(FederationBase): origin, room_id, versions, limit ) - res = self._transaction_from_pdus(pdus).get_dict() + res = self._transaction_dict_from_pdus(pdus) return 200, res async def on_incoming_transaction( - self, origin: str, transaction_data: JsonDict - ) -> Tuple[int, Dict[str, Any]]: + self, + origin: str, + transaction_id: str, + destination: str, + transaction_data: JsonDict, + ) -> Tuple[int, JsonDict]: # If we receive a transaction we should make sure that we kick off handling # any old events in the staging area. if not self._started_handling_of_staged_events: @@ -212,8 +216,14 @@ class FederationServer(FederationBase): # accurate as possible.
request_time = self._clock.time_msec() - transaction = Transaction(**transaction_data) - transaction_id = transaction.transaction_id # type: ignore + transaction = Transaction( + transaction_id=transaction_id, + destination=destination, + origin=origin, + origin_server_ts=transaction_data.get("origin_server_ts"), # type: ignore + pdus=transaction_data.get("pdus"), # type: ignore + edus=transaction_data.get("edus"), + ) if not transaction_id: raise Exception("Transaction missing transaction_id") @@ -221,9 +231,7 @@ class FederationServer(FederationBase): logger.debug("[%s] Got transaction", transaction_id) # Reject malformed transactions early: reject if too many PDUs/EDUs - if len(transaction.pdus) > 50 or ( # type: ignore - hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore - ): + if len(transaction.pdus) > 50 or len(transaction.edus) > 100: logger.info("Transaction PDU or EDU count too large. Returning 400") return 400, {} @@ -263,7 +271,7 @@ class FederationServer(FederationBase): # CRITICAL SECTION: the first thing we must do (before awaiting) is # add an entry to _active_transactions. assert origin not in self._active_transactions - self._active_transactions[origin] = transaction.transaction_id # type: ignore + self._active_transactions[origin] = transaction.transaction_id try: result = await self._handle_incoming_transaction( @@ -291,11 +299,11 @@ class FederationServer(FederationBase): if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id, # type: ignore + transaction.transaction_id, ) return response - logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore + logger.debug("[%s] Transaction is new", transaction.transaction_id) # We process PDUs and EDUs in parallel. This is important as we don't # want to block things like to device messages from reaching clients @@ -334,7 +342,7 @@ class FederationServer(FederationBase): report back to the sending server. 
""" - received_pdus_counter.inc(len(transaction.pdus)) # type: ignore + received_pdus_counter.inc(len(transaction.pdus)) origin_host, _ = parse_server_name(origin) @@ -342,7 +350,7 @@ class FederationServer(FederationBase): newest_pdu_ts = 0 - for p in transaction.pdus: # type: ignore + for p in transaction.pdus: # FIXME (richardv): I don't think this works: # https://github.com/matrix-org/synapse/issues/8429 if "unsigned" in p: @@ -436,10 +444,10 @@ class FederationServer(FederationBase): return pdu_results - async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): + async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None: """Process the EDUs in a received transaction.""" - async def _process_edu(edu_dict): + async def _process_edu(edu_dict: JsonDict) -> None: received_edus_counter.inc() edu = Edu( @@ -452,7 +460,7 @@ class FederationServer(FederationBase): await concurrently_execute( _process_edu, - getattr(transaction, "edus", []), + transaction.edus, TRANSACTION_CONCURRENCY_LIMIT, ) @@ -538,7 +546,7 @@ class FederationServer(FederationBase): pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: - return 200, self._transaction_from_pdus([pdu]).get_dict() + return 200, self._transaction_dict_from_pdus([pdu]) else: return 404, "" @@ -879,18 +887,20 @@ class FederationServer(FederationBase): ts_now_ms = self._clock.time_msec() return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) - def _transaction_from_pdus(self, pdu_list: List[EventBase]) -> Transaction: + def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict: """Returns a new Transaction containing the given PDUs suitable for transmission. """ time_now = self._clock.time_msec() pdus = [p.get_pdu_json(time_now) for p in pdu_list] return Transaction( + # Just need a dummy transaction ID and destination since it won't be used. + transaction_id="", origin=self.server_name, pdus=pdus, origin_server_ts=int(time_now), - destination=None, - ) + destination="", + ).get_dict() async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: """Process a PDU received in a federation /send/ transaction. diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 2f9c9bc2cd..4fead6ca29 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -45,7 +45,7 @@ class TransactionActions: `None` if we have not previously responded to this transaction or a 2-tuple of `(int, dict)` representing the response code and response body. 
""" - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") @@ -56,7 +56,7 @@ class TransactionActions: self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: """Persist how we responded to a transaction.""" - transaction_id = transaction.transaction_id # type: ignore + transaction_id = transaction.transaction_id if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 72a635830b..dc555cca0b 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -27,6 +27,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.metrics import measure_func @@ -104,13 +105,13 @@ class TransactionManager: len(edus), ) - transaction = Transaction.create_new( + transaction = Transaction( origin_server_ts=int(self.clock.time_msec()), transaction_id=txn_id, origin=self._server_name, destination=destination, - pdus=pdus, - edus=edus, + pdus=[p.get_pdu_json() for p in pdus], + edus=[edu.get_dict() for edu in edus], ) self._next_txn_id += 1 @@ -131,7 +132,7 @@ class TransactionManager: # FIXME (richardv): I also believe it no longer works. We (now?) store # "age_ts" in "unsigned" rather than at the top level. See # https://github.com/matrix-org/synapse/issues/8429. - def json_data_cb(): + def json_data_cb() -> JsonDict: data = transaction.get_dict() now = int(self.clock.time_msec()) if "pdus" in data: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 6a8d3ad4fe..90a7c16b62 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -143,7 +143,7 @@ class TransportLayerClient: """Sends the given Transaction to its destination Args: - transaction (Transaction) + transaction Returns: Succeeds when we get a 2xx HTTP response. The result diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5e059d6e09..640f46fff6 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -450,21 +450,12 @@ class FederationSendServlet(BaseFederationServerServlet): len(transaction_data.get("edus", [])), ) - # We should ideally be getting this from the security layer. - # origin = body["origin"] - - # Add some extra data to the transaction dict that isn't included - # in the request body. - transaction_data.update( - transaction_id=transaction_id, destination=self.server_name - ) - except Exception as e: logger.exception(e) return 400, {"error": "Invalid transaction"} code, response = await self.handler.on_incoming_transaction( - origin, transaction_data + origin, transaction_id, self.server_name, transaction_data ) return code, response diff --git a/synapse/federation/units.py b/synapse/federation/units.py index c83a261918..b9b12fbea5 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -17,18 +17,17 @@ server protocol. 
""" import logging -from typing import Optional +from typing import List, Optional import attr from synapse.types import JsonDict -from synapse.util.jsonobject import JsonEncodedObject logger = logging.getLogger(__name__) -@attr.s(slots=True) -class Edu(JsonEncodedObject): +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Edu: """An Edu represents a piece of data sent from one homeserver to another. In comparison to Pdus, Edus are not persisted for a long time on disk, are @@ -36,10 +35,10 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - edu_type = attr.ib(type=str) - content = attr.ib(type=dict) - origin = attr.ib(type=str) - destination = attr.ib(type=str) + edu_type: str + content: dict + origin: str + destination: str def get_dict(self) -> JsonDict: return { @@ -55,14 +54,21 @@ class Edu(JsonEncodedObject): "destination": self.destination, } - def get_context(self): + def get_context(self) -> str: return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}") - def strip_context(self): + def strip_context(self) -> None: getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}" -class Transaction(JsonEncodedObject): +def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]: + if edus is None: + return [] + return edus + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class Transaction: """A transaction is a list of Pdus and Edus to be sent to a remote home server with some extra metadata. @@ -78,47 +84,21 @@ class Transaction(JsonEncodedObject): """ - valid_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "previous_ids", - "pdus", - "edus", - ] - - internal_keys = ["transaction_id", "destination"] - - required_keys = [ - "transaction_id", - "origin", - "destination", - "origin_server_ts", - "pdus", - ] - - def __init__(self, transaction_id=None, pdus: Optional[list] = None, **kwargs): - """If we include a list of pdus then we decode then as PDU's - automatically. - """ - - # If there's no EDUs then remove the arg - if "edus" in kwargs and not kwargs["edus"]: - del kwargs["edus"] - - super().__init__(transaction_id=transaction_id, pdus=pdus or [], **kwargs) - - @staticmethod - def create_new(pdus, **kwargs): - """Used to create a new transaction. Will auto fill out - transaction_id and origin_server_ts keys. - """ - if "origin_server_ts" not in kwargs: - raise KeyError("Require 'origin_server_ts' to construct a Transaction") - if "transaction_id" not in kwargs: - raise KeyError("Require 'transaction_id' to construct a Transaction") - - kwargs["pdus"] = [p.get_pdu_json() for p in pdus] - - return Transaction(**kwargs) + # Required keys. 
+ transaction_id: str + origin: str + destination: str + origin_server_ts: int + pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list) + + def get_dict(self) -> JsonDict: + """A JSON-ready dictionary of valid keys which aren't internal.""" + result = { + "origin": self.origin, + "origin_server_ts": self.origin_server_ts, + "pdus": self.pdus, + } + if self.edus: + result["edus"] = self.edus + return result diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py deleted file mode 100644 index abc12f0837..0000000000 --- a/synapse/util/jsonobject.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class JsonEncodedObject: - """A common base class for defining protocol units that are represented - as JSON. - - Attributes: - unrecognized_keys (dict): A dict containing all the key/value pairs we - don't recognize. - """ - - valid_keys = [] # keys we will store - """A list of strings that represent keys we know about - and can handle. If we have values for these keys they will be - included in the `dictionary` instance variable. - """ - - internal_keys = [] # keys to ignore while building dict - """A list of strings that should *not* be encoded into JSON. - """ - - required_keys = [] - """A list of strings that we require to exist. If they are not given upon - construction it raises an exception. - """ - - def __init__(self, **kwargs): - """Takes the dict of `kwargs` and loads all keys that are *valid* - (i.e., are included in the `valid_keys` list) into the dictionary` - instance variable. - - Any keys that aren't recognized are added to the `unrecognized_keys` - attribute. - - Args: - **kwargs: Attributes associated with this protocol unit. - """ - for required_key in self.required_keys: - if required_key not in kwargs: - raise RuntimeError("Key %s is required" % required_key) - - self.unrecognized_keys = {} # Keys we were given not listed as valid - for k, v in kwargs.items(): - if k in self.valid_keys or k in self.internal_keys: - self.__dict__[k] = v - else: - self.unrecognized_keys[k] = v - - def get_dict(self): - """Converts this protocol unit into a :py:class:`dict`, ready to be - encoded as JSON. 
- - The keys it encodes are: `valid_keys` - `internal_keys` - - Returns - dict - """ - d = { - k: _encode(v) - for (k, v) in self.__dict__.items() - if k in self.valid_keys and k not in self.internal_keys - } - d.update(self.unrecognized_keys) - return d - - def get_internal_dict(self): - d = { - k: _encode(v, internal=True) - for (k, v) in self.__dict__.items() - if k in self.valid_keys - } - d.update(self.unrecognized_keys) - return d - - def __str__(self): - return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__)) - - -def _encode(obj, internal=False): - if type(obj) is list: - return [_encode(o, internal=internal) for o in obj] - - if isinstance(obj, JsonEncodedObject): - if internal: - return obj.get_internal_dict() - else: - return obj.get_dict() - - return obj -- cgit 1.5.1 From 0c246dd4a09e21e677934d0d83efa573c9127a6f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 9 Aug 2021 04:46:39 -0400 Subject: Support MSC3289: Room version 8 (#10449) This adds support for MSC3289: room version 8. This is room version 7 + MSC3083. --- changelog.d/10449.bugfix | 1 + changelog.d/10449.feature | 1 + synapse/api/constants.py | 2 +- synapse/api/room_versions.py | 28 ++++++++++++++-------------- synapse/event_auth.py | 5 +---- synapse/handlers/event_auth.py | 2 +- tests/events/test_utils.py | 2 +- tests/handlers/test_space_summary.py | 12 ++++++------ tests/test_event_auth.py | 18 +++++++++--------- 9 files changed, 35 insertions(+), 36 deletions(-) create mode 100644 changelog.d/10449.bugfix create mode 100644 changelog.d/10449.feature diff --git a/changelog.d/10449.bugfix b/changelog.d/10449.bugfix new file mode 100644 index 0000000000..c5e23ba019 --- /dev/null +++ b/changelog.d/10449.bugfix @@ -0,0 +1 @@ +Mark the experimental room version from [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) as unstable. diff --git a/changelog.d/10449.feature b/changelog.d/10449.feature new file mode 100644 index 0000000000..a45a17cb28 --- /dev/null +++ b/changelog.d/10449.feature @@ -0,0 +1 @@ +Support [MSC3289: room version 8](https://github.com/matrix-org/matrix-doc/pull/3289). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index a986fdb47a..e0e24fddac 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -62,7 +62,7 @@ class JoinRules: INVITE = "invite" PRIVATE = "private" # As defined for MSC3083. 
- MSC3083_RESTRICTED = "restricted" + RESTRICTED = "restricted" class RestrictedJoinRuleTypes: diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index bc678efe49..f32a40ba4a 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -177,9 +177,9 @@ class RoomVersions: msc2403_knocking=False, msc2716_historical=False, ) - MSC3083 = RoomVersion( - "org.matrix.msc3083.v2", - RoomDisposition.UNSTABLE, + V7 = RoomVersion( + "7", + RoomDisposition.STABLE, EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=True, @@ -187,13 +187,13 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc2403_knocking=False, + msc3083_join_rules=False, + msc2403_knocking=True, msc2716_historical=False, ) - V7 = RoomVersion( - "7", - RoomDisposition.STABLE, + MSC2716 = RoomVersion( + "org.matrix.msc2716", + RoomDisposition.UNSTABLE, EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=True, @@ -203,10 +203,10 @@ class RoomVersions: msc2176_redaction_rules=False, msc3083_join_rules=False, msc2403_knocking=True, - msc2716_historical=False, + msc2716_historical=True, ) - MSC2716 = RoomVersion( - "org.matrix.msc2716", + V8 = RoomVersion( + "8", RoomDisposition.STABLE, EventFormatVersions.V3, StateResolutionVersions.V2, @@ -215,9 +215,9 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=False, - msc3083_join_rules=False, + msc3083_join_rules=True, msc2403_knocking=True, - msc2716_historical=True, + msc2716_historical=False, ) @@ -231,9 +231,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.V5, RoomVersions.V6, RoomVersions.MSC2176, - RoomVersions.MSC3083, RoomVersions.V7, RoomVersions.MSC2716, + RoomVersions.V8, ) } diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 4c92e9a2d4..c3a0c10499 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -370,10 +370,7 @@ def _is_membership_change_allowed( raise AuthError(403, "You are banned from this room") elif join_rule == JoinRules.PUBLIC: pass - elif ( - room_version.msc3083_join_rules - and join_rule == JoinRules.MSC3083_RESTRICTED - ): + elif room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED: # This is the same as public, but the event must contain a reference # to the server who authorised the join. If the event does not contain # the proper content it is rejected. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 53fac1f8a3..e2410e482f 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -240,7 +240,7 @@ class EventAuthHandler: # If the join rule is not restricted, this doesn't apply. 
join_rules_event = await self._store.get_event(join_rules_event_id) - return join_rules_event.content.get("join_rule") == JoinRules.MSC3083_RESTRICTED + return join_rules_event.content.get("join_rule") == JoinRules.RESTRICTED async def get_rooms_that_allow_join( self, state_ids: StateMap[str] diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index e2a5fc018c..7a826c086e 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -341,7 +341,7 @@ class PruneEventTestCase(unittest.TestCase): "signatures": {}, "unsigned": {}, }, - room_version=RoomVersions.MSC3083, + room_version=RoomVersions.V8, ) diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 3f73ad7f94..01975c13d4 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -231,13 +231,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): invited_room = self._create_room_with_join_rule(JoinRules.INVITE) self.helper.invite(invited_room, targ=user2, tok=self.token) restricted_room = self._create_room_with_join_rule( - JoinRules.MSC3083_RESTRICTED, - room_version=RoomVersions.MSC3083.identifier, + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, allow=[], ) restricted_accessible_room = self._create_room_with_join_rule( - JoinRules.MSC3083_RESTRICTED, - room_version=RoomVersions.MSC3083.identifier, + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, allow=[ { "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, @@ -459,13 +459,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): { "room_id": restricted_room, "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, + "join_rules": JoinRules.RESTRICTED, "allowed_spaces": [], }, { "room_id": restricted_accessible_room, "world_readable": False, - "join_rules": JoinRules.MSC3083_RESTRICTED, + "join_rules": JoinRules.RESTRICTED, "allowed_spaces": [self.room], }, { diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index e5550aec4d..6ebd01bcbe 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -384,7 +384,7 @@ class EventAuthTestCase(unittest.TestCase): }, ) event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, authorised_join_event, auth_events, do_sig_check=False, @@ -400,7 +400,7 @@ class EventAuthTestCase(unittest.TestCase): "@inviter:foo.test" ) event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event( pleb, additional_content={ @@ -414,7 +414,7 @@ class EventAuthTestCase(unittest.TestCase): # A join which is missing an authorised server is rejected. with self.assertRaises(AuthError): event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event(pleb), auth_events, do_sig_check=False, @@ -427,7 +427,7 @@ class EventAuthTestCase(unittest.TestCase): ) with self.assertRaises(AuthError): event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event( pleb, additional_content={ @@ -442,7 +442,7 @@ class EventAuthTestCase(unittest.TestCase): # *would* be valid, but is sent by a different user.)
with self.assertRaises(AuthError): event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _member_event( pleb, "join", @@ -459,7 +459,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event(pleb, "ban") with self.assertRaises(AuthError): event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, authorised_join_event, auth_events, do_sig_check=False, @@ -468,7 +468,7 @@ class EventAuthTestCase(unittest.TestCase): # A user who left can re-join. auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, authorised_join_event, auth_events, do_sig_check=False, @@ -478,7 +478,7 @@ class EventAuthTestCase(unittest.TestCase): # be authorised since the user is already joined.) auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event(pleb), auth_events, do_sig_check=False, @@ -490,7 +490,7 @@ class EventAuthTestCase(unittest.TestCase): pleb, "invite", sender=creator ) event_auth.check( - RoomVersions.MSC3083, + RoomVersions.V8, _join_event(pleb), auth_events, do_sig_check=False, -- cgit 1.5.1 From ad35b7739e72fe198fa78fa4279f58cacfc9fa37 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 9 Aug 2021 13:41:29 +0100 Subject: 1.40.0rc3 --- CHANGES.md | 21 +++++++++++++++++++++ changelog.d/10449.bugfix | 1 - changelog.d/10449.feature | 1 - changelog.d/10543.doc | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 6 files changed, 28 insertions(+), 4 deletions(-) delete mode 100644 changelog.d/10449.bugfix delete mode 100644 changelog.d/10449.feature delete mode 100644 changelog.d/10543.doc diff --git a/CHANGES.md b/CHANGES.md index 62ea684e58..b04abbeb4d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,24 @@ +Synapse 1.40.0rc3 (2021-08-09) +============================== + +Features +-------- + +- Support [MSC3289: room version 8](https://github.com/matrix-org/matrix-doc/pull/3289). ([\#10449](https://github.com/matrix-org/synapse/issues/10449)) + + +Bugfixes +-------- + +- Mark the experimental room version from [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) as unstable. ([\#10449](https://github.com/matrix-org/synapse/issues/10449)) + + +Improved Documentation +---------------------- + +- Fix broken links in `upgrade.md`. Contributed by @dklimpel. ([\#10543](https://github.com/matrix-org/synapse/issues/10543)) + + Synapse 1.40.0rc2 (2021-08-04) ============================== diff --git a/changelog.d/10449.bugfix b/changelog.d/10449.bugfix deleted file mode 100644 index c5e23ba019..0000000000 --- a/changelog.d/10449.bugfix +++ /dev/null @@ -1 +0,0 @@ -Mark the experimental room version from [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) as unstable. diff --git a/changelog.d/10449.feature b/changelog.d/10449.feature deleted file mode 100644 index a45a17cb28..0000000000 --- a/changelog.d/10449.feature +++ /dev/null @@ -1 +0,0 @@ -Support [MSC3289: room version 8](https://github.com/matrix-org/matrix-doc/pull/3289). diff --git a/changelog.d/10543.doc b/changelog.d/10543.doc deleted file mode 100644 index 6c06722eb4..0000000000 --- a/changelog.d/10543.doc +++ /dev/null @@ -1 +0,0 @@ -Fix broken links in `upgrade.md`. Contributed by @dklimpel. 
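For orientation, the room-version changes in the commits above boil down to one rule: a `restricted` join behaves like a `public` join only when the join event names the server that authorised it. Below is a minimal, self-contained sketch of that gate, not Synapse's actual implementation; the content key name is an assumption based on MSC3083, since the diff elides it.

```python
from typing import Optional

# Assumed content key per MSC3083 (the diff above does not show the real name).
AUTHORISED_SERVER_KEY = "join_authorised_via_users_server"

def restricted_join_allowed(
    msc3083_join_rules: bool,
    join_rule: str,
    content: dict,
    current_membership: Optional[str],
) -> bool:
    """Deliberately simplified mirror of the rules the v8 tests exercise."""
    if current_membership == "ban":
        return False  # banned users are always refused
    if current_membership == "join":
        return True  # already in the room
    if msc3083_join_rules and join_rule == "restricted":
        # Same as public, but the event must name the authorising server.
        return AUTHORISED_SERVER_KEY in content
    return join_rule == "public"

assert restricted_join_allowed(
    True, "restricted", {AUTHORISED_SERVER_KEY: "@inviter:foo.test"}, None
)
assert not restricted_join_allowed(True, "restricted", {}, None)
```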
diff --git a/debian/changelog b/debian/changelog
index c523101f9a..7b44341bc6 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.40.0~rc3) stable; urgency=medium
+
+  * New synapse release 1.40.0~rc3.
+
+ -- Synapse Packaging team Mon, 09 Aug 2021 13:41:08 +0100
+
 matrix-synapse-py3 (1.40.0~rc2) stable; urgency=medium
 
   * New synapse release 1.40.0~rc2.
diff --git a/synapse/__init__.py b/synapse/__init__.py
index da52463531..5cca899f7d 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.40.0rc2"
+__version__ = "1.40.0rc3"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
-- cgit 1.5.1


From 189c055eb6d8a0db7aa520ecec23819d15bfaa26 Mon Sep 17 00:00:00 2001
From: Drew Short
Date: Mon, 9 Aug 2021 10:12:53 -0500
Subject: Moved homeserver documentation above reverse proxy examples (#10551)

Signed-off-by: Drew Short
---
 changelog.d/10551.doc |  1 +
 docs/reverse_proxy.md | 23 +++++++++++++----------
 2 files changed, 14 insertions(+), 10 deletions(-)
 create mode 100644 changelog.d/10551.doc

diff --git a/changelog.d/10551.doc b/changelog.d/10551.doc
new file mode 100644
index 0000000000..4a2b0785bf
--- /dev/null
+++ b/changelog.d/10551.doc
@@ -0,0 +1 @@
+Updated the reverse proxy documentation to highlight the homeserver configuration that is needed to make Synapse aware that it is intentionally reverse proxied.
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 76bb45aff2..5f8d20129e 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -33,6 +33,19 @@ Let's assume that we expect clients to connect to our server at
 `https://example.com:8448`.  The following sections detail the configuration of
 the reverse proxy and the homeserver.
 
+
+## Homeserver Configuration
+
+The HTTP configuration will need to be updated for Synapse to correctly record
+client IP addresses and generate redirect URLs while behind a reverse proxy.
+
+In `homeserver.yaml` set `x_forwarded: true` in the port 8008 section and
+consider setting `bind_addresses: ['127.0.0.1']` so that the server only
+listens to traffic on localhost. (Do not change `bind_addresses` to `127.0.0.1`
+when using a containerized Synapse, as that will prevent it from responding
+to proxied traffic.)
+
+
 ## Reverse-proxy configuration examples
 
 **NOTE**: You only need one of these.
@@ -239,16 +252,6 @@ relay "matrix_federation" {
 }
 ```
 
-## Homeserver Configuration
-
-You will also want to set `bind_addresses: ['127.0.0.1']` and
-`x_forwarded: true` for port 8008 in `homeserver.yaml` to ensure that
-client IP addresses are recorded correctly.
-
-Having done so, you can then use `https://matrix.example.com` (instead
-of `https://matrix.example.com:8448`) as the "Custom server" when
-connecting to Synapse from a client.
- ## Health check endpoint -- cgit 1.5.1 From 6b61debf5cf571ae9e230b102c758865eee2a788 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Mon, 9 Aug 2021 18:21:04 +0200 Subject: Do not remove `status_msg` when user going offline (#10550) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10550.bugfix | 1 + synapse/handlers/presence.py | 11 +-- tests/handlers/test_presence.py | 163 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 166 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10550.bugfix diff --git a/changelog.d/10550.bugfix b/changelog.d/10550.bugfix new file mode 100644 index 0000000000..2e1b7c8bbb --- /dev/null +++ b/changelog.d/10550.bugfix @@ -0,0 +1 @@ +Fix longstanding bug which caused the user "status" to be reset when the user went offline. Contributed by @dklimpel. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 016c5df2ca..7ca14e1d84 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1184,8 +1184,7 @@ class PresenceHandler(BasePresenceHandler): new_fields = {"state": presence} if not ignore_status_msg: - msg = status_msg if presence != PresenceState.OFFLINE else None - new_fields["status_msg"] = msg + new_fields["status_msg"] = status_msg if presence == PresenceState.ONLINE or ( presence == PresenceState.BUSY and self._busy_presence_enabled @@ -1478,7 +1477,7 @@ def format_user_presence_state( content["user_id"] = state.user_id if state.last_active_ts: content["last_active_ago"] = now - state.last_active_ts - if state.status_msg and state.state != PresenceState.OFFLINE: + if state.status_msg: content["status_msg"] = state.status_msg if state.state == PresenceState.ONLINE: content["currently_active"] = state.currently_active @@ -1840,9 +1839,7 @@ def handle_timeout( # don't set them as offline. sync_or_active = max(state.last_user_sync_ts, state.last_active_ts) if now - sync_or_active > SYNC_ONLINE_TIMEOUT: - state = state.copy_and_replace( - state=PresenceState.OFFLINE, status_msg=None - ) + state = state.copy_and_replace(state=PresenceState.OFFLINE) changed = True else: # We expect to be poked occasionally by the other side. @@ -1850,7 +1847,7 @@ def handle_timeout( # no one gets stuck online forever. if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: # The other side seems to have disappeared. - state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None) + state = state.copy_and_replace(state=PresenceState.OFFLINE) changed = True return state if changed else None diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 18e92e90d7..29845a80da 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Optional from unittest.mock import Mock, call from signedjson.key import generate_signing_key @@ -339,8 +339,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): class PresenceTimeoutTestCase(unittest.TestCase): + """Tests different timers and that the timer does not change `status_msg` of user.""" + def test_idle_timer(self): user_id = "@foo:bar" + status_msg = "I'm here!" 
now = 5000000 state = UserPresenceState.default(user_id) @@ -348,12 +351,14 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) + self.assertEquals(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -361,6 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): presence state into unavailable. """ user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -368,15 +374,18 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.BUSY, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.BUSY) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -384,15 +393,18 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=0, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -400,6 +412,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=now - SYNC_ONLINE_TIMEOUT - 1, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -408,9 +421,11 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.ONLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -419,12 +434,13 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state, new_state) + self.assertEquals(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -444,6 +460,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_federation_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" 
now = 5000000
 
         state = UserPresenceState.default(user_id)
@@ -452,6 +469,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_active_ts=now,
             last_user_sync_ts=now,
             last_federation_update_ts=now - FEDERATION_TIMEOUT - 1,
+            status_msg=status_msg,
         )
 
         new_state = handle_timeout(
@@ -460,9 +478,11 @@ class PresenceTimeoutTestCase(unittest.TestCase):
 
         self.assertIsNotNone(new_state)
         self.assertEquals(new_state.state, PresenceState.OFFLINE)
+        self.assertEquals(new_state.status_msg, status_msg)
 
     def test_last_active(self):
         user_id = "@foo:bar"
+        status_msg = "I'm here!"
         now = 5000000
 
         state = UserPresenceState.default(user_id)
@@ -471,6 +491,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
             last_active_ts=now - LAST_ACTIVE_GRANULARITY - 1,
             last_user_sync_ts=now,
             last_federation_update_ts=now,
+            status_msg=status_msg,
         )
 
         new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
@@ -516,6 +537,144 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(state.state, PresenceState.OFFLINE)
 
+    def test_user_goes_offline_by_timeout_status_msg_remain(self):
+        """Test that if a user doesn't update the records for a while, their
+        presence goes `OFFLINE` because of the timeout and `status_msg` remains.
+        """
+        user_id = "@test:server"
+        status_msg = "I'm here!"
+
+        # Mark user as online
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.ONLINE, status_msg
+        )
+
+        # Check that if we wait a while without telling the handler the user has
+        # stopped syncing that their presence state doesn't get timed out.
+        self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2)
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        self.assertEqual(state.state, PresenceState.ONLINE)
+        self.assertEqual(state.status_msg, status_msg)
+
+        # Check that if the timeout fires, then the syncing user gets timed out
+        self.reactor.advance(SYNC_ONLINE_TIMEOUT)
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        # status_msg should remain even after going offline
+        self.assertEqual(state.state, PresenceState.OFFLINE)
+        self.assertEqual(state.status_msg, status_msg)
+
+    def test_user_goes_offline_manually_with_no_status_msg(self):
+        """Test that if a user changes presence manually to `OFFLINE`
+        and no status is set, `status_msg` is `None`.
+        """
+        user_id = "@test:server"
+        status_msg = "I'm here!"
+
+        # Mark user as online
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.ONLINE, status_msg
+        )
+
+        # Mark user as offline
+        self.get_success(
+            self.presence_handler.set_state(
+                UserID.from_string(user_id), {"presence": PresenceState.OFFLINE}
+            )
+        )
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        self.assertEqual(state.state, PresenceState.OFFLINE)
+        self.assertEqual(state.status_msg, None)
+
+    def test_user_goes_offline_manually_with_status_msg(self):
+        """Test that if a user changes presence manually to `OFFLINE`
+        and a status is set, `status_msg` appears.
+        """
+        user_id = "@test:server"
+        status_msg = "I'm here!"
+
+        # Mark user as online
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.ONLINE, status_msg
+        )
+
+        # Mark user as offline
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.OFFLINE, "And now here."
+        )
+
+    def test_user_reset_online_with_no_status(self):
+        """Test that if a user sets the presence manually again
+        and no status is set, `status_msg` is `None`.
+        """
+        user_id = "@test:server"
+        status_msg = "I'm here!"
+
+        # Mark user as online
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.ONLINE, status_msg
+        )
+
+        # Mark user as online again
+        self.get_success(
+            self.presence_handler.set_state(
+                UserID.from_string(user_id), {"presence": PresenceState.ONLINE}
+            )
+        )
+
+        state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        # status_msg should be reset to None when presence is set again without one
+        self.assertEqual(state.state, PresenceState.ONLINE)
+        self.assertEqual(state.status_msg, None)
+
+    def test_set_presence_with_status_msg_none(self):
+        """Test that if a user sets the presence manually again
+        and status is `None`, `status_msg` is `None`.
+        """
+        user_id = "@test:server"
+        status_msg = "I'm here!"
+
+        # Mark user as online
+        self._set_presencestate_with_status_msg(
+            user_id, PresenceState.ONLINE, status_msg
+        )
+
+        # Mark user as online and `status_msg = None`
+        self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None)
+
+    def _set_presencestate_with_status_msg(
+        self, user_id: str, state: PresenceState, status_msg: Optional[str]
+    ):
+        """Set a PresenceState and status_msg and check the result.
+
+        Args:
+            user_id: User for whom the status is to be set.
+            state: The new PresenceState.
+            status_msg: Status message that is to be set.
+        """
+        self.get_success(
+            self.presence_handler.set_state(
+                UserID.from_string(user_id),
+                {"presence": state, "status_msg": status_msg},
+            )
+        )
+
+        new_state = self.get_success(
+            self.presence_handler.get_state(UserID.from_string(user_id))
+        )
+        self.assertEqual(new_state.state, state)
+        self.assertEqual(new_state.status_msg, status_msg)
+
 
 class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-- cgit 1.5.1


From 7afb615839a2df05d39f87718016d278ebdadf5c Mon Sep 17 00:00:00 2001
From: Eric Eastwood
Date: Mon, 9 Aug 2021 20:23:31 -0500
Subject: When redacting, keep event fields around that maintain the
 historical event structure intact (MSC2716) (#10538)

* Keep event fields that maintain the historical event structure intact

Fix https://github.com/matrix-org/synapse/issues/10521

* Add changelog

* Bump room version

* Better changelog text

* Fix up room version after develop merge
---
 changelog.d/10538.feature    |  1 +
 synapse/api/room_versions.py | 37 ++++++++++++++++++++++++++++++-----
 synapse/events/utils.py      |  8 +++++++-
 3 files changed, 40 insertions(+), 6 deletions(-)
 create mode 100644 changelog.d/10538.feature

diff --git a/changelog.d/10538.feature b/changelog.d/10538.feature
new file mode 100644
index 0000000000..120c8e8ca0
--- /dev/null
+++ b/changelog.d/10538.feature
@@ -0,0 +1 @@
+Add support for new redaction rules for historical events specified in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716).
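The presence fix a couple of commits above is easiest to see as a before/after on the timeout path: the handler used to clear `status_msg` whenever it flipped a user to offline, and now it changes only the state field. A rough stand-alone sketch of the new behaviour, using `attr.evolve` in place of Synapse's `copy_and_replace` (the stub class below is an illustration, not the real `UserPresenceState`):

```python
from typing import Optional

import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStub:
    """Stand-in for UserPresenceState; attr.evolve plays copy_and_replace."""
    state: str
    status_msg: Optional[str] = None

def time_out(presence: PresenceStub) -> PresenceStub:
    # Before this change the handler also cleared status_msg when timing a
    # user out; now only the presence state flips.
    return attr.evolve(presence, state="offline")

before = PresenceStub(state="online", status_msg="I'm here!")
after = time_out(before)
assert after.state == "offline"
assert after.status_msg == "I'm here!"  # preserved, as the tests above assert
```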
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index f32a40ba4a..11280c4462 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -76,6 +76,8 @@ class RoomVersion: # MSC2716: Adds m.room.power_levels -> content.historical field to control # whether "insertion", "chunk", "marker" events can be sent msc2716_historical = attr.ib(type=bool) + # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events + msc2716_redactions = attr.ib(type=bool) class RoomVersions: @@ -92,6 +94,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V2 = RoomVersion( "2", @@ -106,6 +109,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V3 = RoomVersion( "3", @@ -120,6 +124,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V4 = RoomVersion( "4", @@ -134,6 +139,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V5 = RoomVersion( "5", @@ -148,6 +154,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V6 = RoomVersion( "6", @@ -162,6 +169,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -176,6 +184,7 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=False, msc2716_historical=False, + msc2716_redactions=False, ) V7 = RoomVersion( "7", @@ -190,6 +199,22 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=True, msc2716_historical=False, + msc2716_redactions=False, + ) + V8 = RoomVersion( + "8", + RoomDisposition.STABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, + msc2403_knocking=True, + msc2716_historical=False, + msc2716_redactions=False, ) MSC2716 = RoomVersion( "org.matrix.msc2716", @@ -204,10 +229,11 @@ class RoomVersions: msc3083_join_rules=False, msc2403_knocking=True, msc2716_historical=True, + msc2716_redactions=False, ) - V8 = RoomVersion( - "8", - RoomDisposition.STABLE, + MSC2716v2 = RoomVersion( + "org.matrix.msc2716v2", + RoomDisposition.UNSTABLE, EventFormatVersions.V3, StateResolutionVersions.V2, enforce_key_validity=True, @@ -215,9 +241,10 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=False, - msc3083_join_rules=True, + msc3083_join_rules=False, msc2403_knocking=True, - msc2716_historical=False, + msc2716_historical=True, + msc2716_redactions=True, ) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a0c07f62f4..b6da2f60af 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -17,7 +17,7 @@ from typing import Any, Mapping, Union from frozendict import frozendict -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results @@ -135,6 +135,12 @@ 
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: add_fields("history_visibility") elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules: add_fields("redacts") + elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION: + add_fields(EventContentFields.MSC2716_NEXT_CHUNK_ID) + elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_CHUNK: + add_fields(EventContentFields.MSC2716_CHUNK_ID) + elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER: + add_fields(EventContentFields.MSC2716_MARKER_INSERTION) allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys} -- cgit 1.5.1 From 9f7c038272318bab09535e85e6bb4345ed2f1368 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 10 Aug 2021 13:50:58 +0100 Subject: 1.40.0 --- CHANGES.md | 6 ++++++ debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index b04abbeb4d..0e5e052951 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +Synapse 1.40.0 (2021-08-10) +=========================== + +No significant changes. + + Synapse 1.40.0rc3 (2021-08-09) ============================== diff --git a/debian/changelog b/debian/changelog index 7b44341bc6..d3da448b0f 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.40.0) stable; urgency=medium + + * New synapse release 1.40.0. + + -- Synapse Packaging team Tue, 10 Aug 2021 13:50:48 +0100 + matrix-synapse-py3 (1.40.0~rc3) stable; urgency=medium * New synapse release 1.40.0~rc3. diff --git a/synapse/__init__.py b/synapse/__init__.py index 5cca899f7d..919293cd80 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.40.0rc3" +__version__ = "1.40.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From 52bfa2d59a5c372dec3204b20efbe19953b53122 Mon Sep 17 00:00:00 2001 From: Hillery Shay Date: Tue, 10 Aug 2021 06:35:54 -0700 Subject: Update contributing.md to warn against rebasing an open PR. (#10563) Signed-off-by: H.Shay --- CONTRIBUTING.md | 1 + changelog.d/10563.misc | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/10563.misc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e7eef23419..4486a4b2cd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -252,6 +252,7 @@ To prepare a Pull Request, please: 4. on GitHub, [create the Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request); 5. add a [changelog entry](#changelog) and push it to your Pull Request; 6. for most contributors, that's all - however, if you are a member of the organization `matrix-org`, on GitHub, please request a review from `matrix.org / Synapse Core`. +7. if you need to update your PR, please avoid rebasing and just add new commits to your branch. ## Changelog diff --git a/changelog.d/10563.misc b/changelog.d/10563.misc new file mode 100644 index 0000000000..8e4e90c8f4 --- /dev/null +++ b/changelog.d/10563.misc @@ -0,0 +1 @@ +Update contributing.md to warn against rebasing an open PR. -- cgit 1.5.1 From 691593bf719edb4c8b0d7a6bee95fcb41d0c56ae Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 10 Aug 2021 10:56:54 -0400 Subject: Fix an edge-case with invited rooms over federation in the spaces summary. 
(#10560) If a room which the requesting user was invited to was queried over federation it will now properly appear in the spaces summary (instead of being stripped out by the requesting server). --- changelog.d/10560.feature | 1 + synapse/handlers/space_summary.py | 93 ++++++++++++++++-------------- tests/handlers/test_space_summary.py | 106 ++++++++++++++++++++++++++++------- 3 files changed, 138 insertions(+), 62 deletions(-) create mode 100644 changelog.d/10560.feature diff --git a/changelog.d/10560.feature b/changelog.d/10560.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10560.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 2517f278b6..d04afe6c31 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -158,48 +158,10 @@ class SpaceSummaryHandler: room = room_entry.room fed_room_id = room_entry.room_id - # The room should only be included in the summary if: - # a. the user is in the room; - # b. the room is world readable; or - # c. the user could join the room, e.g. the join rules - # are set to public or the user is in a space that - # has been granted access to the room. - # - # Note that we know the user is not in the root room (which is - # why the remote call was made in the first place), but the user - # could be in one of the children rooms and we just didn't know - # about the link. - - # The API doesn't return the room version so assume that a - # join rule of knock is valid. - include_room = ( - room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) - or room.get("world_readable") is True - ) - - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_room_ids") or room.get( - "allowed_spaces" - ) - if ( - not include_room - and allowed_rooms - and isinstance(allowed_rooms, list) - ): - include_room = await self._event_auth_handler.is_user_in_rooms( - allowed_rooms, requester - ) - - # Finally, if this isn't the requested room, check ourselves - # if we can access the room. - if not include_room and fed_room_id != queue_entry.room_id: - include_room = await self._is_room_accessible( - fed_room_id, requester, None - ) - # The user can see the room, include it! - if include_room: + if await self._is_remote_room_accessible( + requester, fed_room_id, room + ): # Before returning to the client, remove the allowed_room_ids # and allowed_spaces keys. room.pop("allowed_room_ids", None) @@ -336,7 +298,7 @@ class SpaceSummaryHandler: Returns: A room entry if the room should be returned. None, otherwise. """ - if not await self._is_room_accessible(room_id, requester, origin): + if not await self._is_local_room_accessible(room_id, requester, origin): return None room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) @@ -438,7 +400,7 @@ class SpaceSummaryHandler: return results - async def _is_room_accessible( + async def _is_local_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] ) -> bool: """ @@ -550,6 +512,51 @@ class SpaceSummaryHandler: ) return False + async def _is_remote_room_accessible( + self, requester: str, room_id: str, room: JsonDict + ) -> bool: + """ + Calculate whether the room received over federation should be shown in the spaces summary. 
+
+        It should be included if:
+
+        * The requester is joined or can join the room (per MSC3173).
+        * The history visibility is set to world readable.
+
+        Note that the local server is not in the requested room (which is why the
+        remote call was made in the first place), but the user could have access
+        due to an invite, etc.
+
+        Args:
+            requester: The user requesting the summary.
+            room_id: The room ID returned over federation.
+            room: The summary of the child room returned over federation.
+
+        Returns:
+            True if the room should be included in the spaces summary.
+        """
+        # The API doesn't return the room version so assume that a
+        # join rule of knock is valid.
+        if (
+            room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+            or room.get("world_readable") is True
+        ):
+            return True
+
+        # Check if the user is a member of any of the allowed spaces
+        # from the response.
+        allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces")
+        if allowed_rooms and isinstance(allowed_rooms, list):
+            if await self._event_auth_handler.is_user_in_rooms(
+                allowed_rooms, requester
+            ):
+                return True
+
+        # Finally, check locally if we can access the room. The user might
+        # already be in the room (if it was a child room), or there might be a
+        # pending invite, etc.
+        return await self._is_local_room_accessible(room_id, requester, None)
+
     async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict:
         """
         Generate an entry suitable for the 'rooms' list in the summary response.
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 6cc1a02e12..f470c81ea2 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -30,7 +30,7 @@ from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEn
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
 
@@ -149,6 +149,36 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             events,
         )
 
+    def _poke_fed_invite(self, room_id: str, from_user: str) -> None:
+        """
+        Creates an invite (as if received over federation) for the room from the
+        given user.
+
+        Args:
+            room_id: The room ID to issue an invite for.
+            from_user: The user to invite from.
+        """
+        # Poke an invite over federation into the database.
+        fed_handler = self.hs.get_federation_handler()
+        fed_hostname = UserID.from_string(from_user).domain
+        event = make_event_from_dict(
+            {
+                "room_id": room_id,
+                "event_id": "!abcd:" + fed_hostname,
+                "type": EventTypes.Member,
+                "sender": from_user,
+                "state_key": self.user,
+                "content": {"membership": Membership.INVITE},
+                "prev_events": [],
+                "auth_events": [],
+                "depth": 1,
+                "origin_server_ts": 1234,
+            }
+        )
+        self.get_success(
+            fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6)
+        )
+
     def test_simple_space(self):
         """Test a simple space with a single room."""
         result = self.get_success(self.handler.get_space_summary(self.user, self.space))
@@ -416,24 +446,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         joined_room = self.helper.create_room_as(self.user, tok=self.token)
 
         # Poke an invite over federation into the database.
- fed_handler = self.hs.get_federation_handler() - event = make_event_from_dict( - { - "room_id": invited_room, - "event_id": "!abcd:" + fed_hostname, - "type": EventTypes.Member, - "sender": "@remote:" + fed_hostname, - "state_key": self.user, - "content": {"membership": Membership.INVITE}, - "prev_events": [], - "auth_events": [], - "depth": 1, - "origin_server_ts": 1234, - } - ) - self.get_success( - fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) - ) + self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms @@ -570,3 +583,58 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (subspace, joined_room), ], ) + + def test_fed_invited(self): + """ + A room which the user was invited to should be included in the response. + + This differs from test_fed_filtering in that the room itself is being + queried over federation, instead of it being included as a sub-room of + a space in the response. + """ + fed_hostname = self.hs.hostname + "2" + fed_room = "#subroom:" + fed_hostname + + # Poke an invite over federation into the database. + self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [ + _RoomEntry( + fed_room, + { + "room_id": fed_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ] + + # Add a room to the space which is on another server. + self._add_child(self.space, fed_room, self.token) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + self._assert_rooms( + result, + [ + self.space, + self.room, + fed_room, + ], + ) + self._assert_events( + result, + [ + (self.space, self.room), + (self.space, fed_room), + ], + ) -- cgit 1.5.1 From 8da9e3cb69db1bed68889b6a5fbcecf1bf20a235 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 10 Aug 2021 10:59:13 +0100 Subject: Move test_old_deps.sh to new ci dir --- .buildkite/scripts/test_old_deps.sh | 16 ---------------- .github/workflows/tests.yml | 2 +- MANIFEST.in | 1 + ci/scripts/test_old_deps.sh | 16 ++++++++++++++++ 4 files changed, 18 insertions(+), 17 deletions(-) delete mode 100755 .buildkite/scripts/test_old_deps.sh create mode 100755 ci/scripts/test_old_deps.sh diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh deleted file mode 100755 index 9270d55f04..0000000000 --- a/.buildkite/scripts/test_old_deps.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env bash - -# this script is run by buildkite in a plain `bionic` container; it installs the -# minimal requirements for tox and hands over to the py3-old tox environment. 
- -set -ex - -apt-get update -apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox - -export LANG="C.UTF-8" - -# Prevent virtualenv from auto-updating pip to an incompatible version -export VIRTUALENV_NO_DOWNLOAD=1 - -exec tox -e py3-old,combine diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 75c2976a25..8612d1fb3a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -144,7 +144,7 @@ jobs: uses: docker://ubuntu:bionic # For old python and sqlite with: workdir: /github/workspace - entrypoint: .buildkite/scripts/test_old_deps.sh + entrypoint: ci/scripts/test_old_deps.sh env: TRIAL_FLAGS: "--jobs=2" - name: Dump logs diff --git a/MANIFEST.in b/MANIFEST.in index 0522319c40..174e1b1f47 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -47,6 +47,7 @@ recursive-include changelog.d * prune .buildkite prune .circleci prune .github +prune ci prune contrib prune debian prune demo/etc diff --git a/ci/scripts/test_old_deps.sh b/ci/scripts/test_old_deps.sh new file mode 100755 index 0000000000..8b473936f8 --- /dev/null +++ b/ci/scripts/test_old_deps.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +# this script is run by GitHub Actions in a plain `bionic` container; it installs the +# minimal requirements for tox and hands over to the py3-old tox environment. + +set -ex + +apt-get update +apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox + +export LANG="C.UTF-8" + +# Prevent virtualenv from auto-updating pip to an incompatible version +export VIRTUALENV_NO_DOWNLOAD=1 + +exec tox -e py3-old,combine -- cgit 1.5.1 From 03fb99a5c8dbe67cf300986e76ea0e8183641211 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 10 Aug 2021 12:15:10 +0100 Subject: check-newsfragment: pass pr number explicitly use PULL_REQUEST_NUMBER instead of BUILDKITE_PULL_REQUEST remove the other user of BUILDKITE_PULL_REQUEST, namely merge_base_branch.sh --- .buildkite/.env | 1 - .buildkite/merge_base_branch.sh | 35 ----------------------------------- .github/workflows/tests.yml | 6 ++---- scripts-dev/check-newsfragment | 2 +- 4 files changed, 3 insertions(+), 41 deletions(-) delete mode 100755 .buildkite/merge_base_branch.sh diff --git a/.buildkite/.env b/.buildkite/.env index 85b102d07f..a2969b96a1 100644 --- a/.buildkite/.env +++ b/.buildkite/.env @@ -7,7 +7,6 @@ BUILDKITE_JOB_ID BUILDKITE_BUILD_URL BUILDKITE_PROJECT_SLUG BUILDKITE_COMMIT -BUILDKITE_PULL_REQUEST BUILDKITE_TAG CODECOV_TOKEN TRIAL_FLAGS diff --git a/.buildkite/merge_base_branch.sh b/.buildkite/merge_base_branch.sh deleted file mode 100755 index 361440fd1a..0000000000 --- a/.buildkite/merge_base_branch.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env bash - -set -e - -if [[ "$BUILDKITE_BRANCH" =~ ^(develop|master|dinsic|shhs|release-.*)$ ]]; then - echo "Not merging forward, as this is a release branch" - exit 0 -fi - -if [[ -z $BUILDKITE_PULL_REQUEST_BASE_BRANCH ]]; then - echo "Not a pull request, or hasn't had a PR opened yet..." - - # It probably hasn't had a PR opened yet. Since all PRs land on develop, we - # can probably assume it's based on it and will be merged into it. 
- GITBASE="develop" -else - # Get the reference, using the GitHub API - GITBASE=$BUILDKITE_PULL_REQUEST_BASE_BRANCH -fi - -echo "--- merge_base_branch $GITBASE" - -# Show what we are before -git --no-pager show -s - -# Set up username so it can do a merge -git config --global user.email bot@matrix.org -git config --global user.name "A robot" - -# Fetch and merge. If it doesn't work, it will raise due to set -e. -git fetch -u origin $GITBASE -git merge --no-edit --no-commit origin/$GITBASE - -# Show what we are after. -git --no-pager show -s diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8612d1fb3a..5349e83133 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -47,11 +47,9 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v2 - run: pip install tox - - name: Patch Buildkite-specific test script - run: | - sed -i -e 's/\$BUILDKITE_PULL_REQUEST/${{ github.event.number }}/' \ - scripts-dev/check-newsfragment - run: scripts-dev/check-newsfragment + env: + PULL_REQUEST_NUMBER: ${{ github.event.number }} lint-sdist: runs-on: ubuntu-latest diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment index af6d32e332..393a548d58 100755 --- a/scripts-dev/check-newsfragment +++ b/scripts-dev/check-newsfragment @@ -11,7 +11,7 @@ set -e git remote set-branches --add origin develop git fetch -q origin develop -pr="$BUILDKITE_PULL_REQUEST" +pr="$PULL_REQUEST_NUMBER" # if there are changes in the debian directory, check that the debian changelog # has been updated -- cgit 1.5.1 From 3d67b8c82b0128660376f81d226c111ad3e272a7 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 10 Aug 2021 12:43:50 +0100 Subject: Move sytest worker-blacklist to ci directory --- .buildkite/worker-blacklist | 10 ---------- .github/workflows/tests.yml | 2 +- ci/worker-blacklist | 10 ++++++++++ 3 files changed, 11 insertions(+), 11 deletions(-) delete mode 100644 .buildkite/worker-blacklist create mode 100644 ci/worker-blacklist diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist deleted file mode 100644 index 5975cb98cf..0000000000 --- a/.buildkite/worker-blacklist +++ /dev/null @@ -1,10 +0,0 @@ -# This file serves as a blacklist for SyTest tests that we expect will fail in -# Synapse when run under worker mode. For more details, see sytest-blacklist. - -Can re-join room if re-invited - -# new failures as of https://github.com/matrix-org/sytest/pull/732 -Device list doesn't change if remote server is down - -# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5 -Server correctly handles incoming m.device_list_update diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5349e83133..cd88184488 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -230,7 +230,7 @@ jobs: steps: - uses: actions/checkout@v2 - name: Prepare test blacklist - run: cat sytest-blacklist .buildkite/worker-blacklist > synapse-blacklist-with-workers + run: cat sytest-blacklist ci/worker-blacklist > synapse-blacklist-with-workers - name: Run SyTest run: /bootstrap.sh synapse working-directory: /src diff --git a/ci/worker-blacklist b/ci/worker-blacklist new file mode 100644 index 0000000000..5975cb98cf --- /dev/null +++ b/ci/worker-blacklist @@ -0,0 +1,10 @@ +# This file serves as a blacklist for SyTest tests that we expect will fail in +# Synapse when run under worker mode. For more details, see sytest-blacklist. 
+ +Can re-join room if re-invited + +# new failures as of https://github.com/matrix-org/sytest/pull/732 +Device list doesn't change if remote server is down + +# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5 +Server correctly handles incoming m.device_list_update -- cgit 1.5.1 From c5988a8eb7279f1de2d09258a41ff21158eb62c5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 10 Aug 2021 12:45:13 +0100 Subject: Remove unused BUILDKITE_BRANCH env var --- .buildkite/.env | 1 - .github/workflows/tests.yml | 1 - 2 files changed, 2 deletions(-) diff --git a/.buildkite/.env b/.buildkite/.env index a2969b96a1..fc3606ead2 100644 --- a/.buildkite/.env +++ b/.buildkite/.env @@ -1,7 +1,6 @@ CI BUILDKITE BUILDKITE_BUILD_NUMBER -BUILDKITE_BRANCH BUILDKITE_BUILD_NUMBER BUILDKITE_JOB_ID BUILDKITE_BUILD_URL diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cd88184488..a04f6abbed 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -195,7 +195,6 @@ jobs: volumes: - ${{ github.workspace }}:/src env: - BUILDKITE_BRANCH: ${{ github.head_ref }} POSTGRES: ${{ matrix.postgres && 1}} MULTI_POSTGRES: ${{ (matrix.postgres == 'multi-postgres') && 1}} WORKERS: ${{ matrix.workers && 1 }} -- cgit 1.5.1 From 58e5da5aa06ee4dc1ad5b2774e7bcd4eb9911a70 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 10 Aug 2021 13:11:43 +0100 Subject: Remove buildkite from portdb CI tests --- .buildkite/postgres-config.yaml | 19 ---------- .buildkite/scripts/postgres_exec.py | 31 ---------------- .buildkite/scripts/test_synapse_port_db.sh | 57 ------------------------------ .buildkite/sqlite-config.yaml | 16 --------- .coveragerc | 4 +-- .github/workflows/tests.yml | 8 +---- ci/postgres-config.yaml | 19 ++++++++++ ci/scripts/postgres_exec.py | 31 ++++++++++++++++ ci/scripts/test_synapse_port_db.sh | 57 ++++++++++++++++++++++++++++++ ci/sqlite-config.yaml | 16 +++++++++ 10 files changed, 126 insertions(+), 132 deletions(-) delete mode 100644 .buildkite/postgres-config.yaml delete mode 100755 .buildkite/scripts/postgres_exec.py delete mode 100755 .buildkite/scripts/test_synapse_port_db.sh delete mode 100644 .buildkite/sqlite-config.yaml create mode 100644 ci/postgres-config.yaml create mode 100755 ci/scripts/postgres_exec.py create mode 100755 ci/scripts/test_synapse_port_db.sh create mode 100644 ci/sqlite-config.yaml diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml deleted file mode 100644 index 67e17fa9d1..0000000000 --- a/.buildkite/postgres-config.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# Configuration file used for testing the 'synapse_port_db' script. -# Tells the script to connect to the postgresql database that will be available in the -# CI's Docker setup at the point where this file is considered. -server_name: "localhost:8800" - -signing_key_path: ".buildkite/test.signing.key" - -report_stats: false - -database: - name: "psycopg2" - args: - user: postgres - host: postgres - password: postgres - database: synapse - -# Suppress the key server warning. -trusted_key_servers: [] diff --git a/.buildkite/scripts/postgres_exec.py b/.buildkite/scripts/postgres_exec.py deleted file mode 100755 index 086b391724..0000000000 --- a/.buildkite/scripts/postgres_exec.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -import psycopg2 - -# a very simple replacment for `psql`, to make up for the lack of the postgres client -# libraries in the synapse docker image. - -# We use "postgres" as a database because it's bound to exist and the "synapse" one -# doesn't exist yet. -db_conn = psycopg2.connect( - user="postgres", host="postgres", password="postgres", dbname="postgres" -) -db_conn.autocommit = True -cur = db_conn.cursor() -for c in sys.argv[1:]: - cur.execute(c) diff --git a/.buildkite/scripts/test_synapse_port_db.sh b/.buildkite/scripts/test_synapse_port_db.sh deleted file mode 100755 index 82d7d56d4e..0000000000 --- a/.buildkite/scripts/test_synapse_port_db.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env bash -# -# Test script for 'synapse_port_db'. -# - sets up synapse and deps -# - runs the port script on a prepopulated test sqlite db -# - also runs it against an new sqlite db - - -set -xe -cd `dirname $0`/../.. - -echo "--- Install dependencies" - -# Install dependencies for this test. -pip install psycopg2 coverage coverage-enable-subprocess - -# Install Synapse itself. This won't update any libraries. -pip install -e . - -echo "--- Generate the signing key" - -# Generate the server's signing key. -python -m synapse.app.homeserver --generate-keys -c .buildkite/sqlite-config.yaml - -echo "--- Prepare test database" - -# Make sure the SQLite3 database is using the latest schema and has no pending background update. -scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml - -# Create the PostgreSQL database. -./.buildkite/scripts/postgres_exec.py "CREATE DATABASE synapse" - -echo "+++ Run synapse_port_db against test database" -coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml - -# We should be able to run twice against the same database. -echo "+++ Run synapse_port_db a second time" -coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml - -##### - -# Now do the same again, on an empty database. - -echo "--- Prepare empty SQLite database" - -# we do this by deleting the sqlite db, and then doing the same again. -rm .buildkite/test_db.db - -scripts-dev/update_database --database-config .buildkite/sqlite-config.yaml - -# re-create the PostgreSQL database. -./.buildkite/scripts/postgres_exec.py \ - "DROP DATABASE synapse" \ - "CREATE DATABASE synapse" - -echo "+++ Run synapse_port_db against empty database" -coverage run scripts/synapse_port_db --sqlite-database .buildkite/test_db.db --postgres-config .buildkite/postgres-config.yaml diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml deleted file mode 100644 index d16459cfd9..0000000000 --- a/.buildkite/sqlite-config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -# Configuration file used for testing the 'synapse_port_db' script. -# Tells the 'update_database' script to connect to the test SQLite database to upgrade its -# schema and run background updates on it. 
-server_name: "localhost:8800" - -signing_key_path: ".buildkite/test.signing.key" - -report_stats: false - -database: - name: "sqlite3" - args: - database: ".buildkite/test_db.db" - -# Suppress the key server warning. -trusted_key_servers: [] diff --git a/.coveragerc b/.coveragerc index 11f2ec8387..bbf9046b06 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,8 +1,8 @@ [run] branch = True parallel = True -include=$TOP/synapse/* -data_file = $TOP/.coverage +include=$GITHUB_WORKSPACE/synapse/* +data_file = $GITHUB_WORKSPACE/.coverage [report] precision = 2 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a04f6abbed..572bc81b0f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -278,13 +278,7 @@ jobs: - uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Patch Buildkite-specific test scripts - run: | - sed -i -e 's/host="postgres"/host="localhost"/' .buildkite/scripts/postgres_exec.py - sed -i -e 's/host: postgres/host: localhost/' .buildkite/postgres-config.yaml - sed -i -e 's|/src/||' .buildkite/{sqlite,postgres}-config.yaml - sed -i -e 's/\$TOP/\$GITHUB_WORKSPACE/' .coveragerc - - run: .buildkite/scripts/test_synapse_port_db.sh + - run: ci/scripts/test_synapse_port_db.sh complement: if: ${{ !failure() && !cancelled() }} diff --git a/ci/postgres-config.yaml b/ci/postgres-config.yaml new file mode 100644 index 0000000000..511fef495d --- /dev/null +++ b/ci/postgres-config.yaml @@ -0,0 +1,19 @@ +# Configuration file used for testing the 'synapse_port_db' script. +# Tells the script to connect to the postgresql database that will be available in the +# CI's Docker setup at the point where this file is considered. +server_name: "localhost:8800" + +signing_key_path: "ci/test.signing.key" + +report_stats: false + +database: + name: "psycopg2" + args: + user: postgres + host: localhost + password: postgres + database: synapse + +# Suppress the key server warning. +trusted_key_servers: [] diff --git a/ci/scripts/postgres_exec.py b/ci/scripts/postgres_exec.py new file mode 100755 index 0000000000..0f39a336d5 --- /dev/null +++ b/ci/scripts/postgres_exec.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +import psycopg2 + +# a very simple replacment for `psql`, to make up for the lack of the postgres client +# libraries in the synapse docker image. + +# We use "postgres" as a database because it's bound to exist and the "synapse" one +# doesn't exist yet. 
+db_conn = psycopg2.connect( + user="postgres", host="localhost", password="postgres", dbname="postgres" +) +db_conn.autocommit = True +cur = db_conn.cursor() +for c in sys.argv[1:]: + cur.execute(c) diff --git a/ci/scripts/test_synapse_port_db.sh b/ci/scripts/test_synapse_port_db.sh new file mode 100755 index 0000000000..9ee0ad42fc --- /dev/null +++ b/ci/scripts/test_synapse_port_db.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash +# +# Test script for 'synapse_port_db'. +# - sets up synapse and deps +# - runs the port script on a prepopulated test sqlite db +# - also runs it against an new sqlite db + + +set -xe +cd `dirname $0`/../.. + +echo "--- Install dependencies" + +# Install dependencies for this test. +pip install psycopg2 coverage coverage-enable-subprocess + +# Install Synapse itself. This won't update any libraries. +pip install -e . + +echo "--- Generate the signing key" + +# Generate the server's signing key. +python -m synapse.app.homeserver --generate-keys -c ci/sqlite-config.yaml + +echo "--- Prepare test database" + +# Make sure the SQLite3 database is using the latest schema and has no pending background update. +scripts-dev/update_database --database-config ci/sqlite-config.yaml + +# Create the PostgreSQL database. +./ci/scripts/postgres_exec.py "CREATE DATABASE synapse" + +echo "+++ Run synapse_port_db against test database" +coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml + +# We should be able to run twice against the same database. +echo "+++ Run synapse_port_db a second time" +coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml + +##### + +# Now do the same again, on an empty database. + +echo "--- Prepare empty SQLite database" + +# we do this by deleting the sqlite db, and then doing the same again. +rm ci/test_db.db + +scripts-dev/update_database --database-config ci/sqlite-config.yaml + +# re-create the PostgreSQL database. +./ci/scripts/postgres_exec.py \ + "DROP DATABASE synapse" \ + "CREATE DATABASE synapse" + +echo "+++ Run synapse_port_db against empty database" +coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml diff --git a/ci/sqlite-config.yaml b/ci/sqlite-config.yaml new file mode 100644 index 0000000000..fd5a1c1451 --- /dev/null +++ b/ci/sqlite-config.yaml @@ -0,0 +1,16 @@ +# Configuration file used for testing the 'synapse_port_db' script. +# Tells the 'update_database' script to connect to the test SQLite database to upgrade its +# schema and run background updates on it. +server_name: "localhost:8800" + +signing_key_path: "ci/test.signing.key" + +report_stats: false + +database: + name: "sqlite3" + args: + database: "ci/test_db.db" + +# Suppress the key server warning. 
+trusted_key_servers: [] -- cgit 1.5.1 From c0ebdfc77e8f3e75ea162f12b2188acdde5bf4ef Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 10 Aug 2021 13:28:50 +0100 Subject: Kill off the .buildkite dir completely --- .buildkite/.env | 11 ----------- .buildkite/test_db.db | Bin 19296256 -> 0 bytes MANIFEST.in | 1 - ci/test_db.db | Bin 0 -> 19296256 bytes scripts-dev/lint.sh | 2 +- tox.ini | 2 +- 6 files changed, 2 insertions(+), 14 deletions(-) delete mode 100644 .buildkite/.env delete mode 100644 .buildkite/test_db.db create mode 100644 ci/test_db.db diff --git a/.buildkite/.env b/.buildkite/.env deleted file mode 100644 index fc3606ead2..0000000000 --- a/.buildkite/.env +++ /dev/null @@ -1,11 +0,0 @@ -CI -BUILDKITE -BUILDKITE_BUILD_NUMBER -BUILDKITE_BUILD_NUMBER -BUILDKITE_JOB_ID -BUILDKITE_BUILD_URL -BUILDKITE_PROJECT_SLUG -BUILDKITE_COMMIT -BUILDKITE_TAG -CODECOV_TOKEN -TRIAL_FLAGS diff --git a/.buildkite/test_db.db b/.buildkite/test_db.db deleted file mode 100644 index a0d9f16a75..0000000000 Binary files a/.buildkite/test_db.db and /dev/null differ diff --git a/MANIFEST.in b/MANIFEST.in index 174e1b1f47..37d61f40de 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -44,7 +44,6 @@ include book.toml include pyproject.toml recursive-include changelog.d * -prune .buildkite prune .circleci prune .github prune ci diff --git a/ci/test_db.db b/ci/test_db.db new file mode 100644 index 0000000000..a0d9f16a75 Binary files /dev/null and b/ci/test_db.db differ diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index 869eb2372d..2c77643cda 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -94,7 +94,7 @@ else "scripts-dev/build_debian_packages" "scripts-dev/sign_json" "scripts-dev/update_database" - "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite" + "contrib" "synctl" "setup.py" "synmark" "stubs" "ci" ) fi fi diff --git a/tox.ini b/tox.ini index da77d124fc..b695126019 100644 --- a/tox.ini +++ b/tox.ini @@ -49,7 +49,7 @@ lint_targets = contrib synctl synmark - .buildkite + ci docker # default settings for all tox environments -- cgit 1.5.1 From fe1d0c86180ea025dfb444597c7ad72b036bbb10 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 10 Aug 2021 13:08:17 -0400 Subject: Add local support for the new spaces summary endpoint (MSC2946) (#10549) This adds support for the /hierarchy endpoint, which is an update to MSC2946. Currently this only supports rooms known locally to the homeserver. --- changelog.d/10527.misc | 2 +- changelog.d/10530.misc | 2 +- changelog.d/10549.feature | 1 + synapse/handlers/space_summary.py | 201 +++++++++++++++++- synapse/rest/client/v1/room.py | 41 ++++ tests/handlers/test_space_summary.py | 386 +++++++++++++++++++++++++---------- 6 files changed, 521 insertions(+), 112 deletions(-) create mode 100644 changelog.d/10549.feature diff --git a/changelog.d/10527.misc b/changelog.d/10527.misc index 3cf22f9daf..ffc4e4289c 100644 --- a/changelog.d/10527.misc +++ b/changelog.d/10527.misc @@ -1 +1 @@ -Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10530.misc b/changelog.d/10530.misc index 3cf22f9daf..ffc4e4289c 100644 --- a/changelog.d/10530.misc +++ b/changelog.d/10530.misc @@ -1 +1 @@ -Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)). 
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10549.feature b/changelog.d/10549.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10549.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index d04afe6c31..fd76c34695 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -18,7 +18,7 @@ import re from collections import deque from typing import ( TYPE_CHECKING, - Collection, + Deque, Dict, Iterable, List, @@ -38,9 +38,12 @@ from synapse.api.constants import ( Membership, RoomTypes, ) +from synapse.api.errors import Codes, SynapseError from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict +from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -57,6 +60,29 @@ MAX_ROOMS_PER_SPACE = 50 MAX_SERVERS_PER_SPACE = 3 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PaginationKey: + """The key used to find unique pagination session.""" + + # The first three entries match the request parameters (and cannot change + # during a pagination session). + room_id: str + suggested_only: bool + max_depth: Optional[int] + # The randomly generated token. + token: str + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _PaginationSession: + """The information that is stored for pagination.""" + + # The queue of rooms which are still to process. + room_queue: Deque["_RoomQueueEntry"] + # A set of rooms which have been processed. + processed_rooms: Set[str] + + class SpaceSummaryHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() @@ -67,6 +93,21 @@ class SpaceSummaryHandler: self._server_name = hs.hostname self._federation_client = hs.get_federation_client() + # A map of query information to the current pagination state. + # + # TODO Allow for multiple workers to share this data. + # TODO Expire pagination tokens. + self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {} + + # If a user tries to fetch the same page multiple times in quick succession, + # only process the first attempt and return its result to subsequent requests. + self._pagination_response_cache: ResponseCache[ + Tuple[str, bool, Optional[int], Optional[int], Optional[str]] + ] = ResponseCache( + hs.get_clock(), + "get_room_hierarchy", + ) + async def get_space_summary( self, requester: str, @@ -130,7 +171,7 @@ class SpaceSummaryHandler: requester, None, room_id, suggested_only, max_children ) - events: Collection[JsonDict] = [] + events: Sequence[JsonDict] = [] if room_entry: rooms_result.append(room_entry.room) events = room_entry.children @@ -207,6 +248,154 @@ class SpaceSummaryHandler: return {"rooms": rooms_result, "events": events_result} + async def get_room_hierarchy( + self, + requester: str, + requested_room_id: str, + suggested_only: bool = False, + max_depth: Optional[int] = None, + limit: Optional[int] = None, + from_token: Optional[str] = None, + ) -> JsonDict: + """ + Implementation of the room hierarchy C-S API. + + Args: + requester: The user ID of the user making this request. 
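[Editor's note: a side note on the two attrs classes above: `frozen=True` is what lets `_PaginationKey` index the `_pagination_sessions` dict. A toy sketch of that property (the `DemoKey` class is hypothetical, not part of this changeset):]

```python
import attr

# frozen=True makes instances immutable, and together with the generated
# __eq__ attrs also generates __hash__, so instances can key a plain dict.
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DemoKey:
    room_id: str
    token: str

sessions = {DemoKey("!a:example.org", "tok1"): "saved traversal state"}
assert sessions[DemoKey("!a:example.org", "tok1")] == "saved traversal state"
```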
+ requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: Whether we should only return children with the "suggested" + flag set. + max_depth: The maximum depth in the tree to explore, must be a + non-negative integer. + + 0 would correspond to just the root room, 1 would include just + the root room's children, etc. + limit: An optional limit on the number of rooms to return per + page. Must be a positive integer. + from_token: An optional pagination token. + + Returns: + The JSON hierarchy dictionary. + """ + # If a user tries to fetch the same page multiple times in quick succession, + # only process the first attempt and return its result to subsequent requests. + # + # This is due to the pagination process mutating internal state, attempting + # to process multiple requests for the same page will result in errors. + return await self._pagination_response_cache.wrap( + (requested_room_id, suggested_only, max_depth, limit, from_token), + self._get_room_hierarchy, + requester, + requested_room_id, + suggested_only, + max_depth, + limit, + from_token, + ) + + async def _get_room_hierarchy( + self, + requester: str, + requested_room_id: str, + suggested_only: bool = False, + max_depth: Optional[int] = None, + limit: Optional[int] = None, + from_token: Optional[str] = None, + ) -> JsonDict: + """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" + + # first of all, check that the user is in the room in question (or it's + # world-readable) + await self._auth.check_user_in_room_or_world_readable( + requested_room_id, requester + ) + + # If this is continuing a previous session, pull the persisted data. + if from_token: + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, from_token + ) + if pagination_key not in self._pagination_sessions: + raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) + + # Load the previous state. + pagination_session = self._pagination_sessions[pagination_key] + room_queue = pagination_session.room_queue + processed_rooms = pagination_session.processed_rooms + else: + # the queue of rooms to process + room_queue = deque((_RoomQueueEntry(requested_room_id, ()),)) + + # Rooms we have already processed. + processed_rooms = set() + + rooms_result: List[JsonDict] = [] + + # Cap the limit to a server-side maximum. + if limit is None: + limit = MAX_ROOMS + else: + limit = min(limit, MAX_ROOMS) + + # Iterate through the queue until we reach the limit or run out of + # rooms to include. + while room_queue and len(rooms_result) < limit: + queue_entry = room_queue.popleft() + room_id = queue_entry.room_id + current_depth = queue_entry.depth + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + if is_in_room: + room_entry = await self._summarize_local_room( + requester, + None, + room_id, + suggested_only, + # TODO Handle max children. + max_children=None, + ) + + if room_entry: + rooms_result.append(room_entry.as_json()) + + # Add the child to the queue. We have already validated + # that the vias are a list of server names. + # + # If the current depth is the maximum depth, do not queue + # more entries. 
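[Editor's note: the `ResponseCache.wrap` call above is the usual Synapse pattern for collapsing concurrent identical requests. A simplified, self-contained asyncio analogue of that pattern; this is a sketch, not the real `ResponseCache`, which is Deferred-based and keeps results around for a TTL:]

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, Hashable

class DeduplicatingCache:
    """Toy stand-in for ResponseCache: concurrent callers that present the
    same key share one in-flight computation and its result."""

    def __init__(self) -> None:
        self._pending: Dict[Hashable, "asyncio.Task[Any]"] = {}

    async def wrap(
        self, key: Hashable, callback: Callable[..., Awaitable[Any]], *args: Any
    ) -> Any:
        task = self._pending.get(key)
        if task is None:
            task = asyncio.ensure_future(callback(*args))
            self._pending[key] = task
            # Drop the entry once it completes; the real ResponseCache
            # retains results for a configurable TTL instead.
            task.add_done_callback(lambda _t: self._pending.pop(key, None))
        return await task
```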
+ if max_depth is None or current_depth < max_depth: + room_queue.extendleft( + _RoomQueueEntry( + ev["state_key"], ev["content"]["via"], current_depth + 1 + ) + for ev in reversed(room_entry.children) + ) + + processed_rooms.add(room_id) + else: + # TODO Federation. + pass + + result: JsonDict = {"rooms": rooms_result} + + # If there's additional data, generate a pagination token (and persist state). + if room_queue: + next_token = random_string(24) + result["next_token"] = next_token + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, next_token + ) + self._pagination_sessions[pagination_key] = _PaginationSession( + room_queue, processed_rooms + ) + + return result + async def federation_space_summary( self, origin: str, @@ -652,6 +841,7 @@ class SpaceSummaryHandler: class _RoomQueueEntry: room_id: str via: Sequence[str] + depth: int = 0 @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -662,7 +852,12 @@ class _RoomEntry: # An iterable of the sorted, stripped children events for children of this room. # # This may not include all children. - children: Collection[JsonDict] = () + children: Sequence[JsonDict] = () + + def as_json(self) -> JsonDict: + result = dict(self.room) + result["children_state"] = self.children + return result def _has_valid_via(e: EventBase) -> bool: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f887970b76..b28b72bfbd 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -1445,6 +1445,46 @@ class RoomSpaceSummaryRestServlet(RestServlet): ) +class RoomHierarchyRestServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2946" + "/rooms/(?P[^/]*)/hierarchy$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._space_summary_handler = hs.get_space_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + + max_depth = parse_integer(request, "max_depth") + if max_depth is not None and max_depth < 0: + raise SynapseError( + 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON + ) + + limit = parse_integer(request, "limit") + if limit is not None and limit <= 0: + raise SynapseError( + 400, "'limit' must be a positive integer", Codes.BAD_JSON + ) + + return 200, await self._space_summary_handler.get_room_hierarchy( + requester.user.to_string(), + room_id, + suggested_only=parse_boolean(request, "suggested_only", default=False), + max_depth=max_depth, + limit=limit, + from_token=parse_string(request, "from"), + ) + + def register_servlets(hs: "HomeServer", http_server, is_worker=False): msc2716_enabled = hs.config.experimental.msc2716_enabled @@ -1463,6 +1503,7 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) RoomSpaceSummaryRestServlet(hs).register(http_server) + RoomHierarchyRestServlet(hs).register(http_server) RoomEventServlet(hs).register(http_server) JoinedRoomsRestServlet(hs).register(http_server) RoomAliasListServlet(hs).register(http_server) diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index f470c81ea2..255dd17f86 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -23,7 +23,7 @@ from synapse.api.constants import ( 
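[Editor's note: putting the pieces of `_get_room_hierarchy` together: the walk is depth-first (a `deque` used as a stack via `popleft` plus `extendleft(reversed(...))`, which preserves child order), capped by `limit` and `max_depth`, with leftover work parked behind a random token. A toy, self-contained restatement, with all names hypothetical:]

```python
import secrets
from collections import deque
from typing import Callable, Deque, Dict, List, Optional, Set, Tuple

def paginate_hierarchy(
    root: str,
    get_children: Callable[[str], List[str]],  # stand-in for m.space.child lookups
    limit: int,
    sessions: Dict[str, Tuple[Deque[Tuple[str, int]], Set[str]]],
    from_token: Optional[str] = None,
    max_depth: Optional[int] = None,
) -> Tuple[List[str], Optional[str]]:
    if from_token is not None:
        queue, processed = sessions.pop(from_token)  # unknown token -> KeyError
    else:
        queue, processed = deque([(root, 0)]), set()

    page: List[str] = []
    while queue and len(page) < limit:
        room, depth = queue.popleft()
        if room in processed:
            continue
        processed.add(room)
        page.append(room)
        if max_depth is None or depth < max_depth:
            # extendleft reverses its input, hence the reversed() call.
            queue.extendleft((c, depth + 1) for c in reversed(get_children(room)))

    next_token = None
    if queue:
        # More work remains: issue an opaque token and stash the live state.
        next_token = secrets.token_urlsafe(16)
        sessions[next_token] = (queue, processed)
    return page, next_token
```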
RestrictedJoinRuleTypes, RoomTypes, ) -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry @@ -123,32 +123,83 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.room = self.helper.create_room_as(self.user, tok=self.token) self._add_child(self.space, self.room, self.token) - def _add_child(self, space_id: str, room_id: str, token: str) -> None: + def _add_child( + self, space_id: str, room_id: str, token: str, order: Optional[str] = None + ) -> None: """Add a child room to a space.""" + content = {"via": [self.hs.hostname]} + if order is not None: + content["order"] = order self.helper.send_state( space_id, event_type=EventTypes.SpaceChild, - body={"via": [self.hs.hostname]}, + body=content, tok=token, state_key=room_id, ) - def _assert_rooms(self, result: JsonDict, rooms: Iterable[str]) -> None: - """Assert that the expected room IDs are in the response.""" - self.assertCountEqual([room.get("room_id") for room in result["rooms"]], rooms) - - def _assert_events( - self, result: JsonDict, events: Iterable[Tuple[str, str]] + def _assert_rooms( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] ) -> None: - """Assert that the expected parent / child room IDs are in the response.""" + """ + Assert that the expected room IDs and events are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + if children: + children_ids.extend([(room_id, child_id) for child_id in children]) + self.assertCountEqual( + [room.get("room_id") for room in result["rooms"]], room_ids + ) self.assertCountEqual( [ (event.get("room_id"), event.get("state_key")) for event in result["events"] ], - events, + children_ids, ) + def _assert_hierarchy( + self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] + ) -> None: + """ + Assert that the expected room IDs are in the response. + + Args: + result: The result from the API call. + rooms_and_children: An iterable of tuples where each tuple is: + The expected room ID. + The expected IDs of any children rooms. + """ + result_room_ids = [] + result_children_ids = [] + for result_room in result["rooms"]: + result_room_ids.append(result_room["room_id"]) + result_children_ids.append( + [ + (cs["room_id"], cs["state_key"]) + for cs in result_room.get("children_state") + ] + ) + + room_ids = [] + children_ids = [] + for room_id, children in rooms_and_children: + room_ids.append(room_id) + children_ids.append([(room_id, child_id) for child_id in children]) + + # Note that order matters. + self.assertEqual(result_room_ids, room_ids) + self.assertEqual(result_children_ids, children_ids) + def _poke_fed_invite(self, room_id: str, from_user: str) -> None: """ Creates a invite (as if received over federation) for the room from the @@ -184,8 +235,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have the space and the room in it, along with a link # from space -> room. 
- self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) def test_visibility(self): """A user not in a space cannot inspect it.""" @@ -194,6 +250,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # The user cannot see the space. self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # If the space is made world-readable it should return a result. self.helper.send_state( @@ -203,8 +260,11 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) # Make it not world-readable again and confirm it results in an error. self.helper.send_state( @@ -214,12 +274,15 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # Join the space and results should be returned. self.helper.join(self.space, user2, tok=token2) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, [self.space, self.room]) - self._assert_events(result, [(self.space, self.room)]) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) def _create_room_with_join_rule( self, join_rule: str, room_version: Optional[str] = None, **extra_content @@ -290,34 +353,33 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Join the space. 
self.helper.join(self.space, user2, tok=token2)
         result = self.get_success(self.handler.get_space_summary(user2, self.space))
-
-        self._assert_rooms(
-            result,
-            [
+        expected = [
+            (
                 self.space,
-                self.room,
-                public_room,
-                knock_room,
-                invited_room,
-                restricted_accessible_room,
-                world_readable_room,
-                joined_room,
-            ],
-        )
-        self._assert_events(
-            result,
-            [
-                (self.space, self.room),
-                (self.space, public_room),
-                (self.space, knock_room),
-                (self.space, not_invited_room),
-                (self.space, invited_room),
-                (self.space, restricted_room),
-                (self.space, restricted_accessible_room),
-                (self.space, world_readable_room),
-                (self.space, joined_room),
-            ],
-        )
+                [
+                    self.room,
+                    public_room,
+                    knock_room,
+                    not_invited_room,
+                    invited_room,
+                    restricted_room,
+                    restricted_accessible_room,
+                    world_readable_room,
+                    joined_room,
+                ],
+            ),
+            (self.room, ()),
+            (public_room, ()),
+            (knock_room, ()),
+            (invited_room, ()),
+            (restricted_accessible_room, ()),
+            (world_readable_room, ()),
+            (joined_room, ()),
+        ]
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+        self._assert_hierarchy(result, expected)
 
     def test_complex_space(self):
         """
@@ -349,19 +411,145 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         result = self.get_success(self.handler.get_space_summary(self.user, self.space))
 
         # The result should include each room a single time and each link.
-        self._assert_rooms(result, [self.space, self.room, subspace, subroom])
-        self._assert_events(
-            result,
-            [
-                (self.space, self.room),
-                (self.space, room2),
-                (self.space, subspace),
-                (subspace, subroom),
-                (subspace, self.room),
-                (subspace, room2),
-            ],
+        expected = [
+            (self.space, [self.room, room2, subspace]),
+            (self.room, ()),
+            (subspace, [subroom, self.room, room2]),
+            (subroom, ()),
+        ]
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+        self._assert_hierarchy(result, expected)
+
+    def test_pagination(self):
+        """Test that simple pagination works."""
+        room_ids = []
+        for i in range(1, 10):
+            room = self.helper.create_room_as(self.user, tok=self.token)
+            self._add_child(self.space, room, self.token, order=str(i))
+            room_ids.append(room)
+        # The room created initially doesn't have an order, so comes last.
+        room_ids.append(self.room)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+        )
+        # The result should have the space and all of the links, plus some of the
+        # rooms and a pagination token.
+        expected = [(self.space, room_ids)] + [
+            (room_id, ()) for room_id in room_ids[:6]
+        ]
+        self._assert_hierarchy(result, expected)
+        self.assertIn("next_token", result)
+
+        # Check the next page.
+        result = self.get_success(
+            self.handler.get_room_hierarchy(
+                self.user, self.space, limit=5, from_token=result["next_token"]
+            )
+        )
+        # The result should have the remaining rooms in it, with no further
+        # pagination token.
+        expected = [(room_id, ()) for room_id in room_ids[6:]]
+        self._assert_hierarchy(result, expected)
+        self.assertNotIn("next_token", result)
+
+    def test_invalid_pagination_token(self):
+        """An invalid pagination token (or changed request parameters) is rejected."""
+        room_ids = []
+        for i in range(1, 10):
+            room = self.helper.create_room_as(self.user, tok=self.token)
+            self._add_child(self.space, room, self.token, order=str(i))
+            room_ids.append(room)
+        # The room created initially doesn't have an order, so comes last.
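[Editor's note: these pagination tests rely on children sorting by their `order` key, with unordered children last. A hedged sketch of such a comparison key; the real `_child_events_comparison_key` is imported above but is not part of this diff:]

```python
from typing import Optional, Tuple

# Children with an `order` sort first, lexicographically; the rest fall back
# to origin_server_ts, with state_key as a final tiebreak. This is why the
# initially created room (no order) comes last in the tests above.
def child_comparison_key(
    order: Optional[str], origin_server_ts: int, state_key: str
) -> Tuple[bool, Optional[str], int, str]:
    return (order is None, order, origin_server_ts, state_key)

children = [("2", 30, "!b"), (None, 10, "!z"), ("1", 20, "!a")]
children.sort(key=lambda c: child_comparison_key(*c))
assert [c[2] for c in children] == ["!a", "!b", "!z"]
```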
+ room_ids.append(self.room) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, limit=7) + ) + self.assertIn("next_token", result) + + # Changing the room ID, suggested-only, or max-depth causes an error. + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.room, from_token=result["next_token"] + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, + self.space, + suggested_only=True, + from_token=result["next_token"], + ), + SynapseError, + ) + self.get_failure( + self.handler.get_room_hierarchy( + self.user, self.space, max_depth=0, from_token=result["next_token"] + ), + SynapseError, ) + # An invalid token is ignored. + self.get_failure( + self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), + SynapseError, + ) + + def test_max_depth(self): + """Create a deep tree to test the max depth against.""" + spaces = [self.space] + rooms = [self.room] + for _ in range(5): + spaces.append( + self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": { + EventContentFields.ROOM_TYPE: RoomTypes.SPACE + } + }, + ) + ) + self._add_child(spaces[-2], spaces[-1], self.token) + rooms.append(self.helper.create_room_as(self.user, tok=self.token)) + self._add_child(spaces[-1], rooms[-1], self.token) + + # Test just the space itself. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) + ) + expected = [(spaces[0], [rooms[0], spaces[1]])] + self._assert_hierarchy(result, expected) + + # A single additional layer. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) + ) + expected += [ + (rooms[0], ()), + (spaces[1], [rooms[1], spaces[2]]), + ] + self._assert_hierarchy(result, expected) + + # A few layers. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) + ) + expected += [ + (rooms[1], ()), + (spaces[2], [rooms[2], spaces[3]]), + (rooms[2], ()), + (spaces[3], [rooms[3], spaces[4]]), + ] + self._assert_hierarchy(result, expected) + def test_fed_complex(self): """ Return data over federation and ensure that it is handled properly. 
@@ -417,15 +605,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.handler.get_space_summary(self.user, self.space) ) - self._assert_rooms(result, [self.space, self.room, subspace, subroom]) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, subspace), - (subspace, subroom), - ], - ) + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + (subspace, [subroom]), + (subroom, ()), + ] + self._assert_rooms(result, expected) def test_fed_filtering(self): """ @@ -554,35 +740,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.handler.get_space_summary(self.user, self.space) ) - self._assert_rooms( - result, - [ - self.space, - self.room, + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + ( subspace, - public_room, - knock_room, - invited_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, subspace), - (subspace, public_room), - (subspace, knock_room), - (subspace, not_invited_room), - (subspace, invited_room), - (subspace, restricted_room), - (subspace, restricted_accessible_room), - (subspace, world_readable_room), - (subspace, joined_room), - ], - ) + [ + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) def test_fed_invited(self): """ @@ -623,18 +804,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.handler.get_space_summary(self.user, self.space) ) - self._assert_rooms( - result, - [ - self.space, - self.room, - fed_room, - ], - ) - self._assert_events( - result, - [ - (self.space, self.room), - (self.space, fed_room), - ], - ) + expected = [ + (self.space, [self.room, fed_room]), + (self.room, ()), + (fed_room, ()), + ] + self._assert_rooms(result, expected) -- cgit 1.5.1 From b924a5c2e493423e1411322c933ceaad35fc4803 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 10 Aug 2021 18:37:40 +0100 Subject: Add changelog entry and signoff Signed-off-by: David Robertson --- changelog.d/10573.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/10573.misc diff --git a/changelog.d/10573.misc b/changelog.d/10573.misc new file mode 100644 index 0000000000..fc9b1a2f70 --- /dev/null +++ b/changelog.d/10573.misc @@ -0,0 +1 @@ +Remove references to BuildKite in favour of GitHub Actions. 
\ No newline at end of file -- cgit 1.5.1 From 8c654b73095a594b36101aa81cf91a8e1bebc29f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 10 Aug 2021 18:10:40 -0500 Subject: Only return state events that the AS passed in via `state_events_at_start` (MSC2716) (#10552) * Only return state events that the AS passed in via state_events_at_start As discovered by @Half-Shot in https://github.com/matrix-org/matrix-doc/pull/2716#discussion_r684158448 Part of MSC2716 * Add changelog * Fix changelog extension --- changelog.d/10552.misc | 1 + synapse/rest/client/v1/room.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10552.misc diff --git a/changelog.d/10552.misc b/changelog.d/10552.misc new file mode 100644 index 0000000000..fc5f6aea5f --- /dev/null +++ b/changelog.d/10552.misc @@ -0,0 +1 @@ +Update `/batch_send` endpoint to only return `state_events` created by the `state_events_from_before` passed in. diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index b28b72bfbd..f1bc43be2d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -437,6 +437,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): prev_state_ids = list(prev_state_map.values()) auth_event_ids = prev_state_ids + state_events_at_start = [] for state_event in body["state_events_at_start"]: assert_params_in_dict( state_event, ["type", "origin_server_ts", "content", "sender"] @@ -502,6 +503,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): ) event_id = event.event_id + state_events_at_start.append(event_id) auth_event_ids.append(event_id) events_to_create = body["events"] @@ -651,7 +653,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): event_ids.append(base_insertion_event.event_id) return 200, { - "state_events": auth_event_ids, + "state_events": state_events_at_start, "events": event_ids, "next_chunk_id": insertion_event["content"][ EventContentFields.MSC2716_NEXT_CHUNK_ID -- cgit 1.5.1 From 339c3918e1301d53b998c98282137b12d9d16c45 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 11 Aug 2021 16:34:59 +0200 Subject: support federation queries through http connect proxy (#10475) Signed-off-by: Marcus Hoffmann Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10475.feature | 1 + docs/setup/forward_proxy.md | 6 +- docs/upgrade.md | 27 ++ synapse/http/connectproxyclient.py | 68 ++-- synapse/http/federation/matrix_federation_agent.py | 100 ++++- synapse/http/matrixfederationclient.py | 12 +- synapse/http/proxyagent.py | 51 +-- .../federation/test_matrix_federation_agent.py | 406 +++++++++++++++++---- tests/http/test_proxyagent.py | 75 ++-- 9 files changed, 555 insertions(+), 191 deletions(-) create mode 100644 changelog.d/10475.feature diff --git a/changelog.d/10475.feature b/changelog.d/10475.feature new file mode 100644 index 0000000000..52eab11b03 --- /dev/null +++ b/changelog.d/10475.feature @@ -0,0 +1 @@ +Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. 
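[Editor's note: a toy restatement of the fix above, using hypothetical event IDs. `auth_event_ids` starts from the room's existing state, so returning it leaked events the application service never sent; the new `state_events_at_start` list tracks only the events created from the request body:]

```python
# Hypothetical event IDs, purely for illustration:
prev_state_ids = ["$power_levels", "$create"]  # room state the AS did not send
auth_event_ids = list(prev_state_ids)          # grows as events are created

state_events_at_start = []                     # only the caller's own events
for event_id in ["$as_member_alice", "$as_member_bob"]:
    state_events_at_start.append(event_id)
    auth_event_ids.append(event_id)

# The buggy response returned auth_event_ids, leaking "$power_levels" and
# "$create"; the fixed response returns exactly the caller's events:
assert state_events_at_start == ["$as_member_alice", "$as_member_bob"]
```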
\ No newline at end of file
diff --git a/docs/setup/forward_proxy.md b/docs/setup/forward_proxy.md
index a0720ab342..494c14893b 100644
--- a/docs/setup/forward_proxy.md
+++ b/docs/setup/forward_proxy.md
@@ -45,18 +45,18 @@ The proxy will be **used** for:
 - recaptcha validation
 - CAS auth validation
 - OpenID Connect
+- Outbound federation
 - Federation (checking public key revocation)
+- Fetching public keys of other servers
+- Downloading remote media
 
 It will **not be used** for:
 
 - Application Services
 - Identity servers
-- Outbound federation
 - In worker configurations
   - connections between workers
   - connections from workers to Redis
-- Fetching public keys of other servers
-- Downloading remote media
 
 ## Troubleshooting
 
diff --git a/docs/upgrade.md b/docs/upgrade.md
index ce9167e6de..8831c9d6cf 100644
--- a/docs/upgrade.md
+++ b/docs/upgrade.md
@@ -86,6 +86,33 @@ process, for example:
     ```
 
 
+# Upgrading to v1.xx.0
+
+## Add support for routing outbound HTTP requests via a proxy for federation
+
+Since Synapse 1.6.0 (2019-11-26) you have been able to set a proxy for outbound HTTP
+requests via the `http_proxy`/`https_proxy` environment variables. This proxy was used for:
+- push
+- url previews
+- phone-home stats
+- recaptcha validation
+- CAS auth validation
+- OpenID Connect
+- Federation (checking public key revocation)
+
+In this version we have added support for outbound requests for:
+- Outbound federation
+- Downloading remote media
+- Fetching public keys of other servers
+
+These requests use the same proxy configuration. If you have a proxy configured, we
+recommend that you verify it; it may be necessary to adjust the `no_proxy`
+environment variable.
+
+See the [using a forward proxy with Synapse](setup/forward_proxy.md) documentation for
+details.
+
+
 # Upgrading to v1.39.0
 
 ## Deprecation of the current third-party rules module interface
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 17e1c5abb1..c577142268 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import base64
 import logging
+from typing import Optional
 
+import attr
 from zope.interface import implementer
 
 from twisted.internet import defer, protocol
@@ -21,7 +24,6 @@ from twisted.internet.error import ConnectError
 from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
 from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
 from twisted.web import http
-from twisted.web.http_headers import Headers
 
 logger = logging.getLogger(__name__)
 
@@ -30,6 +32,22 @@ class ProxyConnectError(ConnectError):
     pass
 
 
+@attr.s
+class ProxyCredentials:
+    username_password = attr.ib(type=bytes)
+
+    def as_proxy_authorization_value(self) -> bytes:
+        """
+        Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
+
+        Returns:
+            The authentication string, encoded as the value for a
+            Proxy-Authorization header.
+ """ + # Encode as base64 and prepend the authorization type + return b"Basic " + base64.encodebytes(self.username_password) + + @implementer(IStreamClientEndpoint) class HTTPConnectProxyEndpoint: """An Endpoint implementation which will send a CONNECT request to an http proxy @@ -46,7 +64,7 @@ class HTTPConnectProxyEndpoint: proxy_endpoint: the endpoint to use to connect to the proxy host: hostname that we want to CONNECT to port: port that we want to connect to - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -55,20 +73,20 @@ class HTTPConnectProxyEndpoint: proxy_endpoint: IStreamClientEndpoint, host: bytes, port: int, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self._reactor = reactor self._proxy_endpoint = proxy_endpoint self._host = host self._port = port - self._headers = headers + self._proxy_creds = proxy_creds def __repr__(self): return "" % (self._proxy_endpoint,) def connect(self, protocolFactory: ClientFactory): f = HTTPProxiedClientFactory( - self._host, self._port, protocolFactory, self._headers + self._host, self._port, protocolFactory, self._proxy_creds ) d = self._proxy_endpoint.connect(f) # once the tcp socket connects successfully, we need to wait for the @@ -87,7 +105,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): dst_host: hostname that we want to CONNECT to dst_port: port that we want to connect to wrapped_factory: The original Factory - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -95,12 +113,12 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): dst_host: bytes, dst_port: int, wrapped_factory: ClientFactory, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self.dst_host = dst_host self.dst_port = dst_port self.wrapped_factory = wrapped_factory - self.headers = headers + self.proxy_creds = proxy_creds self.on_connection = defer.Deferred() def startedConnecting(self, connector): @@ -114,7 +132,7 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): self.dst_port, wrapped_protocol, self.on_connection, - self.headers, + self.proxy_creds, ) def clientConnectionFailed(self, connector, reason): @@ -145,7 +163,7 @@ class HTTPConnectProtocol(protocol.Protocol): connected_deferred: a Deferred which will be callbacked with wrapped_protocol when the CONNECT completes - headers: Extra HTTP headers to include in the CONNECT request + proxy_creds: credentials to authenticate at proxy """ def __init__( @@ -154,16 +172,16 @@ class HTTPConnectProtocol(protocol.Protocol): port: int, wrapped_protocol: Protocol, connected_deferred: defer.Deferred, - headers: Headers, + proxy_creds: Optional[ProxyCredentials], ): self.host = host self.port = port self.wrapped_protocol = wrapped_protocol self.connected_deferred = connected_deferred - self.headers = headers + self.proxy_creds = proxy_creds self.http_setup_client = HTTPConnectSetupClient( - self.host, self.port, self.headers + self.host, self.port, self.proxy_creds ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) @@ -205,30 +223,38 @@ class HTTPConnectSetupClient(http.HTTPClient): Args: host: The hostname to send in the CONNECT message port: The port to send in the CONNECT message - headers: Extra headers to send with the CONNECT message + proxy_creds: credentials to authenticate at proxy """ - def __init__(self, host: bytes, port: int, headers: Headers): + def __init__( + self, 
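[Editor's note: for reference, what `as_proxy_authorization_value()` yields for hypothetical `user:pass` credentials. Note that `base64.encodebytes` (unlike `b64encode`) appends a newline and wraps long input every 76 characters:]

```python
import base64

username_password = b"user:pass"  # hypothetical credentials

# What ProxyCredentials.as_proxy_authorization_value() produces:
value = b"Basic " + base64.encodebytes(username_password)
assert value == b"Basic dXNlcjpwYXNz\n"

# b64encode, used by the tests further down, omits the trailing newline:
assert base64.b64encode(username_password) == b"dXNlcjpwYXNz"
```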
+ host: bytes, + port: int, + proxy_creds: Optional[ProxyCredentials], + ): self.host = host self.port = port - self.headers = headers + self.proxy_creds = proxy_creds self.on_connected = defer.Deferred() def connectionMade(self): logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) - # Send any additional specified headers - for name, values in self.headers.getAllRawHeaders(): - for value in values: - self.sendHeader(name, value) + # Determine whether we need to set Proxy-Authorization headers + if self.proxy_creds: + # Set a Proxy-Authorization header + self.sendHeader( + b"Proxy-Authorization", + self.proxy_creds.as_proxy_authorization_value(), + ) self.endHeaders() def handleStatus(self, version: bytes, status: bytes, message: bytes): logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": - raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) + raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}") def handleEndHeaders(self): logger.debug("End Headers") diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index c16b7f10e6..1238bfd287 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -14,6 +14,10 @@ import logging import urllib.parse from typing import Any, Generator, List, Optional +from urllib.request import ( # type: ignore[attr-defined] + getproxies_environment, + proxy_bypass_environment, +) from netaddr import AddrFormatError, IPAddress, IPSet from zope.interface import implementer @@ -30,9 +34,12 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse from synapse.crypto.context_factory import FederationPolicyForHTTPS -from synapse.http.client import BlacklistingAgentWrapper +from synapse.http import proxyagent +from synapse.http.client import BlacklistingAgentWrapper, BlacklistingReactorWrapper +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver +from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import ISynapseReactor from synapse.util import Clock @@ -57,6 +64,14 @@ class MatrixFederationAgent: user_agent: The user agent header to use for federation requests. + ip_whitelist: Allowed IP addresses. + + ip_blacklist: Disallowed IP addresses. + + proxy_reactor: twisted reactor to use for connections to the proxy server + reactor might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. + _srv_resolver: SrvResolver implementation to use for looking up SRV records. None to use a default implementation. @@ -71,11 +86,18 @@ class MatrixFederationAgent: reactor: ISynapseReactor, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, + ip_whitelist: IPSet, ip_blacklist: IPSet, _srv_resolver: Optional[SrvResolver] = None, _well_known_resolver: Optional[WellKnownResolver] = None, ): - self._reactor = reactor + # proxy_reactor is not blacklisted + proxy_reactor = reactor + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. 
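[Editor's note: a sketch of the bytes the `HTTPConnectSetupClient` above ends up sending; twisted's `http.HTTPClient` emits `HTTP/1.0` request lines, and the host and credentials here are hypothetical:]

```python
import base64
from typing import Optional

def build_connect_request(host: bytes, port: int, creds: Optional[bytes]) -> bytes:
    lines = [b"CONNECT %s:%d HTTP/1.0" % (host, port)]
    if creds is not None:
        lines.append(b"Proxy-Authorization: Basic " + base64.b64encode(creds))
    lines += [b"", b""]  # a blank line terminates the header block
    return b"\r\n".join(lines)

print(build_connect_request(b"testserv", 8448, b"user:pass").decode())
# CONNECT testserv:8448 HTTP/1.0
# Proxy-Authorization: Basic dXNlcjpwYXNz
#
# On a 200 response, the raw TLS handshake with testserv then flows
# through the established tunnel.
```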
+ reactor = BlacklistingReactorWrapper(reactor, ip_whitelist, ip_blacklist) + self._clock = Clock(reactor) self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False @@ -83,24 +105,27 @@ class MatrixFederationAgent: self._pool.cachedConnectionTimeout = 2 * 60 self._agent = Agent.usingEndpointFactory( - self._reactor, + reactor, MatrixHostnameEndpointFactory( - reactor, tls_client_options_factory, _srv_resolver + reactor, + proxy_reactor, + tls_client_options_factory, + _srv_resolver, ), pool=self._pool, ) self.user_agent = user_agent if _well_known_resolver is None: - # Note that the name resolver has already been wrapped in a - # IPBlacklistingResolver by MatrixFederationHttpClient. _well_known_resolver = WellKnownResolver( - self._reactor, + reactor, agent=BlacklistingAgentWrapper( - Agent( - self._reactor, + ProxyAgent( + reactor, + proxy_reactor, pool=self._pool, contextFactory=tls_client_options_factory, + use_proxy=True, ), ip_blacklist=ip_blacklist, ), @@ -200,10 +225,12 @@ class MatrixHostnameEndpointFactory: def __init__( self, reactor: IReactorCore, + proxy_reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], srv_resolver: Optional[SrvResolver], ): self._reactor = reactor + self._proxy_reactor = proxy_reactor self._tls_client_options_factory = tls_client_options_factory if srv_resolver is None: @@ -211,9 +238,10 @@ class MatrixHostnameEndpointFactory: self._srv_resolver = srv_resolver - def endpointForURI(self, parsed_uri): + def endpointForURI(self, parsed_uri: URI): return MatrixHostnameEndpoint( self._reactor, + self._proxy_reactor, self._tls_client_options_factory, self._srv_resolver, parsed_uri, @@ -227,23 +255,45 @@ class MatrixHostnameEndpoint: Args: reactor: twisted reactor to use for underlying requests + proxy_reactor: twisted reactor to use for connections to the proxy server. + 'reactor' might have some blacklisting applied (i.e. for DNS queries), + but we need unblocked access to the proxy. tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. srv_resolver: The SRV resolver to use parsed_uri: The parsed URI that we're wanting to connect to. + + Raises: + ValueError if the environment variables contain an invalid proxy specification. + RuntimeError if no tls_options_factory is given for a https connection """ def __init__( self, reactor: IReactorCore, + proxy_reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], srv_resolver: SrvResolver, parsed_uri: URI, ): self._reactor = reactor - self._parsed_uri = parsed_uri + # http_proxy is not needed because federation is always over TLS + proxies = getproxies_environment() + https_proxy = proxies["https"].encode() if "https" in proxies else None + self.no_proxy = proxies["no"] if "no" in proxies else None + + # endpoint and credentials to use to connect to the outbound https proxy, if any. 
+ ( + self._https_proxy_endpoint, + self._https_proxy_creds, + ) = proxyagent.http_proxy_endpoint( + https_proxy, + proxy_reactor, + tls_client_options_factory, + ) + # set up the TLS connection params # # XXX disabling TLS is really only supported here for the benefit of the @@ -273,9 +323,33 @@ class MatrixHostnameEndpoint: host = server.host port = server.port + should_skip_proxy = False + if self.no_proxy is not None: + should_skip_proxy = proxy_bypass_environment( + host.decode(), + proxies={"no": self.no_proxy}, + ) + + endpoint: IStreamClientEndpoint try: - logger.debug("Connecting to %s:%i", host.decode("ascii"), port) - endpoint = HostnameEndpoint(self._reactor, host, port) + if self._https_proxy_endpoint and not should_skip_proxy: + logger.debug( + "Connecting to %s:%i via %s", + host.decode("ascii"), + port, + self._https_proxy_endpoint, + ) + endpoint = HTTPConnectProxyEndpoint( + self._reactor, + self._https_proxy_endpoint, + host, + port, + proxy_creds=self._https_proxy_creds, + ) + else: + logger.debug("Connecting to %s:%i", host.decode("ascii"), port) + # not using a proxy + endpoint = HostnameEndpoint(self._reactor, host, port) if self._tls_options: endpoint = wrapClientTLS(self._tls_options, endpoint) result = await make_deferred_yieldable( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2efa15bf04..2e9898997c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -59,7 +59,6 @@ from synapse.api.errors import ( from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( BlacklistingAgentWrapper, - BlacklistingReactorWrapper, BodyExceededMaxSize, ByteWriteable, encode_query_args, @@ -69,7 +68,7 @@ from synapse.http.federation.matrix_federation_agent import MatrixFederationAgen from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags -from synapse.types import ISynapseReactor, JsonDict +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -325,13 +324,7 @@ class MatrixFederationHttpClient: self.signing_key = hs.signing_key self.server_name = hs.hostname - # We need to use a DNS resolver which filters out blacklisted IP - # addresses, to prevent DNS rebinding. - self.reactor: ISynapseReactor = BlacklistingReactorWrapper( - hs.get_reactor(), - hs.config.federation_ip_range_whitelist, - hs.config.federation_ip_range_blacklist, - ) + self.reactor = hs.get_reactor() user_agent = hs.version_string if hs.config.user_agent_suffix: @@ -342,6 +335,7 @@ class MatrixFederationHttpClient: self.reactor, tls_client_options_factory, user_agent, + hs.config.federation_ip_range_whitelist, hs.config.federation_ip_range_blacklist, ) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 19e987f118..a3f31452d0 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
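[Editor's note: the environment handling above leans on two stdlib helpers. A small sketch of their behaviour with hypothetical values; `getproxies_environment()` maps any `<scheme>_proxy` variable to a key named after its prefix, so `no_proxy` appears under `"no"`:]

```python
import os
from urllib.request import getproxies_environment, proxy_bypass_environment

os.environ["https_proxy"] = "http://user:pass@proxy.com:8888"  # hypothetical
os.environ["no_proxy"] = "testserv,example.com"

proxies = getproxies_environment()
assert proxies["https"] == "http://user:pass@proxy.com:8888"
assert proxies["no"] == "testserv,example.com"

# Hosts listed in no_proxy bypass the proxy entirely:
assert proxy_bypass_environment("testserv", proxies={"no": proxies["no"]})
assert not proxy_bypass_environment("matrix.org", proxies={"no": proxies["no"]})
```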
-import base64 import logging import re from typing import Any, Dict, Optional, Tuple @@ -21,7 +20,6 @@ from urllib.request import ( # type: ignore[attr-defined] proxy_bypass_environment, ) -import attr from zope.interface import implementer from twisted.internet import defer @@ -38,7 +36,7 @@ from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS -from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint +from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -46,22 +44,6 @@ logger = logging.getLogger(__name__) _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") -@attr.s -class ProxyCredentials: - username_password = attr.ib(type=bytes) - - def as_proxy_authorization_value(self) -> bytes: - """ - Return the value for a Proxy-Authorization header (i.e. 'Basic abdef=='). - - Returns: - A transformation of the authentication string the encoded value for - a Proxy-Authorization header. - """ - # Encode as base64 and prepend the authorization type - return b"Basic " + base64.encodebytes(self.username_password) - - @implementer(IAgent) class ProxyAgent(_AgentBase): """An Agent implementation which will use an HTTP proxy if one was requested @@ -95,6 +77,7 @@ class ProxyAgent(_AgentBase): Raises: ValueError if use_proxy is set and the environment variables contain an invalid proxy specification. + RuntimeError if no tls_options_factory is given for a https connection """ def __init__( @@ -131,11 +114,11 @@ class ProxyAgent(_AgentBase): https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None - self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint( + self.http_proxy_endpoint, self.http_proxy_creds = http_proxy_endpoint( http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs ) - self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint( + self.https_proxy_endpoint, self.https_proxy_creds = http_proxy_endpoint( https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs ) @@ -224,22 +207,12 @@ class ProxyAgent(_AgentBase): and self.https_proxy_endpoint and not should_skip_proxy ): - connect_headers = Headers() - - # Determine whether we need to set Proxy-Authorization headers - if self.https_proxy_creds: - # Set a Proxy-Authorization header - connect_headers.addRawHeader( - b"Proxy-Authorization", - self.https_proxy_creds.as_proxy_authorization_value(), - ) - endpoint = HTTPConnectProxyEndpoint( self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, - headers=connect_headers, + self.https_proxy_creds, ) else: # not using a proxy @@ -268,10 +241,10 @@ class ProxyAgent(_AgentBase): ) -def _http_proxy_endpoint( +def http_proxy_endpoint( proxy: Optional[bytes], reactor: IReactorCore, - tls_options_factory: IPolicyForHTTPS, + tls_options_factory: Optional[IPolicyForHTTPS], **kwargs, ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: """Parses an http proxy setting and returns an endpoint for the proxy @@ -294,6 +267,7 @@ def _http_proxy_endpoint( Raise: ValueError if proxy has no hostname or unsupported scheme. 
+ RuntimeError if no tls_options_factory is given for a https connection """ if proxy is None: return None, None @@ -305,8 +279,13 @@ def _http_proxy_endpoint( proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs) if scheme == b"https": - tls_options = tls_options_factory.creatorForNetloc(host, port) - proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) + if tls_options_factory: + tls_options = tls_options_factory.creatorForNetloc(host, port) + proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) + else: + raise RuntimeError( + f"No TLS options for a https connection via proxy {proxy!s}" + ) return proxy_endpoint, credentials diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index a37bce08c3..992d8f94fd 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging -from typing import Optional -from unittest.mock import Mock +import os +from typing import Iterable, Optional +from unittest.mock import Mock, patch import treq from netaddr import IPSet @@ -22,11 +24,12 @@ from zope.interface import implementer from twisted.internet import defer from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions +from twisted.internet.interfaces import IProtocolFactory from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web._newclient import ResponseNeverReceived from twisted.web.client import Agent -from twisted.web.http import HTTPChannel +from twisted.web.http import HTTPChannel, Request from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS @@ -49,24 +52,6 @@ from tests.utils import default_config logger = logging.getLogger(__name__) -test_server_connection_factory = None - - -def get_connection_factory(): - # this needs to happen once, but not until we are ready to run the first test - global test_server_connection_factory - if test_server_connection_factory is None: - test_server_connection_factory = TestServerTLSConnectionFactory( - sanlist=[ - b"DNS:testserv", - b"DNS:target-server", - b"DNS:xn--bcher-kva.com", - b"IP:1.2.3.4", - b"IP:::1", - ] - ) - return test_server_connection_factory - # Once Async Mocks or lambdas are supported this can go away. def generate_resolve_service(result): @@ -100,24 +85,38 @@ class MatrixFederationAgentTests(unittest.TestCase): had_well_known_cache=self.had_well_known_cache, ) - self.agent = MatrixFederationAgent( - reactor=self.reactor, - tls_client_options_factory=self.tls_factory, - user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. 
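[Editor's note: a hedged sketch of the parsing step behind `http_proxy_endpoint()`; the real helper, `parse_proxy`, is not shown in this diff and also handles IPv6 literals. The 1080 fallback port matches the tests below, which expect a bare `proxy.com` to be dialled on port 1080:]

```python
from typing import Optional, Tuple
from urllib.parse import urlparse

def parse_proxy_spec(proxy: bytes) -> Tuple[bytes, Optional[bytes], bytes, int]:
    if b"://" not in proxy:
        proxy = b"http://" + proxy  # assume a plain host[:port] spec
    url = urlparse(proxy)
    if not url.hostname:
        raise ValueError("Proxy URL did not contain a hostname! %r" % (proxy,))
    credentials = None
    if url.username and url.password:
        credentials = url.username + b":" + url.password
    return url.scheme, credentials, url.hostname, url.port or 1080

assert parse_proxy_spec(b"https://user:pass@proxy.com:8888") == (
    b"https", b"user:pass", b"proxy.com", 8888
)
```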
-            ip_blacklist=IPSet(),
-            _srv_resolver=self.mock_resolver,
-            _well_known_resolver=self.well_known_resolver,
-        )
-
-    def _make_connection(self, client_factory, expected_sni):
+    def _make_connection(
+        self,
+        client_factory: IProtocolFactory,
+        ssl: bool = True,
+        expected_sni: bytes = None,
+        tls_sanlist: Optional[Iterable[bytes]] = None,
+    ) -> HTTPChannel:
         """Builds a test server, and completes the outgoing client connection
+        Args:
+            client_factory: the factory that the application is trying to use
+                to make the outbound connection. We will invoke it to build
+                the client Protocol
+
+            ssl: If true, we will expect an ssl connection and wrap
+                server_factory with a TLSMemoryBIOFactory.
+                Set this to False only when the proxy is expected to speak
+                plain http; federation requests themselves always use https.
+
+            expected_sni: the expected SNI value
+
+            tls_sanlist: list of SAN entries for the TLS cert presented by the server.
         Returns:
             the server Protocol returned by server_factory
         """
         # build the test server
-        server_tls_protocol = _build_test_server(get_connection_factory())
+        server_factory = _get_test_protocol_factory()
+        if ssl:
+            server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
+
+        server_protocol = server_factory.buildProtocol(None)
 
         # now, tell the client protocol factory to build the client protocol (it will be a
         # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
@@ -128,35 +127,39 @@ class MatrixFederationAgentTests(unittest.TestCase):
         # stubbing that out here.
         client_protocol = client_factory.buildProtocol(None)
         client_protocol.makeConnection(
-            FakeTransport(server_tls_protocol, self.reactor, client_protocol)
+            FakeTransport(server_protocol, self.reactor, client_protocol)
         )
 
-        # tell the server tls protocol to send its stuff back to the client, too
-        server_tls_protocol.makeConnection(
-            FakeTransport(client_protocol, self.reactor, server_tls_protocol)
+        # tell the server protocol to send its stuff back to the client, too
+        server_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, server_protocol)
         )
 
-        # grab a hold of the TLS connection, in case it gets torn down
-        server_tls_connection = server_tls_protocol._tlsConnection
-
-        # fish the test server back out of the server-side TLS protocol.
-        http_protocol = server_tls_protocol.wrappedProtocol
+        if ssl:
+            # fish the test server back out of the server-side TLS protocol.
+            http_protocol = server_protocol.wrappedProtocol
+            # grab a hold of the TLS connection, in case it gets torn down
+            tls_connection = server_protocol._tlsConnection
+        else:
+            http_protocol = server_protocol
+            tls_connection = None
 
-        # give the reactor a pump to get the TLS juices flowing.
- self.reactor.pump((0.1,)) + # give the reactor a pump to get the TLS juices flowing (if needed) + self.reactor.advance(0) # check the SNI - server_name = server_tls_connection.get_servername() - self.assertEqual( - server_name, - expected_sni, - "Expected SNI %s but got %s" % (expected_sni, server_name), - ) + if expected_sni is not None: + server_name = tls_connection.get_servername() + self.assertEqual( + server_name, + expected_sni, + f"Expected SNI {expected_sni!s} but got {server_name!s}", + ) return http_protocol @defer.inlineCallbacks - def _make_get_request(self, uri): + def _make_get_request(self, uri: bytes): """ Sends a simple GET request via the agent, and checks its logcontext management """ @@ -180,20 +183,20 @@ class MatrixFederationAgentTests(unittest.TestCase): def _handle_well_known_connection( self, - client_factory, - expected_sni, - content, + client_factory: IProtocolFactory, + expected_sni: bytes, + content: bytes, response_headers: Optional[dict] = None, - ): + ) -> HTTPChannel: """Handle an outgoing HTTPs connection: wire it up to a server, check that the request is for a .well-known, and send the response. Args: - client_factory (IProtocolFactory): outgoing connection - expected_sni (bytes): SNI that we expect the outgoing connection to send - content (bytes): content to send back as the .well-known + client_factory: outgoing connection + expected_sni: SNI that we expect the outgoing connection to send + content: content to send back as the .well-known Returns: - HTTPChannel: server impl + server impl """ # make the connection for .well-known well_known_server = self._make_connection( @@ -209,7 +212,10 @@ class MatrixFederationAgentTests(unittest.TestCase): return well_known_server def _send_well_known_response( - self, request, content, headers: Optional[dict] = None + self, + request: Request, + content: bytes, + headers: Optional[dict] = None, ): """Check that an incoming request looks like a valid .well-known request, and send back the response. @@ -225,10 +231,37 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) - def test_get(self): + def _make_agent(self) -> MatrixFederationAgent: """ - happy-path test of a GET request with an explicit port + If a proxy server is set, the MatrixFederationAgent must be created again + because it is created too early during setUp """ + return MatrixFederationAgent( + reactor=self.reactor, + tls_client_options_factory=self.tls_factory, + user_agent="test-agent", # Note that this is unused since _well_known_resolver is provided. 
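[Editor's note: a minimal sketch of the test plumbing referenced above; `_wrap_server_factory_for_tls` and `_get_test_protocol_factory` are not shown in this diff. It assumes `context_factory` supplies server-side TLS options built from a test certificate carrying the desired SAN entries:]

```python
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory
from twisted.web.http import HTTPChannel

def wrap_server_factory_for_tls(context_factory) -> TLSMemoryBIOFactory:
    # An in-memory HTTP server factory, wrapped so that it speaks TLS over
    # the fake transports used by these tests (no real sockets involved).
    server_factory = Factory.forProtocol(HTTPChannel)
    return TLSMemoryBIOFactory(
        context_factory, isClient=False, wrappedFactory=server_factory
    )
```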
+ ip_whitelist=IPSet(), + ip_blacklist=IPSet(), + _srv_resolver=self.mock_resolver, + _well_known_resolver=self.well_known_resolver, + ) + + def test_get(self): + """happy-path test of a GET request with an explicit port""" + self._do_get() + + @patch.dict( + os.environ, + {"https_proxy": "proxy.com", "no_proxy": "testserv"}, + ) + def test_get_bypass_proxy(self): + """test of a GET request with an explicit port and bypass proxy""" + self._do_get() + + def _do_get(self): + """test of a GET request with an explicit port""" + self.agent = self._make_agent() + self.reactor.lookups["testserv"] = "1.2.3.4" test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") @@ -282,10 +315,188 @@ class MatrixFederationAgentTests(unittest.TestCase): json = self.successResultOf(treq.json_content(response)) self.assertEqual(json, {"a": 1}) + @patch.dict( + os.environ, {"https_proxy": "http://proxy.com", "no_proxy": "unused.com"} + ) + def test_get_via_http_proxy(self): + """test for federation request through a http proxy""" + self._do_get_via_proxy(expect_proxy_ssl=False, expected_auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "http://user:pass@proxy.com", "no_proxy": "unused.com"}, + ) + def test_get_via_http_proxy_with_auth(self): + """test for federation request through a http proxy with authentication""" + self._do_get_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"user:pass" + ) + + @patch.dict( + os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} + ) + def test_get_via_https_proxy(self): + """test for federation request through a https proxy""" + self._do_get_via_proxy(expect_proxy_ssl=True, expected_auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "https://user:pass@proxy.com", "no_proxy": "unused.com"}, + ) + def test_get_via_https_proxy_with_auth(self): + """test for federation request through a https proxy with authentication""" + self._do_get_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"user:pass" + ) + + def _do_get_via_proxy( + self, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, + ): + """Send a https federation request via an agent and check that it is correctly + received at the proxy and client. The proxy can use either http or https. + Args: + expect_proxy_ssl: True if we expect the request to connect to the proxy via https. 
+ expected_auth_credentials: credentials we expect to be presented to authenticate at the proxy + """ + self.agent = self._make_agent() + + self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["proxy.com"] = "9.9.9.9" + test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + # make sure we are connecting to the proxy + self.assertEqual(host, "9.9.9.9") + self.assertEqual(port, 1080) + + # make a test server to act as the proxy, and wire up the client + proxy_server = self._make_connection( + client_factory, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, + ) + + assert isinstance(proxy_server, HTTPChannel) + + # now there should be a pending CONNECT request + self.assertEqual(len(proxy_server.requests), 1) + + request = proxy_server.requests[0] + self.assertEqual(request.method, b"CONNECT") + self.assertEqual(request.path, b"testserv:8448") + + # Check whether auth credentials have been supplied to the proxy + proxy_auth_header_values = request.requestHeaders.getRawHeaders( + b"Proxy-Authorization" + ) + + if expected_auth_credentials is not None: + # Compute the correct header value for Proxy-Authorization + encoded_credentials = base64.b64encode(expected_auth_credentials) + expected_header_value = b"Basic " + encoded_credentials + + # Validate the header's value + self.assertIn(expected_header_value, proxy_auth_header_values) + else: + # Check that the Proxy-Authorization header has not been supplied to the proxy + self.assertIsNone(proxy_auth_header_values) + + # tell the proxy server not to close the connection + proxy_server.persistent = True + + request.finish() + + # now we make another test server to act as the upstream HTTP server. + server_ssl_protocol = _wrap_server_factory_for_tls( + _get_test_protocol_factory() + ).buildProtocol(None) + + # Tell the HTTP server to send outgoing traffic back via the proxy's transport. + proxy_server_transport = proxy_server.transport + server_ssl_protocol.makeConnection(proxy_server_transport) + + # ... and replace the protocol on the proxy's transport with the + # TLSMemoryBIOProtocol for the test server, so that incoming traffic + # to the proxy gets sent over to the HTTP(s) server. 
+ + # See also comment at `_do_https_request_via_proxy` + # in ../test_proxyagent.py for more details + if expect_proxy_ssl: + assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) + proxy_server_transport.wrappedProtocol = server_ssl_protocol + else: + assert isinstance(proxy_server_transport, FakeTransport) + client_protocol = proxy_server_transport.other + c2s_transport = client_protocol.transport + c2s_transport.other = server_ssl_protocol + + self.reactor.advance(0) + + server_name = server_ssl_protocol._tlsConnection.get_servername() + expected_sni = b"testserv" + self.assertEqual( + server_name, + expected_sni, + f"Expected SNI {expected_sni!s} but got {server_name!s}", + ) + + # now there should be a pending request + http_server = server_ssl_protocol.wrappedProtocol + self.assertEqual(len(http_server.requests), 1) + + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [b"testserv:8448"] + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"user-agent"), [b"test-agent"] + ) + # Check that the destination server DID NOT receive proxy credentials + self.assertIsNone(request.requestHeaders.getRawHeaders(b"Proxy-Authorization")) + content = request.content.read() + self.assertEqual(content, b"") + + # Deferred is still without a result + self.assertNoResult(test_d) + + # send the headers + request.responseHeaders.setRawHeaders(b"Content-Type", [b"application/json"]) + request.write("") + + self.reactor.pump((0.1,)) + + response = self.successResultOf(test_d) + + # that should give us a Response object + self.assertEqual(response.code, 200) + + # Send the body + request.write('{ "a": 1 }'.encode("ascii")) + request.finish() + + self.reactor.pump((0.1,)) + + # check it can be read + json = self.successResultOf(treq.json_content(response)) + self.assertEqual(json, {"a": 1}) + def test_get_ip_address(self): """ Test the behaviour when the server name contains an explicit IP (with no port) """ + self.agent = self._make_agent() + # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.4"] = "1.2.3.4" @@ -320,6 +531,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IPv6 address (with no port) """ + self.agent = self._make_agent() # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" @@ -355,6 +567,7 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IPv6 address (with explicit port) """ + self.agent = self._make_agent() # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" @@ -389,6 +602,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the certificate on the server doesn't match the hostname """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv1"] = "1.2.3.4" @@ -441,6 +656,8 @@ class MatrixFederationAgentTests(unittest.TestCase): Test the behaviour when the server name contains an explicit IP, but the server cert doesn't cover it """ + self.agent = self._make_agent() + # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.5"] = "1.2.3.5" @@ -471,6 +688,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the server name has no port, no SRV, and no well-known """ + 
self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -524,6 +742,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_get_well_known(self): """Test the behaviour when the .well-known delegates elsewhere""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -587,6 +806,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the server name has no port and no SRV record, but the .well-known has a 300 redirect """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" @@ -675,6 +896,7 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when the server name has an *invalid* well-known (and no SRV) """ + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) self.reactor.lookups["testserv"] = "1.2.3.4" @@ -743,6 +965,7 @@ class MatrixFederationAgentTests(unittest.TestCase): reactor=self.reactor, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. + ip_whitelist=IPSet(), ip_blacklist=IPSet(), _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( @@ -780,6 +1003,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """ Test the behaviour when there is a single SRV record """ + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"srvtarget", port=8443)] ) @@ -820,6 +1045,8 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known redirects to a place where there is a SRV. 
""" + self.agent = self._make_agent() + self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "5.6.7.8" @@ -876,6 +1103,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_servername(self): """test the behaviour when the server name has idna chars in""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) @@ -937,6 +1165,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_idna_srv_target(self): """test the behaviour when the target of a SRV record has idna chars""" + self.agent = self._make_agent() self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com @@ -1140,6 +1369,8 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_srv_fallbacks(self): """Test that other SRV results are tried if the first one fails.""" + self.agent = self._make_agent() + self.mock_resolver.resolve_service.side_effect = generate_resolve_service( [ Server(host=b"target.com", port=8443), @@ -1266,34 +1497,49 @@ def _check_logcontext(context): raise AssertionError("Expected logcontext %s but was %s" % (context, current)) -def _build_test_server(connection_creator): - """Construct a test server - - This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol - +def _wrap_server_factory_for_tls( + factory: IProtocolFactory, sanlist: Iterable[bytes] = None +) -> IProtocolFactory: + """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory + The resultant factory will create a TLS server which presents a certificate + signed by our test CA, valid for the domains in `sanlist` Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + factory: protocol factory to wrap + sanlist: list of domains the cert should be valid for + Returns: + interfaces.IProtocolFactory + """ + if sanlist is None: + sanlist = [ + b"DNS:testserv", + b"DNS:target-server", + b"DNS:xn--bcher-kva.com", + b"IP:1.2.3.4", + b"IP:::1", + ] + + connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist) + return TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=factory + ) + +def _get_test_protocol_factory() -> IProtocolFactory: + """Get a protocol Factory which will build an HTTPChannel Returns: - TLSMemoryBIOProtocol + interfaces.IProtocolFactory """ server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. 
server_factory.log = _log_request - server_tls_factory = TLSMemoryBIOFactory( - connection_creator, isClient=False, wrappedFactory=server_factory - ) - - return server_tls_factory.buildProtocol(None) + return server_factory -def _log_request(request): +def _log_request(request: str): """Implements Factory.log, which is expected by Request.finish""" - logger.info("Completed request %s", request) + logger.info(f"Completed request {request}") @implementer(IPolicyForHTTPS) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index e5865c161d..2db77c6a73 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -29,7 +29,8 @@ from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from synapse.http.client import BlacklistingReactorWrapper -from synapse.http.proxyagent import ProxyAgent, ProxyCredentials, parse_proxy +from synapse.http.connectproxyclient import ProxyCredentials +from synapse.http.proxyagent import ProxyAgent, parse_proxy from tests.http import TestServerTLSConnectionFactory, get_test_https_policy from tests.server import FakeTransport, ThreadedMemoryReactorClock @@ -392,7 +393,9 @@ class MatrixFederationAgentTests(TestCase): """ Tests that requests can be made through a proxy. """ - self._do_http_request_via_proxy(ssl=False, auth_credentials=None) + self._do_http_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -402,13 +405,17 @@ class MatrixFederationAgentTests(TestCase): """ Tests that authenticated requests can be made through a proxy. """ - self._do_http_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + self._do_http_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict( os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} ) def test_http_request_via_https_proxy(self): - self._do_http_request_via_proxy(ssl=True, auth_credentials=None) + self._do_http_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -418,12 +425,16 @@ class MatrixFederationAgentTests(TestCase): }, ) def test_http_request_via_https_proxy_with_auth(self): - self._do_http_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") + self._do_http_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) def test_https_request_via_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=False, auth_credentials=None) + self._do_https_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -431,14 +442,18 @@ class MatrixFederationAgentTests(TestCase): ) def test_https_request_via_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + self._do_https_request_via_proxy( + expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + ) @patch.dict( os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} ) def test_https_request_via_https_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=True, auth_credentials=None) + self._do_https_request_via_proxy( 
+ expect_proxy_ssl=True, expected_auth_credentials=None + ) @patch.dict( os.environ, @@ -446,20 +461,22 @@ class MatrixFederationAgentTests(TestCase): ) def test_https_request_via_https_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") + self._do_https_request_via_proxy( + expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + ) def _do_http_request_via_proxy( self, - ssl: bool = False, - auth_credentials: Optional[bytes] = None, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, ): """Send a http request via an agent and check that it is correctly received at the proxy. The proxy can use either http or https. Args: - ssl: True if we expect the request to connect via https to proxy - auth_credentials: credentials to authenticate at proxy + expect_proxy_ssl: True if we expect the request to connect via https to proxy + expected_auth_credentials: credentials to authenticate at proxy """ - if ssl: + if expect_proxy_ssl: agent = ProxyAgent( self.reactor, use_proxy=True, contextFactory=get_test_https_policy() ) @@ -480,9 +497,9 @@ class MatrixFederationAgentTests(TestCase): http_server = self._make_connection( client_factory, _get_test_protocol_factory(), - ssl=ssl, - tls_sanlist=[b"DNS:proxy.com"] if ssl else None, - expected_sni=b"proxy.com" if ssl else None, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) # the FakeTransport is async, so we need to pump the reactor @@ -498,9 +515,9 @@ class MatrixFederationAgentTests(TestCase): b"Proxy-Authorization" ) - if auth_credentials is not None: + if expected_auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(auth_credentials) + encoded_credentials = base64.b64encode(expected_auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -523,14 +540,14 @@ class MatrixFederationAgentTests(TestCase): def _do_https_request_via_proxy( self, - ssl: bool = False, - auth_credentials: Optional[bytes] = None, + expect_proxy_ssl: bool = False, + expected_auth_credentials: Optional[bytes] = None, ): """Send a https request via an agent and check that it is correctly received at the proxy and client. The proxy can use either http or https. 
Args: - ssl: True if we expect the request to connect via https to proxy - auth_credentials: credentials to authenticate at proxy + expect_proxy_ssl: True if we expect the request to connect via https to proxy + expected_auth_credentials: credentials to authenticate at proxy """ agent = ProxyAgent( self.reactor, @@ -552,9 +569,9 @@ class MatrixFederationAgentTests(TestCase): proxy_server = self._make_connection( client_factory, _get_test_protocol_factory(), - ssl=ssl, - tls_sanlist=[b"DNS:proxy.com"] if ssl else None, - expected_sni=b"proxy.com" if ssl else None, + ssl=expect_proxy_ssl, + tls_sanlist=[b"DNS:proxy.com"] if expect_proxy_ssl else None, + expected_sni=b"proxy.com" if expect_proxy_ssl else None, ) assert isinstance(proxy_server, HTTPChannel) @@ -570,9 +587,9 @@ class MatrixFederationAgentTests(TestCase): b"Proxy-Authorization" ) - if auth_credentials is not None: + if expected_auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(auth_credentials) + encoded_credentials = base64.b64encode(expected_auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -606,7 +623,7 @@ class MatrixFederationAgentTests(TestCase): # Protocol to implement the proxy, which starts out by forwarding to an # HTTPChannel (to implement the CONNECT command) and can then be switched # into a mode where it forwards its traffic to another Protocol.) - if ssl: + if expect_proxy_ssl: assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) proxy_server_transport.wrappedProtocol = server_ssl_protocol else: -- cgit 1.5.1 From fab352ac2cb6a9d69a74be6d4255a9b71e0f7945 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 10:43:40 -0400 Subject: Fix type hints in space summary tests. (#10575) And ensure that the file is checked via mypy. --- changelog.d/10575.feature | 1 + mypy.ini | 1 + tests/handlers/test_space_summary.py | 11 +++++------ tests/rest/client/v1/utils.py | 6 +++--- 4 files changed, 10 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10575.feature diff --git a/changelog.d/10575.feature b/changelog.d/10575.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10575.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/mypy.ini b/mypy.ini index 8717ae738e..5d6cd557bc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,6 +86,7 @@ files = tests/test_event_auth.py, tests/test_utils, tests/handlers/test_password_providers.py, + tests/handlers/test_space_summary.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_itertools.py, diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 255dd17f86..04da9bcc25 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
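A minimal illustration (not part of the patch) of why the explicit `JsonDict` annotations added in the hunks below are needed: without an annotation, mypy infers the narrowest container type from the dict literal and then rejects later assignments of a different value type.

```python
# Sketch only: JsonDict here stands in for synapse.types.JsonDict.
from typing import Any, Dict

JsonDict = Dict[str, Any]

# Left unannotated, mypy would infer Dict[str, List[str]] from the literal
# and flag the second assignment as an incompatible type.
content: JsonDict = {"via": ["example.com"]}
content["order"] = "abc"  # accepted, because the values are typed as Any
```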
-from typing import Any, Iterable, Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple from unittest import mock from synapse.api.constants import ( @@ -127,7 +127,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self, space_id: str, room_id: str, token: str, order: Optional[str] = None ) -> None: """Add a child room to a space.""" - content = {"via": [self.hs.hostname]} + content: JsonDict = {"via": [self.hs.hostname]} if order is not None: content["order"] = order self.helper.send_state( @@ -439,9 +439,8 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) # The result should have the space and all of the links, plus some of the # rooms and a pagination token. - expected = [(self.space, room_ids)] + [ - (room_id, ()) for room_id in room_ids[:6] - ] + expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)] + expected += [(room_id, ()) for room_id in room_ids[:6]] self._assert_hierarchy(result, expected) self.assertIn("next_token", result) @@ -525,7 +524,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success( self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) ) - expected = [(spaces[0], [rooms[0], spaces[1]])] + expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] self._assert_hierarchy(result, expected) # A single additional layer. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index fc2d35596e..954ad1a1fd 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -47,10 +47,10 @@ class RestHelper: def create_room_as( self, - room_creator: str = None, + room_creator: Optional[str] = None, is_public: bool = True, - room_version: str = None, - tok: str = None, + room_version: Optional[str] = None, + tok: Optional[str] = None, expect_code: int = 200, extra_content: Optional[Dict] = None, custom_headers: Optional[ -- cgit 1.5.1 From 2ae2a04616a627eabbf3ca69700462a52f344e69 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 14:31:39 -0400 Subject: Clarify error message when joining a restricted room. (#10572) --- changelog.d/10572.misc | 1 + synapse/handlers/event_auth.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10572.misc diff --git a/changelog.d/10572.misc b/changelog.d/10572.misc new file mode 100644 index 0000000000..008d7be444 --- /dev/null +++ b/changelog.d/10572.misc @@ -0,0 +1 @@ +Clarify error message when failing to join a restricted room. diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index e2410e482f..4288ffff09 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -213,7 +213,7 @@ class EventAuthHandler: raise AuthError( 403, - "You do not belong to any of the required rooms to join this room.", + "You do not belong to any of the required rooms/spaces to join this room.", ) async def has_restricted_join_rules( -- cgit 1.5.1 From 5acd8b5a960b1c53ce0b9efa304010ec5f0f6682 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 14:52:09 -0400 Subject: Expire old spaces summary pagination sessions. 
(#10574)

---
 changelog.d/10574.feature         |  1 +
 synapse/handlers/space_summary.py | 24 +++++++++++++++++++++++-
 2 files changed, 24 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/10574.feature

diff --git a/changelog.d/10574.feature b/changelog.d/10574.feature
new file mode 100644
index 0000000000..ffc4e4289c
--- /dev/null
+++ b/changelog.d/10574.feature
@@ -0,0 +1 @@
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index fd76c34695..8c9852bc89 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -77,6 +77,8 @@ class _PaginationKey:
 class _PaginationSession:
     """The information that is stored for pagination."""
 
+    # The time the pagination session was created, in milliseconds.
+    creation_time_ms: int
     # The queue of rooms which are still to process.
     room_queue: Deque["_RoomQueueEntry"]
     # A set of rooms which have been processed.
@@ -84,6 +86,9 @@ class _PaginationSession:
 
 
 class SpaceSummaryHandler:
+    # The time a pagination session remains valid for.
+    _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
+
     def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
         self._auth = hs.get_auth()
@@ -108,6 +113,21 @@ class SpaceSummaryHandler:
             "get_room_hierarchy",
         )
 
+    def _expire_pagination_sessions(self):
+        """Expire pagination sessions which are old."""
+        expire_before = (
+            self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
+        )
+        to_expire = []
+
+        for key, value in self._pagination_sessions.items():
+            if value.creation_time_ms < expire_before:
+                to_expire.append(key)
+
+        for key in to_expire:
+            logger.debug("Expiring pagination session id %s", key)
+            del self._pagination_sessions[key]
+
     async def get_space_summary(
         self,
         requester: str,
@@ -312,6 +332,8 @@ class SpaceSummaryHandler:
 
         # If this is continuing a previous session, pull the persisted data.
if from_token: + self._expire_pagination_sessions() + pagination_key = _PaginationKey( requested_room_id, suggested_only, max_depth, from_token ) @@ -391,7 +413,7 @@ class SpaceSummaryHandler: requested_room_id, suggested_only, max_depth, next_token ) self._pagination_sessions[pagination_key] = _PaginationSession( - room_queue, processed_rooms + self._clock.time_msec(), room_queue, processed_rooms ) return result -- cgit 1.5.1 From 33ef86aa2515f623fa6e8657d16c918b6a6d9da5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 11 Aug 2021 19:59:57 +0100 Subject: Rename ci to .ci --- .ci/postgres-config.yaml | 19 ++++++++++++ .ci/scripts/postgres_exec.py | 31 ++++++++++++++++++++ .ci/scripts/test_old_deps.sh | 16 ++++++++++ .ci/scripts/test_synapse_port_db.sh | 57 ++++++++++++++++++++++++++++++++++++ .ci/sqlite-config.yaml | 16 ++++++++++ .ci/test_db.db | Bin 0 -> 19296256 bytes .ci/worker-blacklist | 10 +++++++ .github/workflows/tests.yml | 6 ++-- ci/postgres-config.yaml | 19 ------------ ci/scripts/postgres_exec.py | 31 -------------------- ci/scripts/test_old_deps.sh | 16 ---------- ci/scripts/test_synapse_port_db.sh | 57 ------------------------------------ ci/sqlite-config.yaml | 16 ---------- ci/test_db.db | Bin 19296256 -> 0 bytes ci/worker-blacklist | 10 ------- 15 files changed, 152 insertions(+), 152 deletions(-) create mode 100644 .ci/postgres-config.yaml create mode 100755 .ci/scripts/postgres_exec.py create mode 100755 .ci/scripts/test_old_deps.sh create mode 100755 .ci/scripts/test_synapse_port_db.sh create mode 100644 .ci/sqlite-config.yaml create mode 100644 .ci/test_db.db create mode 100644 .ci/worker-blacklist delete mode 100644 ci/postgres-config.yaml delete mode 100755 ci/scripts/postgres_exec.py delete mode 100755 ci/scripts/test_old_deps.sh delete mode 100755 ci/scripts/test_synapse_port_db.sh delete mode 100644 ci/sqlite-config.yaml delete mode 100644 ci/test_db.db delete mode 100644 ci/worker-blacklist diff --git a/.ci/postgres-config.yaml b/.ci/postgres-config.yaml new file mode 100644 index 0000000000..f5a4aecd51 --- /dev/null +++ b/.ci/postgres-config.yaml @@ -0,0 +1,19 @@ +# Configuration file used for testing the 'synapse_port_db' script. +# Tells the script to connect to the postgresql database that will be available in the +# CI's Docker setup at the point where this file is considered. +server_name: "localhost:8800" + +signing_key_path: ".ci/test.signing.key" + +report_stats: false + +database: + name: "psycopg2" + args: + user: postgres + host: localhost + password: postgres + database: synapse + +# Suppress the key server warning. +trusted_key_servers: [] diff --git a/.ci/scripts/postgres_exec.py b/.ci/scripts/postgres_exec.py new file mode 100755 index 0000000000..0f39a336d5 --- /dev/null +++ b/.ci/scripts/postgres_exec.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import sys
+
+import psycopg2
+
+# a very simple replacement for `psql`, to make up for the lack of the postgres client
+# libraries in the synapse docker image.
+
+# We use "postgres" as a database because it's bound to exist and the "synapse" one
+# doesn't exist yet.
+db_conn = psycopg2.connect(
+    user="postgres", host="localhost", password="postgres", dbname="postgres"
+)
+db_conn.autocommit = True
+cur = db_conn.cursor()
+for c in sys.argv[1:]:
+    cur.execute(c)
diff --git a/.ci/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh
new file mode 100755
index 0000000000..8b473936f8
--- /dev/null
+++ b/.ci/scripts/test_old_deps.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+
+# this script is run by GitHub Actions in a plain `bionic` container; it installs the
+# minimal requirements for tox and hands over to the py3-old tox environment.
+
+set -ex
+
+apt-get update
+apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
+
+export LANG="C.UTF-8"
+
+# Prevent virtualenv from auto-updating pip to an incompatible version
+export VIRTUALENV_NO_DOWNLOAD=1
+
+exec tox -e py3-old,combine
diff --git a/.ci/scripts/test_synapse_port_db.sh b/.ci/scripts/test_synapse_port_db.sh
new file mode 100755
index 0000000000..9ee0ad42fc
--- /dev/null
+++ b/.ci/scripts/test_synapse_port_db.sh
@@ -0,0 +1,57 @@
+#!/usr/bin/env bash
+#
+# Test script for 'synapse_port_db'.
+# - sets up synapse and deps
+# - runs the port script on a prepopulated test sqlite db
+# - also runs it against a new sqlite db
+
+
+set -xe
+cd `dirname $0`/../..
+
+echo "--- Install dependencies"
+
+# Install dependencies for this test.
+pip install psycopg2 coverage coverage-enable-subprocess
+
+# Install Synapse itself. This won't update any libraries.
+pip install -e .
+
+echo "--- Generate the signing key"
+
+# Generate the server's signing key.
+python -m synapse.app.homeserver --generate-keys -c ci/sqlite-config.yaml
+
+echo "--- Prepare test database"
+
+# Make sure the SQLite3 database is using the latest schema and has no pending background update.
+scripts-dev/update_database --database-config ci/sqlite-config.yaml
+
+# Create the PostgreSQL database.
+./ci/scripts/postgres_exec.py "CREATE DATABASE synapse"
+
+echo "+++ Run synapse_port_db against test database"
+coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml
+
+# We should be able to run twice against the same database.
+echo "+++ Run synapse_port_db a second time"
+coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml
+
+#####
+
+# Now do the same again, on an empty database.
+
+echo "--- Prepare empty SQLite database"
+
+# we do this by deleting the sqlite db, and then doing the same again.
+rm ci/test_db.db
+
+scripts-dev/update_database --database-config ci/sqlite-config.yaml
+
+# re-create the PostgreSQL database.
+./ci/scripts/postgres_exec.py \
+  "DROP DATABASE synapse" \
+  "CREATE DATABASE synapse"
+
+echo "+++ Run synapse_port_db against empty database"
+coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml
diff --git a/.ci/sqlite-config.yaml b/.ci/sqlite-config.yaml
new file mode 100644
index 0000000000..3373743da3
--- /dev/null
+++ b/.ci/sqlite-config.yaml
@@ -0,0 +1,16 @@
+# Configuration file used for testing the 'synapse_port_db' script.
+# Tells the 'update_database' script to connect to the test SQLite database to upgrade its +# schema and run background updates on it. +server_name: "localhost:8800" + +signing_key_path: ".ci/test.signing.key" + +report_stats: false + +database: + name: "sqlite3" + args: + database: ".ci/test_db.db" + +# Suppress the key server warning. +trusted_key_servers: [] diff --git a/.ci/test_db.db b/.ci/test_db.db new file mode 100644 index 0000000000..a0d9f16a75 Binary files /dev/null and b/.ci/test_db.db differ diff --git a/.ci/worker-blacklist b/.ci/worker-blacklist new file mode 100644 index 0000000000..5975cb98cf --- /dev/null +++ b/.ci/worker-blacklist @@ -0,0 +1,10 @@ +# This file serves as a blacklist for SyTest tests that we expect will fail in +# Synapse when run under worker mode. For more details, see sytest-blacklist. + +Can re-join room if re-invited + +# new failures as of https://github.com/matrix-org/sytest/pull/732 +Device list doesn't change if remote server is down + +# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5 +Server correctly handles incoming m.device_list_update diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 572bc81b0f..df2e3901cb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -142,7 +142,7 @@ jobs: uses: docker://ubuntu:bionic # For old python and sqlite with: workdir: /github/workspace - entrypoint: ci/scripts/test_old_deps.sh + entrypoint: .ci/scripts/test_old_deps.sh env: TRIAL_FLAGS: "--jobs=2" - name: Dump logs @@ -229,7 +229,7 @@ jobs: steps: - uses: actions/checkout@v2 - name: Prepare test blacklist - run: cat sytest-blacklist ci/worker-blacklist > synapse-blacklist-with-workers + run: cat sytest-blacklist .ci/worker-blacklist > synapse-blacklist-with-workers - name: Run SyTest run: /bootstrap.sh synapse working-directory: /src @@ -278,7 +278,7 @@ jobs: - uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - run: ci/scripts/test_synapse_port_db.sh + - run: .ci/scripts/test_synapse_port_db.sh complement: if: ${{ !failure() && !cancelled() }} diff --git a/ci/postgres-config.yaml b/ci/postgres-config.yaml deleted file mode 100644 index 511fef495d..0000000000 --- a/ci/postgres-config.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# Configuration file used for testing the 'synapse_port_db' script. -# Tells the script to connect to the postgresql database that will be available in the -# CI's Docker setup at the point where this file is considered. -server_name: "localhost:8800" - -signing_key_path: "ci/test.signing.key" - -report_stats: false - -database: - name: "psycopg2" - args: - user: postgres - host: localhost - password: postgres - database: synapse - -# Suppress the key server warning. -trusted_key_servers: [] diff --git a/ci/scripts/postgres_exec.py b/ci/scripts/postgres_exec.py deleted file mode 100755 index 0f39a336d5..0000000000 --- a/ci/scripts/postgres_exec.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -import psycopg2 - -# a very simple replacment for `psql`, to make up for the lack of the postgres client -# libraries in the synapse docker image. - -# We use "postgres" as a database because it's bound to exist and the "synapse" one -# doesn't exist yet. -db_conn = psycopg2.connect( - user="postgres", host="localhost", password="postgres", dbname="postgres" -) -db_conn.autocommit = True -cur = db_conn.cursor() -for c in sys.argv[1:]: - cur.execute(c) diff --git a/ci/scripts/test_old_deps.sh b/ci/scripts/test_old_deps.sh deleted file mode 100755 index 8b473936f8..0000000000 --- a/ci/scripts/test_old_deps.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env bash - -# this script is run by GitHub Actions in a plain `bionic` container; it installs the -# minimal requirements for tox and hands over to the py3-old tox environment. - -set -ex - -apt-get update -apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox - -export LANG="C.UTF-8" - -# Prevent virtualenv from auto-updating pip to an incompatible version -export VIRTUALENV_NO_DOWNLOAD=1 - -exec tox -e py3-old,combine diff --git a/ci/scripts/test_synapse_port_db.sh b/ci/scripts/test_synapse_port_db.sh deleted file mode 100755 index 9ee0ad42fc..0000000000 --- a/ci/scripts/test_synapse_port_db.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env bash -# -# Test script for 'synapse_port_db'. -# - sets up synapse and deps -# - runs the port script on a prepopulated test sqlite db -# - also runs it against an new sqlite db - - -set -xe -cd `dirname $0`/../.. - -echo "--- Install dependencies" - -# Install dependencies for this test. -pip install psycopg2 coverage coverage-enable-subprocess - -# Install Synapse itself. This won't update any libraries. -pip install -e . - -echo "--- Generate the signing key" - -# Generate the server's signing key. -python -m synapse.app.homeserver --generate-keys -c ci/sqlite-config.yaml - -echo "--- Prepare test database" - -# Make sure the SQLite3 database is using the latest schema and has no pending background update. -scripts-dev/update_database --database-config ci/sqlite-config.yaml - -# Create the PostgreSQL database. -./ci/scripts/postgres_exec.py "CREATE DATABASE synapse" - -echo "+++ Run synapse_port_db against test database" -coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml - -# We should be able to run twice against the same database. -echo "+++ Run synapse_port_db a second time" -coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml - -##### - -# Now do the same again, on an empty database. - -echo "--- Prepare empty SQLite database" - -# we do this by deleting the sqlite db, and then doing the same again. -rm ci/test_db.db - -scripts-dev/update_database --database-config ci/sqlite-config.yaml - -# re-create the PostgreSQL database. -./ci/scripts/postgres_exec.py \ - "DROP DATABASE synapse" \ - "CREATE DATABASE synapse" - -echo "+++ Run synapse_port_db against empty database" -coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml diff --git a/ci/sqlite-config.yaml b/ci/sqlite-config.yaml deleted file mode 100644 index fd5a1c1451..0000000000 --- a/ci/sqlite-config.yaml +++ /dev/null @@ -1,16 +0,0 @@ -# Configuration file used for testing the 'synapse_port_db' script. 
-# Tells the 'update_database' script to connect to the test SQLite database to upgrade its -# schema and run background updates on it. -server_name: "localhost:8800" - -signing_key_path: "ci/test.signing.key" - -report_stats: false - -database: - name: "sqlite3" - args: - database: "ci/test_db.db" - -# Suppress the key server warning. -trusted_key_servers: [] diff --git a/ci/test_db.db b/ci/test_db.db deleted file mode 100644 index a0d9f16a75..0000000000 Binary files a/ci/test_db.db and /dev/null differ diff --git a/ci/worker-blacklist b/ci/worker-blacklist deleted file mode 100644 index 5975cb98cf..0000000000 --- a/ci/worker-blacklist +++ /dev/null @@ -1,10 +0,0 @@ -# This file serves as a blacklist for SyTest tests that we expect will fail in -# Synapse when run under worker mode. For more details, see sytest-blacklist. - -Can re-join room if re-invited - -# new failures as of https://github.com/matrix-org/sytest/pull/732 -Device list doesn't change if remote server is down - -# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5 -Server correctly handles incoming m.device_list_update -- cgit 1.5.1 From 3ebb6694f018eedb7d3c4fda829540f07b45a5b1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 15:04:51 -0400 Subject: Allow requesting the summary of a space which is joinable. (#10580) As opposed to only allowing the summary of spaces which the user is already in or has world-readable visibility. This makes the logic consistent with whether a space/room is returned as part of a space and whether a space summary can start at a space. --- changelog.d/10580.bugfix | 1 + synapse/handlers/space_summary.py | 31 ++++++++++++++++++------------- tests/handlers/test_space_summary.py | 28 ++++++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 15 deletions(-) create mode 100644 changelog.d/10580.bugfix diff --git a/changelog.d/10580.bugfix b/changelog.d/10580.bugfix new file mode 100644 index 0000000000..f8da7382b7 --- /dev/null +++ b/changelog.d/10580.bugfix @@ -0,0 +1 @@ +Allow public rooms to be previewed in the spaces summary APIs from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 8c9852bc89..893546e661 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -38,7 +38,7 @@ from synapse.api.constants import ( Membership, RoomTypes, ) -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict @@ -91,7 +91,6 @@ class SpaceSummaryHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() - self._auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastore() self._event_serializer = hs.get_event_client_serializer() @@ -153,9 +152,13 @@ class SpaceSummaryHandler: Returns: summary dict to return """ - # first of all, check that the user is in the room in question (or it's - # world-readable) - await self._auth.check_user_in_room_or_world_readable(room_id, requester) + # First of all, check that the room is accessible. 
+ if not await self._is_local_room_accessible(room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, room_id), + ) # the queue of rooms to process room_queue = deque((_RoomQueueEntry(room_id, ()),)) @@ -324,11 +327,13 @@ class SpaceSummaryHandler: ) -> JsonDict: """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" - # first of all, check that the user is in the room in question (or it's - # world-readable) - await self._auth.check_user_in_room_or_world_readable( - requested_room_id, requester - ) + # First of all, check that the room is accessible. + if not await self._is_local_room_accessible(requested_room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, requested_room_id), + ) # If this is continuing a previous session, pull the persisted data. if from_token: @@ -612,7 +617,7 @@ class SpaceSummaryHandler: return results async def _is_local_room_accessible( - self, room_id: str, requester: Optional[str], origin: Optional[str] + self, room_id: str, requester: Optional[str], origin: Optional[str] = None ) -> bool: """ Calculate whether the room should be shown in the spaces summary. @@ -766,7 +771,7 @@ class SpaceSummaryHandler: # Finally, check locally if we can access the room. The user might # already be in the room (if it was a child room), or there might be a # pending invite, etc. - return await self._is_local_room_accessible(room_id, requester, None) + return await self._is_local_room_accessible(room_id, requester) async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: """ @@ -783,7 +788,7 @@ class SpaceSummaryHandler: stats = await self._store.get_room_with_stats(room_id) # currently this should be impossible because we call - # check_user_in_room_or_world_readable on the room before we get here, so + # _is_local_room_accessible on the room before we get here, so # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 04da9bcc25..806b886fe4 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -248,7 +248,21 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") - # The user cannot see the space. + # The user can see the space since it is publicly joinable. + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # If the space is made invite-only, it should no longer be viewable. 
+ self.helper.send_state( + self.space, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) @@ -260,7 +274,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): tok=self.token, ) result = self.get_success(self.handler.get_space_summary(user2, self.space)) - expected = [(self.space, [self.room]), (self.room, ())] self._assert_rooms(result, expected) result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) @@ -277,6 +290,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) # Join the space and results should be returned. + self.helper.invite(self.space, targ=user2, tok=self.token) self.helper.join(self.space, user2, tok=token2) result = self.get_success(self.handler.get_space_summary(user2, self.space)) self._assert_rooms(result, expected) @@ -284,6 +298,16 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) self._assert_hierarchy(result, expected) + # Attempting to view an unknown room returns the same error. + self.get_failure( + self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + self.get_failure( + self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + def _create_room_with_join_rule( self, join_rule: str, room_version: Optional[str] = None, **extra_content ) -> str: -- cgit 1.5.1 From 6fcc3e0bc81b4ed738eee702b0e1d193c052d205 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 11 Aug 2021 20:08:14 +0100 Subject: Teach MANIFEST and tox about ci->.ci --- MANIFEST.in | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 37d61f40de..44d5cc7618 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -46,7 +46,7 @@ recursive-include changelog.d * prune .circleci prune .github -prune ci +prune .ci prune contrib prune debian prune demo/etc diff --git a/tox.ini b/tox.ini index b695126019..5a62ec76c2 100644 --- a/tox.ini +++ b/tox.ini @@ -49,7 +49,7 @@ lint_targets = contrib synctl synmark - ci + .ci docker # default settings for all tox environments -- cgit 1.5.1 From cb5976ebd71e25da2b2370185d640e80a2245d04 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 11 Aug 2021 20:08:48 +0100 Subject: set TOP in sytest containers --- .coveragerc | 4 ++-- .github/workflows/tests.yml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.coveragerc b/.coveragerc index bbf9046b06..11f2ec8387 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,8 +1,8 @@ [run] branch = True parallel = True -include=$GITHUB_WORKSPACE/synapse/* -data_file = $GITHUB_WORKSPACE/.coverage +include=$TOP/synapse/* +data_file = $TOP/.coverage [report] precision = 2 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index df2e3901cb..de022020cc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -200,6 +200,7 @@ jobs: WORKERS: ${{ matrix.workers && 1 }} REDIS: ${{ matrix.redis && 1 }} BLACKLIST: ${{ matrix.workers && 'synapse-blacklist-with-workers' }} + TOP: ${{ github.workspace }} strategy: fail-fast: false -- cgit 1.5.1 From 92a8e68ba2f617119e0506cee76eed2ff45b323a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 11 Aug 
2021 20:19:56 +0100 Subject: Missed another ci->.ci Should have been more systematic with my grepping. --- .ci/scripts/test_synapse_port_db.sh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.ci/scripts/test_synapse_port_db.sh b/.ci/scripts/test_synapse_port_db.sh index 9ee0ad42fc..2b4e5ec170 100755 --- a/.ci/scripts/test_synapse_port_db.sh +++ b/.ci/scripts/test_synapse_port_db.sh @@ -20,22 +20,22 @@ pip install -e . echo "--- Generate the signing key" # Generate the server's signing key. -python -m synapse.app.homeserver --generate-keys -c ci/sqlite-config.yaml +python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml echo "--- Prepare test database" # Make sure the SQLite3 database is using the latest schema and has no pending background update. -scripts-dev/update_database --database-config ci/sqlite-config.yaml +scripts-dev/update_database --database-config .ci/sqlite-config.yaml # Create the PostgreSQL database. -./ci/scripts/postgres_exec.py "CREATE DATABASE synapse" +.ci/scripts/postgres_exec.py "CREATE DATABASE synapse" echo "+++ Run synapse_port_db against test database" -coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml +coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml # We should be able to run twice against the same database. echo "+++ Run synapse_port_db a second time" -coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml +coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml ##### @@ -44,14 +44,14 @@ coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres- echo "--- Prepare empty SQLite database" # we do this by deleting the sqlite db, and then doing the same again. -rm ci/test_db.db +rm .ci/test_db.db -scripts-dev/update_database --database-config ci/sqlite-config.yaml +scripts-dev/update_database --database-config .ci/sqlite-config.yaml # re-create the PostgreSQL database. -./ci/scripts/postgres_exec.py \ +.ci/scripts/postgres_exec.py \ "DROP DATABASE synapse" \ "CREATE DATABASE synapse" echo "+++ Run synapse_port_db against empty database" -coverage run scripts/synapse_port_db --sqlite-database ci/test_db.db --postgres-config ci/postgres-config.yaml +coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml -- cgit 1.5.1 From 915b37e5efd4e0fb9e57ce9895300017b4b3dd43 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 11 Aug 2021 21:29:59 +0200 Subject: Admin API to delete media for a specific user (#10558) --- changelog.d/10558.feature | 1 + docs/admin_api/media_admin_api.md | 9 +- docs/admin_api/user_admin_api.md | 54 ++++- synapse/rest/admin/media.py | 4 +- synapse/rest/admin/users.py | 80 +++++++- synapse/rest/media/v1/media_repository.py | 6 +- tests/rest/admin/test_user.py | 321 +++++++++++++++++++----------- 7 files changed, 347 insertions(+), 128 deletions(-) create mode 100644 changelog.d/10558.feature diff --git a/changelog.d/10558.feature b/changelog.d/10558.feature new file mode 100644 index 0000000000..1f461bc70a --- /dev/null +++ b/changelog.d/10558.feature @@ -0,0 +1 @@ +Admin API to delete several media for a specific user. Contributed by @dklimpel. 
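Before the documentation hunks that follow, a rough client-side sketch of the new endpoint; the homeserver URL, access token and user ID are placeholders, and `requests` is only a convenient stand-in for any HTTP client:

```python
# Hypothetical invocation of the DELETE endpoint added by this commit; all
# concrete names (URL, token, user ID) are illustrative only.
import requests

resp = requests.delete(
    "https://homeserver.example/_synapse/admin/v1/users/@user:server/media",
    headers={"Authorization": "Bearer <admin_access_token>"},
    params={"limit": 100},  # takes the same parameters as the media list API
)
resp.raise_for_status()
body = resp.json()
print("deleted %d media items: %s" % (body["total"], body["deleted_media"]))
```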
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 61bed1e0d5..ea05bd6e44 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -12,6 +12,7 @@ - [Delete local media](#delete-local-media) * [Delete a specific local media](#delete-a-specific-local-media) * [Delete local media by date or size](#delete-local-media-by-date-or-size) + * [Delete media uploaded by a user](#delete-media-uploaded-by-a-user) - [Purge Remote Media API](#purge-remote-media-api) # Querying media @@ -47,7 +48,8 @@ The API returns a JSON body like the following: ## List all media uploaded by a user Listing all media that has been uploaded by a local user can be achieved through -the use of the [List media of a user](user_admin_api.md#list-media-of-a-user) +the use of the +[List media uploaded by a user](user_admin_api.md#list-media-uploaded-by-a-user) Admin API. # Quarantine media @@ -281,6 +283,11 @@ The following fields are returned in the JSON response body: * `deleted_media`: an array of strings - List of deleted `media_id` * `total`: integer - Total number of deleted `media_id` +## Delete media uploaded by a user + +You can find details of how to delete multiple media uploaded by a user in +[User Admin API](user_admin_api.md#delete-media-uploaded-by-a-user). + # Purge Remote Media API The purge remote media API allows server admins to purge old cached remote media. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 160899754e..33811f5bbb 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -443,8 +443,9 @@ The following fields are returned in the JSON response body: - `joined_rooms` - An array of `room_id`. - `total` - Number of rooms. +## User media -## List media of a user +### List media uploaded by a user Gets a list of all local media that a specific `user_id` has created. By default, the response is ordered by descending creation date and ascending media ID. The newest media is on top. You can change the order with parameters @@ -543,7 +544,6 @@ The following fields are returned in the JSON response body: - `media` - An array of objects, each containing information about a media. Media objects contain the following fields: - - `created_ts` - integer - Timestamp when the content was uploaded in ms. - `last_access_ts` - integer - Timestamp when the content was last accessed in ms. - `media_id` - string - The id used to refer to the media. @@ -551,13 +551,58 @@ The following fields are returned in the JSON response body: - `media_type` - string - The MIME-type of the media. - `quarantined_by` - string - The user ID that initiated the quarantine request for this media. - - `safe_from_quarantine` - bool - Status if this media is safe from quarantining. - `upload_name` - string - The name the media was uploaded with. - - `next_token`: integer - Indication for pagination. See above. - `total` - integer - Total number of media. +### Delete media uploaded by a user + +This API deletes the *local* media from the disk of your own server +that a specific `user_id` has created. This includes any local thumbnails. + +This API will not affect media that has been uploaded to external +media repositories (e.g https://github.com/turt2live/matrix-media-repo/). + +By default, the API deletes media ordered by descending creation date and ascending media ID. +The newest media is deleted first. You can change the order with parameters +`order_by` and `dir`. 
If no `limit` is set the API deletes `100` files per request. + +The API is: + +``` +DELETE /_synapse/admin/v1/users//media +``` + +To use it, you will need to authenticate by providing an `access_token` for a +server admin: [Admin API](../usage/administration/admin_api) + +A response body like the following is returned: + +```json +{ + "deleted_media": [ + "abcdefghijklmnopqrstuvwx" + ], + "total": 1 +} +``` + +The following fields are returned in the JSON response body: + +* `deleted_media`: an array of strings - List of deleted `media_id` +* `total`: integer - Total number of deleted `media_id` + +**Note**: There is no `next_token`. This is not useful for deleting media, because +after deleting media the remaining media have a new order. + +**Parameters** + +This API has the same parameters as +[List media uploaded by a user](#list-media-uploaded-by-a-user). +With the parameters you can for example limit the number of files to delete at once or +delete largest/smallest or newest/oldest files first. + ## Login as a user Get an access token that can be used to authenticate as that user. Useful for @@ -1012,4 +1057,3 @@ The following parameters should be set in the URL: - `user_id` - The fully qualified MXID: for example, `@user:server.com`. The user must be local. - diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 0a19a333d7..5f0555039d 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -259,7 +259,9 @@ class DeleteMediaByID(RestServlet): logging.info("Deleting local media by ID: %s", media_id) - deleted_media, total = await self.media_repository.delete_local_media(media_id) + deleted_media, total = await self.media_repository.delete_local_media_ids( + [media_id] + ) return 200, {"deleted_media": deleted_media, "total": total} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index eef76ab18a..41f21ba118 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -172,7 +172,7 @@ class UserRestServletV2(RestServlet): target_user = UserID.from_string(user_id) if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") ret = await self.admin_handler.get_user(target_user) @@ -796,7 +796,7 @@ class PushersRestServlet(RestServlet): await assert_requester_is_admin(self.auth, request) if not self.is_mine(UserID.from_string(user_id)): - raise SynapseError(400, "Can only lookup local users") + raise SynapseError(400, "Can only look up local users") if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") @@ -811,10 +811,10 @@ class PushersRestServlet(RestServlet): class UserMediaRestServlet(RestServlet): """ Gets information about all uploaded local media for a specific `user_id`. + With DELETE request you can delete all this media. Example: - http://localhost:8008/_synapse/admin/v1/users/ - @user:server/media + http://localhost:8008/_synapse/admin/v1/users/@user:server/media Args: The parameters `from` and `limit` are required for pagination. 
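The servlet above now accepts DELETE as well as GET. Because deletion re-orders the remaining media (which is why the documentation earlier in this patch notes there is no `next_token`), a client that wants to remove everything can simply repeat the call until `total` comes back as zero; a small sketch, with the same placeholder names as the earlier example:

```python
# Hypothetical drain loop: each DELETE removes at most `limit` items, so we
# keep calling until the server reports that nothing was deleted.
import requests

def delete_all_user_media(base_url: str, token: str, user_id: str) -> int:
    deleted = 0
    while True:
        resp = requests.delete(
            f"{base_url}/_synapse/admin/v1/users/{user_id}/media",
            headers={"Authorization": f"Bearer {token}"},
            params={"limit": 100},
        )
        resp.raise_for_status()
        total = resp.json()["total"]
        if total == 0:
            return deleted
        deleted += total
```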
@@ -830,6 +830,7 @@ class UserMediaRestServlet(RestServlet):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
+        self.media_repository = hs.get_media_repository()
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
@@ -840,7 +841,7 @@ class UserMediaRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(400, "Can only look up local users")
 
         user = await self.store.get_user_by_id(user_id)
         if user is None:
@@ -898,6 +899,73 @@ class UserMediaRestServlet(RestServlet):
 
         return 200, ret
 
+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.is_mine(UserID.from_string(user_id)):
+            raise SynapseError(400, "Can only look up local users")
+
+        user = await self.store.get_user_by_id(user_id)
+        if user is None:
+            raise NotFoundError("Unknown user")
+
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        # If neither `order_by` nor `dir` is set, default the order so that
+        # the newest media is on top, for backwards compatibility.
+        if b"order_by" not in request.args and b"dir" not in request.args:
+            order_by = MediaSortOrder.CREATED_TS.value
+            direction = "b"
+        else:
+            order_by = parse_string(
+                request,
+                "order_by",
+                default=MediaSortOrder.CREATED_TS.value,
+                allowed_values=(
+                    MediaSortOrder.MEDIA_ID.value,
+                    MediaSortOrder.UPLOAD_NAME.value,
+                    MediaSortOrder.CREATED_TS.value,
+                    MediaSortOrder.LAST_ACCESS_TS.value,
+                    MediaSortOrder.MEDIA_LENGTH.value,
+                    MediaSortOrder.MEDIA_TYPE.value,
+                    MediaSortOrder.QUARANTINED_BY.value,
+                    MediaSortOrder.SAFE_FROM_QUARANTINE.value,
+                ),
+            )
+            direction = parse_string(
+                request, "dir", default="f", allowed_values=("f", "b")
+            )
+
+        media, _ = await self.store.get_local_media_by_user_paginate(
+            start, limit, user_id, order_by, direction
+        )
+
+        deleted_media, total = await self.media_repository.delete_local_media_ids(
+            [row["media_id"] for row in media]
+        )
+
+        return 200, {"deleted_media": deleted_media, "total": total}
+
 
 class UserTokenRestServlet(RestServlet):
     """An admin API for logging in as a user.
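The `on_DELETE` handler above composes two existing building blocks: it fetches one page of the user's media via `get_local_media_by_user_paginate` and hands the resulting IDs to `delete_local_media_ids`. A stripped-down sketch of that flow, with the servlet plumbing omitted (the method names match the calls in the handler above; the wrapper function itself is hypothetical):

```python
from typing import Any, List, Tuple

async def delete_one_page_of_media(
    store: Any, media_repository: Any, user_id: str, start: int = 0, limit: int = 100
) -> Tuple[List[str], int]:
    """Delete a single page of a user's local media, newest first."""
    # Fetch one page of media records; direction "b" (backwards) sorts newest first.
    media, _overall_count = await store.get_local_media_by_user_paginate(
        start, limit, user_id, "created_ts", "b"
    )
    # Delete exactly the media on that page. The returned `total` is the number
    # deleted, i.e. at most the page size - not the user's overall media count.
    media_ids = [row["media_id"] for row in media]
    return await media_repository.delete_local_media_ids(media_ids)
```

One consequence of this design is that a single request deletes at most one page (`limit`, default 100); a caller that wants to remove everything can simply repeat the request until `total` comes back as `0`.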
@@ -1017,7 +1085,7 @@ class RateLimitRestServlet(RestServlet):
         await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Can only lookup local users")
+            raise SynapseError(400, "Can only look up local users")
 
         if not await self.store.get_user_by_id(user_id):
             raise NotFoundError("User not found")
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 4f702f890c..0f5ce41ff8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -836,7 +836,9 @@ class MediaRepository:
 
         return {"deleted": deleted}
 
-    async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]:
+    async def delete_local_media_ids(
+        self, media_ids: List[str]
+    ) -> Tuple[List[str], int]:
         """
         Delete the given local or remote media IDs from this server
 
@@ -845,7 +847,7 @@
         Returns:
             A tuple of (list of deleted media IDs, total deleted media IDs).
         """
-        return await self._remove_local_media_from_disk([media_id])
+        return await self._remove_local_media_from_disk(media_ids)
 
     async def delete_old_local_media(
         self,
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 42f50c0921..13fab5579b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -15,17 +15,21 @@
 import hashlib
 import hmac
 import json
+import os
 import urllib.parse
 from binascii import unhexlify
 from typing import List, Optional
 from unittest.mock import Mock, patch
 
+from parameterized import parameterized
+
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
 from synapse.api.room_versions import RoomVersions
 from synapse.rest.client.v1 import login, logout, profile, room
 from synapse.rest.client.v2_alpha import devices, sync
+from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.types import JsonDict, UserID
 
 from tests import unittest
@@ -72,7 +76,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request("POST", self.url, b"{}")
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(
             "Shared secret registration is not enabled", channel.json_body["error"]
         )
@@ -104,7 +108,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         body = json.dumps({"nonce": nonce})
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("username must be specified", channel.json_body["error"])
 
         # 61 seconds
@@ -112,7 +116,7 @@
 
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("unrecognised nonce", channel.json_body["error"])
 
     def test_register_incorrect_nonce(self):
@@ -166,7 +170,7 @@
         )
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob:test", channel.json_body["user_id"])
 
     def test_nonce_reuse(self):
@@ -191,13 +195,13 @@
         )
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob:test", channel.json_body["user_id"])
 
         # Now, try and reuse it
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("unrecognised nonce", channel.json_body["error"])
 
     def test_missing_parts(self):
@@ -219,7 +223,7 @@
         body = json.dumps({})
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("nonce must be specified", channel.json_body["error"])
 
         #
@@ -230,28 +234,28 @@
         body = json.dumps({"nonce": nonce()})
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("username must be specified", channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": 1234})
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         # Must not be too long
         body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid username", channel.json_body["error"])
 
         #
@@ -262,28 +266,28 @@
         body = json.dumps({"nonce": nonce(), "username": "a"})
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("password must be specified", channel.json_body["error"])
 
         # Must be a string
         body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
         channel = self.make_request("POST", self.url, body.encode("utf8"))
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual("Invalid password", channel.json_body["error"])
 
         # Must not have null bytes
         body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
         channel = self.make_request("POST", self.url,
body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # @@ -301,7 +305,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) def test_displayname(self): @@ -322,11 +326,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob1:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob1", channel.json_body["displayname"]) # displayname is None @@ -348,11 +352,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob2:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("bob2", channel.json_body["displayname"]) # displayname is empty @@ -374,7 +378,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") @@ -399,11 +403,11 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob4:test/displayname") - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("Bob's Name", channel.json_body["displayname"]) @override_config( @@ -449,7 +453,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) channel = self.make_request("POST", self.url, body.encode("utf8")) - self.assertEqual(200, int(channel.result["code"]), 
msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@bob:test", channel.json_body["user_id"])
 
 
@@ -638,7 +642,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
 
         # invalid search order
@@ -1085,7 +1089,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
             content={"erase": False},
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
 
         # Get user
         channel = self.make_request(
@@ -2180,7 +2184,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
         )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
+        self.assertEqual("Can only look up local users", channel.json_body["error"])
 
     def test_get_pushers(self):
         """
@@ -2249,6 +2253,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
         self.media_repo = hs.get_media_repository_resource()
+        self.filepaths = MediaFilePaths(hs.config.media_store_path)
 
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
@@ -2258,37 +2263,34 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
             self.other_user
         )
 
-    def test_no_auth(self):
-        """
-        Try to list media of an user without authentication.
-        """
-        channel = self.make_request("GET", self.url, b"{}")
+    @parameterized.expand(["GET", "DELETE"])
+    def test_no_auth(self, method: str):
+        """Try to list media of a user without authentication."""
+        channel = self.make_request(method, self.url, {})
 
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-    def test_requester_is_no_admin(self):
-        """
-        If the user is not a server admin, an error is returned.
-        """
+    @parameterized.expand(["GET", "DELETE"])
+    def test_requester_is_no_admin(self, method: str):
+        """If the user is not a server admin, an error is returned."""
         other_user_token = self.login("user", "pass")
 
         channel = self.make_request(
-            "GET",
+            method,
             self.url,
             access_token=other_user_token,
         )
 
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_user_does_not_exist(self):
-        """
-        Tests that a lookup for a user that does not exist returns a 404
-        """
+    @parameterized.expand(["GET", "DELETE"])
+    def test_user_does_not_exist(self, method: str):
+        """Tests that a lookup for a user that does not exist returns a 404"""
         url = "/_synapse/admin/v1/users/@unknown_person:test/media"
         channel = self.make_request(
-            "GET",
+            method,
             url,
             access_token=self.admin_user_tok,
         )
@@ -2296,25 +2298,22 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-    def test_user_is_not_local(self):
-        """
-        Tests that a lookup for a user that is not a local returns a 400
-        """
+    @parameterized.expand(["GET", "DELETE"])
+    def test_user_is_not_local(self, method: str):
+        """Tests that a lookup for a user that is not local returns a 400"""
         url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media"
 
         channel = self.make_request(
-            "GET",
+            method,
             url,
             access_token=self.admin_user_tok,
         )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
+        self.assertEqual("Can only look up local users", channel.json_body["error"])
 
-    def test_limit(self):
-        """
-        Testing list of media with limit
-        """
+    def test_limit_GET(self):
+        """Testing list of media with limit"""
 
         number_media = 20
         other_user_tok = self.login("user", "pass")
@@ -2326,16 +2325,31 @@
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(channel.json_body["total"], number_media)
         self.assertEqual(len(channel.json_body["media"]), 5)
         self.assertEqual(channel.json_body["next_token"], 5)
         self._check_fields(channel.json_body["media"])
 
-    def test_from(self):
-        """
-        Testing list of media with a defined starting point (from)
-        """
+    def test_limit_DELETE(self):
+        """Testing delete of media with limit"""
+
+        number_media = 20
+        other_user_tok = self.login("user", "pass")
+        self._create_media_for_user(other_user_tok, number_media)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url + "?limit=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], 5)
+        self.assertEqual(len(channel.json_body["deleted_media"]), 5)
+
+    def test_from_GET(self):
+        """Testing list of media with a defined starting point (from)"""
 
         number_media = 20
         other_user_tok = self.login("user", "pass")
@@ -2347,16 +2361,31 @@
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(channel.json_body["total"], number_media)
         self.assertEqual(len(channel.json_body["media"]), 15)
         self.assertNotIn("next_token", channel.json_body)
         self._check_fields(channel.json_body["media"])
 
-    def test_limit_and_from(self):
-        """
-        Testing list of media with a defined starting point and limit
-        """
+    def test_from_DELETE(self):
+        """Testing delete of media with a defined starting point (from)"""
+
+        number_media = 20
+        other_user_tok = self.login("user", "pass")
+        self._create_media_for_user(other_user_tok, number_media)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url + "?from=5",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], 15)
+        self.assertEqual(len(channel.json_body["deleted_media"]), 15)
+
+    def test_limit_and_from_GET(self):
+        """Testing list of media with a defined starting point and limit"""
 
         number_media = 20
         other_user_tok = self.login("user", "pass")
@@ -2368,59 +2397,78 @@
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(channel.json_body["total"], number_media)
         self.assertEqual(channel.json_body["next_token"], 15)
         self.assertEqual(len(channel.json_body["media"]), 10)
         self._check_fields(channel.json_body["media"])
 
-    def test_invalid_parameter(self):
-        """
-        If parameters are invalid, an error is returned.
-        """
+    def test_limit_and_from_DELETE(self):
+        """Testing delete of media with a defined starting point and limit"""
+
+        number_media = 20
+        other_user_tok = self.login("user", "pass")
+        self._create_media_for_user(other_user_tok, number_media)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url + "?from=5&limit=10",
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], 10)
+        self.assertEqual(len(channel.json_body["deleted_media"]), 10)
+
+    @parameterized.expand(["GET", "DELETE"])
+    def test_invalid_parameter(self, method: str):
+        """If parameters are invalid, an error is returned."""
         # unknown order_by
         channel = self.make_request(
-            "GET",
+            method,
             self.url + "?order_by=bar",
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
 
         # invalid search order
         channel = self.make_request(
-            "GET",
+            method,
             self.url + "?dir=bar",
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
 
         # negative limit
         channel = self.make_request(
-            "GET",
+            method,
             self.url + "?limit=-5",
             access_token=self.admin_user_tok,
        )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
         # negative from
         channel = self.make_request(
-            "GET",
+            method,
             self.url + "?from=-5",
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
 
     def test_next_token(self):
         """
         Testing that `next_token` appears at the right place
+
+        For deleting media `next_token` is not useful, because
+        after deleting media the remaining media have a new order.
         """
 
         number_media = 20
@@ -2435,7 +2483,7 @@
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(channel.json_body["total"], number_media)
         self.assertEqual(len(channel.json_body["media"]), number_media)
         self.assertNotIn("next_token", channel.json_body)
@@ -2448,7 +2496,7 @@
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(channel.json_body["total"], number_media)
         self.assertEqual(len(channel.json_body["media"]), number_media)
         self.assertNotIn("next_token", channel.json_body)
@@ -2461,7 +2509,7 @@
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(channel.json_body["total"], number_media)
         self.assertEqual(len(channel.json_body["media"]), 19)
         self.assertEqual(channel.json_body["next_token"], 19)
@@ -2475,12 +2523,12 @@
             access_token=self.admin_user_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual(channel.json_body["total"], number_media)
         self.assertEqual(len(channel.json_body["media"]), 1)
         self.assertNotIn("next_token", channel.json_body)
 
-    def test_user_has_no_media(self):
+    def test_user_has_no_media_GET(self):
         """
         Tests that a normal lookup for media is successful
         if the user has no media created
         """
@@ -2496,11 +2544,24 @@
         self.assertEqual(0, channel.json_body["total"])
         self.assertEqual(0, len(channel.json_body["media"]))
 
-    def test_get_media(self):
+    def test_user_has_no_media_DELETE(self):
         """
-        Tests that a normal lookup for media is successfully
+        Tests that a delete is successful if user has no media
         """
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(0, channel.json_body["total"])
+        self.assertEqual(0, len(channel.json_body["deleted_media"]))
+
+    def test_get_media(self):
+        """Tests that a normal lookup for media is successful"""
+
         number_media = 5
         other_user_tok = self.login("user", "pass")
         self._create_media_for_user(other_user_tok, number_media)
@@ -2517,6 +2578,35 @@
         self.assertNotIn("next_token", channel.json_body)
         self._check_fields(channel.json_body["media"])
 
+    def test_delete_media(self):
+        """Tests that a normal delete of media is successful"""
+
+        number_media = 5
+        other_user_tok = self.login("user", "pass")
+        media_ids = self._create_media_for_user(other_user_tok, number_media)
+
+        # Test if the file exists
+        local_paths = []
+        for media_id in media_ids:
+            local_path = self.filepaths.local_media_filepath(media_id)
+            self.assertTrue(os.path.exists(local_path))
+            local_paths.append(local_path)
+
+        channel = self.make_request(
+            "DELETE",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(number_media, channel.json_body["total"])
+        self.assertEqual(number_media, len(channel.json_body["deleted_media"]))
+        self.assertCountEqual(channel.json_body["deleted_media"], media_ids)
+
+        # Test if the file is deleted
+        for local_path in local_paths:
+            self.assertFalse(os.path.exists(local_path))
+
     def test_order_by(self):
         """
         Testing order list with parameter `order_by`
@@ -2622,13 +2712,16 @@
             [media2] + sorted([media1, media3]), "safe_from_quarantine", "b"
         )
 
-    def _create_media_for_user(self, user_token: str, number_media: int):
+    def _create_media_for_user(self, user_token: str, number_media: int) -> List[str]:
         """
         Create a number of media for a specific user
 
         Args:
             user_token: Access token of the user
             number_media: Number of media to be created for the user
+        Returns:
+            List of created media IDs
         """
+        media_ids = []
         for _ in range(number_media):
             # file size is 67 Byte
             image_data = unhexlify(
                 b"89504e470d0a1a0a0000000d4948445200000001000000010806"
                 b"0000001f15c4890000000a49444154789c63000100000500010d"
                 b"0a2db40000000049454e44ae426082"
             )
 
-            self._create_media_and_access(user_token, image_data)
+            media_ids.append(self._create_media_and_access(user_token, image_data))
+
+        return media_ids
 
     def _create_media_and_access(
         self,
@@ -2680,7 +2775,7 @@
             200,
             channel.code,
             msg=(
-                "Expected to receive a 200 on accessing media: %s" % server_and_media_id
+                f"Expected to receive a 200 on accessing media: {server_and_media_id}"
             ),
         )
 
@@ -2718,12 +2813,12 @@
 
         url = self.url + "?"
if order_by is not None: - url += "order_by=%s&" % (order_by,) + url += f"order_by={order_by}&" if dir is not None and dir in ("b", "f"): - url += "dir=%s" % (dir,) + url += f"dir={dir}" channel = self.make_request( "GET", - url.encode("ascii"), + url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -2762,7 +2857,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "POST", self.url, b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) return channel.json_body["access_token"] def test_no_auth(self): @@ -2803,7 +2898,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # We should only see the one device (from the login in `prepare`) self.assertEqual(len(channel.json_body["devices"]), 1) @@ -2815,11 +2910,11 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout with the puppet token channel = self.make_request("POST", "logout", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -2829,7 +2924,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) def test_user_logout_all(self): """Tests that the target user calling `/logout/all` does *not* expire @@ -2840,17 +2935,17 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the real user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should still work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # .. 
but the real user's tokens shouldn't channel = self.make_request( @@ -2867,13 +2962,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # Test that we can successfully make a request channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # Logout all with the admin user token channel = self.make_request( "POST", "logout/all", b"{}", access_token=self.admin_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) @@ -2883,7 +2978,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) @unittest.override_config( { @@ -3243,7 +3338,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) + self.assertEqual("Can only look up local users", channel.json_body["error"]) channel = self.make_request( "POST", @@ -3279,7 +3374,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": "string"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # messages_per_second is negative @@ -3290,7 +3385,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"messages_per_second": -1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is a string @@ -3301,7 +3396,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": "string"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # burst_count is negative @@ -3312,7 +3407,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): content={"burst_count": -1}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) def test_return_zero_when_null(self): @@ -3337,7 +3432,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(0, channel.json_body["messages_per_second"]) self.assertEqual(0, channel.json_body["burst_count"]) @@ -3351,7 +3446,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) 
self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3362,7 +3457,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 10, "burst_count": 11}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(10, channel.json_body["messages_per_second"]) self.assertEqual(11, channel.json_body["burst_count"]) @@ -3373,7 +3468,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, content={"messages_per_second": 20, "burst_count": 21}, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3383,7 +3478,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(20, channel.json_body["messages_per_second"]) self.assertEqual(21, channel.json_body["burst_count"]) @@ -3393,7 +3488,7 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) @@ -3403,6 +3498,6 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.url, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertNotIn("messages_per_second", channel.json_body) self.assertNotIn("burst_count", channel.json_body) -- cgit 1.5.1 From 98a3355d9a58538cfbc1c88020e6b6d9bccea516 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Aug 2021 15:44:45 -0400 Subject: Update the pagination parameter name based on MSC2946 review. (#10579) --- changelog.d/10579.feature | 1 + synapse/handlers/space_summary.py | 6 +++--- tests/handlers/test_space_summary.py | 14 +++++++------- 3 files changed, 11 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10579.feature diff --git a/changelog.d/10579.feature b/changelog.d/10579.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10579.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 893546e661..d0060f9046 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -412,10 +412,10 @@ class SpaceSummaryHandler: # If there's additional data, generate a pagination token (and persist state). 
         if room_queue:
-            next_token = random_string(24)
-            result["next_token"] = next_token
+            next_batch = random_string(24)
+            result["next_batch"] = next_batch
             pagination_key = _PaginationKey(
-                requested_room_id, suggested_only, max_depth, next_token
+                requested_room_id, suggested_only, max_depth, next_batch
             )
             self._pagination_sessions[pagination_key] = _PaginationSession(
                 self._clock.time_msec(), room_queue, processed_rooms
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 806b886fe4..83c2bdd8f9 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -466,19 +466,19 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)]
         expected += [(room_id, ()) for room_id in room_ids[:6]]
         self._assert_hierarchy(result, expected)
-        self.assertIn("next_token", result)
+        self.assertIn("next_batch", result)
 
         # Check the next page.
         result = self.get_success(
             self.handler.get_room_hierarchy(
                 self.user, self.space, limit=5, from_token=result["next_batch"]
             )
         )
         # The next page should have the remaining rooms in it.
         expected = [(room_id, ()) for room_id in room_ids[6:]]
         self._assert_hierarchy(result, expected)
-        self.assertNotIn("next_token", result)
+        self.assertNotIn("next_batch", result)
 
     def test_invalid_pagination_token(self):
         """Tests that an invalid pagination token results in an error."""
@@ -493,12 +493,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         result = self.get_success(
             self.handler.get_room_hierarchy(self.user, self.space, limit=7)
         )
-        self.assertIn("next_token", result)
+        self.assertIn("next_batch", result)
 
         # Changing the room ID, suggested-only, or max-depth causes an error.
         self.get_failure(
             self.handler.get_room_hierarchy(
-                self.user, self.room, from_token=result["next_token"]
+                self.user, self.room, from_token=result["next_batch"]
             ),
             SynapseError,
         )
@@ -507,13 +507,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
                 self.user,
                 self.space,
                 suggested_only=True,
-                from_token=result["next_token"],
+                from_token=result["next_batch"],
             ),
             SynapseError,
         )
         self.get_failure(
             self.handler.get_room_hierarchy(
-                self.user, self.space, max_depth=0, from_token=result["next_token"]
+                self.user, self.space, max_depth=0, from_token=result["next_batch"]
             ),
             SynapseError,
         )
-- 
cgit 1.5.1


From 314a739160effac7501201cadfc8aaa9c4f34713 Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Thu, 12 Aug 2021 10:40:32 +0100
Subject: Also rename in lint.sh

---
 scripts-dev/lint.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index 2c77643cda..809eff166a 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -94,7 +94,7 @@ else
           "scripts-dev/build_debian_packages"
           "scripts-dev/sign_json"
           "scripts-dev/update_database"
-          "contrib" "synctl" "setup.py" "synmark" "stubs" "ci"
+          "contrib" "synctl" "setup.py" "synmark" "stubs" ".ci"
       )
   fi
 fi
-- 
cgit 1.5.1


From 74fcd5aab9de111d5c306d3ed28a3f3ef63f3e3e Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Thu, 12 Aug 2021 10:41:01 +0100
Subject: portdb also uses coverage, so provide $TOP there

---
 .github/workflows/tests.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index de022020cc..6874d253ca 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -250,6 +250,8 @@ jobs:
     if: ${{ !failure() && !cancelled() }} # Allow previous steps to be skipped, but not fail
     needs: linting-done
     runs-on: ubuntu-latest
+    env:
+      TOP: ${{ github.workspace }}
     strategy:
       matrix:
         include:
-- 
cgit 1.5.1


From 878528913d2927bba5ba8795c405f8a7475934cd Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Thu, 12 Aug 2021 11:48:36 +0100
Subject: Remove buildkite-era comment

---
 .github/workflows/tests.yml | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 6874d253ca..8736699ad8 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -38,9 +38,6 @@ jobs:
     if: ${{ github.base_ref == 'develop' || contains(github.base_ref, 'release-') }}
     runs-on: ubuntu-latest
     steps:
-      # Note: This and the script can be simplified once we drop Buildkite. See:
-      #   https://github.com/actions/checkout/issues/266#issuecomment-638346893
-      #   https://github.com/actions/checkout/issues/416
      - uses: actions/checkout@v2
        with:
          ref: ${{ github.event.pull_request.head.sha }}
-- 
cgit 1.5.1


From d2ad397d3cbd7e675abbb1f48072f9972c60823d Mon Sep 17 00:00:00 2001
From: David Robertson
Date: Thu, 12 Aug 2021 16:50:18 +0100
Subject: Stop building a debian package for Groovy Gorilla (#10588)

---
 changelog.d/10588.removal         | 1 +
 scripts-dev/build_debian_packages | 1 -
 2 files changed, 1 insertion(+), 1 deletion(-)
 create mode 100644 changelog.d/10588.removal

diff --git a/changelog.d/10588.removal b/changelog.d/10588.removal
new file mode 100644
index 0000000000..90c4b5cee2
--- /dev/null
+++ b/changelog.d/10588.removal
@@ -0,0 +1 @@
+No longer build `.dev` packages for Ubuntu 20.10 Groovy Gorilla, which is now end-of-life.
\ No newline at end of file
diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages
index 0ed1c679fd..6153cb225f 100755
--- a/scripts-dev/build_debian_packages
+++ b/scripts-dev/build_debian_packages
@@ -25,7 +25,6 @@ DISTS = (
     "debian:sid",
     "ubuntu:bionic",  # 18.04 LTS (our EOL forced by Py36 on 2021-12-23)
     "ubuntu:focal",  # 20.04 LTS (our EOL forced by Py38 on 2024-10-14)
-    "ubuntu:groovy",  # 20.10 (EOL 2021-07-07)
     "ubuntu:hirsute",  # 21.04 (EOL 2022-01-05)
 )
-- 
cgit 1.5.1


From c12b5577f22ee587b60ad7b65e88322ce1d86b7b Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 13 Aug 2021 07:49:06 -0400
Subject: Fix a harmless exception when the staged events queue is empty.
 (#10592)

---
 changelog.d/10592.bugfix                |  1 +
 synapse/federation/federation_server.py | 15 ++++++++++-----
 2 files changed, 11 insertions(+), 5 deletions(-)
 create mode 100644 changelog.d/10592.bugfix

diff --git a/changelog.d/10592.bugfix b/changelog.d/10592.bugfix
new file mode 100644
index 0000000000..efcdab1136
--- /dev/null
+++ b/changelog.d/10592.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.37.1 where an error could occur in the asynchronous processing of PDUs when the queue was empty.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 0385aadefa..78d5aac6af 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -972,13 +972,18 @@ class FederationServer(FederationBase):
         # the room, so instead of pulling the event out of the DB and parsing
         # the event we just pull out the next event ID and check if that matches.
         if latest_event is not None and latest_origin is not None:
-            (
-                next_origin,
-                next_event_id,
-            ) = await self.store.get_next_staged_event_id_for_room(room_id)
-            if next_origin != latest_origin or next_event_id != latest_event.event_id:
+            result = await self.store.get_next_staged_event_id_for_room(room_id)
+            if result is None:
                 latest_origin = None
                 latest_event = None
+            else:
+                next_origin, next_event_id = result
+                if (
+                    next_origin != latest_origin
+                    or next_event_id != latest_event.event_id
+                ):
+                    latest_origin = None
+                    latest_event = None
 
         if latest_origin is None or latest_event is None:
             next = await self.store.get_next_staged_event_for_room(
-- 
cgit 1.5.1


From c8d54be44c1da451f01504664d568dd2f2b37316 Mon Sep 17 00:00:00 2001
From: Eric Eastwood
Date: Fri, 13 Aug 2021 14:37:24 -0500
Subject: Move /batch_send to /v2_alpha directory (MSC2716) (#10576)

* Move /batch_send to /v2_alpha directory

As pointed out by @erikjohnston,
https://github.com/matrix-org/synapse/pull/10552#discussion_r685836624
---
 changelog.d/10576.misc               |   1 +
 synapse/rest/__init__.py             |   2 +
 synapse/rest/client/v1/room.py       | 410 +------------------------------
 synapse/rest/client/v2_alpha/room.py | 441 +++++++++++++++++++++++++++++++++++
 4 files changed, 445 insertions(+), 409 deletions(-)
 create mode 100644 changelog.d/10576.misc
 create mode 100644 synapse/rest/client/v2_alpha/room.py

diff --git a/changelog.d/10576.misc b/changelog.d/10576.misc
new file mode 100644
index 0000000000..f9f9c9a6fd
--- /dev/null
+++ b/changelog.d/10576.misc
@@ -0,0 +1 @@
+Move `/batch_send` endpoint defined by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) to the `/v2_alpha` directory.
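The `federation_server.py` fix a few hunks above boils down to a classic pattern: `get_next_staged_event_id_for_room` can return `None`, and unpacking `None` into a tuple raises `TypeError`, so the result must be checked before unpacking. A self-contained sketch of the before/after shape (toy types, not the real Synapse storage API):

```python
from typing import List, Optional, Tuple

def get_next_staged(queue: List[Tuple[str, str]]) -> Optional[Tuple[str, str]]:
    """Return the (origin, event_id) at the head of the queue, or None if empty."""
    return queue[0] if queue else None

queue: List[Tuple[str, str]] = []

# The buggy shape: unconditionally unpacking assumes the queue is non-empty.
try:
    origin, event_id = get_next_staged(queue)  # TypeError when the queue is empty
except TypeError:
    print("cannot unpack None - the exception the patch removes")

# The fixed shape: test for None before unpacking.
result = get_next_staged(queue)
if result is None:
    print("queue empty; nothing to process")
else:
    origin, event_id = result
    print(f"processing {event_id} from {origin}")
```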
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index d29f2fea5e..9cffe59ce5 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -47,6 +47,7 @@ from synapse.rest.client.v2_alpha import ( register, relations, report_event, + room as roomv2, room_keys, room_upgrade_rest_servlet, sendtodevice, @@ -117,6 +118,7 @@ class ClientRestResource(JsonResource): user_directory.register_servlets(hs, client_resource) groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) + roomv2.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f1bc43be2d..2c3be23bc8 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -19,7 +19,7 @@ import re from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from urllib import parse as urlparse -from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -28,7 +28,6 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.filtering import Filter -from synapse.appservice import ApplicationService from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( RestServlet, @@ -47,13 +46,11 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, - Requester, RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID, - create_requester, ) from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -268,407 +265,6 @@ class RoomSendEventRestServlet(TransactionRestServlet): ) -class RoomBatchSendEventRestServlet(TransactionRestServlet): - """ - API endpoint which can insert a chunk of events historically back in time - next to the given `prev_event`. - - `chunk_id` comes from `next_chunk_id `in the response of the batch send - endpoint and is derived from the "insertion" events added to each chunk. - It's not required for the first batch send. - - `state_events_at_start` is used to define the historical state events - needed to auth the events like join events. These events will float - outside of the normal DAG as outlier's and won't be visible in the chat - history which also allows us to insert multiple chunks without having a bunch - of `@mxid joined the room` noise between each chunk. - - `events` is chronological chunk/list of events you want to insert. - There is a reverse-chronological constraint on chunks so once you insert - some messages, you can only insert older ones after that. - tldr; Insert chunks from your most recent history -> oldest history. - - POST /_matrix/client/unstable/org.matrix.msc2716/rooms//batch_send?prev_event=&chunk_id= - { - "events": [ ... ], - "state_events_at_start": [ ... 
] - } - """ - - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2716" - "/rooms/(?P[^/]*)/batch_send$" - ), - ) - - def __init__(self, hs): - super().__init__(hs) - self.hs = hs - self.store = hs.get_datastore() - self.state_store = hs.get_storage().state - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() - self.auth = hs.get_auth() - - async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = await self.store.get_max_depth_of(prev_event_ids) - - # We want to insert the historical event after the `prev_event` but before the successor event - # - # We inherit depth from the successor event instead of the `prev_event` - # because events returned from `/messages` are first sorted by `topological_ordering` - # which is just the `depth` and then tie-break with `stream_ordering`. - # - # We mark these inserted historical events as "backfilled" which gives them a - # negative `stream_ordering`. If we use the same depth as the `prev_event`, - # then our historical event will tie-break and be sorted before the `prev_event` - # when it should come after. - # - # We want to use the successor event depth so they appear after `prev_event` because - # it has a larger `depth` but before the successor event because the `stream_ordering` - # is negative before the successor event. - successor_event_ids = await self.store.get_successor_events( - [most_recent_prev_event_id] - ) - - # If we can't find any successor events, then it's a forward extremity of - # historical messages and we can just inherit from the previous historical - # event which we can already assume has the correct depth where we want - # to insert into. - if not successor_event_ids: - depth = most_recent_prev_event_depth - else: - ( - _, - oldest_successor_depth, - ) = await self.store.get_min_depth_of(successor_event_ids) - - depth = oldest_successor_depth - - return depth - - def _create_insertion_event_dict( - self, sender: str, room_id: str, origin_server_ts: int - ): - """Creates an event dict for an "insertion" event with the proper fields - and a random chunk ID. - - Args: - sender: The event author MXID - room_id: The room ID that the event belongs to - origin_server_ts: Timestamp when the event was sent - - Returns: - Tuple of event ID and stream ordering position - """ - - next_chunk_id = random_string(8) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": sender, - "room_id": room_id, - "content": { - EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, - "origin_server_ts": origin_server_ts, - } - - return insertion_event - - async def _create_requester_for_user_id_from_app_service( - self, user_id: str, app_service: ApplicationService - ) -> Requester: - """Creates a new requester for the given user_id - and validates that the app service is allowed to control - the given user. 
- - Args: - user_id: The author MXID that the app service is controlling - app_service: The app service that controls the user - - Returns: - Requester object - """ - - await self.auth.validate_appservice_can_control_user_id(app_service, user_id) - - return create_requester(user_id, app_service=app_service) - - async def on_POST(self, request, room_id): - requester = await self.auth.get_user_by_req(request, allow_guest=False) - - if not requester.app_service: - raise AuthError( - 403, - "Only application services can use the /batchsend endpoint", - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["state_events_at_start", "events"]) - - prev_events_from_query = parse_strings_from_args(request.args, "prev_event") - chunk_id_from_query = parse_string(request, "chunk_id") - - if prev_events_from_query is None: - raise SynapseError( - 400, - "prev_event query parameter is required when inserting historical messages back in time", - errcode=Codes.MISSING_PARAM, - ) - - # For the event we are inserting next to (`prev_events_from_query`), - # find the most recent auth events (derived from state events) that - # allowed that message to be sent. We will use that as a base - # to auth our historical messages against. - ( - most_recent_prev_event_id, - _, - ) = await self.store.get_max_depth_of(prev_events_from_query) - # mapping from (type, state_key) -> state_event_id - prev_state_map = await self.state_store.get_state_ids_for_event( - most_recent_prev_event_id - ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - - state_events_at_start = [] - for state_event in body["state_events_at_start"]: - assert_params_in_dict( - state_event, ["type", "origin_server_ts", "content", "sender"] - ) - - logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", - state_event, - auth_event_ids, - ) - - event_dict = { - "type": state_event["type"], - "origin_server_ts": state_event["origin_server_ts"], - "content": state_event["content"], - "room_id": room_id, - "sender": state_event["sender"], - "state_key": state_event["state_key"], - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - # Make the state events float off on their own - fake_prev_event_id = "$" + random_string(43) - - # TODO: This is pretty much the same as some other code to handle inserting state in this file - if event_dict["type"] == EventTypes.Member: - membership = event_dict["content"].get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - target=UserID.from_string(event_dict["state_key"]), - room_id=room_id, - action=membership, - content=event_dict["content"], - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - else: - # TODO: Add some complement tests that adds state that is not member joins - # and will use this code path. Maybe we only want to support join state events - # and can get rid of this `else`? 
- ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - event_dict, - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - event_id = event.event_id - - state_events_at_start.append(event_id) - auth_event_ids.append(event_id) - - events_to_create = body["events"] - - inherited_depth = await self._inherit_depth_from_prev_ids( - prev_events_from_query - ) - - # Figure out which chunk to connect to. If they passed in - # chunk_id_from_query let's use it. The chunk ID passed in comes - # from the chunk_id in the "insertion" event from the previous chunk. - last_event_in_chunk = events_to_create[-1] - chunk_id_to_connect_to = chunk_id_from_query - base_insertion_event = None - if chunk_id_from_query: - # All but the first base insertion event should point at a fake - # event, which causes the HS to ask for the state at the start of - # the chunk later. - prev_event_ids = [fake_prev_event_id] - # TODO: Verify the chunk_id_from_query corresponds to an insertion event - pass - # Otherwise, create an insertion event to act as a starting point. - # - # We don't always have an insertion event to start hanging more history - # off of (ideally there would be one in the main DAG, but that's not the - # case if we're wanting to add history to e.g. existing rooms without - # an insertion event), in which case we just create a new insertion event - # that can then get pointed to by a "marker" event later. - else: - prev_event_ids = prev_events_from_query - - base_insertion_event_dict = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - origin_server_ts=last_event_in_chunk["origin_server_ts"], - ) - base_insertion_event_dict["prev_events"] = prev_event_ids.copy() - - ( - base_insertion_event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - base_insertion_event_dict["sender"], - requester.app_service, - ), - base_insertion_event_dict, - prev_event_ids=base_insertion_event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - - chunk_id_to_connect_to = base_insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ] - - # Connect this current chunk to the insertion event from the previous chunk - chunk_event = { - "type": EventTypes.MSC2716_CHUNK, - "sender": requester.user.to_string(), - "room_id": room_id, - "content": { - EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, - EventContentFields.MSC2716_HISTORICAL: True, - }, - # Since the chunk event is put at the end of the chunk, - # where the newest-in-time event is, copy the origin_server_ts from - # the last event we're inserting - "origin_server_ts": last_event_in_chunk["origin_server_ts"], - } - # Add the chunk event to the end of the chunk (newest-in-time) - events_to_create.append(chunk_event) - - # Add an "insertion" event to the start of each chunk (next to the oldest-in-time - # event in the chunk) so the next chunk can be connected to this one. 
- insertion_event = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - # Since the insertion event is put at the start of the chunk, - # where the oldest-in-time event is, copy the origin_server_ts from - # the first event we're inserting - origin_server_ts=events_to_create[0]["origin_server_ts"], - ) - # Prepend the insertion event to the start of the chunk (oldest-in-time) - events_to_create = [insertion_event] + events_to_create - - event_ids = [] - events_to_persist = [] - for ev in events_to_create: - assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - - event_dict = { - "type": ev["type"], - "origin_server_ts": ev["origin_server_ts"], - "content": ev["content"], - "room_id": room_id, - "sender": ev["sender"], # requester.user.to_string(), - "prev_events": prev_event_ids.copy(), - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - event, context = await self.event_creation_handler.create_event( - await self._create_requester_for_user_id_from_app_service( - ev["sender"], requester.app_service - ), - event_dict, - prev_event_ids=event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", - event, - prev_event_ids, - auth_event_ids, - ) - - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) - - events_to_persist.append((event, context)) - event_id = event.event_id - - event_ids.append(event_id) - prev_event_ids = [event_id] - - # Persist events in reverse-chronological order so they have the - # correct stream_ordering as they are backfilled (which decrements). - # Events are sorted by (topological_ordering, stream_ordering) - # where topological_ordering is just depth. 
-        for (event, context) in reversed(events_to_persist):
-            ev = await self.event_creation_handler.handle_new_client_event(
-                await self._create_requester_for_user_id_from_app_service(
-                    event["sender"], requester.app_service
-                ),
-                event=event,
-                context=context,
-            )
-
-        # Add the base_insertion_event to the bottom of the list we return
-        if base_insertion_event is not None:
-            event_ids.append(base_insertion_event.event_id)
-
-        return 200, {
-            "state_events": state_events_at_start,
-            "events": event_ids,
-            "next_chunk_id": insertion_event["content"][
-                EventContentFields.MSC2716_NEXT_CHUNK_ID
-            ],
-        }
-
-    def on_GET(self, request, room_id):
-        return 501, "Not implemented"
-
-    def on_PUT(self, request, room_id):
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id
-        )
-
-
 # TODO: Needs unit testing for room ID + alias joins
 class JoinRoomAliasServlet(TransactionRestServlet):
     def __init__(self, hs):
@@ -1488,8 +1084,6 @@ class RoomHierarchyRestServlet(RestServlet):
 
 
 def register_servlets(hs: "HomeServer", http_server, is_worker=False):
-    msc2716_enabled = hs.config.experimental.msc2716_enabled
-
     RoomStateEventRestServlet(hs).register(http_server)
     RoomMemberListRestServlet(hs).register(http_server)
     JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1497,8 +1091,6 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     JoinRoomAliasServlet(hs).register(http_server)
     RoomMembershipRestServlet(hs).register(http_server)
     RoomSendEventRestServlet(hs).register(http_server)
-    if msc2716_enabled:
-        RoomBatchSendEventRestServlet(hs).register(http_server)
     PublicRoomListRestServlet(hs).register(http_server)
     RoomStateRestServlet(hs).register(http_server)
     RoomRedactEventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/room.py b/synapse/rest/client/v2_alpha/room.py
new file mode 100644
index 0000000000..3172aba605
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/room.py
@@ -0,0 +1,441 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.appservice import ApplicationService
+from synapse.http.servlet import (
+    RestServlet,
+    assert_params_in_dict,
+    parse_json_object_from_request,
+    parse_string,
+    parse_strings_from_args,
+)
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import Requester, UserID, create_requester
+from synapse.util.stringutils import random_string
+
+logger = logging.getLogger(__name__)
+
+
+class RoomBatchSendEventRestServlet(RestServlet):
+    """
+    API endpoint which can insert a chunk of events historically back in time
+    next to the given `prev_event`.
+
+    `chunk_id` comes from `next_chunk_id` in the response of the batch send
+    endpoint and is derived from the "insertion" events added to each chunk.
+    It's not required for the first batch send.
+
+    `state_events_at_start` is used to define the historical state events
+    needed to auth the events like join events. These events will float
+    outside of the normal DAG as outliers and won't be visible in the chat
+    history which also allows us to insert multiple chunks without having a bunch
+    of `@mxid joined the room` noise between each chunk.
+
+    `events` is a chronological chunk/list of events you want to insert.
+    There is a reverse-chronological constraint on chunks so once you insert
+    some messages, you can only insert older ones after that.
+    TL;DR: Insert chunks from your most recent history -> oldest history.
+
+    POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomIdOrAlias>/batch_send?prev_event=<eventID>&chunk_id=<chunkID>
+    {
+        "events": [ ... ],
+        "state_events_at_start": [ ... ]
+    }
+    """
+
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc2716"
+            "/rooms/(?P<room_id>[^/]*)/batch_send$"
+        ),
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.hs = hs
+        self.store = hs.get_datastore()
+        self.state_store = hs.get_storage().state
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.room_member_handler = hs.get_room_member_handler()
+        self.auth = hs.get_auth()
+        self.txns = HttpTransactionCache(hs)
+
+    async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+        (
+            most_recent_prev_event_id,
+            most_recent_prev_event_depth,
+        ) = await self.store.get_max_depth_of(prev_event_ids)
+
+        # We want to insert the historical event after the `prev_event` but before the successor event
+        #
+        # We inherit depth from the successor event instead of the `prev_event`
+        # because events returned from `/messages` are first sorted by `topological_ordering`
+        # which is just the `depth` and then tie-break with `stream_ordering`.
+        #
+        # We mark these inserted historical events as "backfilled" which gives them a
+        # negative `stream_ordering`. If we use the same depth as the `prev_event`,
+        # then our historical event will tie-break and be sorted before the `prev_event`
+        # when it should come after.
+        #
+        # We want to use the successor event depth so they appear after `prev_event` because
+        # it has a larger `depth` but before the successor event because the `stream_ordering`
+        # is negative before the successor event.
+        successor_event_ids = await self.store.get_successor_events(
+            [most_recent_prev_event_id]
+        )
+
+        # If we can't find any successor events, then it's a forward extremity of
+        # historical messages and we can just inherit from the previous historical
+        # event which we can already assume has the correct depth where we want
+        # to insert into.
+        if not successor_event_ids:
+            depth = most_recent_prev_event_depth
+        else:
+            (
+                _,
+                oldest_successor_depth,
+            ) = await self.store.get_min_depth_of(successor_event_ids)
+
+            depth = oldest_successor_depth
+
+        return depth
+
+    def _create_insertion_event_dict(
+        self, sender: str, room_id: str, origin_server_ts: int
+    ):
+        """Creates an event dict for an "insertion" event with the proper fields
+        and a random chunk ID.
+
+        Args:
+            sender: The event author MXID
+            room_id: The room ID that the event belongs to
+            origin_server_ts: Timestamp when the event was sent
+
+        Returns:
+            The new event dictionary to insert.
+        """
+
+        next_chunk_id = random_string(8)
+        insertion_event = {
+            "type": EventTypes.MSC2716_INSERTION,
+            "sender": sender,
+            "room_id": room_id,
+            "content": {
+                EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
+                EventContentFields.MSC2716_HISTORICAL: True,
+            },
+            "origin_server_ts": origin_server_ts,
+        }
+
+        return insertion_event
+
+    async def _create_requester_for_user_id_from_app_service(
+        self, user_id: str, app_service: ApplicationService
+    ) -> Requester:
+        """Creates a new requester for the given user_id
+        and validates that the app service is allowed to control
+        the given user.
+
+        Args:
+            user_id: The author MXID that the app service is controlling
+            app_service: The app service that controls the user
+
+        Returns:
+            Requester object
+        """
+
+        await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
+
+        return create_requester(user_id, app_service=app_service)
+
+    async def on_POST(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+
+        if not requester.app_service:
+            raise AuthError(
+                403,
+                "Only application services can use the /batchsend endpoint",
+            )
+
+        body = parse_json_object_from_request(request)
+        assert_params_in_dict(body, ["state_events_at_start", "events"])
+
+        prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
+        chunk_id_from_query = parse_string(request, "chunk_id")
+
+        if prev_events_from_query is None:
+            raise SynapseError(
+                400,
+                "prev_event query parameter is required when inserting historical messages back in time",
+                errcode=Codes.MISSING_PARAM,
+            )
+
+        # For the event we are inserting next to (`prev_events_from_query`),
+        # find the most recent auth events (derived from state events) that
+        # allowed that message to be sent. We will use that as a base
+        # to auth our historical messages against.
+ ( + most_recent_prev_event_id, + _, + ) = await self.store.get_max_depth_of(prev_events_from_query) + # mapping from (type, state_key) -> state_event_id + prev_state_map = await self.state_store.get_state_ids_for_event( + most_recent_prev_event_id + ) + # List of state event ID's + prev_state_ids = list(prev_state_map.values()) + auth_event_ids = prev_state_ids + + state_events_at_start = [] + for state_event in body["state_events_at_start"]: + assert_params_in_dict( + state_event, ["type", "origin_server_ts", "content", "sender"] + ) + + logger.debug( + "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", + state_event, + auth_event_ids, + ) + + event_dict = { + "type": state_event["type"], + "origin_server_ts": state_event["origin_server_ts"], + "content": state_event["content"], + "room_id": room_id, + "sender": state_event["sender"], + "state_key": state_event["state_key"], + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + # Make the state events float off on their own + fake_prev_event_id = "$" + random_string(43) + + # TODO: This is pretty much the same as some other code to handle inserting state in this file + if event_dict["type"] == EventTypes.Member: + membership = event_dict["content"].get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + target=UserID.from_string(event_dict["state_key"]), + room_id=room_id, + action=membership, + content=event_dict["content"], + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + else: + # TODO: Add some complement tests that adds state that is not member joins + # and will use this code path. Maybe we only want to support join state events + # and can get rid of this `else`? + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + event_dict, + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + event_id = event.event_id + + state_events_at_start.append(event_id) + auth_event_ids.append(event_id) + + events_to_create = body["events"] + + inherited_depth = await self._inherit_depth_from_prev_ids( + prev_events_from_query + ) + + # Figure out which chunk to connect to. If they passed in + # chunk_id_from_query let's use it. The chunk ID passed in comes + # from the chunk_id in the "insertion" event from the previous chunk. + last_event_in_chunk = events_to_create[-1] + chunk_id_to_connect_to = chunk_id_from_query + base_insertion_event = None + if chunk_id_from_query: + # All but the first base insertion event should point at a fake + # event, which causes the HS to ask for the state at the start of + # the chunk later. + prev_event_ids = [fake_prev_event_id] + # TODO: Verify the chunk_id_from_query corresponds to an insertion event + pass + # Otherwise, create an insertion event to act as a starting point. 
+ # + # We don't always have an insertion event to start hanging more history + # off of (ideally there would be one in the main DAG, but that's not the + # case if we're wanting to add history to e.g. existing rooms without + # an insertion event), in which case we just create a new insertion event + # that can then get pointed to by a "marker" event later. + else: + prev_event_ids = prev_events_from_query + + base_insertion_event_dict = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_chunk["origin_server_ts"], + ) + base_insertion_event_dict["prev_events"] = prev_event_ids.copy() + + ( + base_insertion_event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + base_insertion_event_dict["sender"], + requester.app_service, + ), + base_insertion_event_dict, + prev_event_ids=base_insertion_event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + + chunk_id_to_connect_to = base_insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ] + + # Connect this current chunk to the insertion event from the previous chunk + chunk_event = { + "type": EventTypes.MSC2716_CHUNK, + "sender": requester.user.to_string(), + "room_id": room_id, + "content": { + EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, + # Since the chunk event is put at the end of the chunk, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_chunk["origin_server_ts"], + } + # Add the chunk event to the end of the chunk (newest-in-time) + events_to_create.append(chunk_event) + + # Add an "insertion" event to the start of each chunk (next to the oldest-in-time + # event in the chunk) so the next chunk can be connected to this one. 
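The chunk linking set up in this section can be sketched with plain dicts. This is an illustration only, with made-up IDs; the long-form string literals stand in for the `EventTypes.MSC2716_*` and `EventContentFields.MSC2716_*` constants used in this file and are assumptions about their unstable values, not something this patch defines:

```python
# Hypothetical sketch of MSC2716 chunk linking; IDs and string literals
# are illustrative assumptions, not values taken from this patch.
insertion_event = {
    "type": "org.matrix.msc2716.insertion",  # EventTypes.MSC2716_INSERTION
    "content": {
        "org.matrix.msc2716.next_chunk_id": "abc123",  # random_string(8)-style ID
        "org.matrix.msc2716.historical": True,
    },
}

# A later /batch_send call passes ?chunk_id=abc123, and its chunk event
# points back at the earlier insertion event via the same ID:
chunk_event = {
    "type": "org.matrix.msc2716.chunk",  # EventTypes.MSC2716_CHUNK
    "content": {
        "org.matrix.msc2716.chunk_id": "abc123",
        "org.matrix.msc2716.historical": True,
    },
}

assert (
    chunk_event["content"]["org.matrix.msc2716.chunk_id"]
    == insertion_event["content"]["org.matrix.msc2716.next_chunk_id"]
)
```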
+ insertion_event = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + # Since the insertion event is put at the start of the chunk, + # where the oldest-in-time event is, copy the origin_server_ts from + # the first event we're inserting + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) + # Prepend the insertion event to the start of the chunk (oldest-in-time) + events_to_create = [insertion_event] + events_to_create + + event_ids = [] + events_to_persist = [] + for ev in events_to_create: + assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) + + event_dict = { + "type": ev["type"], + "origin_server_ts": ev["origin_server_ts"], + "content": ev["content"], + "room_id": room_id, + "sender": ev["sender"], # requester.user.to_string(), + "prev_events": prev_event_ids.copy(), + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + event, context = await self.event_creation_handler.create_event( + await self._create_requester_for_user_id_from_app_service( + ev["sender"], requester.app_service + ), + event_dict, + prev_event_ids=event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + logger.debug( + "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", + event, + prev_event_ids, + auth_event_ids, + ) + + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, + ) + + events_to_persist.append((event, context)) + event_id = event.event_id + + event_ids.append(event_id) + prev_event_ids = [event_id] + + # Persist events in reverse-chronological order so they have the + # correct stream_ordering as they are backfilled (which decrements). + # Events are sorted by (topological_ordering, stream_ordering) + # where topological_ordering is just depth. 
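The ordering comment above is doing real work, so here is a minimal self-contained sketch of the argument (toy values only; the single assumption is that backfill hands out strictly decreasing stream orderings, as the comment describes):

```python
# Toy model: backfill assigns decreasing stream_ordering values, so
# persisting newest-first leaves the oldest event with the lowest position.
stream_ordering = 0
positions = {}
for event_id in ["$newest", "$middle", "$oldest"]:  # reverse-chronological
    stream_ordering -= 1
    positions[event_id] = stream_ordering

# Sorting by stream_ordering (the tie-break within a fixed depth) then
# recovers chronological order: oldest first.
assert sorted(positions, key=positions.get) == ["$oldest", "$middle", "$newest"]
```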
+ for (event, context) in reversed(events_to_persist): + ev = await self.event_creation_handler.handle_new_client_event( + await self._create_requester_for_user_id_from_app_service( + event["sender"], requester.app_service + ), + event=event, + context=context, + ) + + # Add the base_insertion_event to the bottom of the list we return + if base_insertion_event is not None: + event_ids.append(base_insertion_event.event_id) + + return 200, { + "state_events": state_events_at_start, + "events": event_ids, + "next_chunk_id": insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ], + } + + def on_GET(self, request, room_id): + return 501, "Not implemented" + + def on_PUT(self, request, room_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id + ) + + +def register_servlets(hs, http_server): + msc2716_enabled = hs.config.experimental.msc2716_enabled + + if msc2716_enabled: + RoomBatchSendEventRestServlet(hs).register(http_server) -- cgit 1.5.1 From d1f43b731ce041023563b1b814f858465a199630 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 16 Aug 2021 12:57:09 +0200 Subject: Update the Synapse Grafana dashboard (#10570) --- changelog.d/10570.feature | 1 + contrib/grafana/synapse.json | 550 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 485 insertions(+), 66 deletions(-) create mode 100644 changelog.d/10570.feature diff --git a/changelog.d/10570.feature b/changelog.d/10570.feature new file mode 100644 index 0000000000..bd432685b3 --- /dev/null +++ b/changelog.d/10570.feature @@ -0,0 +1 @@ +Update the Synapse Grafana dashboard. diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index 0c4816b7cd..ed1e8ba7f8 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -54,7 +54,7 @@ "gnetId": null, "graphTooltip": 0, "id": null, - "iteration": 1621258266004, + "iteration": 1628606819564, "links": [ { "asDropdown": false, @@ -307,7 +307,6 @@ ], "thresholds": [ { - "$$hashKey": "object:283", "colorMode": "warning", "fill": false, "line": true, @@ -316,7 +315,6 @@ "yaxis": "left" }, { - "$$hashKey": "object:284", "colorMode": "critical", "fill": false, "line": true, @@ -344,7 +342,6 @@ }, "yaxes": [ { - "$$hashKey": "object:255", "decimals": null, "format": "s", "label": "", @@ -354,7 +351,6 @@ "show": true }, { - "$$hashKey": "object:256", "format": "hertz", "label": "", "logBase": 1, @@ -429,7 +425,6 @@ ], "thresholds": [ { - "$$hashKey": "object:566", "colorMode": "critical", "fill": true, "line": true, @@ -457,7 +452,6 @@ }, "yaxes": [ { - "$$hashKey": "object:538", "decimals": null, "format": "percentunit", "label": null, @@ -467,7 +461,6 @@ "show": true }, { - "$$hashKey": "object:539", "format": "short", "label": null, "logBase": 1, @@ -573,7 +566,6 @@ }, "yaxes": [ { - "$$hashKey": "object:1560", "format": "bytes", "logBase": 1, "max": null, @@ -581,7 +573,6 @@ "show": true }, { - "$$hashKey": "object:1561", "format": "short", "logBase": 1, "max": null, @@ -641,7 +632,6 @@ "renderer": "flot", "seriesOverrides": [ { - "$$hashKey": "object:639", "alias": "/max$/", "color": "#890F02", "fill": 0, @@ -693,7 +683,6 @@ }, "yaxes": [ { - "$$hashKey": "object:650", "decimals": null, "format": "none", "label": "", @@ -703,7 +692,6 @@ "show": true }, { - "$$hashKey": "object:651", "decimals": null, "format": "short", "label": null, @@ -783,11 +771,9 @@ "renderer": "flot", "seriesOverrides": [ { - "$$hashKey": "object:1240", "alias": "/user/" }, { - "$$hashKey": "object:1241", "alias": 
"/system/" } ], @@ -817,7 +803,6 @@ ], "thresholds": [ { - "$$hashKey": "object:1278", "colorMode": "custom", "fillColor": "rgba(255, 255, 255, 1)", "line": true, @@ -827,7 +812,6 @@ "yaxis": "left" }, { - "$$hashKey": "object:1279", "colorMode": "custom", "fillColor": "rgba(255, 255, 255, 1)", "line": true, @@ -837,7 +821,6 @@ "yaxis": "left" }, { - "$$hashKey": "object:1498", "colorMode": "critical", "fill": true, "line": true, @@ -865,7 +848,6 @@ }, "yaxes": [ { - "$$hashKey": "object:1250", "decimals": null, "format": "percentunit", "label": "", @@ -875,7 +857,6 @@ "show": true }, { - "$$hashKey": "object:1251", "format": "short", "logBase": 1, "max": null, @@ -1427,7 +1408,6 @@ }, "yaxes": [ { - "$$hashKey": "object:572", "format": "percentunit", "label": null, "logBase": 1, @@ -1436,7 +1416,6 @@ "show": true }, { - "$$hashKey": "object:573", "format": "short", "label": null, "logBase": 1, @@ -1720,7 +1699,6 @@ }, "yaxes": [ { - "$$hashKey": "object:102", "format": "hertz", "logBase": 1, "max": null, @@ -1728,7 +1706,6 @@ "show": true }, { - "$$hashKey": "object:103", "format": "short", "logBase": 1, "max": null, @@ -3425,7 +3402,7 @@ "h": 9, "w": 12, "x": 0, - "y": 33 + "y": 6 }, "hiddenSeries": false, "id": 79, @@ -3442,9 +3419,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -3526,7 +3506,7 @@ "h": 9, "w": 12, "x": 12, - "y": 33 + "y": 6 }, "hiddenSeries": false, "id": 83, @@ -3543,9 +3523,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -3629,7 +3612,7 @@ "h": 9, "w": 12, "x": 0, - "y": 42 + "y": 15 }, "hiddenSeries": false, "id": 109, @@ -3646,9 +3629,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -3733,7 +3719,7 @@ "h": 9, "w": 12, "x": 12, - "y": 42 + "y": 15 }, "hiddenSeries": false, "id": 111, @@ -3750,9 +3736,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -3831,7 +3820,7 @@ "h": 8, "w": 12, "x": 0, - "y": 51 + "y": 24 }, "hiddenSeries": false, "id": 142, @@ -3847,8 +3836,11 @@ "lines": true, "linewidth": 1, "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 2, "points": false, "renderer": "flot", @@ -3931,7 +3923,7 @@ "h": 9, "w": 12, "x": 12, - "y": 51 + "y": 24 }, "hiddenSeries": false, "id": 140, @@ -3948,9 +3940,12 @@ "linewidth": 1, "links": [], "nullPointMode": "null", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -4079,7 +4074,7 @@ "h": 9, "w": 12, "x": 0, - "y": 59 + "y": 32 }, "heatmap": {}, "hideZeroBuckets": false, @@ -4145,7 +4140,7 @@ "h": 9, "w": 12, "x": 12, - "y": 60 + "y": 33 
}, "hiddenSeries": false, "id": 162, @@ -4163,9 +4158,12 @@ "linewidth": 0, "links": [], "nullPointMode": "connected", + "options": { + "alertThreshold": true + }, "paceLength": 10, "percentage": false, - "pluginVersion": "7.1.3", + "pluginVersion": "7.3.7", "pointradius": 5, "points": false, "renderer": "flot", @@ -4350,7 +4348,7 @@ "h": 9, "w": 12, "x": 0, - "y": 68 + "y": 41 }, "heatmap": {}, "hideZeroBuckets": false, @@ -4396,6 +4394,311 @@ "yBucketBound": "auto", "yBucketNumber": null, "yBucketSize": null + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "editable": true, + "error": false, + "fieldConfig": { + "defaults": { + "custom": {}, + "links": [] + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "grid": {}, + "gridPos": { + "h": 9, + "w": 12, + "x": 12, + "y": 42 + }, + "hiddenSeries": false, + "id": 203, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "paceLength": 10, + "percentage": false, + "pluginVersion": "7.3.7", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "synapse_federation_server_oldest_inbound_pdu_in_staging{job=\"$job\",index=~\"$index\",instance=\"$instance\"}", + "format": "time_series", + "interval": "", + "intervalFactor": 1, + "legendFormat": "rss {{index}}", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Age of oldest event in staging area", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "cumulative" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "ms", + "label": null, + "logBase": 1, + "max": null, + "min": 0, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "editable": true, + "error": false, + "fieldConfig": { + "defaults": { + "custom": {}, + "links": [] + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "grid": {}, + "gridPos": { + "h": 9, + "w": 12, + "x": 0, + "y": 50 + }, + "hiddenSeries": false, + "id": 202, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 2, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "paceLength": 10, + "percentage": false, + "pluginVersion": "7.3.7", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "synapse_federation_server_number_inbound_pdu_in_staging{job=\"$job\",index=~\"$index\",instance=\"$instance\"}", + "format": "time_series", + "interval": "", + "intervalFactor": 1, + "legendFormat": "rss {{index}}", + "refId": "A", + "step": 4 + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": 
null, + "title": "Number of events in federation staging area", + "tooltip": { + "msResolution": false, + "shared": true, + "sort": 0, + "value_type": "cumulative" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "none", + "label": null, + "logBase": 1, + "max": null, + "min": 0, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "${DS_PROMETHEUS}", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 51 + }, + "hiddenSeries": false, + "id": 205, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.3.7", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(synapse_federation_soft_failed_events_total{instance=\"$instance\"}[$bucket_size]))", + "interval": "", + "legendFormat": "soft-failed events", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Soft-failed event rate", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "hertz", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": false + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "title": "Federation", @@ -4647,7 +4950,7 @@ "h": 7, "w": 12, "x": 0, - "y": 8 + "y": 33 }, "hiddenSeries": false, "id": 48, @@ -4749,7 +5052,7 @@ "h": 7, "w": 12, "x": 12, - "y": 8 + "y": 33 }, "hiddenSeries": false, "id": 104, @@ -4877,7 +5180,7 @@ "h": 7, "w": 12, "x": 0, - "y": 15 + "y": 40 }, "hiddenSeries": false, "id": 10, @@ -4981,7 +5284,7 @@ "h": 7, "w": 12, "x": 12, - "y": 15 + "y": 40 }, "hiddenSeries": false, "id": 11, @@ -5086,7 +5389,7 @@ "h": 7, "w": 12, "x": 0, - "y": 22 + "y": 47 }, "hiddenSeries": false, "id": 180, @@ -5168,6 +5471,126 @@ "align": false, "alignLevel": null } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "$datasource", + "fieldConfig": { + "defaults": { + "custom": {}, + "links": [] + }, + "overrides": [] + }, + "fill": 6, + "fillGradient": 0, + "gridPos": { + "h": 9, + "w": 12, + "x": 12, + "y": 47 + }, + "hiddenSeries": false, + "id": 200, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.3.7", + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": 
false, + "steppedLine": false, + "targets": [ + { + "expr": "histogram_quantile(0.99, sum(rate(synapse_storage_schedule_time_bucket{index=~\"$index\",instance=\"$instance\",job=\"$job\"}[$bucket_size])) by (le))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "99%", + "refId": "D" + }, + { + "expr": "histogram_quantile(0.9, sum(rate(synapse_storage_schedule_time_bucket{index=~\"$index\",instance=\"$instance\",job=\"$job\"}[$bucket_size])) by (le))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "90%", + "refId": "A" + }, + { + "expr": "histogram_quantile(0.75, sum(rate(synapse_storage_schedule_time_bucket{index=~\"$index\",instance=\"$instance\",job=\"$job\"}[$bucket_size])) by (le))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "75%", + "refId": "C" + }, + { + "expr": "histogram_quantile(0.5, sum(rate(synapse_storage_schedule_time_bucket{index=~\"$index\",instance=\"$instance\",job=\"$job\"}[$bucket_size])) by (le))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "50%", + "refId": "B" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Time waiting for DB connection quantiles", + "tooltip": { + "shared": true, + "sort": 2, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "decimals": null, + "format": "s", + "label": "", + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": false + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } } ], "repeat": null, @@ -5916,7 +6339,7 @@ "h": 10, "w": 12, "x": 0, - "y": 84 + "y": 35 }, "hiddenSeries": false, "id": 1, @@ -6022,7 +6445,7 @@ "h": 10, "w": 12, "x": 12, - "y": 84 + "y": 35 }, "hiddenSeries": false, "id": 8, @@ -6126,7 +6549,7 @@ "h": 10, "w": 12, "x": 0, - "y": 94 + "y": 45 }, "hiddenSeries": false, "id": 38, @@ -6226,7 +6649,7 @@ "h": 10, "w": 12, "x": 12, - "y": 94 + "y": 45 }, "hiddenSeries": false, "id": 39, @@ -6258,8 +6681,9 @@ "steppedLine": false, "targets": [ { - "expr": "topk(10, rate(synapse_util_caches_cache:total{job=\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size]) - rate(synapse_util_caches_cache:hits{job=\"$job\",instance=\"$instance\"}[$bucket_size]))", + "expr": "topk(10, rate(synapse_util_caches_cache:total{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size]) - rate(synapse_util_caches_cache:hits{job=~\"$job\",index=~\"$index\",instance=\"$instance\"}[$bucket_size]))", "format": "time_series", + "interval": "", "intervalFactor": 2, "legendFormat": "{{name}} {{job}}-{{index}}", "refId": "A", @@ -6326,7 +6750,7 @@ "h": 9, "w": 12, "x": 0, - "y": 104 + "y": 55 }, "hiddenSeries": false, "id": 65, @@ -9051,7 +9475,7 @@ "h": 8, "w": 12, "x": 0, - "y": 119 + "y": 41 }, "hiddenSeries": false, "id": 156, @@ -9089,7 +9513,7 @@ "steppedLine": false, "targets": [ { - "expr": "synapse_admin_mau:current{instance=\"$instance\"}", + "expr": "synapse_admin_mau:current{instance=\"$instance\", job=~\"$job\"}", "format": "time_series", "interval": "", "intervalFactor": 1, @@ -9097,7 +9521,7 @@ "refId": "A" }, { - "expr": "synapse_admin_mau:max{instance=\"$instance\"}", + "expr": "synapse_admin_mau:max{instance=\"$instance\", job=~\"$job\"}", "format": "time_series", "interval": "", "intervalFactor": 1, @@ -9164,7 +9588,7 @@ "h": 
8, "w": 12, "x": 12, - "y": 119 + "y": 41 }, "hiddenSeries": false, "id": 160, @@ -9484,7 +9908,7 @@ "h": 8, "w": 12, "x": 0, - "y": 73 + "y": 43 }, "hiddenSeries": false, "id": 168, @@ -9516,7 +9940,7 @@ { "expr": "rate(synapse_appservice_api_sent_events{instance=\"$instance\"}[$bucket_size])", "interval": "", - "legendFormat": "{{exported_service}}", + "legendFormat": "{{service}}", "refId": "A" } ], @@ -9579,7 +10003,7 @@ "h": 8, "w": 12, "x": 12, - "y": 73 + "y": 43 }, "hiddenSeries": false, "id": 171, @@ -9611,7 +10035,7 @@ { "expr": "rate(synapse_appservice_api_sent_transactions{instance=\"$instance\"}[$bucket_size])", "interval": "", - "legendFormat": "{{exported_service}}", + "legendFormat": "{{service}}", "refId": "A" } ], @@ -9959,7 +10383,6 @@ }, "yaxes": [ { - "$$hashKey": "object:165", "format": "hertz", "label": null, "logBase": 1, @@ -9968,7 +10391,6 @@ "show": true }, { - "$$hashKey": "object:166", "format": "short", "label": null, "logBase": 1, @@ -10071,7 +10493,6 @@ }, "yaxes": [ { - "$$hashKey": "object:390", "format": "hertz", "label": null, "logBase": 1, @@ -10080,7 +10501,6 @@ "show": true }, { - "$$hashKey": "object:391", "format": "short", "label": null, "logBase": 1, @@ -10169,7 +10589,6 @@ }, "yaxes": [ { - "$$hashKey": "object:390", "format": "hertz", "label": null, "logBase": 1, @@ -10178,7 +10597,6 @@ "show": true }, { - "$$hashKey": "object:391", "format": "short", "label": null, "logBase": 1, @@ -10470,5 +10888,5 @@ "timezone": "", "title": "Synapse", "uid": "000000012", - "version": 90 + "version": 99 } \ No newline at end of file -- cgit 1.5.1 From a3a7514570f21dcad6f7ef4c1ee3ed1e30115825 Mon Sep 17 00:00:00 2001 From: Šimon Brandner Date: Mon, 16 Aug 2021 13:22:38 +0200 Subject: Handle string read receipt data (#10606) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Handle string read receipt data Signed-off-by: Šimon Brandner * Test that we handle string read receipt data Signed-off-by: Šimon Brandner * Add changelog for #10606 Signed-off-by: Šimon Brandner * Add docs Signed-off-by: Šimon Brandner * Ignore malformed RRs Signed-off-by: Šimon Brandner * Only surround hidden = ... Signed-off-by: Šimon Brandner * Remove unnecessary argument Signed-off-by: Šimon Brandner * Update changelog.d/10606.bugfix Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/10606.bugfix | 1 + synapse/handlers/receipts.py | 9 ++++++++- tests/handlers/test_receipts.py | 23 +++++++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10606.bugfix diff --git a/changelog.d/10606.bugfix b/changelog.d/10606.bugfix new file mode 100644 index 0000000000..bab9fd2a61 --- /dev/null +++ b/changelog.d/10606.bugfix @@ -0,0 +1 @@ +Fix errors on /sync when read receipt data is a string. Only affects homeservers with the experimental flag for [MSC2285](https://github.com/matrix-org/matrix-doc/pull/2285) enabled. Contributed by @SimonBrandner. 
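For context on the handler change in the diff below, a minimal reproduction of the failure mode may help: the malformed `m.read` entry is a bare string where a dict is expected, so attribute access for `.get` raises `AttributeError`. This is a hypothetical payload, mirroring the shape exercised by the new test:

```python
# A well-formed m.read entry maps a user ID to a dict of receipt data;
# the malformed case seen in the wild is a bare string instead.
user_rr = "string"  # should be something like {"hidden": True}

try:
    hidden = user_rr.get("hidden")  # str has no .get() -> AttributeError
except AttributeError:
    hidden = None  # skip the malformed receipt, as the handler change does

assert hidden is None
```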
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 5fd4525700..fb495229a7 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -188,7 +188,14 @@ class ReceiptEventSource:
 
             new_users = {}
             for rr_user_id, user_rr in m_read.items():
-                hidden = user_rr.get("hidden", None)
+                try:
+                    hidden = user_rr.get("hidden")
+                except AttributeError:
+                    # Due to https://github.com/matrix-org/synapse/issues/10376
+                    # there are cases where user_rr is a string, in those cases
+                    # we just ignore the read receipt
+                    continue
+
                 if hidden is not True or rr_user_id == user_id:
                     new_users[rr_user_id] = user_rr.copy()
                     # If hidden has a value replace hidden with the correct prefixed key
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 93a9a084b2..732a12c9bd 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -286,6 +286,29 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    def test_handles_string_data(self):
+        """
+        Tests that an invalid shape for read-receipts is handled.
+        Context: https://github.com/matrix-org/synapse/issues/10603
+        """
+
+        self._test_filters_hidden(
+            [
+                {
+                    "content": {
+                        "$14356419edgd14394fHBLK:matrix.org": {
+                            "m.read": {
+                                "@rikj:jki.re": "string",
+                            }
+                        },
+                    },
+                    "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
+                    "type": "m.receipt",
+                },
+            ],
+            [],
+        )
+
     def _test_filters_hidden(
         self, events: List[JsonDict], expected_output: List[JsonDict]
     ):
-- 
cgit 1.5.1


From 7de445161f2fec115ce8518cde7a3b333a611f16 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 16 Aug 2021 08:06:17 -0400
Subject: Support federation in the new spaces summary API (MSC2946). (#10569)

---
 changelog.d/10569.feature               |   1 +
 synapse/federation/federation_client.py |  82 +++++++++
 synapse/federation/transport/client.py  |  22 +++
 synapse/federation/transport/server.py  |  28 +++
 synapse/handlers/space_summary.py       | 258 +++++++++++++++++++++++-----
 tests/handlers/test_space_summary.py    | 292 ++++++++++++++++++--------------
 6 files changed, 518 insertions(+), 165 deletions(-)
 create mode 100644 changelog.d/10569.feature

diff --git a/changelog.d/10569.feature b/changelog.d/10569.feature
new file mode 100644
index 0000000000..ffc4e4289c
--- /dev/null
+++ b/changelog.d/10569.feature
@@ -0,0 +1 @@
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 2eefac04fd..0af953a5d6 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1290,6 +1290,88 @@ class FederationClient(FederationBase):
             failover_on_unknown_endpoint=True,
         )
 
+    async def get_room_hierarchy(
+        self,
+        destinations: Iterable[str],
+        room_id: str,
+        suggested_only: bool,
+    ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]:
+        """
+        Call other servers to get a hierarchy of the given room.
+
+        Performs simple data validation and parsing of the response.
+
+        Args:
+            destinations: The remote servers. We will try them in turn, omitting any
+                that have been blacklisted.
+            room_id: ID of the space to be queried
+            suggested_only: If true, ask the remote server to only return children
+                with the "suggested" flag set
+
+        Returns:
+            A tuple of:
+                The room as a JSON dictionary.
+                A list of children rooms, as JSON dictionaries.
+                A list of inaccessible children room IDs.
+ + Raises: + SynapseError if we were unable to get a valid summary from any of the + remote servers + """ + + async def send_request( + destination: str, + ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: + res = await self.transport_layer.get_room_hierarchy( + destination=destination, + room_id=room_id, + suggested_only=suggested_only, + ) + + room = res.get("room") + if not isinstance(room, dict): + raise InvalidResponseError("'room' must be a dict") + + # Validate children_state of the room. + children_state = room.get("children_state", []) + if not isinstance(children_state, Sequence): + raise InvalidResponseError("'room.children_state' must be a list") + if any(not isinstance(e, dict) for e in children_state): + raise InvalidResponseError("Invalid event in 'children_state' list") + try: + [ + FederationSpaceSummaryEventResult.from_json_dict(e) + for e in children_state + ] + except ValueError as e: + raise InvalidResponseError(str(e)) + + # Validate the children rooms. + children = res.get("children", []) + if not isinstance(children, Sequence): + raise InvalidResponseError("'children' must be a list") + if any(not isinstance(r, dict) for r in children): + raise InvalidResponseError("Invalid room in 'children' list") + + # Validate the inaccessible children. + inaccessible_children = res.get("inaccessible_children", []) + if not isinstance(inaccessible_children, Sequence): + raise InvalidResponseError("'inaccessible_children' must be a list") + if any(not isinstance(r, str) for r in inaccessible_children): + raise InvalidResponseError( + "Invalid room ID in 'inaccessible_children' list" + ) + + return room, children, inaccessible_children + + # TODO Fallback to the old federation API and translate the results. + return await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) + @attr.s(frozen=True, slots=True, auto_attribs=True) class FederationSpaceSummaryEventResult: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 90a7c16b62..8b247fe206 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1177,6 +1177,28 @@ class TransportLayerClient: destination=destination, path=path, data=params ) + async def get_room_hierarchy( + self, + destination: str, + room_id: str, + suggested_only: bool, + ) -> JsonDict: + """ + Args: + destination: The remote server + room_id: The room ID to ask about. 
+            suggested_only: if True, only suggested rooms will be returned
+        """
+        path = _create_path(
+            FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/hierarchy/%s", room_id
+        )
+
+        return await self.client.get_json(
+            destination=destination,
+            path=path,
+            args={"suggested_only": "true" if suggested_only else "false"},
+        )
+
 
 def _create_path(federation_prefix: str, path: str, *args: str) -> str:
     """
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 640f46fff6..79a2e1afa0 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1936,6 +1936,33 @@ class FederationSpaceSummaryServlet(BaseFederationServlet):
         )
 
 
+class FederationRoomHierarchyServlet(BaseFederationServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+    PATH = "/hierarchy/(?P<room_id>[^/]*)"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_space_summary_handler()
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
+        return 200, await self.handler.get_federation_hierarchy(
+            origin, room_id, suggested_only
+        )
+
+
 class RoomComplexityServlet(BaseFederationServlet):
     """
     Indicates to other servers how complex (and therefore likely
@@ -1999,6 +2026,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
     FederationVersionServlet,
     RoomComplexityServlet,
     FederationSpaceSummaryServlet,
+    FederationRoomHierarchyServlet,
     FederationV1SendKnockServlet,
     FederationMakeKnockServlet,
 )
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index d0060f9046..c74e90abbc 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -16,17 +16,7 @@ import itertools
 import logging
 import re
 from collections import deque
-from typing import (
-    TYPE_CHECKING,
-    Deque,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple
 
 import attr
 
@@ -80,7 +70,7 @@ class _PaginationSession:
     # The time the pagination session was created, in milliseconds.
     creation_time_ms: int
     # The queue of rooms which are still to process.
-    room_queue: Deque["_RoomQueueEntry"]
+    room_queue: List["_RoomQueueEntry"]
     # A set of rooms which have been processed.
     processed_rooms: Set[str]
 
@@ -197,7 +187,7 @@ class SpaceSummaryHandler:
             events: Sequence[JsonDict] = []
             if room_entry:
                 rooms_result.append(room_entry.room)
-                events = room_entry.children
+                events = room_entry.children_state_events
 
             logger.debug(
                 "Query of local room %s returned events %s",
@@ -232,7 +222,7 @@ class SpaceSummaryHandler:
                     room.pop("allowed_spaces", None)
 
                     rooms_result.append(room)
-                    events.extend(room_entry.children)
+                    events.extend(room_entry.children_state_events)
 
                 # All rooms returned don't need visiting again (even if the user
                 # didn't have access to them).
@@ -350,8 +340,8 @@ class SpaceSummaryHandler:
             room_queue = pagination_session.room_queue
             processed_rooms = pagination_session.processed_rooms
         else:
-            # the queue of rooms to process
-            room_queue = deque((_RoomQueueEntry(requested_room_id, ()),))
+            # The queue of rooms to process, the next room is last on the stack.
+ room_queue = [_RoomQueueEntry(requested_room_id, ())] # Rooms we have already processed. processed_rooms = set() @@ -367,7 +357,7 @@ class SpaceSummaryHandler: # Iterate through the queue until we reach the limit or run out of # rooms to include. while room_queue and len(rooms_result) < limit: - queue_entry = room_queue.popleft() + queue_entry = room_queue.pop() room_id = queue_entry.room_id current_depth = queue_entry.depth if room_id in processed_rooms: @@ -376,6 +366,18 @@ class SpaceSummaryHandler: logger.debug("Processing room %s", room_id) + # A map of summaries for children rooms that might be returned over + # federation. The rationale for caching these and *maybe* using them + # is to prefer any information local to the homeserver before trusting + # data received over federation. + children_room_entries: Dict[str, JsonDict] = {} + # A set of room IDs which are children that did not have information + # returned over federation and are known to be inaccessible to the + # current server. We should not reach out over federation to try to + # summarise these rooms. + inaccessible_children: Set[str] = set() + + # If the room is known locally, summarise it! is_in_room = await self._store.is_host_joined(room_id, self._server_name) if is_in_room: room_entry = await self._summarize_local_room( @@ -387,26 +389,68 @@ class SpaceSummaryHandler: max_children=None, ) - if room_entry: - rooms_result.append(room_entry.as_json()) - - # Add the child to the queue. We have already validated - # that the vias are a list of server names. - # - # If the current depth is the maximum depth, do not queue - # more entries. - if max_depth is None or current_depth < max_depth: - room_queue.extendleft( - _RoomQueueEntry( - ev["state_key"], ev["content"]["via"], current_depth + 1 - ) - for ev in reversed(room_entry.children) - ) - - processed_rooms.add(room_id) + # Otherwise, attempt to use information for federation. else: - # TODO Federation. - pass + # A previous call might have included information for this room. + # It can be used if either: + # + # 1. The room is not a space. + # 2. The maximum depth has been achieved (since no children + # information is needed). + if queue_entry.remote_room and ( + queue_entry.remote_room.get("room_type") != RoomTypes.SPACE + or (max_depth is not None and current_depth >= max_depth) + ): + room_entry = _RoomEntry( + queue_entry.room_id, queue_entry.remote_room + ) + + # If the above isn't true, attempt to fetch the room + # information over federation. + else: + ( + room_entry, + children_room_entries, + inaccessible_children, + ) = await self._summarize_remote_room_hiearchy( + queue_entry, + suggested_only, + ) + + # Ensure this room is accessible to the requester (and not just + # the homeserver). + if room_entry and not await self._is_remote_room_accessible( + requester, queue_entry.room_id, room_entry.room + ): + room_entry = None + + # This room has been processed and should be ignored if it appears + # elsewhere in the hierarchy. + processed_rooms.add(room_id) + + # There may or may not be a room entry based on whether it is + # inaccessible to the requesting user. + if room_entry: + # Add the room (including the stripped m.space.child events). + rooms_result.append(room_entry.as_json()) + + # If this room is not at the max-depth, check if there are any + # children to process. 
+ if max_depth is None or current_depth < max_depth: + # The children get added in reverse order so that the next + # room to process, according to the ordering, is the last + # item in the list. + room_queue.extend( + _RoomQueueEntry( + ev["state_key"], + ev["content"]["via"], + current_depth + 1, + children_room_entries.get(ev["state_key"]), + ) + for ev in reversed(room_entry.children_state_events) + if ev["type"] == EventTypes.SpaceChild + and ev["state_key"] not in inaccessible_children + ) result: JsonDict = {"rooms": rooms_result} @@ -477,15 +521,78 @@ class SpaceSummaryHandler: if room_entry: rooms_result.append(room_entry.room) - events_result.extend(room_entry.children) + events_result.extend(room_entry.children_state_events) # add any children to the queue room_queue.extend( - edge_event["state_key"] for edge_event in room_entry.children + edge_event["state_key"] + for edge_event in room_entry.children_state_events ) return {"rooms": rooms_result, "events": events_result} + async def get_federation_hierarchy( + self, + origin: str, + requested_room_id: str, + suggested_only: bool, + ): + """ + Implementation of the room hierarchy Federation API. + + This is similar to get_room_hierarchy, but does not recurse into the space. + It also considers whether anyone on the server may be able to access the + room, as opposed to whether a specific user can. + + Args: + origin: The server requesting the spaces summary. + requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: whether we should only return children with the "suggested" + flag set. + + Returns: + The JSON hierarchy dictionary. + """ + root_room_entry = await self._summarize_local_room( + None, origin, requested_room_id, suggested_only, max_children=None + ) + if root_room_entry is None: + # Room is inaccessible to the requesting server. + raise SynapseError(404, "Unknown room: %s" % (requested_room_id,)) + + children_rooms_result: List[JsonDict] = [] + inaccessible_children: List[str] = [] + + # Iterate through each child and potentially add it, but not its children, + # to the response. + for child_room in root_room_entry.children_state_events: + room_id = child_room.get("state_key") + assert isinstance(room_id, str) + # If the room is unknown, skip it. + if not await self._store.is_host_joined(room_id, self._server_name): + continue + + room_entry = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_children=0 + ) + # If the room is accessible, include it in the results. + # + # Note that only the room summary (without information on children) + # is included in the summary. + if room_entry: + children_rooms_result.append(room_entry.room) + # Otherwise, note that the requesting server shouldn't bother + # trying to summarize this room - they do not have access to it. + else: + inaccessible_children.append(room_id) + + return { + # Include the requested room (including the stripped children events). + "room": root_room_entry.as_json(), + "children": children_rooms_result, + "inaccessible_children": inaccessible_children, + } + async def _summarize_local_room( self, requester: Optional[str], @@ -519,8 +626,9 @@ class SpaceSummaryHandler: room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) - # If the room is not a space, return just the room information. - if room_entry.get("room_type") != RoomTypes.SPACE: + # If the room is not a space or the children don't matter, return just + # the room information. 
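A note on the stack-based traversal assembled in `get_room_hierarchy` above: the room queue is a plain list used as a stack, and children are pushed in reverse so that the first child by ordering is popped next. A minimal sketch of that discipline (toy data, no Synapse types):

```python
# Iterative depth-first traversal with a list used as a stack. Pushing the
# children reversed leaves the first child on top of the stack.
children = {"root": ["a", "b", "c"], "a": [], "b": [], "c": []}
stack = ["root"]
visited = []
while stack:
    room = stack.pop()
    visited.append(room)
    stack.extend(reversed(children[room]))

assert visited == ["root", "a", "b", "c"]
```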
+ if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: return _RoomEntry(room_id, room_entry) # Otherwise, look for child rooms/spaces. @@ -616,6 +724,59 @@ class SpaceSummaryHandler: return results + async def _summarize_remote_room_hiearchy( + self, room: "_RoomQueueEntry", suggested_only: bool + ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: + """ + Request room entries and a list of event entries for a given room by querying a remote server. + + Args: + room: The room to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + + Returns: + A tuple of: + The room entry. + Partial room data return over federation. + A set of inaccessible children room IDs. + """ + room_id = room.room_id + logger.info("Requesting summary for %s via %s", room_id, room.via) + + via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) + try: + ( + room_response, + children, + inaccessible_children, + ) = await self._federation_client.get_room_hierarchy( + via, + room_id, + suggested_only=suggested_only, + ) + except Exception as e: + logger.warning( + "Unable to get hierarchy of %s via federation: %s", + room_id, + e, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return None, {}, set() + + # Map the children to their room ID. + children_by_room_id = { + c["room_id"]: c + for c in children + if "room_id" in c and isinstance(c["room_id"], str) + } + + return ( + _RoomEntry(room_id, room_response, room_response.pop("children_state", ())), + children_by_room_id, + set(inaccessible_children), + ) + async def _is_local_room_accessible( self, room_id: str, requester: Optional[str], origin: Optional[str] = None ) -> bool: @@ -866,9 +1027,16 @@ class SpaceSummaryHandler: @attr.s(frozen=True, slots=True, auto_attribs=True) class _RoomQueueEntry: + # The room ID of this entry. room_id: str + # The server to query if the room is not known locally. via: Sequence[str] + # The minimum number of hops necessary to get to this room (compared to the + # originally requested room). depth: int = 0 + # The room summary for this room returned via federation. This will only be + # used if the room is not known locally (and is not a space). + remote_room: Optional[JsonDict] = None @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -879,11 +1047,17 @@ class _RoomEntry: # An iterable of the sorted, stripped children events for children of this room. # # This may not include all children. - children: Sequence[JsonDict] = () + children_state_events: Sequence[JsonDict] = () def as_json(self) -> JsonDict: + """ + Returns a JSON dictionary suitable for the room hierarchy endpoint. + + It returns the room summary including the stripped m.space.child events + as a sub-key. 
+        """
         result = dict(self.room)
-        result["children_state"] = self.children
+        result["children_state"] = self.children_state_events
         return result
 
 
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 83c2bdd8f9..bc8e131f4a 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -481,7 +481,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("next_batch", result)
 
     def test_invalid_pagination_token(self):
-        """"""
+        """An invalid pagination token, or changing other parameters, should be rejected."""
         room_ids = []
         for i in range(1, 10):
             room = self.helper.create_room_as(self.user, tok=self.token)
@@ -581,33 +581,40 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         subspace = "#subspace:" + fed_hostname
         subroom = "#subroom:" + fed_hostname
 
+        # Generate some good data, and some bad data:
+        #
+        # * Event *back* to the root room.
+        # * Unrelated events / rooms
+        # * Multiple levels of events (in a not-useful order, e.g. grandchild
+        #   events before child events).
+
+        # Note that these entries are brief, but should contain enough info.
+        requested_room_entry = _RoomEntry(
+            subspace,
+            {
+                "room_id": subspace,
+                "world_readable": True,
+                "room_type": RoomTypes.SPACE,
+            },
+            [
+                {
+                    "type": EventTypes.SpaceChild,
+                    "room_id": subspace,
+                    "state_key": subroom,
+                    "content": {"via": [fed_hostname]},
+                }
+            ],
+        )
+        child_room = {
+            "room_id": subroom,
+            "world_readable": True,
+        }
+
         async def summarize_remote_room(
             _self, room, suggested_only, max_children, exclude_rooms
         ):
-            # Return some good data, and some bad data:
-            #
-            # * Event *back* to the root room.
-            # * Unrelated events / rooms
-            # * Multiple levels of events (in a not-useful order, e.g. grandchild
-            #   events before child events).
-
-            # Note that these entries are brief, but should contain enough info.
             return [
-                _RoomEntry(
-                    subspace,
-                    {
-                        "room_id": subspace,
-                        "world_readable": True,
-                        "room_type": RoomTypes.SPACE,
-                    },
-                    [
-                        {
-                            "room_id": subspace,
-                            "state_key": subroom,
-                            "content": {"via": [fed_hostname]},
-                        }
-                    ],
-                ),
+                requested_room_entry,
                 _RoomEntry(
                     subroom,
                     {
@@ -617,6 +624,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
                 ),
             ]
 
+        async def summarize_remote_room_hiearchy(_self, room, suggested_only):
+            return requested_room_entry, {subroom: child_room}, set()
+
         # Add a room to the space which is on another server.
         self._add_child(self.space, subspace, self.token)
 
@@ -636,6 +646,15 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         ]
         self._assert_rooms(result, expected)
 
+        with mock.patch(
+            "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy",
+            new=summarize_remote_room_hiearchy,
+        ):
+            result = self.get_success(
+                self.handler.get_room_hierarchy(self.user, self.space)
+            )
+        self._assert_hierarchy(result, expected)
+
     def test_fed_filtering(self):
         """
         Rooms returned over federation should be properly filtered to only include
@@ -657,100 +676,106 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         # Poke an invite over federation into the database.
         self._poke_fed_invite(invited_room, "@remote:" + fed_hostname)
 
+        # Note that these entries are brief, but should contain enough info.
+ children_rooms = ( + ( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + ( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + ( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [], + }, + ), + ( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + ( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + ( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ) + + subspace_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": room_id, + "content": {"via": [fed_hostname]}, + } + for room_id, _ in children_rooms + ], + ) + async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms ): - # Note that these entries are brief, but should contain enough info. - rooms = [ - _RoomEntry( - public_room, - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - ), - _RoomEntry( - knock_room, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - ), - _RoomEntry( - not_invited_room, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - _RoomEntry( - invited_room, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - _RoomEntry( - restricted_room, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], - }, - ), - _RoomEntry( - restricted_accessible_room, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], - }, - ), - _RoomEntry( - world_readable_room, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - ), - _RoomEntry( - joined_room, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), + return [subspace_room_entry] + [ + # A copy is made of the room data since the allowed_spaces key + # is removed. + _RoomEntry(child_room[0], dict(child_room[1])) + for child_room in children_rooms ] - # Also include the subspace. - rooms.insert( - 0, - _RoomEntry( - subspace, - { - "room_id": subspace, - "world_readable": True, - }, - # Place each room in the sub-space. - [ - { - "room_id": subspace, - "state_key": room.room_id, - "content": {"via": [fed_hostname]}, - } - for room in rooms - ], - ), - ) - return rooms + async def summarize_remote_room_hiearchy(_self, room, suggested_only): + return subspace_room_entry, dict(children_rooms), set() # Add a room to the space which is on another server. 
self._add_child(self.space, subspace, self.token) @@ -788,6 +813,15 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ] self._assert_rooms(result, expected) + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", + new=summarize_remote_room_hiearchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + def test_fed_invited(self): """ A room which the user was invited to should be included in the response. @@ -802,19 +836,22 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # Poke an invite over federation into the database. self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) + fed_room_entry = _RoomEntry( + fed_room, + { + "room_id": fed_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ) + async def summarize_remote_room( _self, room, suggested_only, max_children, exclude_rooms ): - return [ - _RoomEntry( - fed_room, - { - "room_id": fed_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - ] + return [fed_room_entry] + + async def summarize_remote_room_hiearchy(_self, room, suggested_only): + return fed_room_entry, {}, set() # Add a room to the space which is on another server. self._add_child(self.space, fed_room, self.token) @@ -833,3 +870,12 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): (fed_room, ()), ] self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", + new=summarize_remote_room_hiearchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) -- cgit 1.5.1 From 2d9ca4ca77c2cdf98ddb738aee8d5699c7c8749f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 16 Aug 2021 13:19:02 +0100 Subject: Clean up some logging in the federation event handler (#10591) * Include outlier status in `str(event)` In places where we log event objects, knowing whether or not you're dealing with an outlier is super useful. * Remove duplicated logging in get_missing_events When we process events received from get_missing_events, we log them twice (once in `_get_missing_events_for_pdu`, and once in `on_receive_pdu`). Reduce the duplication by removing the logging in `on_receive_pdu`, and ensuring the call sites do sensible logging. * log in `on_receive_pdu` when we already have the event * Log which prev_events we are missing * changelog --- changelog.d/10591.misc | 1 + synapse/events/__init__.py | 3 +- synapse/federation/federation_server.py | 1 + synapse/handlers/federation.py | 52 +++++++++++++++------------------ 4 files changed, 28 insertions(+), 29 deletions(-) create mode 100644 changelog.d/10591.misc diff --git a/changelog.d/10591.misc b/changelog.d/10591.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10591.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. 
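As background for the `missing_prevs` and `shortstr` changes in the diff below: when a PDU arrives, Synapse compares the event's `prev_events` against the events it already has and logs a truncated rendering of whatever is missing. A minimal sketch of that bookkeeping, assuming a `shortstr`-style helper like the one in `synapse.util.stringutils` (the exact truncation format here is an assumption):

    import itertools
    from typing import Collection, Iterable, Set


    def shortstr(items: Iterable[str], maxitems: int = 5) -> str:
        # Render an iterable for logging, truncating it if it has more than
        # `maxitems` entries (truncation format assumed for illustration).
        head = list(itertools.islice(items, maxitems + 1))
        if len(head) <= maxitems:
            return str(head)
        return "[" + ", ".join(repr(i) for i in head[:maxitems]) + ", ...]"


    def missing_prev_events(prevs: Collection[str], seen: Collection[str]) -> Set[str]:
        # The prev_events we have not yet fetched: the plain set difference
        # that the patch names `missing_prevs = prevs - seen`.
        return set(prevs) - set(seen)


    # e.g. missing_prev_events({"$a", "$b", "$c"}, {"$b"}) == {"$a", "$c"}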
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 0298af4c02..a730c1719a 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -396,10 +396,11 @@ class FrozenEvent(EventBase):
         return self.__repr__()
 
     def __repr__(self):
-        return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % (
+        return "<FrozenEvent event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
             self.get("event_id", None),
             self.get("type", None),
             self.get("state_key", None),
+            self.internal_metadata.is_outlier(),
         )
 
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 78d5aac6af..afd8f8580a 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1003,6 +1003,7 @@ class FederationServer(FederationBase):
             # has started processing).
             while True:
                 async with lock:
+                    logger.info("handling received PDU: %s", event)
                     try:
                         await self.handler.on_receive_pdu(
                             origin, event, sent_to_us_directly=True
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 9a5e726533..c0e13bdaac 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -220,8 +220,6 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        logger.info("handling received PDU: %s", pdu)
-
         # We reprocess pdus when we have seen them only as outliers
         existing = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
@@ -229,14 +227,19 @@
 
         # FIXME: Currently we fetch an event again when we already have it
         # if it has been marked as an outlier.
-
-        already_seen = existing and (
-            not existing.internal_metadata.is_outlier()
-            or pdu.internal_metadata.is_outlier()
-        )
-        if already_seen:
-            logger.debug("Already seen pdu")
-            return
+        if existing:
+            if not existing.internal_metadata.is_outlier():
+                logger.info(
+                    "Ignoring received event %s which we have already seen", event_id
+                )
+                return
+            if pdu.internal_metadata.is_outlier():
+                logger.info(
+                    "Ignoring received outlier %s which we already have as an outlier",
+                    event_id,
+                )
+                return
+            logger.info("De-outliering event %s", event_id)
 
         # do some initial sanity-checking of the event. In particular, make
         # sure it doesn't have hundreds of prev_events or auth_events, which
@@ -331,7 +334,8 @@ class FederationHandler(BaseHandler):
             "Found all missing prev_events",
         )
 
-        if prevs - seen:
+        missing_prevs = prevs - seen
+        if missing_prevs:
             # We've still not been able to get all of the prev_events for this event.
# # In this case, we need to fall back to asking another server in the @@ -359,8 +363,8 @@ class FederationHandler(BaseHandler): if sent_to_us_directly: logger.warning( "Rejecting: failed to fetch %d prev events: %s", - len(prevs - seen), - shortstr(prevs - seen), + len(missing_prevs), + shortstr(missing_prevs), ) raise FederationError( "ERROR", @@ -373,9 +377,10 @@ class FederationHandler(BaseHandler): ) logger.info( - "Event %s is missing prev_events: calculating state for a " + "Event %s is missing prev_events %s: calculating state for a " "backwards extremity", event_id, + shortstr(missing_prevs), ) # Calculate the state after each of the previous events, and @@ -393,7 +398,7 @@ class FederationHandler(BaseHandler): # Ask the remote server for the states we don't # know about - for p in prevs - seen: + for p in missing_prevs: logger.info("Requesting state after missing prev_event %s", p) with nested_logging_context(p): @@ -556,21 +561,14 @@ class FederationHandler(BaseHandler): logger.warning("Failed to get prev_events: %s", e) return - logger.info( - "Got %d prev_events: %s", - len(missing_events), - shortstr(missing_events), - ) + logger.info("Got %d prev_events", len(missing_events)) # We want to sort these by depth so we process them and # tell clients about them in order. missing_events.sort(key=lambda x: x.depth) for ev in missing_events: - logger.info( - "Handling received prev_event %s", - ev.event_id, - ) + logger.info("Handling received prev_event %s", ev) with nested_logging_context(ev.event_id): try: await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) @@ -1762,10 +1760,8 @@ class FederationHandler(BaseHandler): for p, origin in room_queue: try: logger.info( - "Processing queued PDU %s which was received " - "while we were joining %s", - p.event_id, - p.room_id, + "Processing queued PDU %s which was received while we were joining", + p, ) with nested_logging_context(p.event_id): await self.on_receive_pdu(origin, p, sent_to_us_directly=True) -- cgit 1.5.1 From 87b62f8bb23f99d76bf0ee62c8217fa45a087673 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 16 Aug 2021 10:14:31 -0400 Subject: Split `synapse.federation.transport.server` into multiple files. (#10590) --- changelog.d/10590.misc | 1 + synapse/federation/transport/server.py | 2158 -------------------- synapse/federation/transport/server/__init__.py | 332 +++ synapse/federation/transport/server/_base.py | 328 +++ synapse/federation/transport/server/federation.py | 692 +++++++ .../federation/transport/server/groups_local.py | 113 + .../federation/transport/server/groups_server.py | 753 +++++++ 7 files changed, 2219 insertions(+), 2158 deletions(-) create mode 100644 changelog.d/10590.misc delete mode 100644 synapse/federation/transport/server.py create mode 100644 synapse/federation/transport/server/__init__.py create mode 100644 synapse/federation/transport/server/_base.py create mode 100644 synapse/federation/transport/server/federation.py create mode 100644 synapse/federation/transport/server/groups_local.py create mode 100644 synapse/federation/transport/server/groups_server.py diff --git a/changelog.d/10590.misc b/changelog.d/10590.misc new file mode 100644 index 0000000000..62fec717da --- /dev/null +++ b/changelog.d/10590.misc @@ -0,0 +1 @@ +Re-organize the `synapse.federation.transport.server` module to create smaller files. 
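Before the monolithic `server.py` deletion below, the shape of the split is worth spelling out: the flat module becomes a `server/` package whose `__init__.py` re-exports the names callers already import, so existing `from synapse.federation.transport.server import ...` statements keep working. A minimal sketch of that pattern, assuming (from the file list above, not verified line-by-line) that `_base.py` holds the shared authentication plumbing and `federation.py` the core federation servlets:

    # synapse/federation/transport/server/__init__.py -- illustrative sketch only
    from synapse.federation.transport.server._base import (  # noqa: F401
        Authenticator,
        BaseFederationServlet,
    )
    from synapse.federation.transport.server.federation import (  # noqa: F401
        FederationSendServlet,
    )

    # Callers are unaffected by the split, e.g. this import resolves as before:
    #     from synapse.federation.transport.server import Authenticator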
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py deleted file mode 100644 index 79a2e1afa0..0000000000 --- a/synapse/federation/transport/server.py +++ /dev/null @@ -1,2158 +0,0 @@ -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. -# Copyright 2020 Sorunome -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import functools -import logging -import re -from typing import ( - Container, - Dict, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, -) - -from typing_extensions import Literal - -import synapse -from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH -from synapse.api.errors import Codes, FederationDeniedError, SynapseError -from synapse.api.room_versions import RoomVersions -from synapse.api.urls import ( - FEDERATION_UNSTABLE_PREFIX, - FEDERATION_V1_PREFIX, - FEDERATION_V2_PREFIX, -) -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.http.server import HttpServer, JsonResource -from synapse.http.servlet import ( - parse_boolean_from_args, - parse_integer_from_args, - parse_json_object_from_request, - parse_string_from_args, - parse_strings_from_args, -) -from synapse.logging import opentracing -from synapse.logging.context import run_in_background -from synapse.logging.opentracing import ( - SynapseTags, - start_active_span, - start_active_span_from_request, - tags, - whitelisted_homeserver, -) -from synapse.server import HomeServer -from synapse.types import JsonDict, ThirdPartyInstanceID, get_domain_from_id -from synapse.util.ratelimitutils import FederationRateLimiter -from synapse.util.stringutils import parse_and_validate_server_name -from synapse.util.versionstring import get_version_string - -logger = logging.getLogger(__name__) - - -class TransportLayerServer(JsonResource): - """Handles incoming federation HTTP requests""" - - def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): - """Initialize the TransportLayerServer - - Will by default register all servlets. For custom behaviour, pass in - a list of servlet_groups to register. - - Args: - hs: homeserver - servlet_groups: List of servlet groups to register. - Defaults to ``DEFAULT_SERVLET_GROUPS``. 
- """ - self.hs = hs - self.clock = hs.get_clock() - self.servlet_groups = servlet_groups - - super().__init__(hs, canonical_json=False) - - self.authenticator = Authenticator(hs) - self.ratelimiter = hs.get_federation_ratelimiter() - - self.register_servlets() - - def register_servlets(self) -> None: - register_servlets( - self.hs, - resource=self, - ratelimiter=self.ratelimiter, - authenticator=self.authenticator, - servlet_groups=self.servlet_groups, - ) - - -class AuthenticationError(SynapseError): - """There was a problem authenticating the request""" - - -class NoAuthenticationError(AuthenticationError): - """The request had no authentication information""" - - -class Authenticator: - def __init__(self, hs: HomeServer): - self._clock = hs.get_clock() - self.keyring = hs.get_keyring() - self.server_name = hs.hostname - self.store = hs.get_datastore() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist - self.notifier = hs.get_notifier() - - self.replication_client = None - if hs.config.worker.worker_app: - self.replication_client = hs.get_tcp_replication() - - # A method just so we can pass 'self' as the authenticator to the Servlets - async def authenticate_request(self, request, content): - now = self._clock.time_msec() - json_request = { - "method": request.method.decode("ascii"), - "uri": request.uri.decode("ascii"), - "destination": self.server_name, - "signatures": {}, - } - - if content is not None: - json_request["content"] = content - - origin = None - - auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") - - if not auth_headers: - raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED - ) - - for auth in auth_headers: - if auth.startswith(b"X-Matrix"): - (origin, key, sig) = _parse_auth_header(auth) - json_request["origin"] = origin - json_request["signatures"].setdefault(origin, {})[key] = sig - - if ( - self.federation_domain_whitelist is not None - and origin not in self.federation_domain_whitelist - ): - raise FederationDeniedError(origin) - - if origin is None or not json_request["signatures"]: - raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED - ) - - await self.keyring.verify_json_for_server( - origin, - json_request, - now, - ) - - logger.debug("Request from %s", origin) - request.requester = origin - - # If we get a valid signed request from the other side, its probably - # alive - retry_timings = await self.store.get_destination_retry_timings(origin) - if retry_timings and retry_timings.retry_last_ts: - run_in_background(self._reset_retry_timings, origin) - - return origin - - async def _reset_retry_timings(self, origin): - try: - logger.info("Marking origin %r as up", origin) - await self.store.set_destination_retry_timings(origin, None, 0, 0) - - # Inform the relevant places that the remote server is back up. - self.notifier.notify_remote_server_up(origin) - if self.replication_client: - # If we're on a worker we try and inform master about this. The - # replication client doesn't hook into the notifier to avoid - # infinite loops where we send a `REMOTE_SERVER_UP` command to - # master, which then echoes it back to us which in turn pokes - # the notifier. 
- self.replication_client.send_remote_server_up(origin) - - except Exception: - logger.exception("Error resetting retry timings on %s", origin) - - -def _parse_auth_header(header_bytes): - """Parse an X-Matrix auth header - - Args: - header_bytes (bytes): header value - - Returns: - Tuple[str, str, str]: origin, key id, signature. - - Raises: - AuthenticationError if the header could not be parsed - """ - try: - header_str = header_bytes.decode("utf-8") - params = header_str.split(" ")[1].split(",") - param_dict = dict(kv.split("=") for kv in params) - - def strip_quotes(value): - if value.startswith('"'): - return value[1:-1] - else: - return value - - origin = strip_quotes(param_dict["origin"]) - - # ensure that the origin is a valid server name - parse_and_validate_server_name(origin) - - key = strip_quotes(param_dict["key"]) - sig = strip_quotes(param_dict["sig"]) - return origin, key, sig - except Exception as e: - logger.warning( - "Error parsing auth header '%s': %s", - header_bytes.decode("ascii", "replace"), - e, - ) - raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED - ) - - -class BaseFederationServlet: - """Abstract base class for federation servlet classes. - - The servlet object should have a PATH attribute which takes the form of a regexp to - match against the request path (excluding the /federation/v1 prefix). - - The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match - the appropriate HTTP method. These methods must be *asynchronous* and have the - signature: - - on_(self, origin, content, query, **kwargs) - - With arguments: - - origin (unicode|None): The authenticated server_name of the calling server, - unless REQUIRE_AUTH is set to False and authentication failed. - - content (unicode|None): decoded json body of the request. None if the - request was a GET. - - query (dict[bytes, list[bytes]]): Query params from the request. url-decoded - (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded - yet. - - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. - - Returns: - Optional[Tuple[int, object]]: either (response code, response object) to - return a JSON response, or None if the request has already been handled. - - Raises: - SynapseError: to return an error code - - Exception: other exceptions will be caught, logged, and a 500 will be - returned. - """ - - PATH = "" # Overridden in subclasses, the regex to match against the path. - - REQUIRE_AUTH = True - - PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version - - RATELIMIT = True # Whether to rate limit requests or not - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - self.hs = hs - self.authenticator = authenticator - self.ratelimiter = ratelimiter - self.server_name = server_name - - def _wrap(self, func): - authenticator = self.authenticator - ratelimiter = self.ratelimiter - - @functools.wraps(func) - async def new_func(request, *args, **kwargs): - """A callback which can be passed to HttpServer.RegisterPaths - - Args: - request (twisted.web.http.Request): - *args: unused? - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. - - Returns: - Tuple[int, object]|None: (response code, response object) as returned by - the callback method. None if the request has already been handled. 
-            """
-            content = None
-            if request.method in [b"PUT", b"POST"]:
-                # TODO: Handle other method types? other content types?
-                content = parse_json_object_from_request(request)
-
-            try:
-                origin = await authenticator.authenticate_request(request, content)
-            except NoAuthenticationError:
-                origin = None
-                if self.REQUIRE_AUTH:
-                    logger.warning(
-                        "authenticate_request failed: missing authentication"
-                    )
-                    raise
-            except Exception as e:
-                logger.warning("authenticate_request failed: %s", e)
-                raise
-
-            request_tags = {
-                SynapseTags.REQUEST_ID: request.get_request_id(),
-                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
-                tags.HTTP_METHOD: request.get_method(),
-                tags.HTTP_URL: request.get_redacted_uri(),
-                tags.PEER_HOST_IPV6: request.getClientIP(),
-                "authenticated_entity": origin,
-                "servlet_name": request.request_metrics.name,
-            }
-
-            # Only accept the span context if the origin is authenticated
-            # and whitelisted
-            if origin and whitelisted_homeserver(origin):
-                scope = start_active_span_from_request(
-                    request, "incoming-federation-request", tags=request_tags
-                )
-            else:
-                scope = start_active_span(
-                    "incoming-federation-request", tags=request_tags
-                )
-
-            with scope:
-                opentracing.inject_response_headers(request.responseHeaders)
-
-                if origin and self.RATELIMIT:
-                    with ratelimiter.ratelimit(origin) as d:
-                        await d
-                        if request._disconnected:
-                            logger.warning(
-                                "client disconnected before we started processing "
-                                "request"
-                            )
-                            return -1, None
-                        response = await func(
-                            origin, content, request.args, *args, **kwargs
-                        )
-                else:
-                    response = await func(
-                        origin, content, request.args, *args, **kwargs
-                    )
-
-            return response
-
-        return new_func
-
-    def register(self, server):
-        pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
-
-        for method in ("GET", "PUT", "POST"):
-            code = getattr(self, "on_%s" % (method), None)
-            if code is None:
-                continue
-
-            server.register_paths(
-                method,
-                (pattern,),
-                self._wrap(code),
-                self.__class__.__name__,
-            )
-
-
-class BaseFederationServerServlet(BaseFederationServlet):
-    """Abstract base class for federation servlet classes which provides a federation server handler.
-
-    See BaseFederationServlet for more information.
-    """
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self.handler = hs.get_federation_server()
-
-
-class FederationSendServlet(BaseFederationServerServlet):
-    PATH = "/send/(?P<transaction_id>[^/]*)/?"
-
-    # We ratelimit manually in the handler as we queue up the requests and we
-    # don't want to fill up the ratelimiter with blocked requests.
-    RATELIMIT = False
-
-    # This is when someone is trying to send us a bunch of data.
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        transaction_id: str,
-    ) -> Tuple[int, JsonDict]:
-        """Called on PUT /send/<transaction_id>/
-
-        Args:
-            transaction_id: The transaction_id associated with this request. This
-                is *not* None.
-
-        Returns:
-            Tuple of `(code, response)`, where
-            `response` is a python dict to be converted into JSON that is
-            used as the response body.
-        """
-        # Parse the request
-        try:
-            transaction_data = content
-
-            logger.debug("Decoded %s: %s", transaction_id, str(transaction_data))
-
-            logger.info(
-                "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
-                transaction_id,
-                origin,
-                len(transaction_data.get("pdus", [])),
-                len(transaction_data.get("edus", [])),
-            )
-
-        except Exception as e:
-            logger.exception(e)
-            return 400, {"error": "Invalid transaction"}
-
-        code, response = await self.handler.on_incoming_transaction(
-            origin, transaction_id, self.server_name, transaction_data
-        )
-
-        return code, response
-
-
-class FederationEventServlet(BaseFederationServerServlet):
-    PATH = "/event/(?P<event_id>[^/]*)/?"
-
-    # This is when someone asks for a data item for a given server data_id pair.
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        event_id: str,
-    ) -> Tuple[int, Union[JsonDict, str]]:
-        return await self.handler.on_pdu_request(origin, event_id)
-
-
-class FederationStateV1Servlet(BaseFederationServerServlet):
-    PATH = "/state/(?P<room_id>[^/]*)/?"
-
-    # This is when someone asks for all data for a given room.
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        return await self.handler.on_room_state_request(
-            origin,
-            room_id,
-            parse_string_from_args(query, "event_id", None, required=False),
-        )
-
-
-class FederationStateIdsServlet(BaseFederationServerServlet):
-    PATH = "/state_ids/(?P<room_id>[^/]*)/?"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        return await self.handler.on_state_ids_request(
-            origin,
-            room_id,
-            parse_string_from_args(query, "event_id", None, required=True),
-        )
-
-
-class FederationBackfillServlet(BaseFederationServerServlet):
-    PATH = "/backfill/(?P<room_id>[^/]*)/?"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        versions = [x.decode("ascii") for x in query[b"v"]]
-        limit = parse_integer_from_args(query, "limit", None)
-
-        if not limit:
-            return 400, {"error": "Did not include limit param"}
-
-        return await self.handler.on_backfill_request(origin, room_id, versions, limit)
-
-
-class FederationQueryServlet(BaseFederationServerServlet):
-    PATH = "/query/(?P<query_type>[^/]*)"
-
-    # This is when we receive a server-server Query
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        query_type: str,
-    ) -> Tuple[int, JsonDict]:
-        args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
-        args["origin"] = origin
-        return await self.handler.on_query_request(query_type, args)
-
-
-class FederationMakeJoinServlet(BaseFederationServerServlet):
-    PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        """
-        Args:
-            origin: The authenticated server_name of the calling server
-
-            content: (GETs don't have bodies)
-
-            query: Query params from the request.
-
-            **kwargs: the dict mapping keys to path components as specified in
-                the path match regexp.
- - Returns: - Tuple of (response code, response object) - """ - supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") - if supported_versions is None: - supported_versions = ["1"] - - result = await self.handler.on_make_join_request( - origin, room_id, user_id, supported_versions=supported_versions - ) - return 200, result - - -class FederationMakeLeaveServlet(BaseFederationServerServlet): - PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - result = await self.handler.on_make_leave_request(origin, room_id, user_id) - return 200, result - - -class FederationV1SendLeaveServlet(BaseFederationServerServlet): - PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, Tuple[int, JsonDict]]: - result = await self.handler.on_send_leave_request(origin, content, room_id) - return 200, (200, result) - - -class FederationV2SendLeaveServlet(BaseFederationServerServlet): - PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" - - PREFIX = FEDERATION_V2_PREFIX - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - result = await self.handler.on_send_leave_request(origin, content, room_id) - return 200, result - - -class FederationMakeKnockServlet(BaseFederationServerServlet): - PATH = "/make_knock/(?P[^/]*)/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - # Retrieve the room versions the remote homeserver claims to support - supported_versions = parse_strings_from_args( - query, "ver", required=True, encoding="utf-8" - ) - - result = await self.handler.on_make_knock_request( - origin, room_id, user_id, supported_versions=supported_versions - ) - return 200, result - - -class FederationV1SendKnockServlet(BaseFederationServerServlet): - PATH = "/send_knock/(?P[^/]*)/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - result = await self.handler.on_send_knock_request(origin, content, room_id) - return 200, result - - -class FederationEventAuthServlet(BaseFederationServerServlet): - PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - return await self.handler.on_event_auth(origin, room_id, event_id) - - -class FederationV1SendJoinServlet(BaseFederationServerServlet): - PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, Tuple[int, JsonDict]]: - # TODO(paul): assert that event_id parsed from path actually - # match those given in content - result = await self.handler.on_send_join_request(origin, content, room_id) - return 200, (200, result) - - -class FederationV2SendJoinServlet(BaseFederationServerServlet): - PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" - - PREFIX = FEDERATION_V2_PREFIX - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: 
Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - # TODO(paul): assert that event_id parsed from path actually - # match those given in content - result = await self.handler.on_send_join_request(origin, content, room_id) - return 200, result - - -class FederationV1InviteServlet(BaseFederationServerServlet): - PATH = "/invite/(?P[^/]*)/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, Tuple[int, JsonDict]]: - # We don't get a room version, so we have to assume its EITHER v1 or - # v2. This is "fine" as the only difference between V1 and V2 is the - # state resolution algorithm, and we don't use that for processing - # invites - result = await self.handler.on_invite_request( - origin, content, room_version_id=RoomVersions.V1.identifier - ) - - # V1 federation API is defined to return a content of `[200, {...}]` - # due to a historical bug. - return 200, (200, result) - - -class FederationV2InviteServlet(BaseFederationServerServlet): - PATH = "/invite/(?P[^/]*)/(?P[^/]*)" - - PREFIX = FEDERATION_V2_PREFIX - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - event_id: str, - ) -> Tuple[int, JsonDict]: - # TODO(paul): assert that room_id/event_id parsed from path actually - # match those given in content - - room_version = content["room_version"] - event = content["event"] - invite_room_state = content["invite_room_state"] - - # Synapse expects invite_room_state to be in unsigned, as it is in v1 - # API - - event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state - - result = await self.handler.on_invite_request( - origin, event, room_version_id=room_version - ) - return 200, result - - -class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): - PATH = "/exchange_third_party_invite/(?P[^/]*)" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - await self.handler.on_exchange_third_party_invite_request(content) - return 200, {} - - -class FederationClientKeysQueryServlet(BaseFederationServerServlet): - PATH = "/user/keys/query" - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - return await self.handler.on_query_client_keys(origin, content) - - -class FederationUserDevicesQueryServlet(BaseFederationServerServlet): - PATH = "/user/devices/(?P[^/]*)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - user_id: str, - ) -> Tuple[int, JsonDict]: - return await self.handler.on_query_user_devices(origin, user_id) - - -class FederationClientKeysClaimServlet(BaseFederationServerServlet): - PATH = "/user/keys/claim" - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - response = await self.handler.on_claim_client_keys(origin, content) - return 200, response - - -class FederationGetMissingEventsServlet(BaseFederationServerServlet): - # TODO(paul): Why does this path alone end with "/?" optional? - PATH = "/get_missing_events/(?P[^/]*)/?" 
- - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - room_id: str, - ) -> Tuple[int, JsonDict]: - limit = int(content.get("limit", 10)) - earliest_events = content.get("earliest_events", []) - latest_events = content.get("latest_events", []) - - result = await self.handler.on_get_missing_events( - origin, - room_id=room_id, - earliest_events=earliest_events, - latest_events=latest_events, - limit=limit, - ) - - return 200, result - - -class On3pidBindServlet(BaseFederationServerServlet): - PATH = "/3pid/onbind" - - REQUIRE_AUTH = False - - async def on_POST( - self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - if "invites" in content: - last_exception = None - for invite in content["invites"]: - try: - if "signed" not in invite or "token" not in invite["signed"]: - message = ( - "Rejecting received notification of third-" - "party invite without signed: %s" % (invite,) - ) - logger.info(message) - raise SynapseError(400, message) - await self.handler.exchange_third_party_invite( - invite["sender"], - invite["mxid"], - invite["room_id"], - invite["signed"], - ) - except Exception as e: - last_exception = e - if last_exception: - raise last_exception - return 200, {} - - -class OpenIdUserInfo(BaseFederationServerServlet): - """ - Exchange a bearer token for information about a user. - - The response format should be compatible with: - http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse - - GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1 - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "sub": "@userpart:example.org", - } - """ - - PATH = "/openid/userinfo" - - REQUIRE_AUTH = False - - async def on_GET( - self, - origin: Optional[str], - content: Literal[None], - query: Dict[bytes, List[bytes]], - ) -> Tuple[int, JsonDict]: - token = parse_string_from_args(query, "access_token") - if token is None: - return ( - 401, - {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, - ) - - user_id = await self.handler.on_openid_userinfo(token) - - if user_id is None: - return ( - 401, - { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired", - }, - ) - - return 200, {"sub": user_id} - - -class PublicRoomList(BaseFederationServlet): - """ - Fetch the public room list for this server. - - This API returns information in the same format as /publicRooms on the - client API, but will only ever include local public rooms and hence is - intended for consumption by other homeservers. 
- - GET /publicRooms HTTP/1.1 - - HTTP/1.1 200 OK - Content-Type: application/json - - { - "chunk": [ - { - "aliases": [ - "#test:localhost" - ], - "guest_can_join": false, - "name": "test room", - "num_joined_members": 3, - "room_id": "!whkydVegtvatLfXmPN:localhost", - "world_readable": false - } - ], - "end": "END", - "start": "START" - } - """ - - PATH = "/publicRooms" - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - allow_access: bool, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_room_list_handler() - self.allow_access = allow_access - - async def on_GET( - self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - if not self.allow_access: - raise FederationDeniedError(origin) - - limit = parse_integer_from_args(query, "limit", 0) - since_token = parse_string_from_args(query, "since", None) - include_all_networks = parse_boolean_from_args( - query, "include_all_networks", default=False - ) - third_party_instance_id = parse_string_from_args( - query, "third_party_instance_id", None - ) - - if include_all_networks: - network_tuple = None - elif third_party_instance_id: - network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) - else: - network_tuple = ThirdPartyInstanceID(None, None) - - if limit == 0: - # zero is a special value which corresponds to no limit. - limit = None - - data = await self.handler.get_local_public_room_list( - limit, since_token, network_tuple=network_tuple, from_federation=True - ) - return 200, data - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - # This implements MSC2197 (Search Filtering over Federation) - if not self.allow_access: - raise FederationDeniedError(origin) - - limit: Optional[int] = int(content.get("limit", 100)) - since_token = content.get("since", None) - search_filter = content.get("filter", None) - - include_all_networks = content.get("include_all_networks", False) - third_party_instance_id = content.get("third_party_instance_id", None) - - if include_all_networks: - network_tuple = None - if third_party_instance_id is not None: - raise SynapseError( - 400, "Can't use include_all_networks with an explicit network" - ) - elif third_party_instance_id is None: - network_tuple = ThirdPartyInstanceID(None, None) - else: - network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) - - if search_filter is None: - logger.warning("Nonefilter") - - if limit == 0: - # zero is a special value which corresponds to no limit. - limit = None - - data = await self.handler.get_local_public_room_list( - limit=limit, - since_token=since_token, - search_filter=search_filter, - network_tuple=network_tuple, - from_federation=True, - ) - - return 200, data - - -class FederationVersionServlet(BaseFederationServlet): - PATH = "/version" - - REQUIRE_AUTH = False - - async def on_GET( - self, - origin: Optional[str], - content: Literal[None], - query: Dict[bytes, List[bytes]], - ) -> Tuple[int, JsonDict]: - return ( - 200, - {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, - ) - - -class BaseGroupsServerServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups server handler. - - See BaseFederationServlet for more information. 
- """ - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_server_handler() - - -class FederationGroupsProfileServlet(BaseGroupsServerServlet): - """Get/set the basic profile of a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/profile" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_group_profile(group_id, requester_user_id) - - return 200, new_content - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.update_group_profile( - group_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryServlet(BaseGroupsServerServlet): - PATH = "/groups/(?P[^/]*)/summary" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_group_summary(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsRoomsServlet(BaseGroupsServerServlet): - """Get the rooms in a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/rooms" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): - """Add/remove room from group""" - - PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.add_room_to_group( - group_id, requester_user_id, room_id, content - ) - - return 200, new_content - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - 
) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.remove_room_from_group( - group_id, requester_user_id, room_id - ) - - return 200, new_content - - -class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): - """Update room config in group""" - - PATH = ( - "/groups/(?P[^/]*)/room/(?P[^/]*)" - "/config/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - config_key: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - result = await self.handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content - ) - - return 200, result - - -class FederationGroupsUsersServlet(BaseGroupsServerServlet): - """Get the users in a group on behalf of a user""" - - PATH = "/groups/(?P[^/]*)/users" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_users_in_group(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): - """Get the users that have been invited to a group""" - - PATH = "/groups/(?P[^/]*)/invited_users" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_invited_users_in_group( - group_id, requester_user_id - ) - - return 200, new_content - - -class FederationGroupsInviteServlet(BaseGroupsServerServlet): - """Ask a group server to invite someone to the group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.invite_to_group( - group_id, user_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): - """Accept an invitation from the group server""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(user_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - new_content = await self.handler.accept_invite(group_id, user_id, content) - - return 200, new_content - - -class 
FederationGroupsJoinServlet(BaseGroupsServerServlet): - """Attempt to join a group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(user_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - new_content = await self.handler.join_group(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): - """Leave or kick a user from the group""" - - PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.remove_user_from_group( - group_id, user_id, requester_user_id, content - ) - - return 200, new_content - - -class BaseGroupsLocalServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups local handler. - - See BaseFederationServlet for more information. - """ - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_local_handler() - - -class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): - """A group server has invited a local user""" - - PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(group_id) != origin: - raise SynapseError(403, "group_id doesn't match origin") - - assert isinstance( - self.handler, GroupsLocalHandler - ), "Workers cannot handle group invites." - - new_content = await self.handler.on_invite(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): - """A group server has removed a local user""" - - PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, None]: - if get_domain_from_id(group_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - assert isinstance( - self.handler, GroupsLocalHandler - ), "Workers cannot handle group removals." 
- - await self.handler.user_removed_from_group(group_id, user_id, content) - - return 200, None - - -class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): - """A group or user's server renews their attestation""" - - PATH = "/groups/(?P[^/]*)/renew_attestation/(?P[^/]*)" - - def __init__( - self, - hs: HomeServer, - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_attestation_renewer() - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - # We don't need to check auth here as we check the attestation signatures - - new_content = await self.handler.on_renew_attestation( - group_id, user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): - """Add/remove a room from the group summary, with optional category. - - Matches both: - - /groups/:group/summary/rooms/:room_id - - /groups/:group/summary/categories/:category/rooms/:room_id - """ - - PATH = ( - "/groups/(?P[^/]*)/summary" - "(/categories/(?P[^/]+))?" - "/rooms/(?P[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError( - 400, "category_id cannot be empty string", Codes.INVALID_PARAM - ) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - resp = await self.handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): - """Get all categories for a group""" - - PATH = "/groups/(?P[^/]*)/categories/?" 
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        resp = await self.handler.get_group_categories(group_id, requester_user_id)
-
-        return 200, resp
-
-
-class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
-    """Add/remove/get a category in a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        category_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        resp = await self.handler.get_group_category(
-            group_id, requester_user_id, category_id
-        )
-
-        return 200, resp
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        category_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if category_id == "":
-            raise SynapseError(400, "category_id cannot be empty string")
-
-        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
-            raise SynapseError(
-                400,
-                "category_id may not be longer than %s characters"
-                % (MAX_GROUP_CATEGORYID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        resp = await self.handler.upsert_group_category(
-            group_id, requester_user_id, category_id, content
-        )
-
-        return 200, resp
-
-    async def on_DELETE(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        category_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if category_id == "":
-            raise SynapseError(400, "category_id cannot be empty string")
-
-        resp = await self.handler.delete_group_category(
-            group_id, requester_user_id, category_id
-        )
-
-        return 200, resp
-
-
-class FederationGroupsRolesServlet(BaseGroupsServerServlet):
-    """Get roles in a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
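A small grounding example (not part of the patch) for the `requester_user_id` origin check these servlets all repeat; get_domain_from_id comes from synapse.types:

    from synapse.types import get_domain_from_id

    assert get_domain_from_id("@alice:example.org") == "example.org"

    # So a request signed by some other server that claims
    # requester_user_id=@alice:example.org fails the
    # `get_domain_from_id(requester_user_id) != origin` check with a 403,
    # rather than being acted on.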
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        resp = await self.handler.get_group_roles(group_id, requester_user_id)
-
-        return 200, resp
-
-
-class FederationGroupsRoleServlet(BaseGroupsServerServlet):
-    """Add/remove/get a role in a group"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        resp = await self.handler.get_group_role(group_id, requester_user_id, role_id)
-
-        return 200, resp
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if role_id == "":
-            raise SynapseError(
-                400, "role_id cannot be empty string", Codes.INVALID_PARAM
-            )
-
-        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
-            raise SynapseError(
-                400,
-                "role_id may not be longer than %s characters"
-                % (MAX_GROUP_ROLEID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        resp = await self.handler.update_group_role(
-            group_id, requester_user_id, role_id, content
-        )
-
-        return 200, resp
-
-    async def on_DELETE(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty string")
-
-        resp = await self.handler.delete_group_role(
-            group_id, requester_user_id, role_id
-        )
-
-        return 200, resp
-
-
-class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet):
-    """Add/remove a user from the group summary, with optional role.
-
-    Matches both:
-        - /groups/:group/summary/users/:user_id
-        - /groups/:group/summary/roles/:role/users/:user_id
-    """
-
-    PATH = (
-        "/groups/(?P<group_id>[^/]*)/summary"
-        "(/roles/(?P<role_id>[^/]+))?"
-        "/users/(?P<user_id>[^/]*)"
-    )
-
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty string")
-
-        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
-            raise SynapseError(
-                400,
-                "role_id may not be longer than %s characters"
-                % (MAX_GROUP_ROLEID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        resp = await self.handler.update_group_summary_user(
-            group_id,
-            requester_user_id,
-            user_id=user_id,
-            role_id=role_id,
-            content=content,
-        )
-
-        return 200, resp
-
-    async def on_DELETE(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-        role_id: str,
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty string")
-
-        resp = await self.handler.delete_group_summary_user(
-            group_id, requester_user_id, user_id=user_id, role_id=role_id
-        )
-
-        return 200, resp
-
-
-class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
-    """Bulk fetch the publicised groups a set of users are in"""
-
-    PATH = "/get_groups_publicised"
-
-    async def on_POST(
-        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
-    ) -> Tuple[int, JsonDict]:
-        resp = await self.handler.bulk_get_publicised_groups(
-            content["user_ids"], proxy=False
-        )
-
-        return 200, resp
-
-
-class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet):
-    """Sets whether a group is joinable without an invite or knock"""
-
-    PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
-
-    async def on_PUT(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Dict[bytes, List[bytes]],
-        group_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester_user_id = parse_string_from_args(
-            query, "requester_user_id", required=True
-        )
-        if get_domain_from_id(requester_user_id) != origin:
-            raise SynapseError(403, "requester_user_id doesn't match origin")
-
-        new_content = await self.handler.set_group_join_policy(
-            group_id, requester_user_id, content
-        )
-
-        return 200, new_content
-
-
-class FederationSpaceSummaryServlet(BaseFederationServlet):
-    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
-    PATH = "/spaces/(?P<room_id>[^/]*)"
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self.handler = hs.get_space_summary_handler()
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Mapping[bytes, Sequence[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
-        max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
-
-        exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
-
-        return 200, await self.handler.federation_space_summary(
-            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
-        )
-
-    # TODO When switching to the stable endpoint, remove the POST handler.
-    async def on_POST(
-        self,
-        origin: str,
-        content: JsonDict,
-        query: Mapping[bytes, Sequence[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        suggested_only = content.get("suggested_only", False)
-        if not isinstance(suggested_only, bool):
-            raise SynapseError(
-                400, "'suggested_only' must be a boolean", Codes.BAD_JSON
-            )
-
-        exclude_rooms = content.get("exclude_rooms", [])
-        if not isinstance(exclude_rooms, list) or any(
-            not isinstance(x, str) for x in exclude_rooms
-        ):
-            raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON)
-
-        max_rooms_per_space = content.get("max_rooms_per_space")
-        if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
-            raise SynapseError(
-                400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON
-            )
-
-        return 200, await self.handler.federation_space_summary(
-            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
-        )
-
-
-class FederationRoomHierarchyServlet(BaseFederationServlet):
-    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
-    PATH = "/hierarchy/(?P<room_id>[^/]*)"
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self.handler = hs.get_space_summary_handler()
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Mapping[bytes, Sequence[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
-        return 200, await self.handler.get_federation_hierarchy(
-            origin, room_id, suggested_only
-        )
-
-
-class RoomComplexityServlet(BaseFederationServlet):
-    """
-    Indicates to other servers how complex (and therefore likely
-    resource-intensive) a public room this server knows about is.
-    """
-
-    PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
-    PREFIX = FEDERATION_UNSTABLE_PREFIX
-
-    def __init__(
-        self,
-        hs: HomeServer,
-        authenticator: Authenticator,
-        ratelimiter: FederationRateLimiter,
-        server_name: str,
-    ):
-        super().__init__(hs, authenticator, ratelimiter, server_name)
-        self._store = self.hs.get_datastore()
-
-    async def on_GET(
-        self,
-        origin: str,
-        content: Literal[None],
-        query: Dict[bytes, List[bytes]],
-        room_id: str,
-    ) -> Tuple[int, JsonDict]:
-        is_public = await self._store.is_room_world_readable_or_publicly_joinable(
-            room_id
-        )
-
-        if not is_public:
-            raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
-
-        complexity = await self._store.get_room_complexity(room_id)
-        return 200, complexity
-
-
-FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
-    FederationSendServlet,
-    FederationEventServlet,
-    FederationStateV1Servlet,
-    FederationStateIdsServlet,
-    FederationBackfillServlet,
-    FederationQueryServlet,
-    FederationMakeJoinServlet,
-    FederationMakeLeaveServlet,
-    FederationEventServlet,
-    FederationV1SendJoinServlet,
-    FederationV2SendJoinServlet,
-    FederationV1SendLeaveServlet,
-    FederationV2SendLeaveServlet,
-    FederationV1InviteServlet,
-    FederationV2InviteServlet,
-    FederationGetMissingEventsServlet,
-    FederationEventAuthServlet,
-    FederationClientKeysQueryServlet,
-    FederationUserDevicesQueryServlet,
-    FederationClientKeysClaimServlet,
-    FederationThirdPartyInviteExchangeServlet,
-    On3pidBindServlet,
-    FederationVersionServlet,
-    RoomComplexityServlet,
-    FederationSpaceSummaryServlet,
-    FederationRoomHierarchyServlet,
-    FederationV1SendKnockServlet,
-    FederationMakeKnockServlet,
-)
-
-OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,)
-
-ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,)
-
-GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
-    FederationGroupsProfileServlet,
-    FederationGroupsSummaryServlet,
-    FederationGroupsRoomsServlet,
-    FederationGroupsUsersServlet,
-    FederationGroupsInvitedUsersServlet,
-    FederationGroupsInviteServlet,
-    FederationGroupsAcceptInviteServlet,
-    FederationGroupsJoinServlet,
-    FederationGroupsRemoveUserServlet,
-    FederationGroupsSummaryRoomsServlet,
-    FederationGroupsCategoriesServlet,
-    FederationGroupsCategoryServlet,
-    FederationGroupsRolesServlet,
-    FederationGroupsRoleServlet,
-    FederationGroupsSummaryUsersServlet,
-    FederationGroupsAddRoomsServlet,
-    FederationGroupsAddRoomsConfigServlet,
-    FederationGroupsSettingJoinPolicyServlet,
-)
-
-
-GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
-    FederationGroupsLocalInviteServlet,
-    FederationGroupsRemoveLocalUserServlet,
-    FederationGroupsBulkPublicisedServlet,
-)
-
-
-GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
-    FederationGroupsRenewAttestaionServlet,
-)
-
-
-DEFAULT_SERVLET_GROUPS = (
-    "federation",
-    "room_list",
-    "group_server",
-    "group_local",
-    "group_attestation",
-    "openid",
-)
-
-
-def register_servlets(
-    hs: HomeServer,
-    resource: HttpServer,
-    authenticator: Authenticator,
-    ratelimiter: FederationRateLimiter,
-    servlet_groups: Optional[Container[str]] = None,
-):
-    """Initialize and register servlet classes.
-
-    Will by default register all servlets. For custom behaviour, pass in
-    a list of servlet_groups to register.
-
-    Args:
-        hs: homeserver
-        resource: resource class to register to
-        authenticator: authenticator to use
-        ratelimiter: ratelimiter to use
-        servlet_groups: List of servlet groups to register.
-            Defaults to ``DEFAULT_SERVLET_GROUPS``.
-    """
-    if not servlet_groups:
-        servlet_groups = DEFAULT_SERVLET_GROUPS
-
-    if "federation" in servlet_groups:
-        for servletclass in FEDERATION_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
-
-    if "openid" in servlet_groups:
-        for servletclass in OPENID_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
-
-    if "room_list" in servlet_groups:
-        for servletclass in ROOM_LIST_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-                allow_access=hs.config.allow_public_rooms_over_federation,
-            ).register(resource)
-
-    if "group_server" in servlet_groups:
-        for servletclass in GROUP_SERVER_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
-
-    if "group_local" in servlet_groups:
-        for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
-
-    if "group_attestation" in servlet_groups:
-        for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
-            servletclass(
-                hs=hs,
-                authenticator=authenticator,
-                ratelimiter=ratelimiter,
-                server_name=hs.hostname,
-            ).register(resource)
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
new file mode 100644
index 0000000000..95176ba6f9
--- /dev/null
+++ b/synapse/federation/transport/server/__init__.py
@@ -0,0 +1,332 @@
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2020 Sorunome
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Dict, Iterable, List, Optional, Tuple, Type
+
+from typing_extensions import Literal
+
+from synapse.api.errors import FederationDeniedError, SynapseError
+from synapse.federation.transport.server._base import (
+    Authenticator,
+    BaseFederationServlet,
+)
+from synapse.federation.transport.server.federation import FEDERATION_SERVLET_CLASSES
+from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES
+from synapse.federation.transport.server.groups_server import (
+    GROUP_SERVER_SERVLET_CLASSES,
+)
+from synapse.http.server import HttpServer, JsonResource
+from synapse.http.servlet import (
+    parse_boolean_from_args,
+    parse_integer_from_args,
+    parse_string_from_args,
+)
+from synapse.server import HomeServer
+from synapse.types import JsonDict, ThirdPartyInstanceID
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+logger = logging.getLogger(__name__)
+
+
+class TransportLayerServer(JsonResource):
+    """Handles incoming federation HTTP requests"""
+
+    def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None):
+        """Initialize the TransportLayerServer
+
+        Will by default register all servlets. For custom behaviour, pass in
+        a list of servlet_groups to register.
+
+        Args:
+            hs: homeserver
+            servlet_groups: List of servlet groups to register.
+                Defaults to ``DEFAULT_SERVLET_GROUPS``.
+        """
+        self.hs = hs
+        self.clock = hs.get_clock()
+        self.servlet_groups = servlet_groups
+
+        super().__init__(hs, canonical_json=False)
+
+        self.authenticator = Authenticator(hs)
+        self.ratelimiter = hs.get_federation_ratelimiter()
+
+        self.register_servlets()
+
+    def register_servlets(self) -> None:
+        register_servlets(
+            self.hs,
+            resource=self,
+            ratelimiter=self.ratelimiter,
+            authenticator=self.authenticator,
+            servlet_groups=self.servlet_groups,
+        )
+
+
+class PublicRoomList(BaseFederationServlet):
+    """
+    Fetch the public room list for this server.
+
+    This API returns information in the same format as /publicRooms on the
+    client API, but will only ever include local public rooms and hence is
+    intended for consumption by other homeservers.
+
+    GET /publicRooms HTTP/1.1
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+        "chunk": [
+            {
+                "aliases": [
+                    "#test:localhost"
+                ],
+                "guest_can_join": false,
+                "name": "test room",
+                "num_joined_members": 3,
+                "room_id": "!whkydVegtvatLfXmPN:localhost",
+                "world_readable": false
+            }
+        ],
+        "end": "END",
+        "start": "START"
+    }
+    """
+
+    PATH = "/publicRooms"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_room_list_handler()
+        self.allow_access = hs.config.allow_public_rooms_over_federation
+
+    async def on_GET(
+        self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        if not self.allow_access:
+            raise FederationDeniedError(origin)
+
+        limit = parse_integer_from_args(query, "limit", 0)
+        since_token = parse_string_from_args(query, "since", None)
+        include_all_networks = parse_boolean_from_args(
+            query, "include_all_networks", default=False
+        )
+        third_party_instance_id = parse_string_from_args(
+            query, "third_party_instance_id", None
+        )
+
+        if include_all_networks:
+            network_tuple = None
+        elif third_party_instance_id:
+            network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+        else:
+            network_tuple = ThirdPartyInstanceID(None, None)
+
+        if limit == 0:
+            # zero is a special value which corresponds to no limit.
+            limit = None
+
+        data = await self.handler.get_local_public_room_list(
+            limit, since_token, network_tuple=network_tuple, from_federation=True
+        )
+        return 200, data
+
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        # This implements MSC2197 (Search Filtering over Federation)
+        if not self.allow_access:
+            raise FederationDeniedError(origin)
+
+        limit: Optional[int] = int(content.get("limit", 100))
+        since_token = content.get("since", None)
+        search_filter = content.get("filter", None)
+
+        include_all_networks = content.get("include_all_networks", False)
+        third_party_instance_id = content.get("third_party_instance_id", None)
+
+        if include_all_networks:
+            network_tuple = None
+            if third_party_instance_id is not None:
+                raise SynapseError(
+                    400, "Can't use include_all_networks with an explicit network"
+                )
+        elif third_party_instance_id is None:
+            network_tuple = ThirdPartyInstanceID(None, None)
+        else:
+            network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+
+        if search_filter is None:
+            logger.warning("Nonefilter")
+
+        if limit == 0:
+            # zero is a special value which corresponds to no limit.
+            limit = None
+
+        data = await self.handler.get_local_public_room_list(
+            limit=limit,
+            since_token=since_token,
+            search_filter=search_filter,
+            network_tuple=network_tuple,
+            from_federation=True,
+        )
+
+        return 200, data
+
+
+class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
+    """A group or user's server renews their attestation"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_groups_attestation_renewer()
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # We don't need to check auth here as we check the attestation signatures
+
+        new_content = await self.handler.on_renew_attestation(
+            group_id, user_id, content
+        )
+
+        return 200, new_content
+
+
+class OpenIdUserInfo(BaseFederationServlet):
+    """
+    Exchange a bearer token for information about a user.
+
+    The response format should be compatible with:
+        http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
+
+    GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+        "sub": "@userpart:example.org",
+    }
+    """
+
+    PATH = "/openid/userinfo"
+
+    REQUIRE_AUTH = False
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_federation_server()
+
+    async def on_GET(
+        self,
+        origin: Optional[str],
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+    ) -> Tuple[int, JsonDict]:
+        token = parse_string_from_args(query, "access_token")
+        if token is None:
+            return (
+                401,
+                {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"},
+            )
+
+        user_id = await self.handler.on_openid_userinfo(token)
+
+        if user_id is None:
+            return (
+                401,
+                {
+                    "errcode": "M_UNKNOWN_TOKEN",
+                    "error": "Access Token unknown or expired",
+                },
+            )
+
+        return 200, {"sub": user_id}
+
+
+DEFAULT_SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
+    "federation": FEDERATION_SERVLET_CLASSES,
+    "room_list": (PublicRoomList,),
+    "group_server": GROUP_SERVER_SERVLET_CLASSES,
+    "group_local": GROUP_LOCAL_SERVLET_CLASSES,
+    "group_attestation": (FederationGroupsRenewAttestaionServlet,),
+    "openid": (OpenIdUserInfo,),
+}
+
+
+def register_servlets(
+    hs: HomeServer,
+    resource: HttpServer,
+    authenticator: Authenticator,
+    ratelimiter: FederationRateLimiter,
+    servlet_groups: Optional[Iterable[str]] = None,
+):
+    """Initialize and register servlet classes.
+
+    Will by default register all servlets. For custom behaviour, pass in
+    a list of servlet_groups to register.
+
+    Args:
+        hs: homeserver
+        resource: resource class to register to
+        authenticator: authenticator to use
+        ratelimiter: ratelimiter to use
+        servlet_groups: List of servlet groups to register.
+            Defaults to ``DEFAULT_SERVLET_GROUPS``.
+    """
+    if not servlet_groups:
+        servlet_groups = DEFAULT_SERVLET_GROUPS.keys()
+
+    for servlet_group in servlet_groups:
+        # Fail loudly on unknown servlet groups.
+        if servlet_group not in DEFAULT_SERVLET_GROUPS:
+            raise RuntimeError(
+                f"Attempting to register unknown federation servlet: '{servlet_group}'"
+            )
+
+        for servletclass in DEFAULT_SERVLET_GROUPS[servlet_group]:
+            servletclass(
+                hs=hs,
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
new file mode 100644
index 0000000000..624c859f1e
--- /dev/null
+++ b/synapse/federation/transport/server/_base.py
@@ -0,0 +1,328 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import logging
+import re
+
+from synapse.api.errors import Codes, FederationDeniedError, SynapseError
+from synapse.api.urls import FEDERATION_V1_PREFIX
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.logging import opentracing
+from synapse.logging.context import run_in_background
+from synapse.logging.opentracing import (
+    SynapseTags,
+    start_active_span,
+    start_active_span_from_request,
+    tags,
+    whitelisted_homeserver,
+)
+from synapse.server import HomeServer
+from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.stringutils import parse_and_validate_server_name
+
+logger = logging.getLogger(__name__)
+
+
+class AuthenticationError(SynapseError):
+    """There was a problem authenticating the request"""
+
+
+class NoAuthenticationError(AuthenticationError):
+    """The request had no authentication information"""
+
+
+class Authenticator:
+    def __init__(self, hs: HomeServer):
+        self._clock = hs.get_clock()
+        self.keyring = hs.get_keyring()
+        self.server_name = hs.hostname
+        self.store = hs.get_datastore()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+        self.notifier = hs.get_notifier()
+
+        self.replication_client = None
+        if hs.config.worker.worker_app:
+            self.replication_client = hs.get_tcp_replication()
+
+    # A method just so we can pass 'self' as the authenticator to the Servlets
+    async def authenticate_request(self, request, content):
+        now = self._clock.time_msec()
+        json_request = {
+            "method": request.method.decode("ascii"),
+            "uri": request.uri.decode("ascii"),
+            "destination": self.server_name,
+            "signatures": {},
+        }
+
+        if content is not None:
+            json_request["content"] = content
+
+        origin = None
+
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if not auth_headers:
+            raise NoAuthenticationError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED
+            )
+
+        for auth in auth_headers:
+            if auth.startswith(b"X-Matrix"):
+                (origin, key, sig) = _parse_auth_header(auth)
+                json_request["origin"] = origin
+                json_request["signatures"].setdefault(origin, {})[key] = sig
+
+        if (
+            self.federation_domain_whitelist is not None
+            and origin not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(origin)
+
+        if origin is None or not json_request["signatures"]:
+            raise NoAuthenticationError(
+                401, "Missing Authorization headers", Codes.UNAUTHORIZED
+            )
+
+        await self.keyring.verify_json_for_server(
+            origin,
+            json_request,
+            now,
+        )
+
+        logger.debug("Request from %s", origin)
+        request.requester = origin
+
+        # If we get a valid signed request from the other side, it's probably
+        # alive
+        retry_timings = await self.store.get_destination_retry_timings(origin)
+        if retry_timings and retry_timings.retry_last_ts:
+            run_in_background(self._reset_retry_timings, origin)
+
+        return origin
+
+    async def _reset_retry_timings(self, origin):
+        try:
+            logger.info("Marking origin %r as up", origin)
+            await self.store.set_destination_retry_timings(origin, None, 0, 0)
+
+            # Inform the relevant places that the remote server is back up.
+            self.notifier.notify_remote_server_up(origin)
+            if self.replication_client:
+                # If we're on a worker we try and inform master about this. The
+                # replication client doesn't hook into the notifier to avoid
+                # infinite loops where we send a `REMOTE_SERVER_UP` command to
+                # master, which then echoes it back to us which in turn pokes
+                # the notifier.
+                self.replication_client.send_remote_server_up(origin)
+
+        except Exception:
+            logger.exception("Error resetting retry timings on %s", origin)
+
+
+def _parse_auth_header(header_bytes):
+    """Parse an X-Matrix auth header
+
+    Args:
+        header_bytes (bytes): header value
+
+    Returns:
+        Tuple[str, str, str]: origin, key id, signature.
+
+    Raises:
+        AuthenticationError if the header could not be parsed
+    """
+    try:
+        header_str = header_bytes.decode("utf-8")
+        params = header_str.split(" ")[1].split(",")
+        param_dict = dict(kv.split("=") for kv in params)
+
+        def strip_quotes(value):
+            if value.startswith('"'):
+                return value[1:-1]
+            else:
+                return value
+
+        origin = strip_quotes(param_dict["origin"])
+
+        # ensure that the origin is a valid server name
+        parse_and_validate_server_name(origin)
+
+        key = strip_quotes(param_dict["key"])
+        sig = strip_quotes(param_dict["sig"])
+        return origin, key, sig
+    except Exception as e:
+        logger.warning(
+            "Error parsing auth header '%s': %s",
+            header_bytes.decode("ascii", "replace"),
+            e,
+        )
+        raise AuthenticationError(
+            400, "Malformed Authorization header", Codes.UNAUTHORIZED
+        )
+
+
+class BaseFederationServlet:
+    """Abstract base class for federation servlet classes.
+
+    The servlet object should have a PATH attribute which takes the form of a regexp to
+    match against the request path (excluding the /federation/v1 prefix).
+
+    The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
+    the appropriate HTTP method. These methods must be *asynchronous* and have the
+    signature:
+
+        on_<METHOD>(self, origin, content, query, **kwargs)
+
+    With arguments:
+
+        origin (unicode|None): The authenticated server_name of the calling server,
+            unless REQUIRE_AUTH is set to False and authentication failed.
+
+        content (unicode|None): decoded json body of the request. None if the
+            request was a GET.
+
+        query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
+            (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
+            yet.
+
+        **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+            components as specified in the path match regexp.
+
+    Returns:
+        Optional[Tuple[int, object]]: either (response code, response object) to
+            return a JSON response, or None if the request has already been handled.
+
+    Raises:
+        SynapseError: to return an error code
+
+        Exception: other exceptions will be caught, logged, and a 500 will be
+            returned.
+    """
+
+    PATH = ""  # Overridden in subclasses, the regex to match against the path.
+
+    REQUIRE_AUTH = True
+
+    PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
+
+    RATELIMIT = True  # Whether to rate limit requests or not
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        self.hs = hs
+        self.authenticator = authenticator
+        self.ratelimiter = ratelimiter
+        self.server_name = server_name
+
+    def _wrap(self, func):
+        authenticator = self.authenticator
+        ratelimiter = self.ratelimiter
+
+        @functools.wraps(func)
+        async def new_func(request, *args, **kwargs):
+            """A callback which can be passed to HttpServer.RegisterPaths
+
+            Args:
+                request (twisted.web.http.Request):
+                *args: unused?
+                **kwargs (dict[unicode, unicode]): the dict mapping keys to path
+                    components as specified in the path match regexp.
+
+            Returns:
+                Tuple[int, object]|None: (response code, response object) as returned by
+                    the callback method. None if the request has already been handled.
+            """
+            content = None
+            if request.method in [b"PUT", b"POST"]:
+                # TODO: Handle other method types? other content types?
+                content = parse_json_object_from_request(request)
+
+            try:
+                origin = await authenticator.authenticate_request(request, content)
+            except NoAuthenticationError:
+                origin = None
+                if self.REQUIRE_AUTH:
+                    logger.warning(
+                        "authenticate_request failed: missing authentication"
+                    )
+                    raise
+            except Exception as e:
+                logger.warning("authenticate_request failed: %s", e)
+                raise
+
+            request_tags = {
+                SynapseTags.REQUEST_ID: request.get_request_id(),
+                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+                tags.HTTP_METHOD: request.get_method(),
+                tags.HTTP_URL: request.get_redacted_uri(),
+                tags.PEER_HOST_IPV6: request.getClientIP(),
+                "authenticated_entity": origin,
+                "servlet_name": request.request_metrics.name,
+            }
+
+            # Only accept the span context if the origin is authenticated
+            # and whitelisted
+            if origin and whitelisted_homeserver(origin):
+                scope = start_active_span_from_request(
+                    request, "incoming-federation-request", tags=request_tags
+                )
+            else:
+                scope = start_active_span(
+                    "incoming-federation-request", tags=request_tags
+                )
+
+            with scope:
+                opentracing.inject_response_headers(request.responseHeaders)
+
+                if origin and self.RATELIMIT:
+                    with ratelimiter.ratelimit(origin) as d:
+                        await d
+                        if request._disconnected:
+                            logger.warning(
+                                "client disconnected before we started processing "
+                                "request"
+                            )
+                            return -1, None
+                        response = await func(
+                            origin, content, request.args, *args, **kwargs
+                        )
+                else:
+                    response = await func(
+                        origin, content, request.args, *args, **kwargs
+                    )
+
+            return response
+
+        return new_func
+
+    def register(self, server):
+        pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
+
+        for method in ("GET", "PUT", "POST"):
+            code = getattr(self, "on_%s" % (method), None)
+            if code is None:
+                continue
+
+            server.register_paths(
+                method,
+                (pattern,),
+                self._wrap(code),
+                self.__class__.__name__,
+            )
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
new file mode 100644
index 0000000000..2806337846
--- /dev/null
+++ b/synapse/federation/transport/server/federation.py
@@ -0,0 +1,692 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
+
+from typing_extensions import Literal
+
+import synapse
+from synapse.api.errors import Codes, SynapseError
+from synapse.api.room_versions import RoomVersions
+from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
+from synapse.federation.transport.server._base import (
+    Authenticator,
+    BaseFederationServlet,
+)
+from synapse.http.servlet import (
+    parse_boolean_from_args,
+    parse_integer_from_args,
+    parse_string_from_args,
+    parse_strings_from_args,
+)
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.versionstring import get_version_string
+
+logger = logging.getLogger(__name__)
+
+
+class BaseFederationServerServlet(BaseFederationServlet):
+    """Abstract base class for federation servlet classes which provides a federation server handler.
+
+    See BaseFederationServlet for more information.
+    """
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_federation_server()
+
+
+class FederationSendServlet(BaseFederationServerServlet):
+    PATH = "/send/(?P<transaction_id>[^/]*)/?"
+
+    # We ratelimit manually in the handler as we queue up the requests and we
+    # don't want to fill up the ratelimiter with blocked requests.
+    RATELIMIT = False
+
+    # This is when someone is trying to send us a bunch of data.
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        transaction_id: str,
+    ) -> Tuple[int, JsonDict]:
+        """Called on PUT /send/<transaction_id>/
+
+        Args:
+            transaction_id: The transaction_id associated with this request. This
+                is *not* None.
+
+        Returns:
+            Tuple of `(code, response)`, where
+            `response` is a python dict to be converted into JSON that is
+            used as the response body.
+        """
+        # Parse the request
+        try:
+            transaction_data = content
+
+            logger.debug("Decoded %s: %s", transaction_id, str(transaction_data))
+
+            logger.info(
+                "Received txn %s from %s. (PDUs: %d, EDUs: %d)",
+                transaction_id,
+                origin,
+                len(transaction_data.get("pdus", [])),
+                len(transaction_data.get("edus", [])),
+            )
+
+        except Exception as e:
+            logger.exception(e)
+            return 400, {"error": "Invalid transaction"}
+
+        code, response = await self.handler.on_incoming_transaction(
+            origin, transaction_id, self.server_name, transaction_data
+        )
+
+        return code, response
+
+
+class FederationEventServlet(BaseFederationServerServlet):
+    PATH = "/event/(?P<event_id>[^/]*)/?"
+
+    # This is when someone asks for a data item for a given server data_id pair.
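For reference, a sketch (not part of the patch; field values invented, shape per the Matrix federation transaction format) of the body FederationSendServlet above consumes — only the "pdus" and "edus" keys are inspected for the log line, and the whole dict is passed through to the handler:

    PUT /_matrix/federation/v1/send/916d630ea616342b42e98a3be0b74113

    {
        "origin": "remote.example.org",
        "origin_server_ts": 1627992000000,
        "pdus": [{"type": "m.room.message", "room_id": "!r:remote.example.org", "...": "..."}],
        "edus": [{"edu_type": "m.typing", "content": {"...": "..."}}]
    }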
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        event_id: str,
+    ) -> Tuple[int, Union[JsonDict, str]]:
+        return await self.handler.on_pdu_request(origin, event_id)
+
+
+class FederationStateV1Servlet(BaseFederationServerServlet):
+    PATH = "/state/(?P<room_id>[^/]*)/?"
+
+    # This is when someone asks for all data for a given room.
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_room_state_request(
+            origin,
+            room_id,
+            parse_string_from_args(query, "event_id", None, required=False),
+        )
+
+
+class FederationStateIdsServlet(BaseFederationServerServlet):
+    PATH = "/state_ids/(?P<room_id>[^/]*)/?"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_state_ids_request(
+            origin,
+            room_id,
+            parse_string_from_args(query, "event_id", None, required=True),
+        )
+
+
+class FederationBackfillServlet(BaseFederationServerServlet):
+    PATH = "/backfill/(?P<room_id>[^/]*)/?"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        versions = [x.decode("ascii") for x in query[b"v"]]
+        limit = parse_integer_from_args(query, "limit", None)
+
+        if not limit:
+            return 400, {"error": "Did not include limit param"}
+
+        return await self.handler.on_backfill_request(origin, room_id, versions, limit)
+
+
+class FederationQueryServlet(BaseFederationServerServlet):
+    PATH = "/query/(?P<query_type>[^/]*)"
+
+    # This is when we receive a server-server Query
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        query_type: str,
+    ) -> Tuple[int, JsonDict]:
+        args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
+        args["origin"] = origin
+        return await self.handler.on_query_request(query_type, args)
+
+
+class FederationMakeJoinServlet(BaseFederationServerServlet):
+    PATH = "/make_join/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        """
+        Args:
+            origin: The authenticated server_name of the calling server
+
+            content: (GETs don't have bodies)
+
+            query: Query params from the request.
+
+            **kwargs: the dict mapping keys to path components as specified in
+                the path match regexp.
+
+        Returns:
+            Tuple of (response code, response object)
+        """
+        supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8")
+        if supported_versions is None:
+            supported_versions = ["1"]
+
+        result = await self.handler.on_make_join_request(
+            origin, room_id, user_id, supported_versions=supported_versions
+        )
+        return 200, result
+
+
+class FederationMakeLeaveServlet(BaseFederationServerServlet):
+    PATH = "/make_leave/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_make_leave_request(origin, room_id, user_id)
+        return 200, result
+
+
+class FederationV1SendLeaveServlet(BaseFederationServerServlet):
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
+        result = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, (200, result)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServerServlet):
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_send_leave_request(origin, content, room_id)
+        return 200, result
+
+
+class FederationMakeKnockServlet(BaseFederationServerServlet):
+    PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # Retrieve the room versions the remote homeserver claims to support
+        supported_versions = parse_strings_from_args(
+            query, "ver", required=True, encoding="utf-8"
+        )
+
+        result = await self.handler.on_make_knock_request(
+            origin, room_id, user_id, supported_versions=supported_versions
+        )
+        return 200, result
+
+
+class FederationV1SendKnockServlet(BaseFederationServerServlet):
+    PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        result = await self.handler.on_send_knock_request(origin, content, room_id)
+        return 200, result
+
+
+class FederationEventAuthServlet(BaseFederationServerServlet):
+    PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_event_auth(origin, room_id, event_id)
+
+
+class FederationV1SendJoinServlet(BaseFederationServerServlet):
+    PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
+        # TODO(paul): assert that event_id parsed from path actually
+        #   match those given in content
+        result = await self.handler.on_send_join_request(origin, content, room_id)
+        return 200, (200, result)
+
+
+class FederationV2SendJoinServlet(BaseFederationServerServlet):
+    PATH = "/send_join/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # TODO(paul): assert that event_id parsed from path actually
+        #   match those given in content
+        result = await self.handler.on_send_join_request(origin, content, room_id)
+        return 200, result
+
+
+class FederationV1InviteServlet(BaseFederationServerServlet):
+    PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, Tuple[int, JsonDict]]:
+        # We don't get a room version, so we have to assume its EITHER v1 or
+        # v2. This is "fine" as the only difference between V1 and V2 is the
+        # state resolution algorithm, and we don't use that for processing
+        # invites
+        result = await self.handler.on_invite_request(
+            origin, content, room_version_id=RoomVersions.V1.identifier
+        )
+
+        # V1 federation API is defined to return a content of `[200, {...}]`
+        # due to a historical bug.
+        return 200, (200, result)
+
+
+class FederationV2InviteServlet(BaseFederationServerServlet):
+    PATH = "/invite/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+    PREFIX = FEDERATION_V2_PREFIX
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        # TODO(paul): assert that room_id/event_id parsed from path actually
+        #   match those given in content
+
+        room_version = content["room_version"]
+        event = content["event"]
+        invite_room_state = content["invite_room_state"]
+
+        # Synapse expects invite_room_state to be in unsigned, as it is in v1
+        # API
+
+        event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
+
+        result = await self.handler.on_invite_request(
+            origin, event, room_version_id=room_version
+        )
+        return 200, result
+
+
+class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
+    PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
+
+    async def on_PUT(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        await self.handler.on_exchange_third_party_invite_request(content)
+        return 200, {}
+
+
+class FederationClientKeysQueryServlet(BaseFederationServerServlet):
+    PATH = "/user/keys/query"
+
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_query_client_keys(origin, content)
+
+
+class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
+    PATH = "/user/devices/(?P<user_id>[^/]*)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        return await self.handler.on_query_user_devices(origin, user_id)
+
+
+class FederationClientKeysClaimServlet(BaseFederationServerServlet):
+    PATH = "/user/keys/claim"
+
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        response = await self.handler.on_claim_client_keys(origin, content)
+        return 200, response
+
+
+class FederationGetMissingEventsServlet(BaseFederationServerServlet):
+    # TODO(paul): Why does this path alone end with "/?" optional?
+    PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
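An illustrative request (not part of the patch; values invented) for the get_missing_events endpoint, matching the fields its on_POST below extracts:

    POST /_matrix/federation/v1/get_missing_events/!room:example.org

    {
        "limit": 10,
        "earliest_events": ["$an_event_we_already_have"],
        "latest_events": ["$a_newer_event_whose_ancestors_we_are_missing"]
    }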
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        limit = int(content.get("limit", 10))
+        earliest_events = content.get("earliest_events", [])
+        latest_events = content.get("latest_events", [])
+
+        result = await self.handler.on_get_missing_events(
+            origin,
+            room_id=room_id,
+            earliest_events=earliest_events,
+            latest_events=latest_events,
+            limit=limit,
+        )
+
+        return 200, result
+
+
+class On3pidBindServlet(BaseFederationServerServlet):
+    PATH = "/3pid/onbind"
+
+    REQUIRE_AUTH = False
+
+    async def on_POST(
+        self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        if "invites" in content:
+            last_exception = None
+            for invite in content["invites"]:
+                try:
+                    if "signed" not in invite or "token" not in invite["signed"]:
+                        message = (
+                            "Rejecting received notification of third-"
+                            "party invite without signed: %s" % (invite,)
+                        )
+                        logger.info(message)
+                        raise SynapseError(400, message)
+                    await self.handler.exchange_third_party_invite(
+                        invite["sender"],
+                        invite["mxid"],
+                        invite["room_id"],
+                        invite["signed"],
+                    )
+                except Exception as e:
+                    last_exception = e
+            if last_exception:
+                raise last_exception
+        return 200, {}
+
+
+class FederationVersionServlet(BaseFederationServlet):
+    PATH = "/version"
+
+    REQUIRE_AUTH = False
+
+    async def on_GET(
+        self,
+        origin: Optional[str],
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+    ) -> Tuple[int, JsonDict]:
+        return (
+            200,
+            {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
+        )
+
+
+class FederationSpaceSummaryServlet(BaseFederationServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+    PATH = "/spaces/(?P<room_id>[^/]*)"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_space_summary_handler()
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
+        max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space")
+
+        exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[])
+
+        return 200, await self.handler.federation_space_summary(
+            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
+        )
+
+    # TODO When switching to the stable endpoint, remove the POST handler.
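An illustrative body (not part of the patch; values invented) for the unstable MSC2946 POST handler that follows, exercising each of the fields it validates:

    POST /_matrix/federation/unstable/org.matrix.msc2946/spaces/!space:example.org

    {
        "suggested_only": true,
        "exclude_rooms": ["!already_fetched:example.org"],
        "max_rooms_per_space": 50
    }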
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        suggested_only = content.get("suggested_only", False)
+        if not isinstance(suggested_only, bool):
+            raise SynapseError(
+                400, "'suggested_only' must be a boolean", Codes.BAD_JSON
+            )
+
+        exclude_rooms = content.get("exclude_rooms", [])
+        if not isinstance(exclude_rooms, list) or any(
+            not isinstance(x, str) for x in exclude_rooms
+        ):
+            raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON)
+
+        max_rooms_per_space = content.get("max_rooms_per_space")
+        if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
+            raise SynapseError(
+                400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON
+            )
+
+        return 200, await self.handler.federation_space_summary(
+            origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms
+        )
+
+
+class FederationRoomHierarchyServlet(BaseFederationServlet):
+    PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946"
+    PATH = "/hierarchy/(?P<room_id>[^/]*)"
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_space_summary_handler()
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Mapping[bytes, Sequence[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
+        return 200, await self.handler.get_federation_hierarchy(
+            origin, room_id, suggested_only
+        )
+
+
+class RoomComplexityServlet(BaseFederationServlet):
+    """
+    Indicates to other servers how complex (and therefore likely
+    resource-intensive) a public room this server knows about is.
+    """
+
+    PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
+    PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self._store = self.hs.get_datastore()
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        is_public = await self._store.is_room_world_readable_or_publicly_joinable(
+            room_id
+        )
+
+        if not is_public:
+            raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
+
+        complexity = await self._store.get_room_complexity(room_id)
+        return 200, complexity
+
+
+FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
+    FederationSendServlet,
+    FederationEventServlet,
+    FederationStateV1Servlet,
+    FederationStateIdsServlet,
+    FederationBackfillServlet,
+    FederationQueryServlet,
+    FederationMakeJoinServlet,
+    FederationMakeLeaveServlet,
+    FederationEventServlet,
+    FederationV1SendJoinServlet,
+    FederationV2SendJoinServlet,
+    FederationV1SendLeaveServlet,
+    FederationV2SendLeaveServlet,
+    FederationV1InviteServlet,
+    FederationV2InviteServlet,
+    FederationGetMissingEventsServlet,
+    FederationEventAuthServlet,
+    FederationClientKeysQueryServlet,
+    FederationUserDevicesQueryServlet,
+    FederationClientKeysClaimServlet,
+    FederationThirdPartyInviteExchangeServlet,
+    On3pidBindServlet,
+    FederationVersionServlet,
+    RoomComplexityServlet,
+    FederationSpaceSummaryServlet,
+    FederationRoomHierarchyServlet,
+    FederationV1SendKnockServlet,
+    FederationMakeKnockServlet,
+)
diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py
new file mode 100644
index 0000000000..a12cd18d58
--- /dev/null
+++ b/synapse/federation/transport/server/groups_local.py
@@ -0,0 +1,113 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, List, Tuple, Type
+
+from synapse.api.errors import SynapseError
+from synapse.federation.transport.server._base import (
+    Authenticator,
+    BaseFederationServlet,
+)
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.server import HomeServer
+from synapse.types import JsonDict, get_domain_from_id
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+
+class BaseGroupsLocalServlet(BaseFederationServlet):
+    """Abstract base class for federation servlet classes which provides a groups local handler.
+
+    See BaseFederationServlet for more information.
+    """
+
+    def __init__(
+        self,
+        hs: HomeServer,
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self.handler = hs.get_groups_local_handler()
+
+
+class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
+    """A group server has invited a local user"""
+
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        if get_domain_from_id(group_id) != origin:
+            raise SynapseError(403, "group_id doesn't match origin")
+
+        assert isinstance(
+            self.handler, GroupsLocalHandler
+        ), "Workers cannot handle group invites."
+
+
+class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet):
+    """A group server has invited a local user"""
+
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        if get_domain_from_id(group_id) != origin:
+            raise SynapseError(403, "group_id doesn't match origin")
+
+        assert isinstance(
+            self.handler, GroupsLocalHandler
+        ), "Workers cannot handle group invites."
+
+        new_content = await self.handler.on_invite(group_id, user_id, content)
+
+        return 200, new_content
+
+
+class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet):
+    """A group server has removed a local user"""
+
+    PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, None]:
+        if get_domain_from_id(group_id) != origin:
+            raise SynapseError(403, "group_id doesn't match origin")
+
+        assert isinstance(
+            self.handler, GroupsLocalHandler
+        ), "Workers cannot handle group removals."
+
+        await self.handler.user_removed_from_group(group_id, user_id, content)
+
+        return 200, None
+
+
+class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet):
+    """Bulk fetch the groups that a set of users have publicised"""
+
+    PATH = "/get_groups_publicised"
+
+    async def on_POST(
+        self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
+    ) -> Tuple[int, JsonDict]:
+        resp = await self.handler.bulk_get_publicised_groups(
+            content["user_ids"], proxy=False
+        )
+
+        return 200, resp
+
+
+GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
+    FederationGroupsLocalInviteServlet,
+    FederationGroupsRemoveLocalUserServlet,
+    FederationGroupsBulkPublicisedServlet,
+)
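Every groups servlet repeats the same guard: the authenticated origin must be the homeserver that the acting entity (group or user) belongs to. A hypothetical helper expressing that pattern (not part of this patch; the imports are real):

    from synapse.api.errors import SynapseError
    from synapse.types import get_domain_from_id

    def assert_entity_from_origin(entity_id: str, origin: str, name: str) -> None:
        # e.g. assert_entity_from_origin(group_id, origin, "group_id")
        if get_domain_from_id(entity_id) != origin:
            raise SynapseError(403, "%s doesn't match origin" % (name,))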
+ """ + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.handler = hs.get_groups_server_handler() + + +class FederationGroupsProfileServlet(BaseGroupsServerServlet): + """Get/set the basic profile of a group on behalf of a user""" + + PATH = "/groups/(?P[^/]*)/profile" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_group_profile(group_id, requester_user_id) + + return 200, new_content + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.update_group_profile( + group_id, requester_user_id, content + ) + + return 200, new_content + + +class FederationGroupsSummaryServlet(BaseGroupsServerServlet): + PATH = "/groups/(?P[^/]*)/summary" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_group_summary(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsRoomsServlet(BaseGroupsServerServlet): + """Get the rooms in a group on behalf of a user""" + + PATH = "/groups/(?P[^/]*)/rooms" + + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) + + return 200, new_content + + +class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): + """Add/remove room from group""" + + PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + return 200, new_content + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + 
+
+
+class FederationGroupsSummaryServlet(BaseGroupsServerServlet):
+    PATH = "/groups/(?P<group_id>[^/]*)/summary"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = await self.handler.get_group_summary(group_id, requester_user_id)
+
+        return 200, new_content
+
+
+class FederationGroupsRoomsServlet(BaseGroupsServerServlet):
+    """Get the rooms in a group on behalf of a user"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/rooms"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id)
+
+        return 200, new_content
+
+
+class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet):
+    """Add/remove room from group"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = await self.handler.add_room_to_group(
+            group_id, requester_user_id, room_id, content
+        )
+
+        return 200, new_content
+
+    async def on_DELETE(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = await self.handler.remove_room_from_group(
+            group_id, requester_user_id, room_id
+        )
+
+        return 200, new_content
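Note the typing convention used throughout: handlers for bodyless methods (GET/DELETE) declare content: Literal[None], while POST/PUT handlers take a JsonDict. A minimal illustration of what that buys (illustrative only):

    from typing_extensions import Literal

    def handle_delete(content: Literal[None]) -> None:
        # mypy will flag any caller that passes an actual JSON body here.
        assert content is None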
+
+
+class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet):
+    """Update room config in group"""
+
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
+        "/config/(?P<config_key>[^/]*)"
+    )
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        room_id: str,
+        config_key: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        result = await self.handler.update_room_in_group(
+            group_id, requester_user_id, room_id, config_key, content
+        )
+
+        return 200, result
+
+
+class FederationGroupsUsersServlet(BaseGroupsServerServlet):
+    """Get the users in a group on behalf of a user"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/users"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = await self.handler.get_users_in_group(group_id, requester_user_id)
+
+        return 200, new_content
+
+
+class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet):
+    """Get the users that have been invited to a group"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = await self.handler.get_invited_users_in_group(
+            group_id, requester_user_id
+        )
+
+        return 200, new_content
+
+
+class FederationGroupsInviteServlet(BaseGroupsServerServlet):
+    """Ask a group server to invite someone to the group"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = await self.handler.invite_to_group(
+            group_id, user_id, requester_user_id, content
+        )
+
+        return 200, new_content
+
+
+class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet):
+    """Accept an invitation from the group server"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        if get_domain_from_id(user_id) != origin:
+            raise SynapseError(403, "user_id doesn't match origin")
+
+        new_content = await self.handler.accept_invite(group_id, user_id, content)
+
+        return 200, new_content
+
+
+class FederationGroupsJoinServlet(BaseGroupsServerServlet):
+    """Attempt to join a group"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        if get_domain_from_id(user_id) != origin:
+            raise SynapseError(403, "user_id doesn't match origin")
+
+        new_content = await self.handler.join_group(group_id, user_id, content)
+
+        return 200, new_content
+
+
+class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet):
+    """Leave or kick a user from the group"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        new_content = await self.handler.remove_user_from_group(
+            group_id, user_id, requester_user_id, content
+        )
+
+        return 200, new_content
+
+
+class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet):
+    """Add/remove a room from the group summary, with optional category.
+
+    Matches both:
+        - /groups/:group/summary/rooms/:room_id
+        - /groups/:group/summary/categories/:category/rooms/:room_id
+    """
+
+    PATH = (
+        "/groups/(?P<group_id>[^/]*)/summary"
+        "(/categories/(?P<category_id>[^/]+))?"
+        "/rooms/(?P<room_id>[^/]*)"
+    )
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(
+                400, "category_id cannot be empty string", Codes.INVALID_PARAM
+            )
+
+        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
+        resp = await self.handler.update_group_summary_room(
+            group_id,
+            requester_user_id,
+            room_id=room_id,
+            category_id=category_id,
+            content=content,
+        )
+
+        return 200, resp
+
+    async def on_DELETE(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+        room_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = await self.handler.delete_group_summary_room(
+            group_id, requester_user_id, room_id=room_id, category_id=category_id
+        )
+
+        return 200, resp
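The summary-rooms PATH above makes the categories segment optional, so category_id arrives as None when the short form of the URL is used (which is why the empty-string check is separate). A quick demonstration with the same pattern (illustrative only):

    import re

    _SUMMARY_ROOMS_PATH = re.compile(
        "/groups/(?P<group_id>[^/]*)/summary"
        "(/categories/(?P<category_id>[^/]+))?"
        "/rooms/(?P<room_id>[^/]*)"
    )

    m = _SUMMARY_ROOMS_PATH.match("/groups/+g:a/summary/rooms/!r:a")
    assert m is not None and m.group("category_id") is None

    m = _SUMMARY_ROOMS_PATH.match("/groups/+g:a/summary/categories/cat/rooms/!r:a")
    assert m is not None and m.group("category_id") == "cat"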
+
+
+class FederationGroupsCategoriesServlet(BaseGroupsServerServlet):
+    """Get all categories for a group"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = await self.handler.get_group_categories(group_id, requester_user_id)
+
+        return 200, resp
+
+
+class FederationGroupsCategoryServlet(BaseGroupsServerServlet):
+    """Add/remove/get a category in a group"""
+
+    PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
+
+    async def on_GET(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        resp = await self.handler.get_group_category(
+            group_id, requester_user_id, category_id
+        )
+
+        return 200, resp
+
+    async def on_POST(
+        self,
+        origin: str,
+        content: JsonDict,
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
+        resp = await self.handler.upsert_group_category(
+            group_id, requester_user_id, category_id, content
+        )
+
+        return 200, resp
+
+    async def on_DELETE(
+        self,
+        origin: str,
+        content: Literal[None],
+        query: Dict[bytes, List[bytes]],
+        group_id: str,
+        category_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester_user_id = parse_string_from_args(
+            query, "requester_user_id", required=True
+        )
+        if get_domain_from_id(requester_user_id) != origin:
+            raise SynapseError(403, "requester_user_id doesn't match origin")
+
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty string")
+
+        resp = await self.handler.delete_group_category(
+            group_id, requester_user_id, category_id
+        )
+
+        return 200, resp
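category_id and role_id share the same two checks (non-empty, bounded length); a hypothetical shared helper (the MAX_* constants are real, from synapse.api.constants; the helper itself is not in this patch):

    from synapse.api.errors import Codes, SynapseError

    def check_group_subid(value: str, max_len: int, name: str) -> None:
        if value == "":
            raise SynapseError(
                400, "%s cannot be empty string" % (name,), Codes.INVALID_PARAM
            )
        if len(value) > max_len:
            raise SynapseError(
                400,
                "%s may not be longer than %s characters" % (name, max_len),
                Codes.INVALID_PARAM,
            )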
+ "/users/(?P[^/]*)" + ) + + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + + resp = await self.handler.update_group_summary_user( + group_id, + requester_user_id, + user_id=user_id, + role_id=role_id, + content=content, + ) + + return 200, resp + + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + if role_id == "": + raise SynapseError(400, "role_id cannot be empty string") + + resp = await self.handler.delete_group_summary_user( + group_id, requester_user_id, user_id=user_id, role_id=role_id + ) + + return 200, resp + + +class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): + """Sets whether a group is joinable without an invite or knock""" + + PATH = "/groups/(?P[^/]*)/settings/m.join_policy" + + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: + requester_user_id = parse_string_from_args( + query, "requester_user_id", required=True + ) + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = await self.handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + return 200, new_content + + +GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( + FederationGroupsProfileServlet, + FederationGroupsSummaryServlet, + FederationGroupsRoomsServlet, + FederationGroupsUsersServlet, + FederationGroupsInvitedUsersServlet, + FederationGroupsInviteServlet, + FederationGroupsAcceptInviteServlet, + FederationGroupsJoinServlet, + FederationGroupsRemoveUserServlet, + FederationGroupsSummaryRoomsServlet, + FederationGroupsCategoriesServlet, + FederationGroupsCategoryServlet, + FederationGroupsRolesServlet, + FederationGroupsRoleServlet, + FederationGroupsSummaryUsersServlet, + FederationGroupsAddRoomsServlet, + FederationGroupsAddRoomsConfigServlet, + FederationGroupsSettingJoinPolicyServlet, +) -- cgit 1.5.1 From 0ace38b7b310fc1b4f88ac93d01ec900f33f7a07 Mon Sep 17 00:00:00 2001 From: Michael Telatynski <7t3chguy@gmail.com> Date: Mon, 16 Aug 2021 15:49:12 +0100 Subject: Experimental support for MSC3266 Room Summary API. 
(#10394) --- changelog.d/10394.feature | 1 + mypy.ini | 2 +- synapse/config/experimental.py | 3 + synapse/federation/transport/server/federation.py | 4 +- synapse/handlers/room_summary.py | 1171 +++++++++++++++++++++ synapse/handlers/space_summary.py | 1116 -------------------- synapse/http/servlet.py | 58 +- synapse/rest/admin/rooms.py | 45 +- synapse/rest/client/v1/room.py | 90 +- synapse/server.py | 6 +- tests/handlers/test_room_summary.py | 959 +++++++++++++++++ tests/handlers/test_space_summary.py | 881 ---------------- 12 files changed, 2255 insertions(+), 2081 deletions(-) create mode 100644 changelog.d/10394.feature create mode 100644 synapse/handlers/room_summary.py delete mode 100644 synapse/handlers/space_summary.py create mode 100644 tests/handlers/test_room_summary.py delete mode 100644 tests/handlers/test_space_summary.py diff --git a/changelog.d/10394.feature b/changelog.d/10394.feature new file mode 100644 index 0000000000..c8bbc5a740 --- /dev/null +++ b/changelog.d/10394.feature @@ -0,0 +1 @@ +Initial local support for [MSC3266](https://github.com/matrix-org/synapse/pull/10394), Room Summary over the unstable `/rooms/{roomIdOrAlias}/summary` API. diff --git a/mypy.ini b/mypy.ini index 5d6cd557bc..e1b9405daa 100644 --- a/mypy.ini +++ b/mypy.ini @@ -86,7 +86,7 @@ files = tests/test_event_auth.py, tests/test_utils, tests/handlers/test_password_providers.py, - tests/handlers/test_space_summary.py, + tests/handlers/test_room_summary.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_itertools.py, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 4c60ee8c28..b918fb15b0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -38,3 +38,6 @@ class ExperimentalConfig(Config): # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) + + # MSC3266 (room summary api) + self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 2806337846..7d81cc642c 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -547,7 +547,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_space_summary_handler() + self.handler = hs.get_room_summary_handler() async def on_GET( self, @@ -608,7 +608,7 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): server_name: str, ): super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_space_summary_handler() + self.handler = hs.get_room_summary_handler() async def on_GET( self, diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py new file mode 100644 index 0000000000..ac6cfc0da9 --- /dev/null +++ b/synapse/handlers/room_summary.py @@ -0,0 +1,1171 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+import re
+from collections import deque
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple
+
+import attr
+
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    HistoryVisibility,
+    JoinRules,
+    Membership,
+    RoomTypes,
+)
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.events import EventBase
+from synapse.events.utils import format_event_for_client_v2
+from synapse.types import JsonDict
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+# number of rooms to return. We'll stop once we hit this limit.
+MAX_ROOMS = 50
+
+# max number of events to return per room.
+MAX_ROOMS_PER_SPACE = 50
+
+# max number of federation servers to hit per room
+MAX_SERVERS_PER_SPACE = 3
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PaginationKey:
+    """The key used to find a unique pagination session."""
+
+    # The first three entries match the request parameters (and cannot change
+    # during a pagination session).
+    room_id: str
+    suggested_only: bool
+    max_depth: Optional[int]
+    # The randomly generated token.
+    token: str
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PaginationSession:
+    """The information that is stored for pagination."""
+
+    # The time the pagination session was created, in milliseconds.
+    creation_time_ms: int
+    # The queue of rooms which are still to process.
+    room_queue: List["_RoomQueueEntry"]
+    # A set of rooms which have been processed.
+    processed_rooms: Set[str]
+
+
+class RoomSummaryHandler:
+    # The time a pagination session remains valid for.
+    _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000
+
+    def __init__(self, hs: "HomeServer"):
+        self._clock = hs.get_clock()
+        self._event_auth_handler = hs.get_event_auth_handler()
+        self._store = hs.get_datastore()
+        self._event_serializer = hs.get_event_client_serializer()
+        self._server_name = hs.hostname
+        self._federation_client = hs.get_federation_client()
+
+        # A map of query information to the current pagination state.
+        #
+        # TODO Allow for multiple workers to share this data.
+        # TODO Expire pagination tokens.
+        self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
+
+        # If a user tries to fetch the same page multiple times in quick succession,
+        # only process the first attempt and return its result to subsequent requests.
+        self._pagination_response_cache: ResponseCache[
+            Tuple[str, bool, Optional[int], Optional[int], Optional[str]]
+        ] = ResponseCache(
+            hs.get_clock(),
+            "get_room_hierarchy",
+        )
+
+    def _expire_pagination_sessions(self):
+        """Expire pagination sessions which are old."""
+        expire_before = (
+            self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS
+        )
+        to_expire = []
+
+        for key, value in self._pagination_sessions.items():
+            if value.creation_time_ms < expire_before:
+                to_expire.append(key)
+
+        for key in to_expire:
+            logger.debug("Expiring pagination session id %s", key)
+            del self._pagination_sessions[key]
+
+    async def get_space_summary(
+        self,
+        requester: str,
+        room_id: str,
+        suggested_only: bool = False,
+        max_rooms_per_space: Optional[int] = None,
+    ) -> JsonDict:
+        """
+        Implementation of the space summary C-S API
+
+        Args:
+            requester: user id of the user making this request
+
+            room_id: room id to start the summary at
+
+            suggested_only: whether we should only return children with the "suggested"
+                flag set.
+
+            max_rooms_per_space: an optional limit on the number of child rooms we will
+                return. This does not apply to the root room (ie, room_id), and
+                is overridden by MAX_ROOMS_PER_SPACE.
+
+        Returns:
+            summary dict to return
+        """
+        # First of all, check that the room is accessible.
+        if not await self._is_local_room_accessible(room_id, requester):
+            raise AuthError(
+                403,
+                "User %s not in room %s, and room previews are disabled"
+                % (requester, room_id),
+            )
+
+        # the queue of rooms to process
+        room_queue = deque((_RoomQueueEntry(room_id, ()),))
+
+        # rooms we have already processed
+        processed_rooms: Set[str] = set()
+
+        # events we have already processed. We don't necessarily have their event ids,
+        # so instead we key on (room id, state key)
+        processed_events: Set[Tuple[str, str]] = set()
+
+        rooms_result: List[JsonDict] = []
+        events_result: List[JsonDict] = []
+
+        while room_queue and len(rooms_result) < MAX_ROOMS:
+            queue_entry = room_queue.popleft()
+            room_id = queue_entry.room_id
+            if room_id in processed_rooms:
+                # already done this room
+                continue
+
+            logger.debug("Processing room %s", room_id)
+
+            is_in_room = await self._store.is_host_joined(room_id, self._server_name)
+
+            # The client-specified max_rooms_per_space limit doesn't apply to the
+            # room_id specified in the request, so we ignore it if this is the
+            # first room we are processing.
+            max_children = max_rooms_per_space if processed_rooms else None
+
+            if is_in_room:
+                room_entry = await self._summarize_local_room(
+                    requester, None, room_id, suggested_only, max_children
+                )
+
+                events: Sequence[JsonDict] = []
+                if room_entry:
+                    rooms_result.append(room_entry.room)
+                    events = room_entry.children_state_events
+
+                logger.debug(
+                    "Query of local room %s returned events %s",
+                    room_id,
+                    ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
+                )
+            else:
+                fed_rooms = await self._summarize_remote_room(
+                    queue_entry,
+                    suggested_only,
+                    max_children,
+                    exclude_rooms=processed_rooms,
+                )
+
+                # The results over federation might include rooms that we, as the
+                # requesting server, are allowed to see, but that the requesting
+                # user is not permitted to see.
+                #
+                # Filter the returned results to only what is accessible to the user.
+                events = []
+                for room_entry in fed_rooms:
+                    room = room_entry.room
+                    fed_room_id = room_entry.room_id
+
+                    # The user can see the room, include it!
+                    if await self._is_remote_room_accessible(
+                        requester, fed_room_id, room
+                    ):
+                        # Before returning to the client, remove the allowed_room_ids
+                        # and allowed_spaces keys.
+                        room.pop("allowed_room_ids", None)
+                        room.pop("allowed_spaces", None)
+
+                        rooms_result.append(room)
+                        events.extend(room_entry.children_state_events)
+
+                    # All rooms returned don't need visiting again (even if the user
+                    # didn't have access to them).
+                    processed_rooms.add(fed_room_id)
+
+                logger.debug(
+                    "Query of %s returned rooms %s, events %s",
+                    room_id,
+                    [room_entry.room.get("room_id") for room_entry in fed_rooms],
+                    ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events],
+                )
+
+            # the room we queried may or may not have been returned, but don't process
+            # it again, anyway.
+            processed_rooms.add(room_id)
+
+            # XXX: is it ok that we blindly iterate through any events returned by
+            # a remote server, whether or not they actually link to any rooms in our
+            # tree?
+            for ev in events:
+                # remote servers might return events we have already processed
+                # (eg, Dendrite returns inward pointers as well as outward ones), so
+                # we need to filter them out, to avoid returning duplicate links to the
+                # client.
+                ev_key = (ev["room_id"], ev["state_key"])
+                if ev_key in processed_events:
+                    continue
+                events_result.append(ev)
+
+                # add the child to the queue. we have already validated
+                # that the vias are a list of server names.
+                room_queue.append(
+                    _RoomQueueEntry(ev["state_key"], ev["content"]["via"])
+                )
+                processed_events.add(ev_key)
+
+        return {"rooms": rooms_result, "events": events_result}
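get_space_summary above is a breadth-first walk with a visited set and a result cap; stripped to its shape (illustrative only, accessibility filtering omitted):

    from collections import deque

    def walk_space(root, children_of, max_rooms=50):
        # children_of(room) -> iterable of child room IDs, per m.space.child events.
        queue, processed, result = deque([root]), set(), []
        while queue and len(result) < max_rooms:
            room = queue.popleft()
            if room in processed:
                continue  # a room can be linked from several parents
            processed.add(room)
            result.append(room)
            queue.extend(children_of(room))
        return result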
+
+    async def get_room_hierarchy(
+        self,
+        requester: str,
+        requested_room_id: str,
+        suggested_only: bool = False,
+        max_depth: Optional[int] = None,
+        limit: Optional[int] = None,
+        from_token: Optional[str] = None,
+    ) -> JsonDict:
+        """
+        Implementation of the room hierarchy C-S API.
+
+        Args:
+            requester: The user ID of the user making this request.
+            requested_room_id: The room ID to start the hierarchy at (the "root" room).
+            suggested_only: Whether we should only return children with the "suggested"
+                flag set.
+            max_depth: The maximum depth in the tree to explore, must be a
+                non-negative integer.
+
+                0 would correspond to just the root room, 1 would include just
+                the root room's children, etc.
+            limit: An optional limit on the number of rooms to return per
+                page. Must be a positive integer.
+            from_token: An optional pagination token.
+
+        Returns:
+            The JSON hierarchy dictionary.
+        """
+        # If a user tries to fetch the same page multiple times in quick succession,
+        # only process the first attempt and return its result to subsequent requests.
+        #
+        # This is due to the pagination process mutating internal state; attempting
+        # to process multiple requests for the same page will result in errors.
+        return await self._pagination_response_cache.wrap(
+            (requested_room_id, suggested_only, max_depth, limit, from_token),
+            self._get_room_hierarchy,
+            requester,
+            requested_room_id,
+            suggested_only,
+            max_depth,
+            limit,
+            from_token,
+        )
+
+    async def _get_room_hierarchy(
+        self,
+        requester: str,
+        requested_room_id: str,
+        suggested_only: bool = False,
+        max_depth: Optional[int] = None,
+        limit: Optional[int] = None,
+        from_token: Optional[str] = None,
+    ) -> JsonDict:
+        """See docstring for RoomSummaryHandler.get_room_hierarchy."""
+
+        # First of all, check that the room is accessible.
+ if not await self._is_local_room_accessible(requested_room_id, requester): + raise AuthError( + 403, + "User %s not in room %s, and room previews are disabled" + % (requester, requested_room_id), + ) + + # If this is continuing a previous session, pull the persisted data. + if from_token: + self._expire_pagination_sessions() + + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, from_token + ) + if pagination_key not in self._pagination_sessions: + raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) + + # Load the previous state. + pagination_session = self._pagination_sessions[pagination_key] + room_queue = pagination_session.room_queue + processed_rooms = pagination_session.processed_rooms + else: + # The queue of rooms to process, the next room is last on the stack. + room_queue = [_RoomQueueEntry(requested_room_id, ())] + + # Rooms we have already processed. + processed_rooms = set() + + rooms_result: List[JsonDict] = [] + + # Cap the limit to a server-side maximum. + if limit is None: + limit = MAX_ROOMS + else: + limit = min(limit, MAX_ROOMS) + + # Iterate through the queue until we reach the limit or run out of + # rooms to include. + while room_queue and len(rooms_result) < limit: + queue_entry = room_queue.pop() + room_id = queue_entry.room_id + current_depth = queue_entry.depth + if room_id in processed_rooms: + # already done this room + continue + + logger.debug("Processing room %s", room_id) + + # A map of summaries for children rooms that might be returned over + # federation. The rationale for caching these and *maybe* using them + # is to prefer any information local to the homeserver before trusting + # data received over federation. + children_room_entries: Dict[str, JsonDict] = {} + # A set of room IDs which are children that did not have information + # returned over federation and are known to be inaccessible to the + # current server. We should not reach out over federation to try to + # summarise these rooms. + inaccessible_children: Set[str] = set() + + # If the room is known locally, summarise it! + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + if is_in_room: + room_entry = await self._summarize_local_room( + requester, + None, + room_id, + suggested_only, + # TODO Handle max children. + max_children=None, + ) + + # Otherwise, attempt to use information for federation. + else: + # A previous call might have included information for this room. + # It can be used if either: + # + # 1. The room is not a space. + # 2. The maximum depth has been achieved (since no children + # information is needed). + if queue_entry.remote_room and ( + queue_entry.remote_room.get("room_type") != RoomTypes.SPACE + or (max_depth is not None and current_depth >= max_depth) + ): + room_entry = _RoomEntry( + queue_entry.room_id, queue_entry.remote_room + ) + + # If the above isn't true, attempt to fetch the room + # information over federation. + else: + ( + room_entry, + children_room_entries, + inaccessible_children, + ) = await self._summarize_remote_room_hierarchy( + queue_entry, + suggested_only, + ) + + # Ensure this room is accessible to the requester (and not just + # the homeserver). + if room_entry and not await self._is_remote_room_accessible( + requester, queue_entry.room_id, room_entry.room + ): + room_entry = None + + # This room has been processed and should be ignored if it appears + # elsewhere in the hierarchy. 
+ processed_rooms.add(room_id) + + # There may or may not be a room entry based on whether it is + # inaccessible to the requesting user. + if room_entry: + # Add the room (including the stripped m.space.child events). + rooms_result.append(room_entry.as_json()) + + # If this room is not at the max-depth, check if there are any + # children to process. + if max_depth is None or current_depth < max_depth: + # The children get added in reverse order so that the next + # room to process, according to the ordering, is the last + # item in the list. + room_queue.extend( + _RoomQueueEntry( + ev["state_key"], + ev["content"]["via"], + current_depth + 1, + children_room_entries.get(ev["state_key"]), + ) + for ev in reversed(room_entry.children_state_events) + if ev["type"] == EventTypes.SpaceChild + and ev["state_key"] not in inaccessible_children + ) + + result: JsonDict = {"rooms": rooms_result} + + # If there's additional data, generate a pagination token (and persist state). + if room_queue: + next_batch = random_string(24) + result["next_batch"] = next_batch + pagination_key = _PaginationKey( + requested_room_id, suggested_only, max_depth, next_batch + ) + self._pagination_sessions[pagination_key] = _PaginationSession( + self._clock.time_msec(), room_queue, processed_rooms + ) + + return result + + async def federation_space_summary( + self, + origin: str, + room_id: str, + suggested_only: bool, + max_rooms_per_space: Optional[int], + exclude_rooms: Iterable[str], + ) -> JsonDict: + """ + Implementation of the space summary Federation API + + Args: + origin: The server requesting the spaces summary. + + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. Unlike the C-S API, this applies to the root room (room_id). + It is clipped to MAX_ROOMS_PER_SPACE. + + exclude_rooms: a list of rooms to skip over (presumably because the + calling server has already seen them). + + Returns: + summary dict to return + """ + # the queue of rooms to process + room_queue = deque((room_id,)) + + # the set of rooms that we should not walk further. Initialise it with the + # excluded-rooms list; we will add other rooms as we process them so that + # we do not loop. + processed_rooms: Set[str] = set(exclude_rooms) + + rooms_result: List[JsonDict] = [] + events_result: List[JsonDict] = [] + + while room_queue and len(rooms_result) < MAX_ROOMS: + room_id = room_queue.popleft() + if room_id in processed_rooms: + # already done this room + continue + + room_entry = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_rooms_per_space + ) + + processed_rooms.add(room_id) + + if room_entry: + rooms_result.append(room_entry.room) + events_result.extend(room_entry.children_state_events) + + # add any children to the queue + room_queue.extend( + edge_event["state_key"] + for edge_event in room_entry.children_state_events + ) + + return {"rooms": rooms_result, "events": events_result} + + async def get_federation_hierarchy( + self, + origin: str, + requested_room_id: str, + suggested_only: bool, + ): + """ + Implementation of the room hierarchy Federation API. + + This is similar to get_room_hierarchy, but does not recurse into the space. + It also considers whether anyone on the server may be able to access the + room, as opposed to whether a specific user can. + + Args: + origin: The server requesting the spaces summary. 
+ requested_room_id: The room ID to start the hierarchy at (the "root" room). + suggested_only: whether we should only return children with the "suggested" + flag set. + + Returns: + The JSON hierarchy dictionary. + """ + root_room_entry = await self._summarize_local_room( + None, origin, requested_room_id, suggested_only, max_children=None + ) + if root_room_entry is None: + # Room is inaccessible to the requesting server. + raise SynapseError(404, "Unknown room: %s" % (requested_room_id,)) + + children_rooms_result: List[JsonDict] = [] + inaccessible_children: List[str] = [] + + # Iterate through each child and potentially add it, but not its children, + # to the response. + for child_room in root_room_entry.children_state_events: + room_id = child_room.get("state_key") + assert isinstance(room_id, str) + # If the room is unknown, skip it. + if not await self._store.is_host_joined(room_id, self._server_name): + continue + + room_entry = await self._summarize_local_room( + None, origin, room_id, suggested_only, max_children=0 + ) + # If the room is accessible, include it in the results. + # + # Note that only the room summary (without information on children) + # is included in the summary. + if room_entry: + children_rooms_result.append(room_entry.room) + # Otherwise, note that the requesting server shouldn't bother + # trying to summarize this room - they do not have access to it. + else: + inaccessible_children.append(room_id) + + return { + # Include the requested room (including the stripped children events). + "room": root_room_entry.as_json(), + "children": children_rooms_result, + "inaccessible_children": inaccessible_children, + } + + async def _summarize_local_room( + self, + requester: Optional[str], + origin: Optional[str], + room_id: str, + suggested_only: bool, + max_children: Optional[int], + ) -> Optional["_RoomEntry"]: + """ + Generate a room entry and a list of event entries for a given room. + + Args: + requester: + The user requesting the summary, if it is a local request. None + if this is a federation request. + origin: + The server requesting the summary, if it is a federation request. + None if this is a local request. + room_id: The room ID to summarize. + suggested_only: True if only suggested children should be returned. + Otherwise, all children are returned. + max_children: + The maximum number of children rooms to include. This is capped + to a server-set limit. + + Returns: + A room entry if the room should be returned. None, otherwise. + """ + if not await self._is_local_room_accessible(room_id, requester, origin): + return None + + room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) + + # If the room is not a space or the children don't matter, return just + # the room information. + if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: + return _RoomEntry(room_id, room_entry) + + # Otherwise, look for child rooms/spaces. 
+        child_events = await self._get_child_events(room_id)
+
+        if suggested_only:
+            # we only care about suggested children
+            child_events = filter(_is_suggested_child_event, child_events)
+
+        if max_children is None or max_children > MAX_ROOMS_PER_SPACE:
+            max_children = MAX_ROOMS_PER_SPACE
+
+        now = self._clock.time_msec()
+        events_result: List[JsonDict] = []
+        for edge_event in itertools.islice(child_events, max_children):
+            events_result.append(
+                await self._event_serializer.serialize_event(
+                    edge_event,
+                    time_now=now,
+                    event_format=format_event_for_client_v2,
+                )
+            )
+
+        return _RoomEntry(room_id, room_entry, events_result)
+
+    async def _summarize_remote_room(
+        self,
+        room: "_RoomQueueEntry",
+        suggested_only: bool,
+        max_children: Optional[int],
+        exclude_rooms: Iterable[str],
+    ) -> Iterable["_RoomEntry"]:
+        """
+        Request room entries and a list of event entries for a given room by querying a remote server.
+
+        Args:
+            room: The room to summarize.
+            suggested_only: True if only suggested children should be returned.
+                Otherwise, all children are returned.
+            max_children:
+                The maximum number of children rooms to include. This is capped
+                to a server-set limit.
+            exclude_rooms:
+                Room IDs which do not need to be summarized.
+
+        Returns:
+            An iterable of room entries.
+        """
+        room_id = room.room_id
+        logger.info("Requesting summary for %s via %s", room_id, room.via)
+
+        # we need to make the exclusion list json-serialisable
+        exclude_rooms = list(exclude_rooms)
+
+        via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE)
+        try:
+            res = await self._federation_client.get_space_summary(
+                via,
+                room_id,
+                suggested_only=suggested_only,
+                max_rooms_per_space=max_children,
+                exclude_rooms=exclude_rooms,
+            )
+        except Exception as e:
+            logger.warning(
+                "Unable to get summary of %s via federation: %s",
+                room_id,
+                e,
+                exc_info=logger.isEnabledFor(logging.DEBUG),
+            )
+            return ()
+
+        # Group the events by their room.
+        children_by_room: Dict[str, List[JsonDict]] = {}
+        for ev in res.events:
+            if ev.event_type == EventTypes.SpaceChild:
+                children_by_room.setdefault(ev.room_id, []).append(ev.data)
+
+        # Generate the final results.
+        results = []
+        for fed_room in res.rooms:
+            fed_room_id = fed_room.get("room_id")
+            if not fed_room_id or not isinstance(fed_room_id, str):
+                continue
+
+            results.append(
+                _RoomEntry(
+                    fed_room_id,
+                    fed_room,
+                    children_by_room.get(fed_room_id, []),
+                )
+            )
+
+        return results
+
+    async def _summarize_remote_room_hierarchy(
+        self, room: "_RoomQueueEntry", suggested_only: bool
+    ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]:
+        """
+        Request room entries and a list of event entries for a given room by querying a remote server.
+
+        Args:
+            room: The room to summarize.
+            suggested_only: True if only suggested children should be returned.
+                Otherwise, all children are returned.
+
+        Returns:
+            A tuple of:
+                The room entry.
+                Partial room data returned over federation.
+                A set of inaccessible children room IDs.
+ """ + room_id = room.room_id + logger.info("Requesting summary for %s via %s", room_id, room.via) + + via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) + try: + ( + room_response, + children, + inaccessible_children, + ) = await self._federation_client.get_room_hierarchy( + via, + room_id, + suggested_only=suggested_only, + ) + except Exception as e: + logger.warning( + "Unable to get hierarchy of %s via federation: %s", + room_id, + e, + exc_info=logger.isEnabledFor(logging.DEBUG), + ) + return None, {}, set() + + # Map the children to their room ID. + children_by_room_id = { + c["room_id"]: c + for c in children + if "room_id" in c and isinstance(c["room_id"], str) + } + + return ( + _RoomEntry(room_id, room_response, room_response.pop("children_state", ())), + children_by_room_id, + set(inaccessible_children), + ) + + async def _is_local_room_accessible( + self, room_id: str, requester: Optional[str], origin: Optional[str] = None + ) -> bool: + """ + Calculate whether the room should be shown to the requester. + + It should return true if: + + * The requester is joined or can join the room (per MSC3173). + * The origin server has any user that is joined or can join the room. + * The history visibility is set to world readable. + + Args: + room_id: The room ID to check accessibility of. + requester: + The user making the request, if it is a local request. + None if this is a federation request. + origin: + The server making the request, if it is a federation request. + None if this is a local request. + + Returns: + True if the room is accessible to the requesting user or server. + """ + state_ids = await self._store.get_current_state_ids(room_id) + + # If there's no state for the room, it isn't known. + if not state_ids: + # The user might have a pending invite for the room. + if requester and await self._store.get_invite_for_local_user_in_room( + requester, room_id + ): + return True + + logger.info("room %s is unknown, omitting from summary", room_id) + return False + + room_version = await self._store.get_room_version(room_id) + + # Include the room if it has join rules of public or knock. + join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) + if join_rules_event_id: + join_rules_event = await self._store.get_event(join_rules_event_id) + join_rule = join_rules_event.content.get("join_rule") + if join_rule == JoinRules.PUBLIC or ( + room_version.msc2403_knocking and join_rule == JoinRules.KNOCK + ): + return True + + # Include the room if it is peekable. + hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_event_id: + hist_vis_ev = await self._store.get_event(hist_vis_event_id) + hist_vis = hist_vis_ev.content.get("history_visibility") + if hist_vis == HistoryVisibility.WORLD_READABLE: + return True + + # Otherwise we need to check information specific to the user or server. + + # If we have an authenticated requesting user, check if they are a member + # of the room (or can join the room). + if requester: + member_event_id = state_ids.get((EventTypes.Member, requester), None) + + # If they're in the room they can see info on it. + if member_event_id: + member_event = await self._store.get_event(member_event_id) + if member_event.membership in (Membership.JOIN, Membership.INVITE): + return True + + # Otherwise, check if they should be allowed access via membership in a space. 
+            if await self._event_auth_handler.has_restricted_join_rules(
+                state_ids, room_version
+            ):
+                allowed_rooms = (
+                    await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
+                )
+                if await self._event_auth_handler.is_user_in_rooms(
+                    allowed_rooms, requester
+                ):
+                    return True
+
+        # If this is a request over federation, check if the host is in the room or
+        # has a user who could join the room.
+        elif origin:
+            if await self._event_auth_handler.check_host_in_room(
+                room_id, origin
+            ) or await self._store.is_host_invited(room_id, origin):
+                return True
+
+            # Alternately, if the host has a user in any of the spaces specified
+            # for access, then the host can see this room (and should do filtering
+            # if the requester cannot see it).
+            if await self._event_auth_handler.has_restricted_join_rules(
+                state_ids, room_version
+            ):
+                allowed_rooms = (
+                    await self._event_auth_handler.get_rooms_that_allow_join(state_ids)
+                )
+                for space_id in allowed_rooms:
+                    if await self._event_auth_handler.check_host_in_room(
+                        space_id, origin
+                    ):
+                        return True
+
+        logger.info(
+            "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary",
+            room_id,
+            requester or origin,
+        )
+        return False
+
+    async def _is_remote_room_accessible(
+        self, requester: str, room_id: str, room: JsonDict
+    ) -> bool:
+        """
+        Calculate whether the room received over federation should be shown to the requester.
+
+        It should return true if:
+
+        * The requester is joined or can join the room (per MSC3173).
+        * The history visibility is set to world readable.
+
+        Note that the local server is not in the requested room (which is why the
+        remote call was made in the first place), but the user could have access
+        due to an invite, etc.
+
+        Args:
+            requester: The user requesting the summary.
+            room_id: The room ID returned over federation.
+            room: The summary of the room returned over federation.
+
+        Returns:
+            True if the room is accessible to the requesting user.
+        """
+        # The API doesn't return the room version so assume that a
+        # join rule of knock is valid.
+        if (
+            room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK)
+            or room.get("world_readable") is True
+        ):
+            return True
+
+        # Check if the user is a member of any of the allowed spaces
+        # from the response.
+        allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces")
+        if allowed_rooms and isinstance(allowed_rooms, list):
+            if await self._event_auth_handler.is_user_in_rooms(
+                allowed_rooms, requester
+            ):
+                return True
+
+        # Finally, check locally if we can access the room. The user might
+        # already be in the room (if it was a child room), or there might be a
+        # pending invite, etc.
+        return await self._is_local_room_accessible(room_id, requester)
+
+    async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict:
+        """
+        Generate an entry summarising a single room.
+
+        Args:
+            room_id: The room ID to summarize.
+            for_federation: True if this is a summary requested over federation
+                (which includes additional fields).
+
+        Returns:
+            The JSON dictionary for the room.
+ """ + stats = await self._store.get_room_with_stats(room_id) + + # currently this should be impossible because we call + # _is_local_room_accessible on the room before we get here, so + # there should always be an entry + assert stats is not None, "unable to retrieve stats for %s" % (room_id,) + + current_state_ids = await self._store.get_current_state_ids(room_id) + create_event = await self._store.get_event( + current_state_ids[(EventTypes.Create, "")] + ) + + entry = { + "room_id": stats["room_id"], + "name": stats["name"], + "topic": stats["topic"], + "canonical_alias": stats["canonical_alias"], + "num_joined_members": stats["joined_members"], + "avatar_url": stats["avatar"], + "join_rules": stats["join_rules"], + "world_readable": ( + stats["history_visibility"] == HistoryVisibility.WORLD_READABLE + ), + "guest_can_join": stats["guest_access"] == "can_join", + "creation_ts": create_event.origin_server_ts, + "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), + } + + # Federation requests need to provide additional information so the + # requested server is able to filter the response appropriately. + if for_federation: + room_version = await self._store.get_room_version(room_id) + if await self._event_auth_handler.has_restricted_join_rules( + current_state_ids, room_version + ): + allowed_rooms = ( + await self._event_auth_handler.get_rooms_that_allow_join( + current_state_ids + ) + ) + if allowed_rooms: + entry["allowed_room_ids"] = allowed_rooms + # TODO Remove this key once the API is stable. + entry["allowed_spaces"] = allowed_rooms + + # Filter out Nones – rather omit the field altogether + room_entry = {k: v for k, v in entry.items() if v is not None} + + return room_entry + + async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: + """ + Get the child events for a given room. + + The returned results are sorted for stability. + + Args: + room_id: The room id to get the children of. + + Returns: + An iterable of sorted child events. + """ + + # look for child rooms/spaces. + current_state_ids = await self._store.get_current_state_ids(room_id) + + events = await self._store.get_events_as_list( + [ + event_id + for key, event_id in current_state_ids.items() + if key[0] == EventTypes.SpaceChild + ] + ) + + # filter out any events without a "via" (which implies it has been redacted), + # and order to ensure we return stable results. + return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) + + async def get_room_summary( + self, + requester: Optional[str], + room_id: str, + remote_room_hosts: Optional[List[str]] = None, + ) -> JsonDict: + """ + Implementation of the room summary C-S API from MSC3266 + + Args: + requester: user id of the user making this request, will be None + for unauthenticated requests + + room_id: room id to summarise. + + remote_room_hosts: a list of homeservers to try fetching data through + if we don't know it ourselves + + Returns: + summary dict to return + """ + is_in_room = await self._store.is_host_joined(room_id, self._server_name) + + if is_in_room: + room_entry = await self._summarize_local_room( + requester, + None, + room_id, + # Suggested-only doesn't matter since no children are requested. + suggested_only=False, + max_children=0, + ) + + if not room_entry: + raise NotFoundError("Room not found or is not accessible") + + room_summary = room_entry.room + + # If there was a requester, add their membership. 
+            if requester:
+                (
+                    membership,
+                    _,
+                ) = await self._store.get_local_current_membership_for_user_in_room(
+                    requester, room_id
+                )
+
+                room_summary["membership"] = membership or "leave"
+        else:
+            # TODO federation API, descoped from initial unstable implementation
+            # as MSC needs more maturing on that side.
+            raise SynapseError(400, "Federation is not currently supported.")
+
+        return room_summary
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _RoomQueueEntry:
+    # The room ID of this entry.
+    room_id: str
+    # The server to query if the room is not known locally.
+    via: Sequence[str]
+    # The minimum number of hops necessary to get to this room (compared to the
+    # originally requested room).
+    depth: int = 0
+    # The room summary for this room returned via federation. This will only be
+    # used if the room is not known locally (and is not a space).
+    remote_room: Optional[JsonDict] = None
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _RoomEntry:
+    room_id: str
+    # The room summary for this room.
+    room: JsonDict
+    # An iterable of the sorted, stripped children events for children of this room.
+    #
+    # This may not include all children.
+    children_state_events: Sequence[JsonDict] = ()
+
+    def as_json(self) -> JsonDict:
+        """
+        Returns a JSON dictionary suitable for the room hierarchy endpoint.
+
+        It returns the room summary including the stripped m.space.child events
+        as a sub-key.
+        """
+        result = dict(self.room)
+        result["children_state"] = self.children_state_events
+        return result
+
+
+def _has_valid_via(e: EventBase) -> bool:
+    via = e.content.get("via")
+    if not via or not isinstance(via, Sequence):
+        return False
+    for v in via:
+        if not isinstance(v, str):
+            logger.debug("Ignoring edge event %s with invalid via entry", e.event_id)
+            return False
+    return True
+
+
+def _is_suggested_child_event(edge_event: EventBase) -> bool:
+    suggested = edge_event.content.get("suggested")
+    if isinstance(suggested, bool) and suggested:
+        return True
+    logger.debug("Ignoring not-suggested child %s", edge_event.state_key)
+    return False
+
+
+# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive.
+_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]")
+
+
+def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]:
+    """
+    Generate a value for comparing two child events for ordering.
+
+    The rules for ordering are supposed to be:
+
+    1. The 'order' key, if it is valid.
+    2. The 'origin_server_ts' of the 'm.room.create' event.
+    3. The 'room_id'.
+
+    But we skip step 2 since we may not have any state from the room.
+
+    Args:
+        child: The event for generating a comparison key.
+
+    Returns:
+        The comparison key as a tuple of:
+            False if the ordering is valid.
+            The ordering field.
+            The room ID.
+    """
+    order = child.content.get("order")
+    # If order is not a string or doesn't meet the requirements, ignore it.
+    if not isinstance(order, str):
+        order = None
+    elif len(order) > 50 or _INVALID_ORDER_CHARS_RE.search(order):
+        order = None
+
+    # Items without an order come last.
+    return (order is None, order, child.room_id)
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
deleted file mode 100644
index c74e90abbc..0000000000
--- a/synapse/handlers/space_summary.py
+++ /dev/null
@@ -1,1116 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import itertools -import logging -import re -from collections import deque -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Set, Tuple - -import attr - -from synapse.api.constants import ( - EventContentFields, - EventTypes, - HistoryVisibility, - JoinRules, - Membership, - RoomTypes, -) -from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.events import EventBase -from synapse.events.utils import format_event_for_client_v2 -from synapse.types import JsonDict -from synapse.util.caches.response_cache import ResponseCache -from synapse.util.stringutils import random_string - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - -# number of rooms to return. We'll stop once we hit this limit. -MAX_ROOMS = 50 - -# max number of events to return per room. -MAX_ROOMS_PER_SPACE = 50 - -# max number of federation servers to hit per room -MAX_SERVERS_PER_SPACE = 3 - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _PaginationKey: - """The key used to find unique pagination session.""" - - # The first three entries match the request parameters (and cannot change - # during a pagination session). - room_id: str - suggested_only: bool - max_depth: Optional[int] - # The randomly generated token. - token: str - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _PaginationSession: - """The information that is stored for pagination.""" - - # The time the pagination session was created, in milliseconds. - creation_time_ms: int - # The queue of rooms which are still to process. - room_queue: List["_RoomQueueEntry"] - # A set of rooms which have been processed. - processed_rooms: Set[str] - - -class SpaceSummaryHandler: - # The time a pagination session remains valid for. - _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 - - def __init__(self, hs: "HomeServer"): - self._clock = hs.get_clock() - self._event_auth_handler = hs.get_event_auth_handler() - self._store = hs.get_datastore() - self._event_serializer = hs.get_event_client_serializer() - self._server_name = hs.hostname - self._federation_client = hs.get_federation_client() - - # A map of query information to the current pagination state. - # - # TODO Allow for multiple workers to share this data. - # TODO Expire pagination tokens. - self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {} - - # If a user tries to fetch the same page multiple times in quick succession, - # only process the first attempt and return its result to subsequent requests. 
- self._pagination_response_cache: ResponseCache[ - Tuple[str, bool, Optional[int], Optional[int], Optional[str]] - ] = ResponseCache( - hs.get_clock(), - "get_room_hierarchy", - ) - - def _expire_pagination_sessions(self): - """Expire pagination session which are old.""" - expire_before = ( - self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS - ) - to_expire = [] - - for key, value in self._pagination_sessions.items(): - if value.creation_time_ms < expire_before: - to_expire.append(key) - - for key in to_expire: - logger.debug("Expiring pagination session id %s", key) - del self._pagination_sessions[key] - - async def get_space_summary( - self, - requester: str, - room_id: str, - suggested_only: bool = False, - max_rooms_per_space: Optional[int] = None, - ) -> JsonDict: - """ - Implementation of the space summary C-S API - - Args: - requester: user id of the user making this request - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. This does not apply to the root room (ie, room_id), and - is overridden by MAX_ROOMS_PER_SPACE. - - Returns: - summary dict to return - """ - # First of all, check that the room is accessible. - if not await self._is_local_room_accessible(room_id, requester): - raise AuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester, room_id), - ) - - # the queue of rooms to process - room_queue = deque((_RoomQueueEntry(room_id, ()),)) - - # rooms we have already processed - processed_rooms: Set[str] = set() - - # events we have already processed. We don't necessarily have their event ids, - # so instead we key on (room id, state key) - processed_events: Set[Tuple[str, str]] = set() - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - while room_queue and len(rooms_result) < MAX_ROOMS: - queue_entry = room_queue.popleft() - room_id = queue_entry.room_id - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - - # The client-specified max_rooms_per_space limit doesn't apply to the - # room_id specified in the request, so we ignore it if this is the - # first room we are processing. - max_children = max_rooms_per_space if processed_rooms else None - - if is_in_room: - room_entry = await self._summarize_local_room( - requester, None, room_id, suggested_only, max_children - ) - - events: Sequence[JsonDict] = [] - if room_entry: - rooms_result.append(room_entry.room) - events = room_entry.children_state_events - - logger.debug( - "Query of local room %s returned events %s", - room_id, - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - else: - fed_rooms = await self._summarize_remote_room( - queue_entry, - suggested_only, - max_children, - exclude_rooms=processed_rooms, - ) - - # The results over federation might include rooms that the we, - # as the requesting server, are allowed to see, but the requesting - # user is not permitted see. - # - # Filter the returned results to only what is accessible to the user. - events = [] - for room_entry in fed_rooms: - room = room_entry.room - fed_room_id = room_entry.room_id - - # The user can see the room, include it! 
- if await self._is_remote_room_accessible( - requester, fed_room_id, room - ): - # Before returning to the client, remove the allowed_room_ids - # and allowed_spaces keys. - room.pop("allowed_room_ids", None) - room.pop("allowed_spaces", None) - - rooms_result.append(room) - events.extend(room_entry.children_state_events) - - # All rooms returned don't need visiting again (even if the user - # didn't have access to them). - processed_rooms.add(fed_room_id) - - logger.debug( - "Query of %s returned rooms %s, events %s", - room_id, - [room_entry.room.get("room_id") for room_entry in fed_rooms], - ["%s->%s" % (ev["room_id"], ev["state_key"]) for ev in events], - ) - - # the room we queried may or may not have been returned, but don't process - # it again, anyway. - processed_rooms.add(room_id) - - # XXX: is it ok that we blindly iterate through any events returned by - # a remote server, whether or not they actually link to any rooms in our - # tree? - for ev in events: - # remote servers might return events we have already processed - # (eg, Dendrite returns inward pointers as well as outward ones), so - # we need to filter them out, to avoid returning duplicate links to the - # client. - ev_key = (ev["room_id"], ev["state_key"]) - if ev_key in processed_events: - continue - events_result.append(ev) - - # add the child to the queue. we have already validated - # that the vias are a list of server names. - room_queue.append( - _RoomQueueEntry(ev["state_key"], ev["content"]["via"]) - ) - processed_events.add(ev_key) - - return {"rooms": rooms_result, "events": events_result} - - async def get_room_hierarchy( - self, - requester: str, - requested_room_id: str, - suggested_only: bool = False, - max_depth: Optional[int] = None, - limit: Optional[int] = None, - from_token: Optional[str] = None, - ) -> JsonDict: - """ - Implementation of the room hierarchy C-S API. - - Args: - requester: The user ID of the user making this request. - requested_room_id: The room ID to start the hierarchy at (the "root" room). - suggested_only: Whether we should only return children with the "suggested" - flag set. - max_depth: The maximum depth in the tree to explore, must be a - non-negative integer. - - 0 would correspond to just the root room, 1 would include just - the root room's children, etc. - limit: An optional limit on the number of rooms to return per - page. Must be a positive integer. - from_token: An optional pagination token. - - Returns: - The JSON hierarchy dictionary. - """ - # If a user tries to fetch the same page multiple times in quick succession, - # only process the first attempt and return its result to subsequent requests. - # - # This is due to the pagination process mutating internal state, attempting - # to process multiple requests for the same page will result in errors. - return await self._pagination_response_cache.wrap( - (requested_room_id, suggested_only, max_depth, limit, from_token), - self._get_room_hierarchy, - requester, - requested_room_id, - suggested_only, - max_depth, - limit, - from_token, - ) - - async def _get_room_hierarchy( - self, - requester: str, - requested_room_id: str, - suggested_only: bool = False, - max_depth: Optional[int] = None, - limit: Optional[int] = None, - from_token: Optional[str] = None, - ) -> JsonDict: - """See docstring for SpaceSummaryHandler.get_room_hierarchy.""" - - # First of all, check that the room is accessible. 
- if not await self._is_local_room_accessible(requested_room_id, requester): - raise AuthError( - 403, - "User %s not in room %s, and room previews are disabled" - % (requester, requested_room_id), - ) - - # If this is continuing a previous session, pull the persisted data. - if from_token: - self._expire_pagination_sessions() - - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, from_token - ) - if pagination_key not in self._pagination_sessions: - raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) - - # Load the previous state. - pagination_session = self._pagination_sessions[pagination_key] - room_queue = pagination_session.room_queue - processed_rooms = pagination_session.processed_rooms - else: - # The queue of rooms to process, the next room is last on the stack. - room_queue = [_RoomQueueEntry(requested_room_id, ())] - - # Rooms we have already processed. - processed_rooms = set() - - rooms_result: List[JsonDict] = [] - - # Cap the limit to a server-side maximum. - if limit is None: - limit = MAX_ROOMS - else: - limit = min(limit, MAX_ROOMS) - - # Iterate through the queue until we reach the limit or run out of - # rooms to include. - while room_queue and len(rooms_result) < limit: - queue_entry = room_queue.pop() - room_id = queue_entry.room_id - current_depth = queue_entry.depth - if room_id in processed_rooms: - # already done this room - continue - - logger.debug("Processing room %s", room_id) - - # A map of summaries for children rooms that might be returned over - # federation. The rationale for caching these and *maybe* using them - # is to prefer any information local to the homeserver before trusting - # data received over federation. - children_room_entries: Dict[str, JsonDict] = {} - # A set of room IDs which are children that did not have information - # returned over federation and are known to be inaccessible to the - # current server. We should not reach out over federation to try to - # summarise these rooms. - inaccessible_children: Set[str] = set() - - # If the room is known locally, summarise it! - is_in_room = await self._store.is_host_joined(room_id, self._server_name) - if is_in_room: - room_entry = await self._summarize_local_room( - requester, - None, - room_id, - suggested_only, - # TODO Handle max children. - max_children=None, - ) - - # Otherwise, attempt to use information for federation. - else: - # A previous call might have included information for this room. - # It can be used if either: - # - # 1. The room is not a space. - # 2. The maximum depth has been achieved (since no children - # information is needed). - if queue_entry.remote_room and ( - queue_entry.remote_room.get("room_type") != RoomTypes.SPACE - or (max_depth is not None and current_depth >= max_depth) - ): - room_entry = _RoomEntry( - queue_entry.room_id, queue_entry.remote_room - ) - - # If the above isn't true, attempt to fetch the room - # information over federation. - else: - ( - room_entry, - children_room_entries, - inaccessible_children, - ) = await self._summarize_remote_room_hiearchy( - queue_entry, - suggested_only, - ) - - # Ensure this room is accessible to the requester (and not just - # the homeserver). - if room_entry and not await self._is_remote_room_accessible( - requester, queue_entry.room_id, room_entry.room - ): - room_entry = None - - # This room has been processed and should be ignored if it appears - # elsewhere in the hierarchy. 
- processed_rooms.add(room_id) - - # There may or may not be a room entry based on whether it is - # inaccessible to the requesting user. - if room_entry: - # Add the room (including the stripped m.space.child events). - rooms_result.append(room_entry.as_json()) - - # If this room is not at the max-depth, check if there are any - # children to process. - if max_depth is None or current_depth < max_depth: - # The children get added in reverse order so that the next - # room to process, according to the ordering, is the last - # item in the list. - room_queue.extend( - _RoomQueueEntry( - ev["state_key"], - ev["content"]["via"], - current_depth + 1, - children_room_entries.get(ev["state_key"]), - ) - for ev in reversed(room_entry.children_state_events) - if ev["type"] == EventTypes.SpaceChild - and ev["state_key"] not in inaccessible_children - ) - - result: JsonDict = {"rooms": rooms_result} - - # If there's additional data, generate a pagination token (and persist state). - if room_queue: - next_batch = random_string(24) - result["next_batch"] = next_batch - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, next_batch - ) - self._pagination_sessions[pagination_key] = _PaginationSession( - self._clock.time_msec(), room_queue, processed_rooms - ) - - return result - - async def federation_space_summary( - self, - origin: str, - room_id: str, - suggested_only: bool, - max_rooms_per_space: Optional[int], - exclude_rooms: Iterable[str], - ) -> JsonDict: - """ - Implementation of the space summary Federation API - - Args: - origin: The server requesting the spaces summary. - - room_id: room id to start the summary at - - suggested_only: whether we should only return children with the "suggested" - flag set. - - max_rooms_per_space: an optional limit on the number of child rooms we will - return. Unlike the C-S API, this applies to the root room (room_id). - It is clipped to MAX_ROOMS_PER_SPACE. - - exclude_rooms: a list of rooms to skip over (presumably because the - calling server has already seen them). - - Returns: - summary dict to return - """ - # the queue of rooms to process - room_queue = deque((room_id,)) - - # the set of rooms that we should not walk further. Initialise it with the - # excluded-rooms list; we will add other rooms as we process them so that - # we do not loop. - processed_rooms: Set[str] = set(exclude_rooms) - - rooms_result: List[JsonDict] = [] - events_result: List[JsonDict] = [] - - while room_queue and len(rooms_result) < MAX_ROOMS: - room_id = room_queue.popleft() - if room_id in processed_rooms: - # already done this room - continue - - room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_rooms_per_space - ) - - processed_rooms.add(room_id) - - if room_entry: - rooms_result.append(room_entry.room) - events_result.extend(room_entry.children_state_events) - - # add any children to the queue - room_queue.extend( - edge_event["state_key"] - for edge_event in room_entry.children_state_events - ) - - return {"rooms": rooms_result, "events": events_result} - - async def get_federation_hierarchy( - self, - origin: str, - requested_room_id: str, - suggested_only: bool, - ): - """ - Implementation of the room hierarchy Federation API. - - This is similar to get_room_hierarchy, but does not recurse into the space. - It also considers whether anyone on the server may be able to access the - room, as opposed to whether a specific user can. - - Args: - origin: The server requesting the spaces summary. 
- requested_room_id: The room ID to start the hierarchy at (the "root" room). - suggested_only: whether we should only return children with the "suggested" - flag set. - - Returns: - The JSON hierarchy dictionary. - """ - root_room_entry = await self._summarize_local_room( - None, origin, requested_room_id, suggested_only, max_children=None - ) - if root_room_entry is None: - # Room is inaccessible to the requesting server. - raise SynapseError(404, "Unknown room: %s" % (requested_room_id,)) - - children_rooms_result: List[JsonDict] = [] - inaccessible_children: List[str] = [] - - # Iterate through each child and potentially add it, but not its children, - # to the response. - for child_room in root_room_entry.children_state_events: - room_id = child_room.get("state_key") - assert isinstance(room_id, str) - # If the room is unknown, skip it. - if not await self._store.is_host_joined(room_id, self._server_name): - continue - - room_entry = await self._summarize_local_room( - None, origin, room_id, suggested_only, max_children=0 - ) - # If the room is accessible, include it in the results. - # - # Note that only the room summary (without information on children) - # is included in the summary. - if room_entry: - children_rooms_result.append(room_entry.room) - # Otherwise, note that the requesting server shouldn't bother - # trying to summarize this room - they do not have access to it. - else: - inaccessible_children.append(room_id) - - return { - # Include the requested room (including the stripped children events). - "room": root_room_entry.as_json(), - "children": children_rooms_result, - "inaccessible_children": inaccessible_children, - } - - async def _summarize_local_room( - self, - requester: Optional[str], - origin: Optional[str], - room_id: str, - suggested_only: bool, - max_children: Optional[int], - ) -> Optional["_RoomEntry"]: - """ - Generate a room entry and a list of event entries for a given room. - - Args: - requester: - The user requesting the summary, if it is a local request. None - if this is a federation request. - origin: - The server requesting the summary, if it is a federation request. - None if this is a local request. - room_id: The room ID to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - - Returns: - A room entry if the room should be returned. None, otherwise. - """ - if not await self._is_local_room_accessible(room_id, requester, origin): - return None - - room_entry = await self._build_room_entry(room_id, for_federation=bool(origin)) - - # If the room is not a space or the children don't matter, return just - # the room information. - if room_entry.get("room_type") != RoomTypes.SPACE or max_children == 0: - return _RoomEntry(room_id, room_entry) - - # Otherwise, look for child rooms/spaces. 
- child_events = await self._get_child_events(room_id) - - if suggested_only: - # we only care about suggested children - child_events = filter(_is_suggested_child_event, child_events) - - if max_children is None or max_children > MAX_ROOMS_PER_SPACE: - max_children = MAX_ROOMS_PER_SPACE - - now = self._clock.time_msec() - events_result: List[JsonDict] = [] - for edge_event in itertools.islice(child_events, max_children): - events_result.append( - await self._event_serializer.serialize_event( - edge_event, - time_now=now, - event_format=format_event_for_client_v2, - ) - ) - - return _RoomEntry(room_id, room_entry, events_result) - - async def _summarize_remote_room( - self, - room: "_RoomQueueEntry", - suggested_only: bool, - max_children: Optional[int], - exclude_rooms: Iterable[str], - ) -> Iterable["_RoomEntry"]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. - exclude_rooms: - Rooms IDs which do not need to be summarized. - - Returns: - An iterable of room entries. - """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - # we need to make the exclusion list json-serialisable - exclude_rooms = list(exclude_rooms) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - res = await self._federation_client.get_space_summary( - via, - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_children, - exclude_rooms=exclude_rooms, - ) - except Exception as e: - logger.warning( - "Unable to get summary of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return () - - # Group the events by their room. - children_by_room: Dict[str, List[JsonDict]] = {} - for ev in res.events: - if ev.event_type == EventTypes.SpaceChild: - children_by_room.setdefault(ev.room_id, []).append(ev.data) - - # Generate the final results. - results = [] - for fed_room in res.rooms: - fed_room_id = fed_room.get("room_id") - if not fed_room_id or not isinstance(fed_room_id, str): - continue - - results.append( - _RoomEntry( - fed_room_id, - fed_room, - children_by_room.get(fed_room_id, []), - ) - ) - - return results - - async def _summarize_remote_room_hiearchy( - self, room: "_RoomQueueEntry", suggested_only: bool - ) -> Tuple[Optional["_RoomEntry"], Dict[str, JsonDict], Set[str]]: - """ - Request room entries and a list of event entries for a given room by querying a remote server. - - Args: - room: The room to summarize. - suggested_only: True if only suggested children should be returned. - Otherwise, all children are returned. - - Returns: - A tuple of: - The room entry. - Partial room data return over federation. - A set of inaccessible children room IDs. 
- """ - room_id = room.room_id - logger.info("Requesting summary for %s via %s", room_id, room.via) - - via = itertools.islice(room.via, MAX_SERVERS_PER_SPACE) - try: - ( - room_response, - children, - inaccessible_children, - ) = await self._federation_client.get_room_hierarchy( - via, - room_id, - suggested_only=suggested_only, - ) - except Exception as e: - logger.warning( - "Unable to get hierarchy of %s via federation: %s", - room_id, - e, - exc_info=logger.isEnabledFor(logging.DEBUG), - ) - return None, {}, set() - - # Map the children to their room ID. - children_by_room_id = { - c["room_id"]: c - for c in children - if "room_id" in c and isinstance(c["room_id"], str) - } - - return ( - _RoomEntry(room_id, room_response, room_response.pop("children_state", ())), - children_by_room_id, - set(inaccessible_children), - ) - - async def _is_local_room_accessible( - self, room_id: str, requester: Optional[str], origin: Optional[str] = None - ) -> bool: - """ - Calculate whether the room should be shown in the spaces summary. - - It should be included if: - - * The requester is joined or can join the room (per MSC3173). - * The origin server has any user that is joined or can join the room. - * The history visibility is set to world readable. - - Args: - room_id: The room ID to summarize. - requester: - The user requesting the summary, if it is a local request. None - if this is a federation request. - origin: - The server requesting the summary, if it is a federation request. - None if this is a local request. - - Returns: - True if the room should be included in the spaces summary. - """ - state_ids = await self._store.get_current_state_ids(room_id) - - # If there's no state for the room, it isn't known. - if not state_ids: - # The user might have a pending invite for the room. - if requester and await self._store.get_invite_for_local_user_in_room( - requester, room_id - ): - return True - - logger.info("room %s is unknown, omitting from summary", room_id) - return False - - room_version = await self._store.get_room_version(room_id) - - # Include the room if it has join rules of public or knock. - join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) - if join_rules_event_id: - join_rules_event = await self._store.get_event(join_rules_event_id) - join_rule = join_rules_event.content.get("join_rule") - if join_rule == JoinRules.PUBLIC or ( - room_version.msc2403_knocking and join_rule == JoinRules.KNOCK - ): - return True - - # Include the room if it is peekable. - hist_vis_event_id = state_ids.get((EventTypes.RoomHistoryVisibility, "")) - if hist_vis_event_id: - hist_vis_ev = await self._store.get_event(hist_vis_event_id) - hist_vis = hist_vis_ev.content.get("history_visibility") - if hist_vis == HistoryVisibility.WORLD_READABLE: - return True - - # Otherwise we need to check information specific to the user or server. - - # If we have an authenticated requesting user, check if they are a member - # of the room (or can join the room). - if requester: - member_event_id = state_ids.get((EventTypes.Member, requester), None) - - # If they're in the room they can see info on it. - if member_event_id: - member_event = await self._store.get_event(member_event_id) - if member_event.membership in (Membership.JOIN, Membership.INVITE): - return True - - # Otherwise, check if they should be allowed access via membership in a space. 
- if await self._event_auth_handler.has_restricted_join_rules( - state_ids, room_version - ): - allowed_rooms = ( - await self._event_auth_handler.get_rooms_that_allow_join(state_ids) - ) - if await self._event_auth_handler.is_user_in_rooms( - allowed_rooms, requester - ): - return True - - # If this is a request over federation, check if the host is in the room or - # has a user who could join the room. - elif origin: - if await self._event_auth_handler.check_host_in_room( - room_id, origin - ) or await self._store.is_host_invited(room_id, origin): - return True - - # Alternately, if the host has a user in any of the spaces specified - # for access, then the host can see this room (and should do filtering - # if the requester cannot see it). - if await self._event_auth_handler.has_restricted_join_rules( - state_ids, room_version - ): - allowed_rooms = ( - await self._event_auth_handler.get_rooms_that_allow_join(state_ids) - ) - for space_id in allowed_rooms: - if await self._event_auth_handler.check_host_in_room( - space_id, origin - ): - return True - - logger.info( - "room %s is unpeekable and requester %s is not a member / not allowed to join, omitting from summary", - room_id, - requester or origin, - ) - return False - - async def _is_remote_room_accessible( - self, requester: str, room_id: str, room: JsonDict - ) -> bool: - """ - Calculate whether the room received over federation should be shown in the spaces summary. - - It should be included if: - - * The requester is joined or can join the room (per MSC3173). - * The history visibility is set to world readable. - - Note that the local server is not in the requested room (which is why the - remote call was made in the first place), but the user could have access - due to an invite, etc. - - Args: - requester: The user requesting the summary. - room_id: The room ID returned over federation. - room: The summary of the child room returned over federation. - - Returns: - True if the room should be included in the spaces summary. - """ - # The API doesn't return the room version so assume that a - # join rule of knock is valid. - if ( - room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) - or room.get("world_readable") is True - ): - return True - - # Check if the user is a member of any of the allowed spaces - # from the response. - allowed_rooms = room.get("allowed_room_ids") or room.get("allowed_spaces") - if allowed_rooms and isinstance(allowed_rooms, list): - if await self._event_auth_handler.is_user_in_rooms( - allowed_rooms, requester - ): - return True - - # Finally, check locally if we can access the room. The user might - # already be in the room (if it was a child room), or there might be a - # pending invite, etc. - return await self._is_local_room_accessible(room_id, requester) - - async def _build_room_entry(self, room_id: str, for_federation: bool) -> JsonDict: - """ - Generate en entry suitable for the 'rooms' list in the summary response. - - Args: - room_id: The room ID to summarize. - for_federation: True if this is a summary requested over federation - (which includes additional fields). - - Returns: - The JSON dictionary for the room. 
- """ - stats = await self._store.get_room_with_stats(room_id) - - # currently this should be impossible because we call - # _is_local_room_accessible on the room before we get here, so - # there should always be an entry - assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - - current_state_ids = await self._store.get_current_state_ids(room_id) - create_event = await self._store.get_event( - current_state_ids[(EventTypes.Create, "")] - ) - - entry = { - "room_id": stats["room_id"], - "name": stats["name"], - "topic": stats["topic"], - "canonical_alias": stats["canonical_alias"], - "num_joined_members": stats["joined_members"], - "avatar_url": stats["avatar"], - "join_rules": stats["join_rules"], - "world_readable": ( - stats["history_visibility"] == HistoryVisibility.WORLD_READABLE - ), - "guest_can_join": stats["guest_access"] == "can_join", - "creation_ts": create_event.origin_server_ts, - "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), - } - - # Federation requests need to provide additional information so the - # requested server is able to filter the response appropriately. - if for_federation: - room_version = await self._store.get_room_version(room_id) - if await self._event_auth_handler.has_restricted_join_rules( - current_state_ids, room_version - ): - allowed_rooms = ( - await self._event_auth_handler.get_rooms_that_allow_join( - current_state_ids - ) - ) - if allowed_rooms: - entry["allowed_room_ids"] = allowed_rooms - # TODO Remove this key once the API is stable. - entry["allowed_spaces"] = allowed_rooms - - # Filter out Nones – rather omit the field altogether - room_entry = {k: v for k, v in entry.items() if v is not None} - - return room_entry - - async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: - """ - Get the child events for a given room. - - The returned results are sorted for stability. - - Args: - room_id: The room id to get the children of. - - Returns: - An iterable of sorted child events. - """ - - # look for child rooms/spaces. - current_state_ids = await self._store.get_current_state_ids(room_id) - - events = await self._store.get_events_as_list( - [ - event_id - for key, event_id in current_state_ids.items() - if key[0] == EventTypes.SpaceChild - ] - ) - - # filter out any events without a "via" (which implies it has been redacted), - # and order to ensure we return stable results. - return sorted(filter(_has_valid_via, events), key=_child_events_comparison_key) - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class _RoomQueueEntry: - # The room ID of this entry. - room_id: str - # The server to query if the room is not known locally. - via: Sequence[str] - # The minimum number of hops necessary to get to this room (compared to the - # originally requested room). - depth: int = 0 - # The room summary for this room returned via federation. This will only be - # used if the room is not known locally (and is not a space). - remote_room: Optional[JsonDict] = None - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class _RoomEntry: - room_id: str - # The room summary for this room. - room: JsonDict - # An iterable of the sorted, stripped children events for children of this room. - # - # This may not include all children. - children_state_events: Sequence[JsonDict] = () - - def as_json(self) -> JsonDict: - """ - Returns a JSON dictionary suitable for the room hierarchy endpoint. - - It returns the room summary including the stripped m.space.child events - as a sub-key. 
- """ - result = dict(self.room) - result["children_state"] = self.children_state_events - return result - - -def _has_valid_via(e: EventBase) -> bool: - via = e.content.get("via") - if not via or not isinstance(via, Sequence): - return False - for v in via: - if not isinstance(v, str): - logger.debug("Ignoring edge event %s with invalid via entry", e.event_id) - return False - return True - - -def _is_suggested_child_event(edge_event: EventBase) -> bool: - suggested = edge_event.content.get("suggested") - if isinstance(suggested, bool) and suggested: - return True - logger.debug("Ignorning not-suggested child %s", edge_event.state_key) - return False - - -# Order may only contain characters in the range of \x20 (space) to \x7E (~) inclusive. -_INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") - - -def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: - """ - Generate a value for comparing two child events for ordering. - - The rules for ordering are supposed to be: - - 1. The 'order' key, if it is valid. - 2. The 'origin_server_ts' of the 'm.room.create' event. - 3. The 'room_id'. - - But we skip step 2 since we may not have any state from the room. - - Args: - child: The event for generating a comparison key. - - Returns: - The comparison key as a tuple of: - False if the ordering is valid. - The ordering field. - The room ID. - """ - order = child.content.get("order") - # If order is not a string or doesn't meet the requirements, ignore it. - if not isinstance(order, str): - order = None - elif len(order) > 50 or _INVALID_ORDER_CHARS_RE.search(order): - order = None - - # Items without an order come last. - return (order is None, order, child.room_id) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 732a1e6aeb..a12fa30bfd 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,16 +14,28 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging -from typing import Iterable, List, Mapping, Optional, Sequence, overload +from typing import ( + TYPE_CHECKING, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + overload, +) from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError -from synapse.types import JsonDict +from synapse.types import JsonDict, RoomAlias, RoomID from synapse.util import json_decoder +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -663,3 +675,45 @@ class RestServlet: else: raise NotImplementedError("RestServlet must register something.") + + +class ResolveRoomIdMixin: + def __init__(self, hs: "HomeServer"): + self.room_member_handler = hs.get_room_member_handler() + + async def resolve_room_id( + self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None + ) -> Tuple[str, Optional[List[str]]]: + """ + Resolve a room identifier to a room ID, if necessary. + + This also performanes checks to ensure the room ID is of the proper form. + + Args: + room_identifier: The room ID or alias. + remote_room_hosts: The potential remote room hosts to use. + + Returns: + The resolved room ID. + + Raises: + SynapseError if the room ID is of the wrong form. 
+ """ + if RoomID.is_valid(room_identifier): + resolved_room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + ( + room_id, + remote_room_hosts, + ) = await self.room_member_handler.lookup_room_alias(room_alias) + resolved_room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id, remote_room_hosts diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 40ee33646c..975c28b225 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -20,6 +20,7 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.http.servlet import ( + ResolveRoomIdMixin, RestServlet, assert_params_in_dict, parse_integer, @@ -33,7 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import json_decoder if TYPE_CHECKING: @@ -45,48 +46,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ResolveRoomIdMixin: - def __init__(self, hs: "HomeServer"): - self.room_member_handler = hs.get_room_member_handler() - - async def resolve_room_id( - self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None - ) -> Tuple[str, Optional[List[str]]]: - """ - Resolve a room identifier to a room ID, if necessary. - - This also performanes checks to ensure the room ID is of the proper form. - - Args: - room_identifier: The room ID or alias. - remote_room_hosts: The potential remote room hosts to use. - - Returns: - The resolved room ID. - - Raises: - SynapseError if the room ID is of the wrong form. - """ - if RoomID.is_valid(room_identifier): - resolved_room_id = room_identifier - elif RoomAlias.is_valid(room_identifier): - room_alias = RoomAlias.from_string(room_identifier) - ( - room_id, - remote_room_hosts, - ) = await self.room_member_handler.lookup_room_alias(room_alias) - resolved_room_id = room_id.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) - if not resolved_room_id: - raise SynapseError( - 400, "Unknown room ID or room alias %s" % room_identifier - ) - return resolved_room_id, remote_room_hosts - - class ShutdownRoomRestServlet(RestServlet): """Shuts down a room by removing all local users from the room and blocking all future invites and joins to the room. 
Any local aliases will be repointed diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2c3be23bc8..d3882a84e2 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -24,12 +24,14 @@ from synapse.api.errors import ( AuthError, Codes, InvalidClientCredentialsError, + MissingClientTokenError, ShadowBanError, SynapseError, ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( + ResolveRoomIdMixin, RestServlet, assert_params_in_dict, parse_boolean, @@ -44,14 +46,7 @@ from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import ( - JsonDict, - RoomAlias, - RoomID, - StreamToken, - ThirdPartyInstanceID, - UserID, -) +from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string @@ -266,10 +261,10 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins -class JoinRoomAliasServlet(TransactionRestServlet): +class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): def __init__(self, hs): super().__init__(hs) - self.room_member_handler = hs.get_room_member_handler() + super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up self.auth = hs.get_auth() def register(self, http_server): @@ -292,24 +287,13 @@ class JoinRoomAliasServlet(TransactionRestServlet): # cheekily send invalid bodies. content = {} - if RoomID.is_valid(room_identifier): - room_id = room_identifier - - # twisted.web.server.Request.args is incorrectly defined as Optional[Any] - args: Dict[bytes, List[bytes]] = request.args # type: ignore - - remote_room_hosts = parse_strings_from_args( - args, "server_name", required=False - ) - elif RoomAlias.is_valid(room_identifier): - handler = self.room_member_handler - room_alias = RoomAlias.from_string(room_identifier) - room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias) - room_id = room_id_obj.to_string() - else: - raise SynapseError( - 400, "%s was not legal room ID or room alias" % (room_identifier,) - ) + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "server_name", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) await self.room_member_handler.update_membership( requester=requester, @@ -1002,14 +986,14 @@ class RoomSpaceSummaryRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._auth = hs.get_auth() - self._space_summary_handler = hs.get_space_summary_handler() + self._room_summary_handler = hs.get_room_summary_handler() async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request, allow_guest=True) - return 200, await self._space_summary_handler.get_space_summary( + return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), @@ -1035,7 +1019,7 @@ class 
RoomSpaceSummaryRestServlet(RestServlet):
                 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON
             )
 
-        return 200, await self._space_summary_handler.get_space_summary(
+        return 200, await self._room_summary_handler.get_space_summary(
             requester.user.to_string(),
             room_id,
             suggested_only=suggested_only,
@@ -1054,7 +1038,7 @@ class RoomHierarchyRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self._auth = hs.get_auth()
-        self._space_summary_handler = hs.get_space_summary_handler()
+        self._room_summary_handler = hs.get_room_summary_handler()
 
     async def on_GET(
         self, request: SynapseRequest, room_id: str
@@ -1073,7 +1057,7 @@ class RoomHierarchyRestServlet(RestServlet):
                 400, "'limit' must be a positive integer", Codes.BAD_JSON
             )
 
-        return 200, await self._space_summary_handler.get_room_hierarchy(
+        return 200, await self._room_summary_handler.get_room_hierarchy(
             requester.user.to_string(),
             room_id,
             suggested_only=parse_boolean(request, "suggested_only", default=False),
@@ -1083,6 +1067,44 @@ class RoomHierarchyRestServlet(RestServlet):
         )
 
 
+class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/im.nheko.summary"
+            "/rooms/(?P<room_identifier>[^/]*)/summary$"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+        self._auth = hs.get_auth()
+        self._room_summary_handler = hs.get_room_summary_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
+        try:
+            requester = await self._auth.get_user_by_req(request, allow_guest=True)
+            requester_user_id: Optional[str] = requester.user.to_string()
+        except MissingClientTokenError:
+            # auth is optional
+            requester_user_id = None
+
+        # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
+        args: Dict[bytes, List[bytes]] = request.args  # type: ignore
+        remote_room_hosts = parse_strings_from_args(args, "via", required=False)
+        room_id, remote_room_hosts = await self.resolve_room_id(
+            room_identifier,
+            remote_room_hosts,
+        )
+
+        return 200, await self._room_summary_handler.get_room_summary(
+            requester_user_id,
+            room_id,
+            remote_room_hosts,
+        )
+
+
 def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     RoomStateEventRestServlet(hs).register(http_server)
     RoomMemberListRestServlet(hs).register(http_server)
@@ -1098,6 +1120,8 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     RoomEventContextServlet(hs).register(http_server)
     RoomSpaceSummaryRestServlet(hs).register(http_server)
     RoomHierarchyRestServlet(hs).register(http_server)
+    if hs.config.experimental.msc3266_enabled:
+        RoomSummaryRestServlet(hs).register(http_server)
     RoomEventServlet(hs).register(http_server)
     JoinedRoomsRestServlet(hs).register(http_server)
     RoomAliasListServlet(hs).register(http_server)
diff --git a/synapse/server.py b/synapse/server.py
index 6c867f0f47..de6517663e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -99,10 +99,10 @@ from synapse.handlers.room import (
 from synapse.handlers.room_list import RoomListHandler
 from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
 from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
+from synapse.handlers.room_summary import RoomSummaryHandler
 from synapse.handlers.search import SearchHandler
 from synapse.handlers.send_email import SendEmailHandler
 from synapse.handlers.set_password import SetPasswordHandler
-from synapse.handlers.space_summary import SpaceSummaryHandler
from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler @@ -772,8 +772,8 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountDataHandler(self) @cache_in_self - def get_space_summary_handler(self) -> SpaceSummaryHandler: - return SpaceSummaryHandler(self) + def get_room_summary_handler(self) -> RoomSummaryHandler: + return RoomSummaryHandler(self) @cache_in_self def get_event_auth_handler(self) -> EventAuthHandler: diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py new file mode 100644 index 0000000000..732d746e38 --- /dev/null +++ b/tests/handlers/test_room_summary.py @@ -0,0 +1,959 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Iterable, List, Optional, Tuple +from unittest import mock + +from synapse.api.constants import ( + EventContentFields, + EventTypes, + HistoryVisibility, + JoinRules, + Membership, + RestrictedJoinRuleTypes, + RoomTypes, +) +from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict +from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID + +from tests import unittest + + +def _create_event(room_id: str, order: Optional[Any] = None): + result = mock.Mock() + result.room_id = room_id + result.content = {} + if order is not None: + result.content["order"] = order + return result + + +def _order(*events): + return sorted(events, key=_child_events_comparison_key) + + +class TestSpaceSummarySort(unittest.TestCase): + def test_no_order_last(self): + """An event with no ordering is placed behind those with an ordering.""" + ev1 = _create_event("!abc:test") + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_order(self): + """The ordering should be used.""" + ev1 = _create_event("!abc:test", "xyz") + ev2 = _create_event("!xyz:test", "abc") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def test_order_room_id(self): + """Room ID is a tie-breaker for ordering.""" + ev1 = _create_event("!abc:test", "abc") + ev2 = _create_event("!xyz:test", "abc") + + self.assertEqual([ev1, ev2], _order(ev1, ev2)) + + def test_invalid_ordering_type(self): + """Invalid orderings are considered the same as missing.""" + ev1 = _create_event("!abc:test", 1) + ev2 = _create_event("!xyz:test", "xyz") + + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", {}) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", []) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + ev1 = _create_event("!abc:test", True) + self.assertEqual([ev2, ev1], _order(ev1, ev2)) + + def 
test_invalid_ordering_value(self):
+        """Invalid orderings are considered the same as missing."""
+        ev1 = _create_event("!abc:test", "foo\n")
+        ev2 = _create_event("!xyz:test", "xyz")
+
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+        ev1 = _create_event("!abc:test", "a" * 51)
+        self.assertEqual([ev2, ev1], _order(ev1, ev2))
+
+
+class SpaceSummaryTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs: HomeServer):
+        self.hs = hs
+        self.handler = self.hs.get_room_summary_handler()
+
+        # Create a user.
+        self.user = self.register_user("user", "pass")
+        self.token = self.login("user", "pass")
+
+        # Create a space and a child room.
+        self.space = self.helper.create_room_as(
+            self.user,
+            tok=self.token,
+            extra_content={
+                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+            },
+        )
+        self.room = self.helper.create_room_as(self.user, tok=self.token)
+        self._add_child(self.space, self.room, self.token)
+
+    def _add_child(
+        self, space_id: str, room_id: str, token: str, order: Optional[str] = None
+    ) -> None:
+        """Add a child room to a space."""
+        content: JsonDict = {"via": [self.hs.hostname]}
+        if order is not None:
+            content["order"] = order
+        self.helper.send_state(
+            space_id,
+            event_type=EventTypes.SpaceChild,
+            body=content,
+            tok=token,
+            state_key=room_id,
+        )
+
+    def _assert_rooms(
+        self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]]
+    ) -> None:
+        """
+        Assert that the expected room IDs and events are in the response.
+
+        Args:
+            result: The result from the API call.
+            rooms_and_children: An iterable of tuples where each tuple is:
+                The expected room ID.
+                The expected IDs of any children rooms.
+        """
+        room_ids = []
+        children_ids = []
+        for room_id, children in rooms_and_children:
+            room_ids.append(room_id)
+            if children:
+                children_ids.extend([(room_id, child_id) for child_id in children])
+        self.assertCountEqual(
+            [room.get("room_id") for room in result["rooms"]], room_ids
+        )
+        self.assertCountEqual(
+            [
+                (event.get("room_id"), event.get("state_key"))
+                for event in result["events"]
+            ],
+            children_ids,
+        )
+
+    def _assert_hierarchy(
+        self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]]
+    ) -> None:
+        """
+        Assert that the expected room IDs are in the response.
+
+        Args:
+            result: The result from the API call.
+            rooms_and_children: An iterable of tuples where each tuple is:
+                The expected room ID.
+                The expected IDs of any children rooms.
+        """
+        result_room_ids = []
+        result_children_ids = []
+        for result_room in result["rooms"]:
+            result_room_ids.append(result_room["room_id"])
+            result_children_ids.append(
+                [
+                    (cs["room_id"], cs["state_key"])
+                    for cs in result_room.get("children_state")
+                ]
+            )
+
+        room_ids = []
+        children_ids = []
+        for room_id, children in rooms_and_children:
+            room_ids.append(room_id)
+            children_ids.append([(room_id, child_id) for child_id in children])
+
+        # Note that order matters.
+        self.assertEqual(result_room_ids, room_ids)
+        self.assertEqual(result_children_ids, children_ids)
+
+    def _poke_fed_invite(self, room_id: str, from_user: str) -> None:
+        """
+        Creates an invite (as if received over federation) for the room from the
+        given hostname.
+
+        Args:
+            room_id: The room ID to issue an invite for.
+            from_user: The user to invite from.
+        """
+        # Poke an invite over federation into the database.
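+        # We hand-craft a minimal m.room.member event with an invite
+        # membership and pass it to the federation handler, as though a
+        # remote server had sent it to us.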
+ fed_handler = self.hs.get_federation_handler() + fed_hostname = UserID.from_string(from_user).domain + event = make_event_from_dict( + { + "room_id": room_id, + "event_id": "!abcd:" + fed_hostname, + "type": EventTypes.Member, + "sender": from_user, + "state_key": self.user, + "content": {"membership": Membership.INVITE}, + "prev_events": [], + "auth_events": [], + "depth": 1, + "origin_server_ts": 1234, + } + ) + self.get_success( + fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) + ) + + def test_simple_space(self): + """Test a simple space with a single room.""" + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + # The result should have the space and the room in it, along with a link + # from space -> room. + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_visibility(self): + """A user not in a space cannot inspect it.""" + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # The user can see the space since it is publicly joinable. + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [(self.space, [self.room]), (self.room, ())] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # If the space is made invite-only, it should no longer be viewable. + self.helper.send_state( + self.space, + event_type=EventTypes.JoinRules, + body={"join_rule": JoinRules.INVITE}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + + # If the space is made world-readable it should return a result. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # Make it not world-readable again and confirm it results in an error. + self.helper.send_state( + self.space, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.JOINED}, + tok=self.token, + ) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) + self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) + + # Join the space and results should be returned. + self.helper.invite(self.space, targ=user2, tok=self.token) + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + # Attempting to view an unknown room returns the same error. 
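+        # (The same AuthError is raised for unknown and inaccessible rooms,
+        # so the response does not reveal whether a room exists.)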
+ self.get_failure( + self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + self.get_failure( + self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), + AuthError, + ) + + def _create_room_with_join_rule( + self, join_rule: str, room_version: Optional[str] = None, **extra_content + ) -> str: + """Create a room with the given join rule and add it to the space.""" + room_id = self.helper.create_room_as( + self.user, + room_version=room_version, + tok=self.token, + extra_content={ + "initial_state": [ + { + "type": EventTypes.JoinRules, + "state_key": "", + "content": { + "join_rule": join_rule, + **extra_content, + }, + } + ] + }, + ) + self._add_child(self.space, room_id, self.token) + return room_id + + def test_filtering(self): + """ + Rooms should be properly filtered to only include rooms the user has access to. + """ + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + + # Create a few rooms which will have different properties. + public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) + knock_room = self._create_room_with_join_rule( + JoinRules.KNOCK, room_version=RoomVersions.V7.identifier + ) + not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + invited_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(invited_room, targ=user2, tok=self.token) + restricted_room = self._create_room_with_join_rule( + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, + allow=[], + ) + restricted_accessible_room = self._create_room_with_join_rule( + JoinRules.RESTRICTED, + room_version=RoomVersions.V8.identifier, + allow=[ + { + "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, + "room_id": self.space, + "via": [self.hs.hostname], + } + ], + ) + world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.send_state( + world_readable_room, + event_type=EventTypes.RoomHistoryVisibility, + body={"history_visibility": HistoryVisibility.WORLD_READABLE}, + tok=self.token, + ) + joined_room = self._create_room_with_join_rule(JoinRules.INVITE) + self.helper.invite(joined_room, targ=user2, tok=self.token) + self.helper.join(joined_room, user2, tok=token2) + + # Join the space. + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + expected = [ + ( + self.space, + [ + self.room, + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (self.room, ()), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) + self._assert_hierarchy(result, expected) + + def test_complex_space(self): + """ + Create a "complex" space to see how it handles things like loops and subspaces. + """ + # Create an inaccessible room. + user2 = self.register_user("user2", "pass") + token2 = self.login("user2", "pass") + room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) + # This is a bit odd as "user" is adding a room they don't know about, but + # it works for the tests. + self._add_child(self.space, room2, self.token) + + # Create a subspace under the space with an additional room in it. 
+        subspace = self.helper.create_room_as(
+            self.user,
+            tok=self.token,
+            extra_content={
+                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE}
+            },
+        )
+        subroom = self.helper.create_room_as(self.user, tok=self.token)
+        self._add_child(self.space, subspace, token=self.token)
+        self._add_child(subspace, subroom, token=self.token)
+        # Also add the two rooms from the space into this subspace (causing loops).
+        self._add_child(subspace, self.room, token=self.token)
+        self._add_child(subspace, room2, self.token)
+
+        result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+
+        # The result should include each room a single time and each link.
+        expected = [
+            (self.space, [self.room, room2, subspace]),
+            (self.room, ()),
+            (subspace, [subroom, self.room, room2]),
+            (subroom, ()),
+        ]
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+        self._assert_hierarchy(result, expected)
+
+    def test_pagination(self):
+        """Test simple pagination works."""
+        room_ids = []
+        for i in range(1, 10):
+            room = self.helper.create_room_as(self.user, tok=self.token)
+            self._add_child(self.space, room, self.token, order=str(i))
+            room_ids.append(room)
+        # The room created initially doesn't have an order, so comes last.
+        room_ids.append(self.room)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+        )
+        # The result should have the space and all of the links, plus some of the
+        # rooms and a pagination token.
+        expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)]
+        expected += [(room_id, ()) for room_id in room_ids[:6]]
+        self._assert_hierarchy(result, expected)
+        self.assertIn("next_batch", result)
+
+        # Check the next page.
+        result = self.get_success(
+            self.handler.get_room_hierarchy(
+                self.user, self.space, limit=5, from_token=result["next_batch"]
+            )
+        )
+        # The result should have the remaining rooms in it, without a further
+        # pagination token.
+        expected = [(room_id, ()) for room_id in room_ids[6:]]
+        self._assert_hierarchy(result, expected)
+        self.assertNotIn("next_batch", result)
+
+    def test_invalid_pagination_token(self):
+        """An invalid pagination token, or changing other parameters, should be rejected."""
+        room_ids = []
+        for i in range(1, 10):
+            room = self.helper.create_room_as(self.user, tok=self.token)
+            self._add_child(self.space, room, self.token, order=str(i))
+            room_ids.append(room)
+        # The room created initially doesn't have an order, so comes last.
+        room_ids.append(self.room)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+        )
+        self.assertIn("next_batch", result)
+
+        # Changing the room ID, suggested-only, or max-depth causes an error.
+        self.get_failure(
+            self.handler.get_room_hierarchy(
+                self.user, self.room, from_token=result["next_batch"]
+            ),
+            SynapseError,
+        )
+        self.get_failure(
+            self.handler.get_room_hierarchy(
+                self.user,
+                self.space,
+                suggested_only=True,
+                from_token=result["next_batch"],
+            ),
+            SynapseError,
+        )
+        self.get_failure(
+            self.handler.get_room_hierarchy(
+                self.user, self.space, max_depth=0, from_token=result["next_batch"]
+            ),
+            SynapseError,
+        )
+
+        # An invalid token is rejected.
+ self.get_failure( + self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), + SynapseError, + ) + + def test_max_depth(self): + """Create a deep tree to test the max depth against.""" + spaces = [self.space] + rooms = [self.room] + for _ in range(5): + spaces.append( + self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": { + EventContentFields.ROOM_TYPE: RoomTypes.SPACE + } + }, + ) + ) + self._add_child(spaces[-2], spaces[-1], self.token) + rooms.append(self.helper.create_room_as(self.user, tok=self.token)) + self._add_child(spaces[-1], rooms[-1], self.token) + + # Test just the space itself. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) + ) + expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] + self._assert_hierarchy(result, expected) + + # A single additional layer. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) + ) + expected += [ + (rooms[0], ()), + (spaces[1], [rooms[1], spaces[2]]), + ] + self._assert_hierarchy(result, expected) + + # A few layers. + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) + ) + expected += [ + (rooms[1], ()), + (spaces[2], [rooms[2], spaces[3]]), + (rooms[2], ()), + (spaces[3], [rooms[3], spaces[4]]), + ] + self._assert_hierarchy(result, expected) + + def test_fed_complex(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + subroom = "#subroom:" + fed_hostname + + # Generate some good data, and some bad data: + # + # * Event *back* to the root room. + # * Unrelated events / rooms + # * Multiple levels of events (in a not-useful order, e.g. grandchild + # events before child events). + + # Note that these entries are brief, but should contain enough info. + requested_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ) + child_room = { + "room_id": subroom, + "world_readable": True, + } + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [ + requested_room_entry, + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), + ] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return requested_room_entry, {subroom: child_room}, set() + + # Add a room to the space which is on another server. 
+ self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + (subspace, [subroom]), + (subroom, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_filtering(self): + """ + Rooms returned over federation should be properly filtered to only include + rooms the user has access to. + """ + fed_hostname = self.hs.hostname + "2" + subspace = "#subspace:" + fed_hostname + + # Create a few rooms which will have different properties. + public_room = "#public:" + fed_hostname + knock_room = "#knock:" + fed_hostname + not_invited_room = "#not_invited:" + fed_hostname + invited_room = "#invited:" + fed_hostname + restricted_room = "#restricted:" + fed_hostname + restricted_accessible_room = "#restricted_accessible:" + fed_hostname + world_readable_room = "#world_readable:" + fed_hostname + joined_room = self.helper.create_room_as(self.user, tok=self.token) + + # Poke an invite over federation into the database. + self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) + + # Note that these entries are brief, but should contain enough info. + children_rooms = ( + ( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + ( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + ( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [], + }, + ), + ( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + ( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + ( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + ) + + subspace_room_entry = _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "type": EventTypes.SpaceChild, + "room_id": subspace, + "state_key": room_id, + "content": {"via": [fed_hostname]}, + } + for room_id, _ in children_rooms + ], + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [subspace_room_entry] + [ + # A copy is made of the room data since the allowed_spaces key + # is removed. 
+ _RoomEntry(child_room[0], dict(child_room[1])) + for child_room in children_rooms + ] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return subspace_room_entry, dict(children_rooms), set() + + # Add a room to the space which is on another server. + self._add_child(self.space, subspace, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, subspace]), + (self.room, ()), + ( + subspace, + [ + public_room, + knock_room, + not_invited_room, + invited_room, + restricted_room, + restricted_accessible_room, + world_readable_room, + joined_room, + ], + ), + (public_room, ()), + (knock_room, ()), + (invited_room, ()), + (restricted_accessible_room, ()), + (world_readable_room, ()), + (joined_room, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + def test_fed_invited(self): + """ + A room which the user was invited to should be included in the response. + + This differs from test_fed_filtering in that the room itself is being + queried over federation, instead of it being included as a sub-room of + a space in the response. + """ + fed_hostname = self.hs.hostname + "2" + fed_room = "#subroom:" + fed_hostname + + # Poke an invite over federation into the database. + self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) + + fed_room_entry = _RoomEntry( + fed_room, + { + "room_id": fed_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ) + + async def summarize_remote_room( + _self, room, suggested_only, max_children, exclude_rooms + ): + return [fed_room_entry] + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return fed_room_entry, {}, set() + + # Add a room to the space which is on another server. + self._add_child(self.space, fed_room, self.token) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room", + new=summarize_remote_room, + ): + result = self.get_success( + self.handler.get_space_summary(self.user, self.space) + ) + + expected = [ + (self.space, [self.room, fed_room]), + (self.room, ()), + (fed_room, ()), + ] + self._assert_rooms(result, expected) + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(self.user, self.space) + ) + self._assert_hierarchy(result, expected) + + +class RoomSummaryTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): + self.hs = hs + self.handler = self.hs.get_room_summary_handler() + + # Create a user. + self.user = self.register_user("user", "pass") + self.token = self.login("user", "pass") + + # Create a simple room. 
+        self.room = self.helper.create_room_as(self.user, tok=self.token)
+        self.helper.send_state(
+            self.room,
+            event_type=EventTypes.JoinRules,
+            body={"join_rule": JoinRules.INVITE},
+            tok=self.token,
+        )
+
+    def test_own_room(self):
+        """Test a simple room created by the requester."""
+        result = self.get_success(self.handler.get_room_summary(self.user, self.room))
+        self.assertEqual(result.get("room_id"), self.room)
+
+    def test_visibility(self):
+        """A user not in a private room cannot get its summary."""
+        user2 = self.register_user("user2", "pass")
+        token2 = self.login("user2", "pass")
+
+        # The user cannot see the room.
+        self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError)
+
+        # If the room is made world-readable it should return a result.
+        self.helper.send_state(
+            self.room,
+            event_type=EventTypes.RoomHistoryVisibility,
+            body={"history_visibility": HistoryVisibility.WORLD_READABLE},
+            tok=self.token,
+        )
+        result = self.get_success(self.handler.get_room_summary(user2, self.room))
+        self.assertEqual(result.get("room_id"), self.room)
+
+        # Make it not world-readable again and confirm it results in an error.
+        self.helper.send_state(
+            self.room,
+            event_type=EventTypes.RoomHistoryVisibility,
+            body={"history_visibility": HistoryVisibility.JOINED},
+            tok=self.token,
+        )
+        self.get_failure(self.handler.get_room_summary(user2, self.room), NotFoundError)
+
+        # If the room is made public it should return a result.
+        self.helper.send_state(
+            self.room,
+            event_type=EventTypes.JoinRules,
+            body={"join_rule": JoinRules.PUBLIC},
+            tok=self.token,
+        )
+        result = self.get_success(self.handler.get_room_summary(user2, self.room))
+        self.assertEqual(result.get("room_id"), self.room)
+
+        # Join the room, make it invite-only again and results should be returned.
+        self.helper.join(self.room, user2, tok=token2)
+        self.helper.send_state(
+            self.room,
+            event_type=EventTypes.JoinRules,
+            body={"join_rule": JoinRules.INVITE},
+            tok=self.token,
+        )
+        result = self.get_success(self.handler.get_room_summary(user2, self.room))
+        self.assertEqual(result.get("room_id"), self.room)
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
deleted file mode 100644
index bc8e131f4a..0000000000
--- a/tests/handlers/test_space_summary.py
+++ /dev/null
@@ -1,881 +0,0 @@
-# Copyright 2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Any, Iterable, List, Optional, Tuple -from unittest import mock - -from synapse.api.constants import ( - EventContentFields, - EventTypes, - HistoryVisibility, - JoinRules, - Membership, - RestrictedJoinRuleTypes, - RoomTypes, -) -from synapse.api.errors import AuthError, SynapseError -from synapse.api.room_versions import RoomVersions -from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry -from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.server import HomeServer -from synapse.types import JsonDict, UserID - -from tests import unittest - - -def _create_event(room_id: str, order: Optional[Any] = None): - result = mock.Mock() - result.room_id = room_id - result.content = {} - if order is not None: - result.content["order"] = order - return result - - -def _order(*events): - return sorted(events, key=_child_events_comparison_key) - - -class TestSpaceSummarySort(unittest.TestCase): - def test_no_order_last(self): - """An event with no ordering is placed behind those with an ordering.""" - ev1 = _create_event("!abc:test") - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_order(self): - """The ordering should be used.""" - ev1 = _create_event("!abc:test", "xyz") - ev2 = _create_event("!xyz:test", "abc") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_order_room_id(self): - """Room ID is a tie-breaker for ordering.""" - ev1 = _create_event("!abc:test", "abc") - ev2 = _create_event("!xyz:test", "abc") - - self.assertEqual([ev1, ev2], _order(ev1, ev2)) - - def test_invalid_ordering_type(self): - """Invalid orderings are considered the same as missing.""" - ev1 = _create_event("!abc:test", 1) - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", {}) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", []) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", True) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - def test_invalid_ordering_value(self): - """Invalid orderings are considered the same as missing.""" - ev1 = _create_event("!abc:test", "foo\n") - ev2 = _create_event("!xyz:test", "xyz") - - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - ev1 = _create_event("!abc:test", "a" * 51) - self.assertEqual([ev2, ev1], _order(ev1, ev2)) - - -class SpaceSummaryTestCase(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, hs: HomeServer): - self.hs = hs - self.handler = self.hs.get_space_summary_handler() - - # Create a user. - self.user = self.register_user("user", "pass") - self.token = self.login("user", "pass") - - # Create a space and a child room. 
- self.space = self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - self.room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, self.room, self.token) - - def _add_child( - self, space_id: str, room_id: str, token: str, order: Optional[str] = None - ) -> None: - """Add a child room to a space.""" - content: JsonDict = {"via": [self.hs.hostname]} - if order is not None: - content["order"] = order - self.helper.send_state( - space_id, - event_type=EventTypes.SpaceChild, - body=content, - tok=token, - state_key=room_id, - ) - - def _assert_rooms( - self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] - ) -> None: - """ - Assert that the expected room IDs and events are in the response. - - Args: - result: The result from the API call. - rooms_and_children: An iterable of tuples where each tuple is: - The expected room ID. - The expected IDs of any children rooms. - """ - room_ids = [] - children_ids = [] - for room_id, children in rooms_and_children: - room_ids.append(room_id) - if children: - children_ids.extend([(room_id, child_id) for child_id in children]) - self.assertCountEqual( - [room.get("room_id") for room in result["rooms"]], room_ids - ) - self.assertCountEqual( - [ - (event.get("room_id"), event.get("state_key")) - for event in result["events"] - ], - children_ids, - ) - - def _assert_hierarchy( - self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]] - ) -> None: - """ - Assert that the expected room IDs are in the response. - - Args: - result: The result from the API call. - rooms_and_children: An iterable of tuples where each tuple is: - The expected room ID. - The expected IDs of any children rooms. - """ - result_room_ids = [] - result_children_ids = [] - for result_room in result["rooms"]: - result_room_ids.append(result_room["room_id"]) - result_children_ids.append( - [ - (cs["room_id"], cs["state_key"]) - for cs in result_room.get("children_state") - ] - ) - - room_ids = [] - children_ids = [] - for room_id, children in rooms_and_children: - room_ids.append(room_id) - children_ids.append([(room_id, child_id) for child_id in children]) - - # Note that order matters. - self.assertEqual(result_room_ids, room_ids) - self.assertEqual(result_children_ids, children_ids) - - def _poke_fed_invite(self, room_id: str, from_user: str) -> None: - """ - Creates a invite (as if received over federation) for the room from the - given hostname. - - Args: - room_id: The room ID to issue an invite for. - fed_hostname: The user to invite from. - """ - # Poke an invite over federation into the database. - fed_handler = self.hs.get_federation_handler() - fed_hostname = UserID.from_string(from_user).domain - event = make_event_from_dict( - { - "room_id": room_id, - "event_id": "!abcd:" + fed_hostname, - "type": EventTypes.Member, - "sender": from_user, - "state_key": self.user, - "content": {"membership": Membership.INVITE}, - "prev_events": [], - "auth_events": [], - "depth": 1, - "origin_server_ts": 1234, - } - ) - self.get_success( - fed_handler.on_invite_request(fed_hostname, event, RoomVersions.V6) - ) - - def test_simple_space(self): - """Test a simple space with a single room.""" - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - # The result should have the space and the room in it, along with a link - # from space -> room. 
- expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) - - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) - - def test_visibility(self): - """A user not in a space cannot inspect it.""" - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - - # The user can see the space since it is publicly joinable. - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - expected = [(self.space, [self.room]), (self.room, ())] - self._assert_rooms(result, expected) - - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) - self._assert_hierarchy(result, expected) - - # If the space is made invite-only, it should no longer be viewable. - self.helper.send_state( - self.space, - event_type=EventTypes.JoinRules, - body={"join_rule": JoinRules.INVITE}, - tok=self.token, - ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) - - # If the space is made world-readable it should return a result. - self.helper.send_state( - self.space, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.WORLD_READABLE}, - tok=self.token, - ) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) - self._assert_hierarchy(result, expected) - - # Make it not world-readable again and confirm it results in an error. - self.helper.send_state( - self.space, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.JOINED}, - tok=self.token, - ) - self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) - self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError) - - # Join the space and results should be returned. - self.helper.invite(self.space, targ=user2, tok=self.token) - self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - self._assert_rooms(result, expected) - - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) - self._assert_hierarchy(result, expected) - - # Attempting to view an unknown room returns the same error. - self.get_failure( - self.handler.get_space_summary(user2, "#not-a-space:" + self.hs.hostname), - AuthError, - ) - self.get_failure( - self.handler.get_room_hierarchy(user2, "#not-a-space:" + self.hs.hostname), - AuthError, - ) - - def _create_room_with_join_rule( - self, join_rule: str, room_version: Optional[str] = None, **extra_content - ) -> str: - """Create a room with the given join rule and add it to the space.""" - room_id = self.helper.create_room_as( - self.user, - room_version=room_version, - tok=self.token, - extra_content={ - "initial_state": [ - { - "type": EventTypes.JoinRules, - "state_key": "", - "content": { - "join_rule": join_rule, - **extra_content, - }, - } - ] - }, - ) - self._add_child(self.space, room_id, self.token) - return room_id - - def test_filtering(self): - """ - Rooms should be properly filtered to only include rooms the user has access to. - """ - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - - # Create a few rooms which will have different properties. 
- public_room = self._create_room_with_join_rule(JoinRules.PUBLIC) - knock_room = self._create_room_with_join_rule( - JoinRules.KNOCK, room_version=RoomVersions.V7.identifier - ) - not_invited_room = self._create_room_with_join_rule(JoinRules.INVITE) - invited_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.invite(invited_room, targ=user2, tok=self.token) - restricted_room = self._create_room_with_join_rule( - JoinRules.RESTRICTED, - room_version=RoomVersions.V8.identifier, - allow=[], - ) - restricted_accessible_room = self._create_room_with_join_rule( - JoinRules.RESTRICTED, - room_version=RoomVersions.V8.identifier, - allow=[ - { - "type": RestrictedJoinRuleTypes.ROOM_MEMBERSHIP, - "room_id": self.space, - "via": [self.hs.hostname], - } - ], - ) - world_readable_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.send_state( - world_readable_room, - event_type=EventTypes.RoomHistoryVisibility, - body={"history_visibility": HistoryVisibility.WORLD_READABLE}, - tok=self.token, - ) - joined_room = self._create_room_with_join_rule(JoinRules.INVITE) - self.helper.invite(joined_room, targ=user2, tok=self.token) - self.helper.join(joined_room, user2, tok=token2) - - # Join the space. - self.helper.join(self.space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, self.space)) - expected = [ - ( - self.space, - [ - self.room, - public_room, - knock_room, - not_invited_room, - invited_room, - restricted_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ), - (self.room, ()), - (public_room, ()), - (knock_room, ()), - (invited_room, ()), - (restricted_accessible_room, ()), - (world_readable_room, ()), - (joined_room, ()), - ] - self._assert_rooms(result, expected) - - result = self.get_success(self.handler.get_room_hierarchy(user2, self.space)) - self._assert_hierarchy(result, expected) - - def test_complex_space(self): - """ - Create a "complex" space to see how it handles things like loops and subspaces. - """ - # Create an inaccessible room. - user2 = self.register_user("user2", "pass") - token2 = self.login("user2", "pass") - room2 = self.helper.create_room_as(user2, is_public=False, tok=token2) - # This is a bit odd as "user" is adding a room they don't know about, but - # it works for the tests. - self._add_child(self.space, room2, self.token) - - # Create a subspace under the space with an additional room in it. - subspace = self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} - }, - ) - subroom = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, subspace, token=self.token) - self._add_child(subspace, subroom, token=self.token) - # Also add the two rooms from the space into this subspace (causing loops). - self._add_child(subspace, self.room, token=self.token) - self._add_child(subspace, room2, self.token) - - result = self.get_success(self.handler.get_space_summary(self.user, self.space)) - - # The result should include each room a single time and each link. 
- expected = [ - (self.space, [self.room, room2, subspace]), - (self.room, ()), - (subspace, [subroom, self.room, room2]), - (subroom, ()), - ] - self._assert_rooms(result, expected) - - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) - - def test_pagination(self): - """Test simple pagination works.""" - room_ids = [] - for i in range(1, 10): - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, room, self.token, order=str(i)) - room_ids.append(room) - # The room created initially doesn't have an order, so comes last. - room_ids.append(self.room) - - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, limit=7) - ) - # The result should have the space and all of the links, plus some of the - # rooms and a pagination token. - expected: List[Tuple[str, Iterable[str]]] = [(self.space, room_ids)] - expected += [(room_id, ()) for room_id in room_ids[:6]] - self._assert_hierarchy(result, expected) - self.assertIn("next_batch", result) - - # Check the next page. - result = self.get_success( - self.handler.get_room_hierarchy( - self.user, self.space, limit=5, from_token=result["next_batch"] - ) - ) - # The result should have the space and the room in it, along with a link - # from space -> room. - expected = [(room_id, ()) for room_id in room_ids[6:]] - self._assert_hierarchy(result, expected) - self.assertNotIn("next_batch", result) - - def test_invalid_pagination_token(self): - """An invalid pagination token, or changing other parameters, shoudl be rejected.""" - room_ids = [] - for i in range(1, 10): - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(self.space, room, self.token, order=str(i)) - room_ids.append(room) - # The room created initially doesn't have an order, so comes last. - room_ids.append(self.room) - - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, limit=7) - ) - self.assertIn("next_batch", result) - - # Changing the room ID, suggested-only, or max-depth causes an error. - self.get_failure( - self.handler.get_room_hierarchy( - self.user, self.room, from_token=result["next_batch"] - ), - SynapseError, - ) - self.get_failure( - self.handler.get_room_hierarchy( - self.user, - self.space, - suggested_only=True, - from_token=result["next_batch"], - ), - SynapseError, - ) - self.get_failure( - self.handler.get_room_hierarchy( - self.user, self.space, max_depth=0, from_token=result["next_batch"] - ), - SynapseError, - ) - - # An invalid token is ignored. - self.get_failure( - self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"), - SynapseError, - ) - - def test_max_depth(self): - """Create a deep tree to test the max depth against.""" - spaces = [self.space] - rooms = [self.room] - for _ in range(5): - spaces.append( - self.helper.create_room_as( - self.user, - tok=self.token, - extra_content={ - "creation_content": { - EventContentFields.ROOM_TYPE: RoomTypes.SPACE - } - }, - ) - ) - self._add_child(spaces[-2], spaces[-1], self.token) - rooms.append(self.helper.create_room_as(self.user, tok=self.token)) - self._add_child(spaces[-1], rooms[-1], self.token) - - # Test just the space itself. - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=0) - ) - expected: List[Tuple[str, Iterable[str]]] = [(spaces[0], [rooms[0], spaces[1]])] - self._assert_hierarchy(result, expected) - - # A single additional layer. 
- result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=1) - ) - expected += [ - (rooms[0], ()), - (spaces[1], [rooms[1], spaces[2]]), - ] - self._assert_hierarchy(result, expected) - - # A few layers. - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space, max_depth=3) - ) - expected += [ - (rooms[1], ()), - (spaces[2], [rooms[2], spaces[3]]), - (rooms[2], ()), - (spaces[3], [rooms[3], spaces[4]]), - ] - self._assert_hierarchy(result, expected) - - def test_fed_complex(self): - """ - Return data over federation and ensure that it is handled properly. - """ - fed_hostname = self.hs.hostname + "2" - subspace = "#subspace:" + fed_hostname - subroom = "#subroom:" + fed_hostname - - # Generate some good data, and some bad data: - # - # * Event *back* to the root room. - # * Unrelated events / rooms - # * Multiple levels of events (in a not-useful order, e.g. grandchild - # events before child events). - - # Note that these entries are brief, but should contain enough info. - requested_room_entry = _RoomEntry( - subspace, - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - [ - { - "type": EventTypes.SpaceChild, - "room_id": subspace, - "state_key": subroom, - "content": {"via": [fed_hostname]}, - } - ], - ) - child_room = { - "room_id": subroom, - "world_readable": True, - } - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [ - requested_room_entry, - _RoomEntry( - subroom, - { - "room_id": subroom, - "world_readable": True, - }, - ), - ] - - async def summarize_remote_room_hiearchy(_self, room, suggested_only): - return requested_room_entry, {subroom: child_room}, set() - - # Add a room to the space which is on another server. - self._add_child(self.space, subspace, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - expected = [ - (self.space, [self.room, subspace]), - (self.room, ()), - (subspace, [subroom]), - (subroom, ()), - ] - self._assert_rooms(result, expected) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", - new=summarize_remote_room_hiearchy, - ): - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) - - def test_fed_filtering(self): - """ - Rooms returned over federation should be properly filtered to only include - rooms the user has access to. - """ - fed_hostname = self.hs.hostname + "2" - subspace = "#subspace:" + fed_hostname - - # Create a few rooms which will have different properties. - public_room = "#public:" + fed_hostname - knock_room = "#knock:" + fed_hostname - not_invited_room = "#not_invited:" + fed_hostname - invited_room = "#invited:" + fed_hostname - restricted_room = "#restricted:" + fed_hostname - restricted_accessible_room = "#restricted_accessible:" + fed_hostname - world_readable_room = "#world_readable:" + fed_hostname - joined_room = self.helper.create_room_as(self.user, tok=self.token) - - # Poke an invite over federation into the database. - self._poke_fed_invite(invited_room, "@remote:" + fed_hostname) - - # Note that these entries are brief, but should contain enough info. 
- children_rooms = ( - ( - public_room, - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - ), - ( - knock_room, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - ), - ( - not_invited_room, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - ( - invited_room, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - ( - restricted_room, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], - }, - ), - ( - restricted_accessible_room, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], - }, - ), - ( - world_readable_room, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - ), - ( - joined_room, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ), - ) - - subspace_room_entry = _RoomEntry( - subspace, - { - "room_id": subspace, - "world_readable": True, - }, - # Place each room in the sub-space. - [ - { - "type": EventTypes.SpaceChild, - "room_id": subspace, - "state_key": room_id, - "content": {"via": [fed_hostname]}, - } - for room_id, _ in children_rooms - ], - ) - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [subspace_room_entry] + [ - # A copy is made of the room data since the allowed_spaces key - # is removed. - _RoomEntry(child_room[0], dict(child_room[1])) - for child_room in children_rooms - ] - - async def summarize_remote_room_hiearchy(_self, room, suggested_only): - return subspace_room_entry, dict(children_rooms), set() - - # Add a room to the space which is on another server. - self._add_child(self.space, subspace, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - expected = [ - (self.space, [self.room, subspace]), - (self.room, ()), - ( - subspace, - [ - public_room, - knock_room, - not_invited_room, - invited_room, - restricted_room, - restricted_accessible_room, - world_readable_room, - joined_room, - ], - ), - (public_room, ()), - (knock_room, ()), - (invited_room, ()), - (restricted_accessible_room, ()), - (world_readable_room, ()), - (joined_room, ()), - ] - self._assert_rooms(result, expected) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", - new=summarize_remote_room_hiearchy, - ): - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) - - def test_fed_invited(self): - """ - A room which the user was invited to should be included in the response. - - This differs from test_fed_filtering in that the room itself is being - queried over federation, instead of it being included as a sub-room of - a space in the response. - """ - fed_hostname = self.hs.hostname + "2" - fed_room = "#subroom:" + fed_hostname - - # Poke an invite over federation into the database. 
- self._poke_fed_invite(fed_room, "@remote:" + fed_hostname) - - fed_room_entry = _RoomEntry( - fed_room, - { - "room_id": fed_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ) - - async def summarize_remote_room( - _self, room, suggested_only, max_children, exclude_rooms - ): - return [fed_room_entry] - - async def summarize_remote_room_hiearchy(_self, room, suggested_only): - return fed_room_entry, {}, set() - - # Add a room to the space which is on another server. - self._add_child(self.space, fed_room, self.token) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room", - new=summarize_remote_room, - ): - result = self.get_success( - self.handler.get_space_summary(self.user, self.space) - ) - - expected = [ - (self.space, [self.room, fed_room]), - (self.room, ()), - (fed_room, ()), - ] - self._assert_rooms(result, expected) - - with mock.patch( - "synapse.handlers.space_summary.SpaceSummaryHandler._summarize_remote_room_hiearchy", - new=summarize_remote_room_hiearchy, - ): - result = self.get_success( - self.handler.get_room_hierarchy(self.user, self.space) - ) - self._assert_hierarchy(result, expected) -- cgit 1.5.1 From 5af83efe8d106ee6fe6568f6758d458159341531 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 16 Aug 2021 12:01:30 -0400 Subject: Validate the max_rooms_per_space parameter to ensure it is non-negative. (#10611) --- changelog.d/10611.bugfix | 1 + synapse/federation/transport/server/federation.py | 22 ++++++++++++++++---- synapse/rest/client/v1/room.py | 25 ++++++++++++++++++----- 3 files changed, 39 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10611.bugfix diff --git a/changelog.d/10611.bugfix b/changelog.d/10611.bugfix new file mode 100644 index 0000000000..ecbe408b47 --- /dev/null +++ b/changelog.d/10611.bugfix @@ -0,0 +1 @@ +Additional validation for the spaces summary API to avoid errors like `ValueError: Stop argument for islice() must be None or an integer`. The missing validation has existed since v1.31.0. 
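For context, the `islice()` failure quoted in this changelog is reproducible in isolation: further down the stack the unvalidated `max_rooms_per_space` value ends up being used as a slice bound, and the standard library rejects anything negative. A minimal standalone repro (plain Python, not Synapse code):

    from itertools import islice

    # islice() accepts only None or a non-negative integer as a bound,
    # which is why an unvalidated negative max_rooms_per_space crashed
    # the summary handler before this change.
    try:
        list(islice(range(5), -1))
    except ValueError as exc:
        print(exc)
        # "Stop argument for islice() must be None or an integer:
        # 0 <= x <= sys.maxsize."

The hunks below close this off by rejecting negative values at the HTTP layer with a 400 error instead.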
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 7d81cc642c..2fdf6cc99e 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -557,7 +557,14 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): room_id: str, ) -> Tuple[int, JsonDict]: suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) + max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") + if max_rooms_per_space is not None and max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) @@ -586,10 +593,17 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): raise SynapseError(400, "bad value for 'exclude_rooms'", Codes.BAD_JSON) max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON - ) + if max_rooms_per_space is not None: + if not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "bad value for 'max_rooms_per_space'", Codes.BAD_JSON + ) + if max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) return 200, await self.handler.federation_space_summary( origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index d3882a84e2..ba7250ad8e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -993,11 +993,19 @@ class RoomSpaceSummaryRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request, allow_guest=True) + max_rooms_per_space = parse_integer(request, "max_rooms_per_space") + if max_rooms_per_space is not None and max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) + return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), room_id, suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=parse_integer(request, "max_rooms_per_space"), + max_rooms_per_space=max_rooms_per_space, ) # TODO When switching to the stable endpoint, remove the POST handler. 
@@ -1014,10 +1022,17 @@ class RoomSpaceSummaryRestServlet(RestServlet): ) max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) + if max_rooms_per_space is not None: + if not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON + ) + if max_rooms_per_space < 0: + raise SynapseError( + 400, + "Value for 'max_rooms_per_space' must be a non-negative integer", + Codes.BAD_JSON, + ) return 200, await self._room_summary_handler.get_space_summary( requester.user.to_string(), -- cgit 1.5.1 From 0db8cab72c8a39b4e8154295d473fbbc154854b4 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 16 Aug 2021 18:09:47 +0100 Subject: Update CONTRIBUTING.md to fix index links and SyTest instructions (#10599) Signed-off-by: Olivier Wilkinson (reivilibre) --- CONTRIBUTING.md | 7 ++++--- changelog.d/10599.doc | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) create mode 100644 changelog.d/10599.doc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4486a4b2cd..cd6c34df85 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -13,8 +13,9 @@ This document aims to get you started with contributing to this repo! - [7. Turn coffee and documentation into code and documentation!](#7-turn-coffee-and-documentation-into-code-and-documentation) - [8. Test, test, test!](#8-test-test-test) * [Run the linters.](#run-the-linters) - * [Run the unit tests.](#run-the-unit-tests) - * [Run the integration tests.](#run-the-integration-tests) + * [Run the unit tests.](#run-the-unit-tests-twisted-trial) + * [Run the integration tests (SyTest).](#run-the-integration-tests-sytest) + * [Run the integration tests (Complement).](#run-the-integration-tests-complement) - [9. Submit your patch.](#9-submit-your-patch) * [Changelog](#changelog) + [How do I know what to call the changelog file before I create the PR?](#how-do-i-know-what-to-call-the-changelog-file-before-i-create-the-pr) @@ -197,7 +198,7 @@ The following command will let you run the integration test with the most common configuration: ```sh -$ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:py37 +$ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:buster ``` This configuration should generally cover your needs. For more details about other configurations, see [documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md). diff --git a/changelog.d/10599.doc b/changelog.d/10599.doc new file mode 100644 index 0000000000..66e72078f0 --- /dev/null +++ b/changelog.d/10599.doc @@ -0,0 +1 @@ +Update CONTRIBUTING.md to fix index links and the instructions for SyTest in docker. 
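Looking back at the #10611 validation above: with both the client and federation hunks applied, a negative value is rejected up front. An illustrative client exchange (the unstable MSC2946 path, room ID, and hostname are assumptions for illustration; the `errcode` and error string come from the patch):

    GET /_matrix/client/unstable/org.matrix.msc2946/rooms/!space:example.org/spaces?max_rooms_per_space=-1

    HTTP/1.1 400 Bad Request
    {
        "errcode": "M_BAD_JSON",
        "error": "Value for 'max_rooms_per_space' must be a non-negative integer"
    }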
-- cgit 1.5.1 From 19e51b14d23f756883688fd8238da61c6ff29cc3 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Mon, 16 Aug 2021 18:11:48 +0100 Subject: Manhole: wrap coroutines in `defer.ensureDeferred` automatically (#10602) --- changelog.d/10602.feature | 1 + docs/manhole.md | 2 +- synapse/util/manhole.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10602.feature diff --git a/changelog.d/10602.feature b/changelog.d/10602.feature new file mode 100644 index 0000000000..ab18291a20 --- /dev/null +++ b/changelog.d/10602.feature @@ -0,0 +1 @@ +The Synapse manhole no longer needs coroutines to be wrapped in `defer.ensureDeferred`. diff --git a/docs/manhole.md b/docs/manhole.md index 37d1d7823c..db92df88dc 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -67,7 +67,7 @@ This gives a Python REPL in which `hs` gives access to the `synapse.server.HomeServer` object - which in turn gives access to many other parts of the process. -Note that any call which returns a coroutine will need to be wrapped in `ensureDeferred`. +Note that, prior to Synapse 1.41, any call which returns a coroutine will need to be wrapped in `ensureDeferred`. As a simple example, retrieving an event from the database: diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index da24ba0470..522daa323d 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import sys import traceback @@ -20,6 +21,7 @@ from twisted.conch.insults import insults from twisted.conch.manhole import ColoredManhole, ManholeInterpreter from twisted.conch.ssh.keys import Key from twisted.cred import checkers, portal +from twisted.internet import defer PUBLIC_KEY = ( "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5" @@ -141,3 +143,15 @@ class SynapseManholeInterpreter(ManholeInterpreter): self.write("".join(lines)) finally: last_tb = ei = None + + def displayhook(self, obj): + """ + We override the displayhook so that we automatically convert coroutines + into Deferreds. (Our superclass' displayhook will take care of the rest, + by displaying the Deferred if it's ready, or registering a callback + if it's not). + """ + if inspect.iscoroutine(obj): + super().displayhook(defer.ensureDeferred(obj)) + else: + super().displayhook(obj) -- cgit 1.5.1 From a933c2c7d8ef49c3c98ef443d959f955600bfb6b Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Tue, 17 Aug 2021 10:52:38 +0100 Subject: Add an admin API to check if a username is available (#10578) This adds a new API GET /_synapse/admin/v1/username_available?username=foo to check if a username is available. It is the counterpart to https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available, except that it works even if registration is disabled. 
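An illustrative exchange with the new endpoint (the hostname and bearer token are placeholders; the status codes and response bodies follow the servlet and tests below):

    GET /_synapse/admin/v1/username_available?username=foo HTTP/1.1
    Host: synapse.example.org
    Authorization: Bearer <admin access token>

    HTTP/1.1 200 OK
    {"available": true}

Querying a localpart that is already taken instead yields a 400 with `errcode` `M_USER_IN_USE` and the error message "User ID already taken.".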
---
 changelog.d/10578.feature                   |  1 +
 docs/admin_api/user_admin_api.md            | 20 ++++++++++
 synapse/rest/admin/__init__.py              |  2 +
 synapse/rest/admin/username_available.py    | 51 ++++++++++++++++++++++++
 tests/rest/admin/test_username_available.py | 62 +++++++++++++++++++++++++++++
 5 files changed, 136 insertions(+)
 create mode 100644 changelog.d/10578.feature
 create mode 100644 synapse/rest/admin/username_available.py
 create mode 100644 tests/rest/admin/test_username_available.py

diff --git a/changelog.d/10578.feature b/changelog.d/10578.feature
new file mode 100644
index 0000000000..02397f0009
--- /dev/null
+++ b/changelog.d/10578.feature
@@ -0,0 +1 @@
+Add an admin API (`GET /_synapse/admin/v1/username_available`) to check if a username is available (regardless of registration settings).
\ No newline at end of file
diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md
index 33811f5bbb..4b5dd4685a 100644
--- a/docs/admin_api/user_admin_api.md
+++ b/docs/admin_api/user_admin_api.md
@@ -1057,3 +1057,23 @@ The following parameters should be set in the URL:
 
 - `user_id` - The fully qualified MXID: for example, `@user:server.com`. The user must
   be local.
+
+### Check username availability
+
+Checks to see if a username is available, and valid, for the server. See [the client-server
+API](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available)
+for more information.
+
+This endpoint will work even if registration is disabled on the server, unlike
+`/_matrix/client/r0/register/available`.
+
+The API is:
+
+```
+GET /_synapse/admin/v1/username_available?username=$localpart
+```
+
+The request and response format is the same as the [/_matrix/client/r0/register/available](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) API.
+
+To use it, you will need to authenticate by providing an `access_token` for a
+server admin: [Admin API](../usage/administration/admin_api)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index abf749b001..8a91068092 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -51,6 +51,7 @@ from synapse.rest.admin.rooms import (
 )
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
 from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
+from synapse.rest.admin.username_available import UsernameAvailableRestServlet
 from synapse.rest.admin.users import (
     AccountValidityRenewServlet,
     DeactivateAccountRestServlet,
@@ -241,6 +242,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ForwardExtremitiesRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
     RateLimitRestServlet(hs).register(http_server)
+    UsernameAvailableRestServlet(hs).register(http_server)
 
 
 def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py
new file mode 100644
index 0000000000..2bf1472967
--- /dev/null
+++ b/synapse/rest/admin/username_available.py
@@ -0,0 +1,51 @@
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple + +from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class UsernameAvailableRestServlet(RestServlet): + """An admin API to check if a given username is available, regardless of whether registration is enabled. + + Example: + GET /_synapse/admin/v1/username_available?username=foo + 200 OK + { + "available": true + } + """ + + PATTERNS = admin_patterns("/username_available") + + def __init__(self, hs: "HomeServer"): + self.auth = hs.get_auth() + self.registration_handler = hs.get_registration_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + username = parse_string(request, "username", required=True) + await self.registration_handler.check_username(username) + return HTTPStatus.OK, {"available": True} diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py new file mode 100644 index 0000000000..53cbc8ddab --- /dev/null +++ b/tests/rest/admin/test_username_available.py @@ -0,0 +1,62 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import synapse.rest.admin +from synapse.api.errors import Codes, SynapseError +from synapse.rest.client.v1 import login + +from tests import unittest + + +class UsernameAvailableTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + url = "/_synapse/admin/v1/username_available" + + def prepare(self, reactor, clock, hs): + self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + async def check_username(username): + if username == "allowed": + return True + raise SynapseError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) + + handler = self.hs.get_registration_handler() + handler.check_username = check_username + + def test_username_available(self): + """ + The endpoint should return a 200 response if the username does not exist + """ + + url = "%s?username=%s" % (self.url, "allowed") + channel = self.make_request("GET", url, None, self.admin_user_tok) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertTrue(channel.json_body["available"]) + + def test_username_unavailable(self): + """ + The endpoint should return a 400 response if the username is already in use + """ + + url = "%s?username=%s" % (self.url, "disallowed") + channel = self.make_request("GET", url, None, self.admin_user_tok) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], "M_USER_IN_USE") + self.assertEqual(channel.json_body["error"], "User ID already taken.") -- cgit 1.5.1 From ae2714c1f31f2a843e19dc44501784401181162c Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 17 Aug 2021 12:23:14 +0200 Subject: Allow using several custom template directories (#10587) Allow using several directories in read_templates. --- changelog.d/10587.misc | 1 + synapse/config/_base.py | 43 ++++++++++++++----------- synapse/config/account_validity.py | 2 +- synapse/config/emailconfig.py | 8 +++-- synapse/config/sso.py | 2 +- synapse/module_api/__init__.py | 5 ++- tests/config/test_base.py | 64 ++++++++++++++++++++++++++++++-- 7 files changed, 98 insertions(+), 27 deletions(-) create mode 100644 changelog.d/10587.misc diff --git a/changelog.d/10587.misc b/changelog.d/10587.misc new file mode 100644 index 0000000000..4c6167977c --- /dev/null +++ b/changelog.d/10587.misc @@ -0,0 +1 @@ +Allow multiple custom directories in `read_templates`. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index d6ec618f8f..2cc242782a 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -237,13 +237,14 @@ class Config: def read_templates( self, filenames: List[str], - custom_template_directory: Optional[str] = None, + custom_template_directories: Optional[Iterable[str]] = None, ) -> List[jinja2.Template]: """Load a list of template files from disk using the given variables. This function will attempt to load the given templates from the default Synapse - template directory. If `custom_template_directory` is supplied, that directory - is tried first. + template directory. If `custom_template_directories` is supplied, any directory + in this list is tried (in the order they appear in the list) before trying + Synapse's default directory. Files read are treated as Jinja templates. The templates are not rendered yet and have autoescape enabled. Args: filenames: A list of template filenames to read.
- custom_template_directory: A directory to try to look for the templates - before using the default Synapse template directory instead. + custom_template_directories: A list of directories to search for the + templates before using the default Synapse template directory instead. Raises: ConfigError: if the file's path is incorrect or otherwise cannot be read. Returns: A list of jinja2 templates. """ - search_directories = [self.default_template_dir] - - # The loader will first look in the custom template directory (if specified) for the - # given filename. If it doesn't find it, it will use the default template dir instead - if custom_template_directory: - # Check that the given template directory exists - if not self.path_exists(custom_template_directory): - raise ConfigError( - "Configured template directory does not exist: %s" - % (custom_template_directory,) - ) + search_directories = [] + + # The loader will first look in the custom template directories (if specified) + # for the given filename. If it doesn't find it, it will use the default + # template dir instead. + if custom_template_directories is not None: + for custom_template_directory in custom_template_directories: + # Check that the given template directory exists + if not self.path_exists(custom_template_directory): + raise ConfigError( + "Configured template directory does not exist: %s" + % (custom_template_directory,) + ) + + # Search the custom template directory as well + search_directories.append(custom_template_directory) - # Search the custom template directory as well - search_directories.insert(0, custom_template_directory) + # Append the default directory at the end of the list so Jinja can fall back on it + # if a template is missing from any custom directory.
+ search_directories.append(self.default_template_dir) # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 6be4eafe55..9acce5996e 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -88,5 +88,5 @@ class AccountValidityConfig(Config): "account_previously_renewed.html", invalid_token_template_filename, ], - account_validity_template_dir, + (td for td in (account_validity_template_dir,) if td), ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 42526502f0..fc74b4a8b9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -257,7 +257,9 @@ class EmailConfig(Config): registration_template_success_html, add_threepid_template_success_html, ], - template_dir, + ( + td for td in (template_dir,) if td + ), # Filter out template_dir if not provided ) # Render templates that do not contain any placeholders @@ -297,7 +299,7 @@ class EmailConfig(Config): self.email_notif_template_text, ) = self.read_templates( [notif_template_html, notif_template_text], - template_dir, + (td for td in (template_dir,) if td), ) self.email_notif_for_new_users = email_config.get( @@ -320,7 +322,7 @@ class EmailConfig(Config): self.account_validity_template_text, ) = self.read_templates( [expiry_template_html, expiry_template_text], - template_dir, + (td for td in (template_dir,) if td), ) subjects_config = email_config.get("subjects", {}) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index d0f04cf8e6..4b590e0535 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -63,7 +63,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - self.sso_template_dir, + (td for td in (self.sso_template_dir,) if td), ) # These templates have no placeholders, so render them here diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1cc13fc97b..82725853bc 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -677,7 +677,10 @@ class ModuleApi: A list containing the loaded templates, with the orders matching the one of the filenames parameter. 
""" - return self._hs.config.read_templates(filenames, custom_template_directory) + return self._hs.config.read_templates( + filenames, + (td for td in (custom_template_directory,) if td), + ) class PublicRoomListManager: diff --git a/tests/config/test_base.py b/tests/config/test_base.py index 84ae3b88ae..baa5313fb3 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -30,7 +30,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +60,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], tmp_dir) + self.hs.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -74,8 +74,66 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): "Template file did not contain our test string", ) + def test_multiple_custom_template_directories(self): + """Tests that directories are searched in the right order if multiple custom + template directories are provided. + """ + # Create two temporary directories on the filesystem. + tempdirs = [ + tempfile.TemporaryDirectory(), + tempfile.TemporaryDirectory(), + ] + + # Create one template in each directory, whose content is the index of the + # directory in the list. + template_filename = "my_template.html.j2" + for i in range(len(tempdirs)): + tempdir = tempdirs[i] + template_path = os.path.join(tempdir.name, template_filename) + + with open(template_path, "w") as fp: + fp.write(str(i)) + fp.flush() + + # Retrieve the template. + template = ( + self.hs.config.read_templates( + [template_filename], + (td.name for td in tempdirs), + ) + )[0] + + # Test that we got the template we dropped in the first directory in the list. + self.assertEqual(template.render(), "0") + + # Add another template, this one only in the second directory in the list, so we + # can test that the second directory is still searched into when no matching file + # could be found in the first one. + other_template_name = "my_other_template.html.j2" + other_template_path = os.path.join(tempdirs[1].name, other_template_name) + + with open(other_template_path, "w") as fp: + fp.write("hello world") + fp.flush() + + # Retrieve the template. + template = ( + self.hs.config.read_templates( + [other_template_name], + (td.name for td in tempdirs), + ) + )[0] + + # Test that the file has the expected content. + self.assertEqual(template.render(), "hello world") + + # Cleanup the temporary directories manually since we're not using a context + # manager. 
+ for td in tempdirs: + td.cleanup() + def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): self.hs.config.read_templates( - ["some_filename.html"], "a_nonexistent_directory" + ["some_filename.html"], ("a_nonexistent_directory",) ) -- cgit 1.5.1 From 58f0d97275e9ffc134f9aaf59ce01c0e745ec041 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 17 Aug 2021 11:45:35 +0100 Subject: update links to schema doc (#10620) --- changelog.d/10620.misc | 1 + synapse/storage/schema/README.md | 2 +- synapse/storage/schema/__init__.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 changelog.d/10620.misc diff --git a/changelog.d/10620.misc b/changelog.d/10620.misc new file mode 100644 index 0000000000..8b29668a1f --- /dev/null +++ b/changelog.d/10620.misc @@ -0,0 +1 @@ +Fix up a couple of links to the database schema documentation. diff --git a/synapse/storage/schema/README.md b/synapse/storage/schema/README.md index 729f44ea6c..4fc2061a3d 100644 --- a/synapse/storage/schema/README.md +++ b/synapse/storage/schema/README.md @@ -1,4 +1,4 @@ # Synapse Database Schemas This directory contains the schema files used to build Synapse databases. For more -information, see /docs/development/database_schema.md. +information, see https://matrix-org.github.io/synapse/develop/development/database_schema.html. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index fd4dd67d91..7e0687e197 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -19,8 +19,8 @@ This should be incremented whenever the codebase changes its requirements on the shape of the database schema (even if those requirements are backwards-compatible with older versions of Synapse). -See `README.md `_ for more information on how this -works. +See https://matrix-org.github.io/synapse/develop/development/database_schema.html +for more information on how this works. Changes in SCHEMA_VERSION = 61: - The `user_stats_historical` and `room_stats_historical` tables are not written and -- cgit 1.5.1 From 3bcd525b46678ff228c4275acad47c12974c9a33 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 Aug 2021 12:56:11 +0200 Subject: Allow to edit `external_ids` by Edit User admin API (#10598) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10598.feature | 1 + docs/admin_api/user_admin_api.md | 40 +++-- synapse/rest/admin/users.py | 139 +++++++++------ synapse/storage/databases/main/registration.py | 22 +++ tests/rest/admin/test_user.py | 227 +++++++++++++++++++++---- 5 files changed, 340 insertions(+), 89 deletions(-) create mode 100644 changelog.d/10598.feature diff --git a/changelog.d/10598.feature b/changelog.d/10598.feature new file mode 100644 index 0000000000..92c159118b --- /dev/null +++ b/changelog.d/10598.feature @@ -0,0 +1 @@ +Allow editing a user's `external_ids` via the "Edit User" admin API. Contributed by @dklimpel. 
\ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 4b5dd4685a..6a9335d6ec 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -81,6 +81,16 @@ with a body of: "address": "" } ], + "external_ids": [ + { + "auth_provider": "", + "external_id": "" + }, + { + "auth_provider": "", + "external_id": "" + } + ], "avatar_url": "", "admin": false, "deactivated": false @@ -90,26 +100,34 @@ with a body of: To use it, you will need to authenticate by providing an `access_token` for a server admin: [Admin API](../usage/administration/admin_api) +Returns HTTP status code: +- `201` - When a new user object was created. +- `200` - When a user was modified. + URL parameters: - `user_id`: fully-qualified user id: for example, `@user:server.com`. Body parameters: -- `password`, optional. If provided, the user's password is updated and all +- `password` - string, optional. If provided, the user's password is updated and all devices are logged out. - -- `displayname`, optional, defaults to the value of `user_id`. - -- `threepids`, optional, allows setting the third-party IDs (email, msisdn) +- `displayname` - string, optional, defaults to the value of `user_id`. +- `threepids` - array, optional, allows setting the third-party IDs (email, msisdn) + - `medium` - string. Kind of third-party ID, either `email` or `msisdn`. + - `address` - string. Value of third-party ID. belonging to a user. - -- `avatar_url`, optional, must be a +- `external_ids` - array, optional. Allows setting the identifiers of the external identity + providers for SSO (single sign-on). Details in + [Sample Configuration File](../usage/configuration/homeserver_sample_config.html) + sections `sso` and `oidc_providers`. + - `auth_provider` - string. ID of the external identity provider. Value of `idp_id` + in homeserver configuration. + - `external_id` - string, user ID in the external identity provider. +- `avatar_url` - string, optional, must be a [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris). - -- `admin`, optional, defaults to `false`. - -- `deactivated`, optional. If unspecified, deactivation state will be left +- `admin` - bool, optional, defaults to `false`. +- `deactivated` - bool, optional. If unspecified, deactivation state will be left unchanged on existing accounts and set to `false` for new accounts. A user cannot be erased by deactivating with this API. For details on deactivating users see [Deactivate Account](#deactivate-account).
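As a concrete illustration of the `external_ids` parameter documented above, the following is a minimal sketch of driving the `PUT /_synapse/admin/v2/users/<user_id>` endpoint from Python with the `requests` library; it is not part of the patch, and the homeserver URL, admin access token, and `auth_provider` value are hypothetical placeholders.

```
# A minimal sketch, assuming a reachable homeserver and a server admin's
# access token. HOMESERVER, ADMIN_TOKEN and the auth_provider value below
# are hypothetical placeholders, not values from this patch.
import requests

HOMESERVER = "https://matrix.example.com"
ADMIN_TOKEN = "<admin access token>"

resp = requests.put(
    f"{HOMESERVER}/_synapse/admin/v2/users/@alice:example.com",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={
        "external_ids": [
            # The submitted list replaces the stored one: entries omitted
            # here are removed, new entries are added.
            {"auth_provider": "oidc-example", "external_id": "alice@idp.example"}
        ]
    },
)
resp.raise_for_status()  # 200: user modified; 201: user created
print(resp.json()["external_ids"])
```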
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 41f21ba118..c885fd77ab 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -196,20 +196,57 @@ class UserRestServletV2(RestServlet): user = await self.admin_handler.get_user(target_user) user_id = target_user.to_string() + # check for required parameters for each threepid + threepids = body.get("threepids") + if threepids is not None: + for threepid in threepids: + assert_params_in_dict(threepid, ["medium", "address"]) + + # check for required parameters for each external_id + external_ids = body.get("external_ids") + if external_ids is not None: + for external_id in external_ids: + assert_params_in_dict(external_id, ["auth_provider", "external_id"]) + + user_type = body.get("user_type", None) + if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: + raise SynapseError(400, "Invalid user type") + + set_admin_to = body.get("admin", False) + if not isinstance(set_admin_to, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'admin' must be a boolean, if given", + Codes.BAD_JSON, + ) + + password = body.get("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + + deactivate = body.get("deactivated", False) + if not isinstance(deactivate, bool): + raise SynapseError(400, "'deactivated' parameter is not of type boolean") + + # convert into List[Tuple[str, str]] + if external_ids is not None: + new_external_ids = [] + for external_id in external_ids: + new_external_ids.append( + (external_id["auth_provider"], external_id["external_id"]) + ) + if user: # modify user if "displayname" in body: await self.profile_handler.set_displayname( target_user, requester, body["displayname"], True ) - if "threepids" in body: - # check for required parameters for each threepid - for threepid in body["threepids"]: - assert_params_in_dict(threepid, ["medium", "address"]) - + if threepids is not None: # remove old threepids from user - threepids = await self.store.user_get_threepids(user_id) - for threepid in threepids: + old_threepids = await self.store.user_get_threepids(user_id) + for threepid in old_threepids: try: await self.auth_handler.delete_threepid( user_id, threepid["medium"], threepid["address"], None @@ -220,18 +257,39 @@ class UserRestServletV2(RestServlet): # add new threepids to user current_time = self.hs.get_clock().time_msec() - for threepid in body["threepids"]: + for threepid in threepids: await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) - if "avatar_url" in body and type(body["avatar_url"]) == str: + if external_ids is not None: + # get changed external_ids (added and removed) + cur_external_ids = await self.store.get_external_ids_by_user(user_id) + add_external_ids = set(new_external_ids) - set(cur_external_ids) + del_external_ids = set(cur_external_ids) - set(new_external_ids) + + # remove old external_ids + for auth_provider, external_id in del_external_ids: + await self.store.remove_user_external_id( + auth_provider, + external_id, + user_id, + ) + + # add new external_ids + for auth_provider, external_id in add_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True ) if "admin" in body: - set_admin_to = 
bool(body["admin"]) if set_admin_to != user["admin"]: auth_user = requester.user if target_user == auth_user and not set_admin_to: @@ -239,29 +297,18 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) - if "password" in body: - if not isinstance(body["password"], str) or len(body["password"]) > 512: - raise SynapseError(400, "Invalid password") - else: - new_password = body["password"] - logout_devices = True - - new_password_hash = await self.auth_handler.hash(new_password) - - await self.set_password_handler.set_password( - target_user.to_string(), - new_password_hash, - logout_devices, - requester, - ) + if password is not None: + logout_devices = True + new_password_hash = await self.auth_handler.hash(password) + + await self.set_password_handler.set_password( + target_user.to_string(), + new_password_hash, + logout_devices, + requester, + ) if "deactivated" in body: - deactivate = body["deactivated"] - if not isinstance(deactivate, bool): - raise SynapseError( - 400, "'deactivated' parameter is not of type boolean" - ) - if deactivate and not user["deactivated"]: await self.deactivate_account_handler.deactivate_account( target_user.to_string(), False, requester, by_admin=True @@ -285,36 +332,24 @@ class UserRestServletV2(RestServlet): return 200, user else: # create user - password = body.get("password") + displayname = body.get("displayname", None) + password_hash = None if password is not None: - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) - admin = body.get("admin", None) - user_type = body.get("user_type", None) - displayname = body.get("displayname", None) - - if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES: - raise SynapseError(400, "Invalid user type") - user_id = await self.registration_handler.register_user( localpart=target_user.localpart, password_hash=password_hash, - admin=bool(admin), + admin=set_admin_to, default_display_name=displayname, user_type=user_type, by_admin=True, ) - if "threepids" in body: - # check for required parameters for each threepid - for threepid in body["threepids"]: - assert_params_in_dict(threepid, ["medium", "address"]) - + if threepids is not None: current_time = self.hs.get_clock().time_msec() - for threepid in body["threepids"]: + for threepid in threepids: await self.auth_handler.add_threepid( user_id, threepid["medium"], threepid["address"], current_time ) @@ -334,6 +369,14 @@ class UserRestServletV2(RestServlet): data={}, ) + if external_ids is not None: + for auth_provider, external_id in new_external_ids: + await self.store.record_user_external_id( + auth_provider, + external_id, + user_id, + ) + if "avatar_url" in body and isinstance(body["avatar_url"], str): await self.profile_handler.set_avatar_url( target_user, requester, body["avatar_url"], True diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 14670c2881..c67bea81c6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -599,6 +599,28 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="record_user_external_id", ) + async def remove_user_external_id( + self, auth_provider: str, external_id: str, user_id: str + ) -> None: + """Remove a mapping from an external user id to a mxid + + If the mapping is not found, this method does nothing. 
+ + Args: + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ + await self.db_pool.simple_delete( + table="user_external_ids", + keyvalues={ + "auth_provider": auth_provider, + "external_id": external_id, + "user_id": user_id, + }, + desc="remove_user_external_id", + ) + async def get_user_by_external_id( self, auth_provider: str, external_id: str ) -> Optional[str]: diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 13fab5579b..a736ec4754 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1240,56 +1240,114 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual("M_NOT_FOUND", channel.json_body["errcode"]) - def test_get_user(self): + def test_invalid_parameter(self): """ - Test a simple get of a user. + If parameters are invalid, an error is returned. """ + + # admin not bool channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"admin": "not_bool"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual("User", channel.json_body["displayname"]) - self._check_fields(channel.json_body) + # deactivated not bool + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": "not_bool"}, + ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - def test_get_user_with_sso(self): - """ - Test get a user with SSO details. 
- """ - self.get_success( - self.store.record_user_external_id( - "auth_provider1", "external_id1", self.other_user - ) + # password not str + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": True}, ) - self.get_success( - self.store.record_user_external_id( - "auth_provider2", "external_id2", self.other_user - ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # password not length + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"password": "x" * 513}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # user_type not valid channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"user_type": "new type"}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual( - "external_id1", channel.json_body["external_ids"][0]["external_id"] + # external_ids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": {"auth_provider": "prov", "wrong_external_id": "id"} + }, ) - self.assertEqual( - "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": {"external_id": "id"}}, ) - self.assertEqual( - "external_id2", channel.json_body["external_ids"][1]["external_id"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # threepids not valid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"medium": "email", "wrong_address": "id"}}, ) - self.assertEqual( - "auth_provider2", channel.json_body["external_ids"][1]["auth_provider"] + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": {"address": "value"}}, ) + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + def test_get_user(self): + """ + Test a simple get of a user. 
+ """ + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual("User", channel.json_body["displayname"]) self._check_fields(channel.json_body) def test_create_server_admin(self): @@ -1353,6 +1411,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): "admin": False, "displayname": "Bob's name", "threepids": [{"medium": "email", "address": "bob@bob.bob"}], + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + ], "avatar_url": "mxc://fibble/wibble", } @@ -1368,6 +1432,12 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual( + "external_id1", channel.json_body["external_ids"][0]["external_id"] + ) + self.assertEqual( + "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] + ) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1632,6 +1702,103 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + def test_set_external_id(self): + """ + Test setting external id for an other user. + """ + + # Add two external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id1", + "auth_provider": "auth_provider1", + }, + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + # result does not always have the same sort order, therefore it becomes sorted + self.assertEqual( + sorted(channel.json_body["external_ids"], key=lambda k: k["auth_provider"]), + [ + {"auth_provider": "auth_provider1", "external_id": "external_id1"}, + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + ], + ) + self._check_fields(channel.json_body) + + # Set a new and remove an external_id + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "external_ids": [ + { + "external_id": "external_id2", + "auth_provider": "auth_provider2", + }, + { + "external_id": "external_id3", + "auth_provider": "auth_provider3", + }, + ] + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", 
channel.json_body["name"]) + self.assertEqual( + channel.json_body["external_ids"], + [ + {"auth_provider": "auth_provider2", "external_id": "external_id2"}, + {"auth_provider": "auth_provider3", "external_id": "external_id3"}, + ], + ) + self._check_fields(channel.json_body) + + # Remove external_ids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"external_ids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["external_ids"])) + def test_deactivate_user(self): """ Test deactivating another user. -- cgit 1.5.1 From b62eba770522fde7bf1204eb5771ee24d9a5e7bc Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 17 Aug 2021 12:32:25 +0100 Subject: Always list fallback key types in /sync (#10623) --- changelog.d/10623.bugfix | 1 + synapse/rest/client/v2_alpha/sync.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10623.bugfix diff --git a/changelog.d/10623.bugfix b/changelog.d/10623.bugfix new file mode 100644 index 0000000000..759fba3513 --- /dev/null +++ b/changelog.d/10623.bugfix @@ -0,0 +1 @@ +Revert behaviour introduced in v1.38.0 that strips `org.matrix.msc2732.device_unused_fallback_key_types` from `/sync` when its value is empty. This field should instead always be present according to [MSC2732](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2732-olm-fallback-keys.md). \ No newline at end of file diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index e321668698..e18f4d01b3 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -259,10 +259,11 @@ class SyncRestServlet(RestServlet): # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456 response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count - if sync_result.device_unused_fallback_key_types: - response[ - "org.matrix.msc2732.device_unused_fallback_key_types" - ] = sync_result.device_unused_fallback_key_types + # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md + # states that this field should always be included, as long as the server supports the feature. 
+ response[ + "org.matrix.msc2732.device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types if joined: response["rooms"][Membership.JOIN] = joined -- cgit 1.5.1 From 642a42eddece60afbbd5e5a6659fa9b939238b4a Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 17 Aug 2021 12:57:58 +0100 Subject: Flatten the synapse.rest.client package (#10600) --- changelog.d/10600.misc | 1 + synapse/app/generic_worker.py | 40 +- synapse/handlers/auth.py | 6 +- synapse/rest/__init__.py | 35 +- synapse/rest/admin/users.py | 4 +- synapse/rest/client/__init__.py | 2 +- synapse/rest/client/_base.py | 100 ++ synapse/rest/client/account.py | 910 ++++++++++++++++ synapse/rest/client/account_data.py | 122 +++ synapse/rest/client/account_validity.py | 104 ++ synapse/rest/client/auth.py | 143 +++ synapse/rest/client/capabilities.py | 68 ++ synapse/rest/client/devices.py | 300 +++++ synapse/rest/client/directory.py | 185 ++++ synapse/rest/client/events.py | 94 ++ synapse/rest/client/filter.py | 94 ++ synapse/rest/client/groups.py | 957 ++++++++++++++++ synapse/rest/client/initial_sync.py | 47 + synapse/rest/client/keys.py | 344 ++++++ synapse/rest/client/knock.py | 107 ++ synapse/rest/client/login.py | 600 ++++++++++ synapse/rest/client/logout.py | 72 ++ synapse/rest/client/notifications.py | 91 ++ synapse/rest/client/openid.py | 94 ++ synapse/rest/client/password_policy.py | 57 + synapse/rest/client/presence.py | 95 ++ synapse/rest/client/profile.py | 155 +++ synapse/rest/client/push_rule.py | 354 ++++++ synapse/rest/client/pusher.py | 171 +++ synapse/rest/client/read_marker.py | 74 ++ synapse/rest/client/receipts.py | 71 ++ synapse/rest/client/register.py | 879 +++++++++++++++ synapse/rest/client/relations.py | 381 +++++++ synapse/rest/client/report_event.py | 68 ++ synapse/rest/client/room.py | 1152 ++++++++++++++++++++ synapse/rest/client/room_batch.py | 441 ++++++++ synapse/rest/client/room_keys.py | 391 +++++++ synapse/rest/client/room_upgrade_rest_servlet.py | 88 ++ synapse/rest/client/sendtodevice.py | 67 ++ synapse/rest/client/shared_rooms.py | 67 ++ synapse/rest/client/sync.py | 532 +++++++++ synapse/rest/client/tags.py | 85 ++ synapse/rest/client/thirdparty.py | 111 ++ synapse/rest/client/tokenrefresh.py | 37 + synapse/rest/client/user_directory.py | 79 ++ synapse/rest/client/v1/__init__.py | 13 - synapse/rest/client/v1/directory.py | 185 ---- synapse/rest/client/v1/events.py | 94 -- synapse/rest/client/v1/initial_sync.py | 47 - synapse/rest/client/v1/login.py | 600 ---------- synapse/rest/client/v1/logout.py | 72 -- synapse/rest/client/v1/presence.py | 95 -- synapse/rest/client/v1/profile.py | 155 --- synapse/rest/client/v1/push_rule.py | 354 ------ synapse/rest/client/v1/pusher.py | 171 --- synapse/rest/client/v1/room.py | 1152 -------------------- synapse/rest/client/v1/voip.py | 73 -- synapse/rest/client/v2_alpha/__init__.py | 13 - synapse/rest/client/v2_alpha/_base.py | 100 -- synapse/rest/client/v2_alpha/account.py | 910 ---------------- synapse/rest/client/v2_alpha/account_data.py | 122 --- synapse/rest/client/v2_alpha/account_validity.py | 104 -- synapse/rest/client/v2_alpha/auth.py | 143 --- synapse/rest/client/v2_alpha/capabilities.py | 68 -- synapse/rest/client/v2_alpha/devices.py | 300 ----- synapse/rest/client/v2_alpha/filter.py | 94 -- synapse/rest/client/v2_alpha/groups.py | 957 ---------------- synapse/rest/client/v2_alpha/keys.py | 344 ------ synapse/rest/client/v2_alpha/knock.py | 107 -- 
synapse/rest/client/v2_alpha/notifications.py | 91 -- synapse/rest/client/v2_alpha/openid.py | 94 -- synapse/rest/client/v2_alpha/password_policy.py | 57 - synapse/rest/client/v2_alpha/read_marker.py | 74 -- synapse/rest/client/v2_alpha/receipts.py | 71 -- synapse/rest/client/v2_alpha/register.py | 879 --------------- synapse/rest/client/v2_alpha/relations.py | 381 ------- synapse/rest/client/v2_alpha/report_event.py | 68 -- synapse/rest/client/v2_alpha/room.py | 441 -------- synapse/rest/client/v2_alpha/room_keys.py | 391 ------- .../client/v2_alpha/room_upgrade_rest_servlet.py | 88 -- synapse/rest/client/v2_alpha/sendtodevice.py | 67 -- synapse/rest/client/v2_alpha/shared_rooms.py | 67 -- synapse/rest/client/v2_alpha/sync.py | 532 --------- synapse/rest/client/v2_alpha/tags.py | 85 -- synapse/rest/client/v2_alpha/thirdparty.py | 111 -- synapse/rest/client/v2_alpha/tokenrefresh.py | 37 - synapse/rest/client/v2_alpha/user_directory.py | 79 -- synapse/rest/client/voip.py | 73 ++ tests/app/test_phone_stats_home.py | 2 +- tests/events/test_presence_router.py | 2 +- tests/events/test_snapshot.py | 2 +- tests/federation/test_complexity.py | 2 +- tests/federation/test_federation_catch_up.py | 2 +- tests/federation/test_federation_sender.py | 2 +- tests/federation/test_federation_server.py | 2 +- tests/federation/transport/test_knocking.py | 2 +- tests/handlers/test_admin.py | 4 +- tests/handlers/test_directory.py | 2 +- tests/handlers/test_federation.py | 2 +- tests/handlers/test_message.py | 2 +- tests/handlers/test_password_providers.py | 3 +- tests/handlers/test_presence.py | 2 +- tests/handlers/test_room_summary.py | 2 +- tests/handlers/test_stats.py | 2 +- tests/handlers/test_user_directory.py | 3 +- tests/module_api/test_api.py | 2 +- tests/push/test_email.py | 2 +- tests/push/test_http.py | 3 +- tests/replication/tcp/streams/test_events.py | 2 +- tests/replication/test_auth.py | 2 +- tests/replication/test_client_reader_shard.py | 2 +- tests/replication/test_federation_sender_shard.py | 2 +- tests/replication/test_multi_media_repo.py | 2 +- tests/replication/test_pusher_shard.py | 2 +- tests/replication/test_sharded_event_persister.py | 3 +- tests/rest/admin/test_admin.py | 3 +- tests/rest/admin/test_device.py | 2 +- tests/rest/admin/test_event_reports.py | 3 +- tests/rest/admin/test_media.py | 2 +- tests/rest/admin/test_room.py | 2 +- tests/rest/admin/test_statistics.py | 2 +- tests/rest/admin/test_user.py | 3 +- tests/rest/admin/test_username_available.py | 2 +- tests/rest/client/test_consent.py | 2 +- tests/rest/client/test_ephemeral_message.py | 2 +- tests/rest/client/test_identity.py | 2 +- tests/rest/client/test_power_levels.py | 3 +- tests/rest/client/test_redactions.py | 3 +- tests/rest/client/test_retention.py | 2 +- tests/rest/client/test_shadow_banned.py | 9 +- tests/rest/client/test_third_party_rules.py | 2 +- tests/rest/client/v1/test_directory.py | 2 +- tests/rest/client/v1/test_events.py | 2 +- tests/rest/client/v1/test_login.py | 5 +- tests/rest/client/v1/test_presence.py | 2 +- tests/rest/client/v1/test_profile.py | 2 +- tests/rest/client/v1/test_push_rule_attrs.py | 2 +- tests/rest/client/v1/test_rooms.py | 3 +- tests/rest/client/v1/test_typing.py | 2 +- tests/rest/client/v2_alpha/test_account.py | 3 +- tests/rest/client/v2_alpha/test_auth.py | 3 +- tests/rest/client/v2_alpha/test_capabilities.py | 3 +- tests/rest/client/v2_alpha/test_filter.py | 2 +- tests/rest/client/v2_alpha/test_password_policy.py | 3 +- tests/rest/client/v2_alpha/test_register.py | 3 +- 
tests/rest/client/v2_alpha/test_relations.py | 3 +- tests/rest/client/v2_alpha/test_report_event.py | 3 +- tests/rest/client/v2_alpha/test_sendtodevice.py | 3 +- tests/rest/client/v2_alpha/test_shared_rooms.py | 3 +- tests/rest/client/v2_alpha/test_sync.py | 3 +- tests/rest/client/v2_alpha/test_upgrade_room.py | 3 +- tests/rest/media/v1/test_media_storage.py | 2 +- tests/server_notices/test_consent.py | 3 +- .../test_resource_limits_server_notices.py | 3 +- tests/storage/databases/main/test_events_worker.py | 2 +- tests/storage/test_cleanup_extrems.py | 2 +- tests/storage/test_client_ips.py | 2 +- tests/storage/test_event_chain.py | 2 +- tests/storage/test_events.py | 2 +- tests/storage/test_purge.py | 2 +- tests/storage/test_roommember.py | 2 +- tests/test_mau.py | 2 +- tests/test_terms_auth.py | 2 +- 163 files changed, 9984 insertions(+), 10035 deletions(-) create mode 100644 changelog.d/10600.misc create mode 100644 synapse/rest/client/_base.py create mode 100644 synapse/rest/client/account.py create mode 100644 synapse/rest/client/account_data.py create mode 100644 synapse/rest/client/account_validity.py create mode 100644 synapse/rest/client/auth.py create mode 100644 synapse/rest/client/capabilities.py create mode 100644 synapse/rest/client/devices.py create mode 100644 synapse/rest/client/directory.py create mode 100644 synapse/rest/client/events.py create mode 100644 synapse/rest/client/filter.py create mode 100644 synapse/rest/client/groups.py create mode 100644 synapse/rest/client/initial_sync.py create mode 100644 synapse/rest/client/keys.py create mode 100644 synapse/rest/client/knock.py create mode 100644 synapse/rest/client/login.py create mode 100644 synapse/rest/client/logout.py create mode 100644 synapse/rest/client/notifications.py create mode 100644 synapse/rest/client/openid.py create mode 100644 synapse/rest/client/password_policy.py create mode 100644 synapse/rest/client/presence.py create mode 100644 synapse/rest/client/profile.py create mode 100644 synapse/rest/client/push_rule.py create mode 100644 synapse/rest/client/pusher.py create mode 100644 synapse/rest/client/read_marker.py create mode 100644 synapse/rest/client/receipts.py create mode 100644 synapse/rest/client/register.py create mode 100644 synapse/rest/client/relations.py create mode 100644 synapse/rest/client/report_event.py create mode 100644 synapse/rest/client/room.py create mode 100644 synapse/rest/client/room_batch.py create mode 100644 synapse/rest/client/room_keys.py create mode 100644 synapse/rest/client/room_upgrade_rest_servlet.py create mode 100644 synapse/rest/client/sendtodevice.py create mode 100644 synapse/rest/client/shared_rooms.py create mode 100644 synapse/rest/client/sync.py create mode 100644 synapse/rest/client/tags.py create mode 100644 synapse/rest/client/thirdparty.py create mode 100644 synapse/rest/client/tokenrefresh.py create mode 100644 synapse/rest/client/user_directory.py delete mode 100644 synapse/rest/client/v1/__init__.py delete mode 100644 synapse/rest/client/v1/directory.py delete mode 100644 synapse/rest/client/v1/events.py delete mode 100644 synapse/rest/client/v1/initial_sync.py delete mode 100644 synapse/rest/client/v1/login.py delete mode 100644 synapse/rest/client/v1/logout.py delete mode 100644 synapse/rest/client/v1/presence.py delete mode 100644 synapse/rest/client/v1/profile.py delete mode 100644 synapse/rest/client/v1/push_rule.py delete mode 100644 synapse/rest/client/v1/pusher.py delete mode 100644 synapse/rest/client/v1/room.py delete mode 100644 
synapse/rest/client/v1/voip.py delete mode 100644 synapse/rest/client/v2_alpha/__init__.py delete mode 100644 synapse/rest/client/v2_alpha/_base.py delete mode 100644 synapse/rest/client/v2_alpha/account.py delete mode 100644 synapse/rest/client/v2_alpha/account_data.py delete mode 100644 synapse/rest/client/v2_alpha/account_validity.py delete mode 100644 synapse/rest/client/v2_alpha/auth.py delete mode 100644 synapse/rest/client/v2_alpha/capabilities.py delete mode 100644 synapse/rest/client/v2_alpha/devices.py delete mode 100644 synapse/rest/client/v2_alpha/filter.py delete mode 100644 synapse/rest/client/v2_alpha/groups.py delete mode 100644 synapse/rest/client/v2_alpha/keys.py delete mode 100644 synapse/rest/client/v2_alpha/knock.py delete mode 100644 synapse/rest/client/v2_alpha/notifications.py delete mode 100644 synapse/rest/client/v2_alpha/openid.py delete mode 100644 synapse/rest/client/v2_alpha/password_policy.py delete mode 100644 synapse/rest/client/v2_alpha/read_marker.py delete mode 100644 synapse/rest/client/v2_alpha/receipts.py delete mode 100644 synapse/rest/client/v2_alpha/register.py delete mode 100644 synapse/rest/client/v2_alpha/relations.py delete mode 100644 synapse/rest/client/v2_alpha/report_event.py delete mode 100644 synapse/rest/client/v2_alpha/room.py delete mode 100644 synapse/rest/client/v2_alpha/room_keys.py delete mode 100644 synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py delete mode 100644 synapse/rest/client/v2_alpha/sendtodevice.py delete mode 100644 synapse/rest/client/v2_alpha/shared_rooms.py delete mode 100644 synapse/rest/client/v2_alpha/sync.py delete mode 100644 synapse/rest/client/v2_alpha/tags.py delete mode 100644 synapse/rest/client/v2_alpha/thirdparty.py delete mode 100644 synapse/rest/client/v2_alpha/tokenrefresh.py delete mode 100644 synapse/rest/client/v2_alpha/user_directory.py create mode 100644 synapse/rest/client/voip.py diff --git a/changelog.d/10600.misc b/changelog.d/10600.misc new file mode 100644 index 0000000000..489dc20b11 --- /dev/null +++ b/changelog.d/10600.misc @@ -0,0 +1 @@ +Flatten the `synapse.rest.client` package by moving the contents of `v1` and `v2_alpha` into the parent. 
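For code that imports these servlets, the visible effect of the flattening is purely an import-path change; below is a minimal sketch using two of the moved modules from the file list above (assuming a Synapse checkout that includes this commit, so that `synapse` is importable).

```
# Old paths, deleted by this commit:
#   from synapse.rest.client.v1.login import LoginRestServlet
#   from synapse.rest.client.v2_alpha.sync import SyncRestServlet

# New paths, created by this commit:
from synapse.rest.client.login import LoginRestServlet
from synapse.rest.client.sync import SyncRestServlet
```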
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 3b7131af8f..d7b425a7ab 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -66,40 +66,40 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.room import RoomStore from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.client.v1 import events, login, presence, room -from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet -from synapse.rest.client.v1.profile import ( - ProfileAvatarURLRestServlet, - ProfileDisplaynameRestServlet, - ProfileRestServlet, -) -from synapse.rest.client.v1.push_rule import PushRuleRestServlet -from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import ( +from synapse.rest.client import ( account_data, + events, groups, + login, + presence, read_marker, receipts, + room, room_keys, sync, tags, user_directory, ) -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.rest.client.v2_alpha.account import ThreepidRestServlet -from synapse.rest.client.v2_alpha.account_data import ( - AccountDataServlet, - RoomAccountDataServlet, -) -from synapse.rest.client.v2_alpha.devices import DevicesRestServlet -from synapse.rest.client.v2_alpha.keys import ( +from synapse.rest.client._base import client_patterns +from synapse.rest.client.account import ThreepidRestServlet +from synapse.rest.client.account_data import AccountDataServlet, RoomAccountDataServlet +from synapse.rest.client.devices import DevicesRestServlet +from synapse.rest.client.initial_sync import InitialSyncRestServlet +from synapse.rest.client.keys import ( KeyChangesServlet, KeyQueryServlet, OneTimeKeyServlet, ) -from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet +from synapse.rest.client.profile import ( + ProfileAvatarURLRestServlet, + ProfileDisplaynameRestServlet, + ProfileRestServlet, +) +from synapse.rest.client.push_rule import PushRuleRestServlet +from synapse.rest.client.register import RegisterRestServlet +from synapse.rest.client.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet +from synapse.rest.client.voip import VoipRestServlet from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.synapse.client import build_synapse_client_resource_tree diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 22a8552241..161b3c933c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -73,7 +73,7 @@ from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: - from synapse.rest.client.v1.login import LoginResponse + from synapse.rest.client.login import LoginResponse from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -461,7 +461,7 @@ class AuthHandler(BaseHandler): If no auth flows have been completed successfully, raises an InteractiveAuthIncompleteError. To handle this, you can use - synapse.rest.client.v2_alpha._base.interactive_auth_handler as a + synapse.rest.client._base.interactive_auth_handler as a decorator. 
Args: @@ -543,7 +543,7 @@ class AuthHandler(BaseHandler): # Note that the registration endpoint explicitly removes the # "initial_device_display_name" parameter if it is provided # without a "password" parameter. See the changes to - # synapse.rest.client.v2_alpha.register.RegisterRestServlet.on_POST + # synapse.rest.client.register.RegisterRestServlet.on_POST # in commit 544722bad23fc31056b9240189c3cbbbf0ffd3f9. if not clientdict: clientdict = session.clientdict diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 9cffe59ce5..3adc576124 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -14,40 +14,36 @@ # limitations under the License. from synapse.http.server import JsonResource from synapse.rest import admin -from synapse.rest.client import versions -from synapse.rest.client.v1 import ( - directory, - events, - initial_sync, - login as v1_login, - logout, - presence, - profile, - push_rule, - pusher, - room, - voip, -) -from synapse.rest.client.v2_alpha import ( +from synapse.rest.client import ( account, account_data, account_validity, auth, capabilities, devices, + directory, + events, filter, groups, + initial_sync, keys, knock, + login as v1_login, + logout, notifications, openid, password_policy, + presence, + profile, + push_rule, + pusher, read_marker, receipts, register, relations, report_event, - room as roomv2, + room, + room_batch, room_keys, room_upgrade_rest_servlet, sendtodevice, @@ -57,6 +53,8 @@ from synapse.rest.client.v2_alpha import ( thirdparty, tokenrefresh, user_directory, + versions, + voip, ) @@ -85,7 +83,6 @@ class ClientRestResource(JsonResource): # Partially deprecated in r0 events.register_servlets(hs, client_resource) - # "v1" + "r0" room.register_servlets(hs, client_resource) v1_login.register_servlets(hs, client_resource) profile.register_servlets(hs, client_resource) @@ -95,8 +92,6 @@ class ClientRestResource(JsonResource): pusher.register_servlets(hs, client_resource) push_rule.register_servlets(hs, client_resource) logout.register_servlets(hs, client_resource) - - # "v2" sync.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource) @@ -118,7 +113,7 @@ class ClientRestResource(JsonResource): user_directory.register_servlets(hs, client_resource) groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) - roomv2.register_servlets(hs, client_resource) + room_batch.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c885fd77ab..93193b0864 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -34,7 +34,7 @@ from synapse.rest.admin._base import ( assert_requester_is_admin, assert_user_is_admin, ) -from synapse.rest.client.v2_alpha._base import client_patterns +from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.media_repository import MediaSortOrder from synapse.storage.databases.main.stats import UserSortOrder from synapse.types import JsonDict, UserID @@ -504,7 +504,7 @@ class UserRegisterServlet(RestServlet): raise SynapseError(403, "HMAC incorrect") # Reuse the parts of RegisterRestServlet to reduce code duplication - from synapse.rest.client.v2_alpha.register import RegisterRestServlet + from 
synapse.rest.client.register import RegisterRestServlet register = RegisterRestServlet(self.hs) diff --git a/synapse/rest/client/__init__.py b/synapse/rest/client/__init__.py index 629e2df74a..f9830cc51f 100644 --- a/synapse/rest/client/__init__.py +++ b/synapse/rest/client/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2014-2016 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py new file mode 100644 index 0000000000..0443f4571c --- /dev/null +++ b/synapse/rest/client/_base.py @@ -0,0 +1,100 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" +import logging +import re +from typing import Iterable, Pattern + +from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.urls import CLIENT_API_PREFIX +from synapse.types import JsonDict + +logger = logging.getLogger(__name__) + + +def client_patterns( + path_regex: str, + releases: Iterable[int] = (0,), + unstable: bool = True, + v1: bool = False, +) -> Iterable[Pattern]: + """Creates a regex compiled client path with the correct client path + prefix. + + Args: + path_regex: The regex string to match. This should NOT have a ^ + as this will be prefixed. + releases: An iterable of releases to include this endpoint under. + unstable: If true, include this endpoint under the "unstable" prefix. + v1: If true, include this endpoint under the "api/v1" prefix. + Returns: + An iterable of patterns. + """ + patterns = [] + + if unstable: + unstable_prefix = CLIENT_API_PREFIX + "/unstable" + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + if v1: + v1_prefix = CLIENT_API_PREFIX + "/api/v1" + patterns.append(re.compile("^" + v1_prefix + path_regex)) + for release in releases: + new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) + patterns.append(re.compile("^" + new_prefix + path_regex)) + + return patterns + + +def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: + """ + Enforces a maximum limit of a timeline query. + + Params: + filter_json: The timeline query to modify. + filter_timeline_limit: The maximum limit to allow, passing -1 will + disable enforcing a maximum limit. + """ + if filter_timeline_limit < 0: + return # no upper limits + timeline = filter_json.get("room", {}).get("timeline", {}) + if "limit" in timeline: + filter_json["room"]["timeline"]["limit"] = min( + filter_json["room"]["timeline"]["limit"], filter_timeline_limit + ) + + +def interactive_auth_handler(orig): + """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors + + Takes a on_POST method which returns an Awaitable (errcode, body) response + and adds exception handling to turn a InteractiveAuthIncompleteError into + a 401 response. 
+ + Normal usage is: + + @interactive_auth_handler + async def on_POST(self, request): + # ... + await self.auth_handler.check_auth + """ + + async def wrapped(*args, **kwargs): + try: + return await orig(*args, **kwargs) + except InteractiveAuthIncompleteError as e: + return 401, e.result + + return wrapped diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py new file mode 100644 index 0000000000..fb5ad2906e --- /dev/null +++ b/synapse/rest/client/account.py @@ -0,0 +1,910 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import random +from http import HTTPStatus +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from synapse.api.constants import LoginType +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, +) +from synapse.config.emailconfig import ThreepidBehaviour +from synapse.handlers.ui_auth import UIAuthSessionDataConstants +from synapse.http.server import finish_request, respond_with_html +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, +) +from synapse.metrics import threepid_send_requests +from synapse.push.mailer import Mailer +from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.stringutils import assert_valid_client_secret, random_string +from synapse.util.threepids import check_3pid_allowed, validate_email + +from ._base import client_patterns, interactive_auth_handler + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +class EmailPasswordRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/account/password/email/requestToken$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.datastore = hs.get_datastore() + self.config = hs.config + self.identity_handler = hs.get_identity_handler() + + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=self.config.email_password_reset_template_html, + template_text=self.config.email_password_reset_template_text, + ) + + async def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "User password resets have been disabled due to lack of email config" + ) + raise SynapseError( + 400, "Email-based password resets have been disabled on this server" + ) + + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) + + # Extract params from body + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + # Canonicalise the email address. 
The addresses are all stored canonicalised
+        # in the database. This allows the user to reset their password without having
+        # to know the exact spelling (e.g. upper and lower case) of the address in the
+        # database. For example, with "foo@bar.com" stored, a request using
+        # "FOO@bar.com" would otherwise raise a Not Found error.
+        try:
+            email = validate_email(body["email"])
+        except ValueError as e:
+            raise SynapseError(400, str(e))
+        send_attempt = body["send_attempt"]
+        next_link = body.get("next_link")  # Optional param
+
+        if next_link:
+            # Raise if the provided next_link value isn't valid
+            assert_valid_next_link(self.hs, next_link)
+
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
+
+        # The email will be sent to the stored address.
+        # This avoids a potential account hijack by requesting a password reset to
+        # an email address which is controlled by the attacker but which, after
+        # canonicalisation, matches the one in our database.
+        existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+            "email", email
+        )
+
+        if existing_user_id is None:
+            if self.config.request_token_inhibit_3pid_errors:
+                # Make the client think the operation succeeded. See the rationale in the
+                # comments for request_token_inhibit_3pid_errors.
+                # Also wait for some random amount of time between 100ms and 1s to make it
+                # look like we did something.
+                await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
+                return 200, {"sid": random_string(16)}
+
+            raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
+
+        if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+            assert self.hs.config.account_threepid_delegate_email
+
+            # Have the configured identity server handle the request
+            ret = await self.identity_handler.requestEmailToken(
+                self.hs.config.account_threepid_delegate_email,
+                email,
+                client_secret,
+                send_attempt,
+                next_link,
+            )
+        else:
+            # Send password reset emails from Synapse
+            sid = await self.identity_handler.send_threepid_validation(
+                email,
+                client_secret,
+                send_attempt,
+                self.mailer.send_password_reset_mail,
+                next_link,
+            )
+
+            # Wrap the session id in a JSON object
+            ret = {"sid": sid}
+
+        threepid_send_requests.labels(type="email", reason="password_reset").observe(
+            send_attempt
+        )
+
+        return 200, ret
+
+
+class PasswordRestServlet(RestServlet):
+    PATTERNS = client_patterns("/account/password$")
+
+    def __init__(self, hs):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.auth_handler = hs.get_auth_handler()
+        self.datastore = self.hs.get_datastore()
+        self.password_policy_handler = hs.get_password_policy_handler()
+        self._set_password_handler = hs.get_set_password_handler()
+
+    @interactive_auth_handler
+    async def on_POST(self, request):
+        body = parse_json_object_from_request(request)
+
+        # we do basic sanity checks here because the auth layer will store these
+        # in sessions. Pull out the new password provided to us.
+        new_password = body.pop("new_password", None)
+        if new_password is not None:
+            if not isinstance(new_password, str) or len(new_password) > 512:
+                raise SynapseError(400, "Invalid password")
+            self.password_policy_handler.validate_password(new_password)
+
+        # there are two possibilities here. Either the user does not have an
+        # access token, and needs to do a password reset; or they have one and
+        # need to validate their identity.
+ # + # In the first case, we offer a couple of means of identifying + # themselves (email and msisdn, though it's unclear if msisdn actually + # works). + # + # In the second case, we require a password to confirm their identity. + + if self.auth.has_access_token(request): + requester = await self.auth.get_user_by_req(request) + try: + params, session_id = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, + ) + raise + user_id = requester.user.to_string() + else: + requester = None + try: + result, params, session_id = await self.auth_handler.check_ui_auth( + [[LoginType.EMAIL_IDENTITY]], + request, + body, + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, + ) + raise + + if LoginType.EMAIL_IDENTITY in result: + threepid = result[LoginType.EMAIL_IDENTITY] + if "medium" not in threepid or "address" not in threepid: + raise SynapseError(500, "Malformed threepid") + if threepid["medium"] == "email": + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) + try: + threepid["address"] = validate_email(threepid["address"]) + except ValueError as e: + raise SynapseError(400, str(e)) + # if using email, we must know about the email they're authing with! + threepid_user_id = await self.datastore.get_user_id_by_threepid( + threepid["medium"], threepid["address"] + ) + if not threepid_user_id: + raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) + user_id = threepid_user_id + else: + logger.error("Auth succeeded but no known type! %r", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + # If we have a password in this request, prefer it. Otherwise, use the + # password hash from an earlier request. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + elif session_id is not None: + password_hash = await self.auth_handler.get_session_data( + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None + ) + else: + # UI validation was skipped, but the request did not include a new + # password. 
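+            # With no password hash to fall back on, the request is rejected
+            # just below with a "Missing params: password" error.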
+ password_hash = None + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + + logout_devices = params.get("logout_devices", True) + + await self._set_password_handler.set_password( + user_id, password_hash, logout_devices, requester + ) + + return 200, {} + + +class DeactivateAccountRestServlet(RestServlet): + PATTERNS = client_patterns("/account/deactivate$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self._deactivate_account_handler = hs.get_deactivate_account_handler() + + @interactive_auth_handler + async def on_POST(self, request): + body = parse_json_object_from_request(request) + erase = body.get("erase", False) + if not isinstance(erase, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'erase' must be a boolean, if given", + Codes.BAD_JSON, + ) + + requester = await self.auth.get_user_by_req(request) + + # allow ASes to deactivate their own users + if requester.app_service: + await self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), erase, requester + ) + return 200, {} + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "deactivate your account", + ) + result = await self._deactivate_account_handler.deactivate_account( + requester.user.to_string(), + erase, + requester, + id_server=body.get("id_server"), + ) + if result: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + return 200, {"id_server_unbind_result": id_server_unbind_result} + + +class EmailThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/email/requestToken$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.config = hs.config + self.identity_handler = hs.get_identity_handler() + self.store = self.hs.get_datastore() + + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=self.config.email_add_threepid_template_html, + template_text=self.config.email_add_threepid_template_text, + ) + + async def on_POST(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + # Canonicalise the email address. The addresses are all stored canonicalised + # in the database. + # This ensures that the validation email is sent to the canonicalised address + # as it will later be entered into the database. + # Otherwise the email will be sent to "FOO@bar.com" and stored as + # "foo@bar.com" in database. 
+ try: + email = validate_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + if not check_3pid_allowed(self.hs, "email", email): + raise SynapseError( + 403, + "Your email domain is not authorized on this server", + Codes.THREEPID_DENIED, + ) + + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) + + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + + existing_user_id = await self.store.get_user_id_by_threepid("email", email) + + if existing_user_id is not None: + if self.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( + self.hs.config.account_threepid_delegate_email, + email, + client_secret, + send_attempt, + next_link, + ) + else: + # Send threepid validation emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_add_threepid_mail, + next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + + threepid_send_requests.labels(type="email", reason="add_threepid").observe( + send_attempt + ) + + return 200, ret + + +class MsisdnThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + super().__init__() + self.store = self.hs.get_datastore() + self.identity_handler = hs.get_identity_handler() + + async def on_POST(self, request): + body = parse_json_object_from_request(request) + assert_params_in_dict( + body, ["client_secret", "country", "phone_number", "send_attempt"] + ) + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + country = body["country"] + phone_number = body["phone_number"] + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + msisdn = phone_number_to_msisdn(country, phone_number) + + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, + "Account phone numbers are not authorized on this server", + Codes.THREEPID_DENIED, + ) + + await self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + + if next_link: + # Raise if the provided next_link value isn't valid + assert_valid_next_link(self.hs, next_link) + + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) + + if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. 
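+                # (The random delay also stops the response timing from
+                # revealing whether the number is already registered here.)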
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) + + if not self.hs.config.account_threepid_delegate_msisdn: + logger.warning( + "No upstream msisdn account_threepid_delegate configured on the server to " + "handle this request" + ) + raise SynapseError( + 400, + "Adding phone numbers to user account is not supported by this homeserver", + ) + + ret = await self.identity_handler.requestMsisdnToken( + self.hs.config.account_threepid_delegate_msisdn, + country, + phone_number, + client_secret, + send_attempt, + next_link, + ) + + threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe( + send_attempt + ) + + return 200, ret + + +class AddThreepidEmailSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding an email to a user's account""" + + PATTERNS = client_patterns( + "/add_threepid/email/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self._failure_email_template = ( + self.config.email_add_threepid_template_failure_html + ) + + async def on_GET(self, request): + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Adding emails have been disabled due to lack of an email config" + ) + raise SynapseError( + 400, "Adding an email to your account is disabled on this server" + ) + elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + raise SynapseError( + 400, + "This homeserver is not validating threepids. 
Use an identity server " + "instead.", + ) + + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None + + # Otherwise show the success template + html = self.config.email_add_threepid_template_success_html_content + status_code = 200 + except ThreepidValidationError as e: + status_code = e.code + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html = self._failure_email_template.render(**template_vars) + + respond_with_html(request, status_code, html) + + +class AddThreepidMsisdnSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission for adding a phone number to a user's + account + """ + + PATTERNS = client_patterns( + "/add_threepid/msisdn/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.identity_handler = hs.get_identity_handler() + + async def on_POST(self, request): + if not self.config.account_threepid_delegate_msisdn: + raise SynapseError( + 400, + "This homeserver is not validating phone numbers. Use an identity server " + "instead.", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["client_secret", "sid", "token"]) + assert_valid_client_secret(body["client_secret"]) + + # Proxy submit_token request to msisdn threepid delegate + response = await self.identity_handler.proxy_msisdn_submit_token( + self.config.account_threepid_delegate_msisdn, + body["client_secret"], + body["sid"], + body["token"], + ) + return 200, response + + +class ThreepidRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self.datastore = self.hs.get_datastore() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + + threepids = await self.datastore.user_get_threepids(requester.user.to_string()) + + return 200, {"threepids": threepids} + + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") + if threepid_creds is None: + raise SynapseError( + 400, "Missing param three_pid_creds", Codes.MISSING_PARAM + ) + assert_params_in_dict(threepid_creds, ["client_secret", "sid"]) + + sid = threepid_creds["sid"] + client_secret = threepid_creds["client_secret"] + assert_valid_client_secret(client_secret) + + validation_session = await self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if 
validation_session: + await self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} + + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED + ) + + +class ThreepidAddRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/add$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["client_secret", "sid"]) + sid = body["sid"] + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "add a third-party identifier to your account", + ) + + validation_session = await self.identity_handler.validate_threepid_session( + client_secret, sid + ) + if validation_session: + await self.auth_handler.add_threepid( + user_id, + validation_session["medium"], + validation_session["address"], + validation_session["validated_at"], + ) + return 200, {} + + raise SynapseError( + 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED + ) + + +class ThreepidBindRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/bind$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.auth = hs.get_auth() + + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) + id_server = body["id_server"] + sid = body["sid"] + id_access_token = body.get("id_access_token") # optional + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + await self.identity_handler.bind_threepid( + client_secret, sid, user_id, id_server, id_access_token + ) + + return 200, {} + + +class ThreepidUnbindRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/unbind$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.auth = hs.get_auth() + self.datastore = self.hs.get_datastore() + + async def on_POST(self, request): + """Unbind the given 3pid from a specific identity server, or identity servers that are + known to have this 3pid bound + """ + requester = await self.auth.get_user_by_req(request) + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["medium", "address"]) + + medium = body.get("medium") + address = body.get("address") + id_server = body.get("id_server") + + # Attempt to unbind the threepid from an identity server. 
If id_server is None, try to + # unbind from all identity servers this threepid has been added to in the past + result = await self.identity_handler.try_unbind_threepid( + requester.user.to_string(), + {"address": address, "medium": medium, "id_server": id_server}, + ) + return 200, {"id_server_unbind_result": "success" if result else "no-support"} + + +class ThreepidDeleteRestServlet(RestServlet): + PATTERNS = client_patterns("/account/3pid/delete$") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + async def on_POST(self, request): + if not self.hs.config.enable_3pid_changes: + raise SynapseError( + 400, "3PID changes are disabled on this server", Codes.FORBIDDEN + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["medium", "address"]) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + try: + ret = await self.auth_handler.delete_threepid( + user_id, body["medium"], body["address"], body.get("id_server") + ) + except Exception: + # NB. This endpoint should succeed if there is nothing to + # delete, so it should only throw if something is wrong + # that we ought to care about. + logger.exception("Failed to remove threepid") + raise SynapseError(500, "Failed to remove threepid") + + if ret: + id_server_unbind_result = "success" + else: + id_server_unbind_result = "no-support" + + return 200, {"id_server_unbind_result": id_server_unbind_result} + + +def assert_valid_next_link(hs: "HomeServer", next_link: str): + """ + Raises a SynapseError if a given next_link value is invalid + + next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config + option is either empty or contains a domain that matches the one in the given next_link + + Args: + hs: The homeserver object + next_link: The next_link value given by the client + + Raises: + SynapseError: If the next_link is invalid + """ + valid = True + + # Parse the contents of the URL + next_link_parsed = urlparse(next_link) + + # Scheme must not point to the local drive + if next_link_parsed.scheme == "file": + valid = False + + # If the domain whitelist is set, the domain must be in it + if ( + valid + and hs.config.next_link_domain_whitelist is not None + and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist + ): + valid = False + + if not valid: + raise SynapseError( + 400, + "'next_link' domain not included in whitelist, or not http(s)", + errcode=Codes.INVALID_PARAM, + ) + + +class WhoamiRestServlet(RestServlet): + PATTERNS = client_patterns("/account/whoami$") + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + + response = {"user_id": requester.user.to_string()} + + # Appservices and similar accounts do not have device IDs + # that we can report on, so exclude them for compliance. 
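+        # (Such requesters authenticate with an appservice token rather than a
+        # per-device access token, so requester.device_id is None for them.)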
+ if requester.device_id is not None: + response["device_id"] = requester.device_id + + return 200, response + + +def register_servlets(hs, http_server): + EmailPasswordRequestTokenRestServlet(hs).register(http_server) + PasswordRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + EmailThreepidRequestTokenRestServlet(hs).register(http_server) + MsisdnThreepidRequestTokenRestServlet(hs).register(http_server) + AddThreepidEmailSubmitTokenServlet(hs).register(http_server) + AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server) + ThreepidRestServlet(hs).register(http_server) + ThreepidAddRestServlet(hs).register(http_server) + ThreepidBindRestServlet(hs).register(http_server) + ThreepidUnbindRestServlet(hs).register(http_server) + ThreepidDeleteRestServlet(hs).register(http_server) + WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py new file mode 100644 index 0000000000..7517e9304e --- /dev/null +++ b/synapse/rest/client/account_data.py @@ -0,0 +1,122 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class AccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 + GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/user/(?P[^/]*)/account_data/(?P[^/]*)" + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.handler = hs.get_account_data_handler() + + async def on_PUT(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + body = parse_json_object_from_request(request) + + await self.handler.add_account_data_for_user(user_id, account_data_type, body) + + return 200, {} + + async def on_GET(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot get account data for other users.") + + event = await self.store.get_global_account_data_by_type_for_user( + account_data_type, user_id + ) + + if event is None: + raise NotFoundError("Account data not found") + + return 200, event + + +class RoomAccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 + GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/user/(?P[^/]*)" + "/rooms/(?P[^/]*)" + "/account_data/(?P[^/]*)" + ) + + def __init__(self, hs): + super().__init__() + self.auth = 
hs.get_auth() + self.store = hs.get_datastore() + self.handler = hs.get_account_data_handler() + + async def on_PUT(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + body = parse_json_object_from_request(request) + + if account_data_type == "m.fully_read": + raise SynapseError( + 405, + "Cannot set m.fully_read through this API." + " Use /rooms/!roomId:server.name/read_markers", + ) + + await self.handler.add_account_data_to_room( + user_id, room_id, account_data_type, body + ) + + return 200, {} + + async def on_GET(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot get account data for other users.") + + event = await self.store.get_account_data_for_room_and_type( + user_id, room_id, account_data_type + ) + + if event is None: + raise NotFoundError("Room account data not found") + + return 200, event + + +def register_servlets(hs, http_server): + AccountDataServlet(hs).register(http_server) + RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py new file mode 100644 index 0000000000..3ebe401861 --- /dev/null +++ b/synapse/rest/client/account_validity.py @@ -0,0 +1,104 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
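For orientation, here is a minimal client-side sketch (not part of the patch) that exercises the account-data servlets above. The homeserver URL, access token, user ID, and the com.example.prefs event type are placeholders, and the requests library stands in for a real Matrix client; the r0 prefix follows from client_patterns' default releases=(0,).

import requests
from urllib.parse import quote

HOMESERVER = "https://matrix.example.com"  # placeholder
USER_ID = quote("@alice:example.com", safe="")  # placeholder
HEADERS = {"Authorization": "Bearer syt_example_token"}  # placeholder

# Store a global account-data blob under a custom type (AccountDataServlet.on_PUT).
requests.put(
    f"{HOMESERVER}/_matrix/client/r0/user/{USER_ID}/account_data/com.example.prefs",
    headers=HEADERS,
    json={"theme": "dark"},
)

# Read it back; an unknown type yields the 404 NotFoundError seen above.
resp = requests.get(
    f"{HOMESERVER}/_matrix/client/r0/user/{USER_ID}/account_data/com.example.prefs",
    headers=HEADERS,
)
print(resp.json())  # {"theme": "dark"}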
+ +import logging + +from synapse.api.errors import SynapseError +from synapse.http.server import respond_with_html +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class AccountValidityRenewServlet(RestServlet): + PATTERNS = client_patterns("/account_validity/renew$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + + self.hs = hs + self.account_activity_handler = hs.get_account_validity_handler() + self.auth = hs.get_auth() + self.account_renewed_template = ( + hs.config.account_validity.account_validity_account_renewed_template + ) + self.account_previously_renewed_template = ( + hs.config.account_validity.account_validity_account_previously_renewed_template + ) + self.invalid_token_template = ( + hs.config.account_validity.account_validity_invalid_token_template + ) + + async def on_GET(self, request): + if b"token" not in request.args: + raise SynapseError(400, "Missing renewal token") + renewal_token = request.args[b"token"][0] + + ( + token_valid, + token_stale, + expiration_ts, + ) = await self.account_activity_handler.renew_account( + renewal_token.decode("utf8") + ) + + if token_valid: + status_code = 200 + response = self.account_renewed_template.render(expiration_ts=expiration_ts) + elif token_stale: + status_code = 200 + response = self.account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) + else: + status_code = 404 + response = self.invalid_token_template.render(expiration_ts=expiration_ts) + + respond_with_html(request, status_code, response) + + +class AccountValiditySendMailServlet(RestServlet): + PATTERNS = client_patterns("/account_validity/send_mail$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + + self.hs = hs + self.account_activity_handler = hs.get_account_validity_handler() + self.auth = hs.get_auth() + self.account_validity_renew_by_email_enabled = ( + hs.config.account_validity.account_validity_renew_by_email_enabled + ) + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_expired=True) + user_id = requester.user.to_string() + await self.account_activity_handler.send_renewal_email_to_user(user_id) + + return 200, {} + + +def register_servlets(hs, http_server): + AccountValidityRenewServlet(hs).register(http_server) + AccountValiditySendMailServlet(hs).register(http_server) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py new file mode 100644 index 0000000000..6ea1b50a62 --- /dev/null +++ b/synapse/rest/client/auth.py @@ -0,0 +1,143 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
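To summarise the branching in AccountValidityRenewServlet.on_GET above, a condensed sketch follows; the templates are stubbed out as plain strings and the function name render_renewal is invented for illustration.

def render_renewal(token_valid: bool, token_stale: bool, expiration_ts: int):
    # Mirrors the servlet: a valid token renews the account, a stale
    # (previously used) token reports the earlier renewal, and anything
    # else gets a 404 with the invalid-token page.
    if token_valid:
        return 200, f"Account renewed; expires at {expiration_ts}"
    if token_stale:
        return 200, f"Account previously renewed; expires at {expiration_ts}"
    return 404, "Invalid renewal token"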
+ +import logging +from typing import TYPE_CHECKING + +from synapse.api.constants import LoginType +from synapse.api.errors import SynapseError +from synapse.api.urls import CLIENT_API_PREFIX +from synapse.http.server import respond_with_html +from synapse.http.servlet import RestServlet, parse_string + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class AuthRestServlet(RestServlet): + """ + Handles Client / Server API authentication in any situations where it + cannot be handled in the normal flow (with requests to the same endpoint). + Current use is for web fallback auth. + """ + + PATTERNS = client_patterns(r"/auth/(?P[\w\.]*)/fallback/web") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + self.registration_handler = hs.get_registration_handler() + self.recaptcha_template = hs.config.recaptcha_template + self.terms_template = hs.config.terms_template + self.success_template = hs.config.fallback_success_template + + async def on_GET(self, request, stagetype): + session = parse_string(request, "session") + if not session: + raise SynapseError(400, "No session supplied") + + if stagetype == LoginType.RECAPTCHA: + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + sitekey=self.hs.config.recaptcha_public_key, + ) + elif stagetype == LoginType.TERMS: + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" + % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), + myurl="%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), + ) + + elif stagetype == LoginType.SSO: + # Display a confirmation page which prompts the user to + # re-authenticate with their SSO provider. + html = await self.auth_handler.start_sso_ui_auth(request, session) + + else: + raise SynapseError(404, "Unknown auth stage type") + + # Render the HTML and return. 
+ respond_with_html(request, 200, html) + return None + + async def on_POST(self, request, stagetype): + + session = parse_string(request, "session") + if not session: + raise SynapseError(400, "No session supplied") + + if stagetype == LoginType.RECAPTCHA: + response = parse_string(request, "g-recaptcha-response") + + if not response: + raise SynapseError(400, "No captcha response supplied") + + authdict = {"response": response, "session": session} + + success = await self.auth_handler.add_oob_auth( + LoginType.RECAPTCHA, authdict, request.getClientIP() + ) + + if success: + html = self.success_template.render() + else: + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + sitekey=self.hs.config.recaptcha_public_key, + ) + elif stagetype == LoginType.TERMS: + authdict = {"session": session} + + success = await self.auth_handler.add_oob_auth( + LoginType.TERMS, authdict, request.getClientIP() + ) + + if success: + html = self.success_template.render() + else: + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" + % ( + self.hs.config.public_baseurl, + self.hs.config.user_consent_version, + ), + myurl="%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), + ) + elif stagetype == LoginType.SSO: + # The SSO fallback workflow should not post here, + raise SynapseError(404, "Fallback SSO auth does not support POST requests.") + else: + raise SynapseError(404, "Unknown auth stage type") + + # Render the HTML and return. + respond_with_html(request, 200, html) + return None + + +def register_servlets(hs, http_server): + AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py new file mode 100644 index 0000000000..88e3aac797 --- /dev/null +++ b/synapse/rest/client/capabilities.py @@ -0,0 +1,68 @@ +# Copyright 2019 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
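The fallback page rendered by AuthRestServlet above is reached by clients that receive a 401 UI-auth response and cannot handle a stage natively. A small sketch of constructing that URL, matching the "myurl" values in the servlet; fallback_url and the homeserver argument are illustrative names.

def fallback_url(homeserver: str, stage: str, session: str) -> str:
    # e.g. stage = "m.login.recaptcha" or "m.login.terms"; the page is
    # opened in a browser and POSTs back to the same endpoint.
    return f"{homeserver}/_matrix/client/r0/auth/{stage}/fallback/web?session={session}"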
+import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class CapabilitiesRestServlet(RestServlet): + """End point to expose the capabilities of the server.""" + + PATTERNS = client_patterns("/capabilities$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.config = hs.config + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + change_password = self.auth_handler.can_change_password() + + response = { + "capabilities": { + "m.room_versions": { + "default": self.config.default_room_version.identifier, + "available": { + v.identifier: v.disposition + for v in KNOWN_ROOM_VERSIONS.values() + }, + }, + "m.change_password": {"enabled": change_password}, + } + } + + if self.config.experimental.msc3244_enabled: + response["capabilities"]["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ] = MSC3244_CAPABILITIES + + return 200, response + + +def register_servlets(hs: "HomeServer", http_server): + CapabilitiesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py new file mode 100644 index 0000000000..8b9674db06 --- /dev/null +++ b/synapse/rest/client/devices.py @@ -0,0 +1,300 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api import errors +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest + +from ._base import client_patterns, interactive_auth_handler + +logger = logging.getLogger(__name__) + + +class DevicesRestServlet(RestServlet): + PATTERNS = client_patterns("/devices$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + devices = await self.device_handler.get_devices_by_user( + requester.user.to_string() + ) + return 200, {"devices": devices} + + +class DeleteDevicesRestServlet(RestServlet): + """ + API for bulk deletion of devices. Accepts a JSON object with a devices + key which lists the device_ids to delete. Requires user interactive auth. 
+ """ + + PATTERNS = client_patterns("/delete_devices") + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + + try: + body = parse_json_object_from_request(request) + except errors.SynapseError as e: + if e.errcode == errors.Codes.NOT_JSON: + # DELETE + # deal with older clients which didn't pass a JSON dict + # the same as those that pass an empty dict + body = {} + else: + raise e + + assert_params_in_dict(body, ["devices"]) + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "remove device(s) from your account", + # Users might call this multiple times in a row while cleaning up + # devices, allow a single UI auth session to be re-used. + can_skip_ui_auth=True, + ) + + await self.device_handler.delete_devices( + requester.user.to_string(), body["devices"] + ) + return 200, {} + + +class DeviceRestServlet(RestServlet): + PATTERNS = client_patterns("/devices/(?P[^/]*)$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + self.auth_handler = hs.get_auth_handler() + + async def on_GET(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + device = await self.device_handler.get_device( + requester.user.to_string(), device_id + ) + return 200, device + + @interactive_auth_handler + async def on_DELETE(self, request, device_id): + requester = await self.auth.get_user_by_req(request) + + try: + body = parse_json_object_from_request(request) + + except errors.SynapseError as e: + if e.errcode == errors.Codes.NOT_JSON: + # deal with older clients which didn't pass a JSON dict + # the same as those that pass an empty dict + body = {} + else: + raise + + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "remove a device from your account", + # Users might call this multiple times in a row while cleaning up + # devices, allow a single UI auth session to be re-used. + can_skip_ui_auth=True, + ) + + await self.device_handler.delete_device(requester.user.to_string(), device_id) + return 200, {} + + async def on_PUT(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + body = parse_json_object_from_request(request) + await self.device_handler.update_device( + requester.user.to_string(), device_id, body + ) + return 200, {} + + +class DehydratedDeviceServlet(RestServlet): + """Retrieve or store a dehydrated device. 
+ + GET /org.matrix.msc2697.v2/dehydrated_device + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + PUT /org.matrix.msc2697/dehydrated_device + Content-Type: application/json + + { + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + """ + + PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_GET(self, request: SynapseRequest): + requester = await self.auth.get_user_by_req(request) + dehydrated_device = await self.device_handler.get_dehydrated_device( + requester.user.to_string() + ) + if dehydrated_device is not None: + (device_id, device_data) = dehydrated_device + result = {"device_id": device_id, "device_data": device_data} + return (200, result) + else: + raise errors.NotFoundError("No dehydrated device available") + + async def on_PUT(self, request: SynapseRequest): + submission = parse_json_object_from_request(request) + requester = await self.auth.get_user_by_req(request) + + if "device_data" not in submission: + raise errors.SynapseError( + 400, + "device_data missing", + errcode=errors.Codes.MISSING_PARAM, + ) + elif not isinstance(submission["device_data"], dict): + raise errors.SynapseError( + 400, + "device_data must be an object", + errcode=errors.Codes.INVALID_PARAM, + ) + + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), + submission["device_data"], + submission.get("initial_device_display_name", None), + ) + return 200, {"device_id": device_id} + + +class ClaimDehydratedDeviceServlet(RestServlet): + """Claim a dehydrated device. 
+ + POST /org.matrix.msc2697.v2/dehydrated_device/claim + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "success": true, + } + + """ + + PATTERNS = client_patterns( + "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + async def on_POST(self, request: SynapseRequest): + requester = await self.auth.get_user_by_req(request) + + submission = parse_json_object_from_request(request) + + if "device_id" not in submission: + raise errors.SynapseError( + 400, + "device_id missing", + errcode=errors.Codes.MISSING_PARAM, + ) + elif not isinstance(submission["device_id"], str): + raise errors.SynapseError( + 400, + "device_id must be a string", + errcode=errors.Codes.INVALID_PARAM, + ) + + result = await self.device_handler.rehydrate_device( + requester.user.to_string(), + self.auth.get_access_token_from_request(request), + submission["device_id"], + ) + + return (200, result) + + +def register_servlets(hs, http_server): + DeleteDevicesRestServlet(hs).register(http_server) + DevicesRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) + DehydratedDeviceServlet(hs).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py new file mode 100644 index 0000000000..ffa075c8e5 --- /dev/null +++ b/synapse/rest/client/directory.py @@ -0,0 +1,185 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
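As a reading aid for the two MSC2697 servlets above, the round-trip looks roughly like this; the request and response bodies are taken from the servlet docstrings, and the unstable path prefix follows from client_patterns being called with releases=().

# PUT /_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device
store_body = {
    "device_data": {
        "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
        "account": "dehydrated_device",
    }
}
# ... responds {"device_id": "dehydrated_device_id"}

# POST /_matrix/client/unstable/org.matrix.msc2697.v2/dehydrated_device/claim
claim_body = {"device_id": "dehydrated_device_id"}
# ... responds {"success": True}; rehydrate_device() then transfers the
# stored device to the caller's access token.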
+ + +import logging + +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + NotFoundError, + SynapseError, +) +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client._base import client_patterns +from synapse.types import RoomAlias + +logger = logging.getLogger(__name__) + + +def register_servlets(hs, http_server): + ClientDirectoryServer(hs).register(http_server) + ClientDirectoryListServer(hs).register(http_server) + ClientAppserviceDirectoryListServer(hs).register(http_server) + + +class ClientDirectoryServer(RestServlet): + PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.directory_handler = hs.get_directory_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_alias): + room_alias = RoomAlias.from_string(room_alias) + + res = await self.directory_handler.get_association(room_alias) + + return 200, res + + async def on_PUT(self, request, room_alias): + room_alias = RoomAlias.from_string(room_alias) + + content = parse_json_object_from_request(request) + if "room_id" not in content: + raise SynapseError( + 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON + ) + + logger.debug("Got content: %s", content) + logger.debug("Got room name: %s", room_alias.to_string()) + + room_id = content["room_id"] + servers = content["servers"] if "servers" in content else None + + logger.debug("Got room_id: %s", room_id) + logger.debug("Got servers: %s", servers) + + # TODO(erikj): Check types. + + room = await self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Room does not exist") + + requester = await self.auth.get_user_by_req(request) + + await self.directory_handler.create_association( + requester, room_alias, room_id, servers + ) + + return 200, {} + + async def on_DELETE(self, request, room_alias): + try: + service = self.auth.get_appservice_by_req(request) + room_alias = RoomAlias.from_string(room_alias) + await self.directory_handler.delete_appservice_association( + service, room_alias + ) + logger.info( + "Application service at %s deleted alias %s", + service.url, + room_alias.to_string(), + ) + return 200, {} + except InvalidClientCredentialsError: + # fallback to default user behaviour if they aren't an AS + pass + + requester = await self.auth.get_user_by_req(request) + user = requester.user + + room_alias = RoomAlias.from_string(room_alias) + + await self.directory_handler.delete_association(requester, room_alias) + + logger.info( + "User %s deleted alias %s", user.to_string(), room_alias.to_string() + ) + + return 200, {} + + +class ClientDirectoryListServer(RestServlet): + PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.directory_handler = hs.get_directory_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id): + room = await self.store.get_room(room_id) + if room is None: + raise NotFoundError("Unknown room") + + return 200, {"visibility": "public" if room["is_public"] else "private"} + + async def on_PUT(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + visibility = content.get("visibility", "public") + + await self.directory_handler.edit_published_room_list( + requester, room_id, visibility + ) + + return 200, {} + + 
async def on_DELETE(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + await self.directory_handler.edit_published_room_list( + requester, room_id, "private" + ) + + return 200, {} + + +class ClientAppserviceDirectoryListServer(RestServlet): + PATTERNS = client_patterns( + "/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True + ) + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.directory_handler = hs.get_directory_handler() + self.auth = hs.get_auth() + + def on_PUT(self, request, network_id, room_id): + content = parse_json_object_from_request(request) + visibility = content.get("visibility", "public") + return self._edit(request, network_id, room_id, visibility) + + def on_DELETE(self, request, network_id, room_id): + return self._edit(request, network_id, room_id, "private") + + async def _edit(self, request, network_id, room_id, visibility): + requester = await self.auth.get_user_by_req(request) + if not requester.app_service: + raise AuthError( + 403, "Only appservices can edit the appservice published room list" + ) + + await self.directory_handler.edit_published_appservice_room_list( + requester.app_service.id, network_id, room_id, visibility + ) + + return 200, {} diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py new file mode 100644 index 0000000000..52bb579cfd --- /dev/null +++ b/synapse/rest/client/events.py @@ -0,0 +1,94 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
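A minimal sketch (not part of the patch) of driving ClientDirectoryListServer above; the homeserver, token, and room ID are placeholders, and the room ID is percent-encoded since it contains "!" and ":" characters. The api/v1 prefix follows from the v1=True flag on these patterns.

import requests
from urllib.parse import quote

HOMESERVER = "https://matrix.example.com"  # placeholder
ROOM_ID = quote("!room:example.com", safe="")  # placeholder
HEADERS = {"Authorization": "Bearer syt_example_token"}  # placeholder

# Publish the room in the public directory (on_PUT); on_DELETE flips it
# back to "private".
requests.put(
    f"{HOMESERVER}/_matrix/client/api/v1/directory/list/room/{ROOM_ID}",
    headers=HEADERS,
    json={"visibility": "public"},
)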
+
+"""This module contains REST servlets to do with event streaming, /events."""
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.rest.client._base import client_patterns
+from synapse.streams.config import PaginationConfig
+
+logger = logging.getLogger(__name__)
+
+
+class EventStreamRestServlet(RestServlet):
+    PATTERNS = client_patterns("/events$", v1=True)
+
+    DEFAULT_LONGPOLL_TIME_MS = 30000
+
+    def __init__(self, hs):
+        super().__init__()
+        self.event_stream_handler = hs.get_event_stream_handler()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        is_guest = requester.is_guest
+        room_id = None
+        if is_guest:
+            if b"room_id" not in request.args:
+                raise SynapseError(400, "Guest users must specify room_id param")
+        if b"room_id" in request.args:
+            room_id = request.args[b"room_id"][0].decode("ascii")
+
+        pagin_config = await PaginationConfig.from_request(self.store, request)
+        timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
+        if b"timeout" in request.args:
+            try:
+                timeout = int(request.args[b"timeout"][0])
+            except ValueError:
+                raise SynapseError(400, "timeout must be in milliseconds.")
+
+        as_client_event = b"raw" not in request.args
+
+        chunk = await self.event_stream_handler.get_stream(
+            requester.user.to_string(),
+            pagin_config,
+            timeout=timeout,
+            as_client_event=as_client_event,
+            affect_presence=(not is_guest),
+            room_id=room_id,
+            is_guest=is_guest,
+        )
+
+        return 200, chunk
+
+
+class EventRestServlet(RestServlet):
+    PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.clock = hs.get_clock()
+        self.event_handler = hs.get_event_handler()
+        self.auth = hs.get_auth()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    async def on_GET(self, request, event_id):
+        requester = await self.auth.get_user_by_req(request)
+        event = await self.event_handler.get_event(requester.user, None, event_id)
+
+        time_now = self.clock.time_msec()
+        if event:
+            event = await self._event_serializer.serialize_event(event, time_now)
+            return 200, event
+        else:
+            return 404, "Event not found."
+
+
+def register_servlets(hs, http_server):
+    EventStreamRestServlet(hs).register(http_server)
+    EventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
new file mode 100644
index 0000000000..411667a9c8
--- /dev/null
+++ b/synapse/rest/client/filter.py
@@ -0,0 +1,94 @@
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
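
[Editor's note] The /events servlet above implements classic long-polling: each request blocks for up to `timeout` milliseconds, then returns a chunk of events plus an `end` token to resume from. A minimal polling loop, assuming a hypothetical homeserver URL and access token:

    import requests

    HS = "https://homeserver.example"  # hypothetical homeserver
    params = {"access_token": "<access_token>", "timeout": 30000}

    from_token = None
    for _ in range(3):                 # poll a few times, then stop
        if from_token is not None:
            params["from"] = from_token  # resume where the last poll ended
        chunk = requests.get(f"{HS}/_matrix/client/r0/events", params=params).json()
        for event in chunk.get("chunk", []):
            print(event.get("type"))
        from_token = chunk.get("end")
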
+
+import logging
+
+from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
+
+from ._base import client_patterns, set_timeline_upper_limit
+
+logger = logging.getLogger(__name__)
+
+
+class GetFilterRestServlet(RestServlet):
+    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
+
+    def __init__(self, hs):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.filtering = hs.get_filtering()
+
+    async def on_GET(self, request, user_id, filter_id):
+        target_user = UserID.from_string(user_id)
+        requester = await self.auth.get_user_by_req(request)
+
+        if target_user != requester.user:
+            raise AuthError(403, "Cannot get filters for other users")
+
+        if not self.hs.is_mine(target_user):
+            raise AuthError(403, "Can only get filters for local users")
+
+        try:
+            filter_id = int(filter_id)
+        except Exception:
+            raise SynapseError(400, "Invalid filter_id")
+
+        try:
+            filter_collection = await self.filtering.get_user_filter(
+                user_localpart=target_user.localpart, filter_id=filter_id
+            )
+        except StoreError as e:
+            if e.code != 404:
+                raise
+            raise NotFoundError("No such filter")
+
+        return 200, filter_collection.get_filter_json()
+
+
+class CreateFilterRestServlet(RestServlet):
+    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
+
+    def __init__(self, hs):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.filtering = hs.get_filtering()
+
+    async def on_POST(self, request, user_id):
+
+        target_user = UserID.from_string(user_id)
+        requester = await self.auth.get_user_by_req(request)
+
+        if target_user != requester.user:
+            raise AuthError(403, "Cannot create filters for other users")
+
+        if not self.hs.is_mine(target_user):
+            raise AuthError(403, "Can only create filters for local users")
+
+        content = parse_json_object_from_request(request)
+        set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
+
+        filter_id = await self.filtering.add_user_filter(
+            user_localpart=target_user.localpart, user_filter=content
+        )
+
+        return 200, {"filter_id": str(filter_id)}
+
+
+def register_servlets(hs, http_server):
+    GetFilterRestServlet(hs).register(http_server)
+    CreateFilterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
new file mode 100644
index 0000000000..6285680c00
--- /dev/null
+++ b/synapse/rest/client/groups.py
@@ -0,0 +1,957 @@
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
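
[Editor's note] To see the filter servlets above end to end: a client POSTs a filter definition and receives an opaque `filter_id`, then reads the definition back via the GET endpoint. A sketch with placeholder URL, token, and user ID:

    import requests

    HS = "https://homeserver.example"            # hypothetical homeserver
    AUTH = {"access_token": "<access_token>"}    # placeholder token
    USER = "@alice:homeserver.example"           # must be the requester herself

    # CreateFilterRestServlet: store a definition, receive an opaque ID.
    definition = {"room": {"timeline": {"limit": 10}}}
    filter_id = requests.post(
        f"{HS}/_matrix/client/r0/user/{USER}/filter", params=AUTH, json=definition
    ).json()["filter_id"]

    # GetFilterRestServlet: only the owning, local user may read it back.
    print(
        requests.get(
            f"{HS}/_matrix/client/r0/user/{USER}/filter/{filter_id}", params=AUTH
        ).json()
    )
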
+ +import logging +from functools import wraps +from typing import TYPE_CHECKING, Optional, Tuple + +from twisted.web.server import Request + +from synapse.api.constants import ( + MAX_GROUP_CATEGORYID_LENGTH, + MAX_GROUP_ROLEID_LENGTH, + MAX_GROUPID_LENGTH, +) +from synapse.api.errors import Codes, SynapseError +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest +from synapse.types import GroupID, JsonDict + +from ._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +def _validate_group_id(f): + """Wrapper to validate the form of the group ID. + + Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. + """ + + @wraps(f) + def wrapper(self, request: Request, group_id: str, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) + + return f(self, request, group_id, *args, **kwargs) + + return wrapper + + +class GroupServlet(RestServlet): + """Get the group profile""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/profile$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + group_description = await self.groups_handler.get_group_profile( + group_id, requester_user_id + ) + + return 200, group_description + + @_validate_group_id + async def on_POST( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert_params_in_dict( + content, ("name", "avatar_url", "short_description", "long_description") + ) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot create group profiles." + await self.groups_handler.update_group_profile( + group_id, requester_user_id, content + ) + + return 200, {} + + +class GroupSummaryServlet(RestServlet): + """Get the full group summary""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/summary$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + get_group_summary = await self.groups_handler.get_group_summary( + group_id, requester_user_id + ) + + return 200, get_group_summary + + +class GroupSummaryRoomsCatServlet(RestServlet): + """Update/delete a rooms entry in the summary. + + Matches both: + - /groups/:group/summary/rooms/:room_id + - /groups/:group/summary/categories/:category/rooms/:room_id + """ + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/summary" + "(/categories/(?P[^/]+))?" 
+ "/rooms/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, + request: SynapseRequest, + group_id: str, + category_id: Optional[str], + room_id: str, + ): + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + if category_id == "": + raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) + + if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group summaries." + resp = await self.groups_handler.update_group_summary_room( + group_id, + requester_user_id, + room_id=room_id, + category_id=category_id, + content=content, + ) + + return 200, resp + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, category_id: str, room_id: str + ): + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group profiles." + resp = await self.groups_handler.delete_group_summary_room( + group_id, requester_user_id, room_id=room_id, category_id=category_id + ) + + return 200, resp + + +class GroupCategoryServlet(RestServlet): + """Get/add/update/delete a group category""" + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/categories/(?P[^/]+)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = await self.groups_handler.get_group_category( + group_id, requester_user_id, category_id=category_id + ) + + return 200, category + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + if not category_id: + raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) + + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." 
+ resp = await self.groups_handler.update_group_category( + group_id, requester_user_id, category_id=category_id, content=content + ) + + return 200, resp + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." + resp = await self.groups_handler.delete_group_category( + group_id, requester_user_id, category_id=category_id + ) + + return 200, resp + + +class GroupCategoriesServlet(RestServlet): + """Get all group categories""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/categories/$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = await self.groups_handler.get_group_categories( + group_id, requester_user_id + ) + + return 200, category + + +class GroupRoleServlet(RestServlet): + """Get/add/update/delete a group role""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/roles/(?P[^/]+)$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + category = await self.groups_handler.get_group_role( + group_id, requester_user_id, role_id=role_id + ) + + return 200, category + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + if not role_id: + raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group roles." + resp = await self.groups_handler.update_group_role( + group_id, requester_user_id, role_id=role_id, content=content + ) + + return 200, resp + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group roles." 
+        resp = await self.groups_handler.delete_group_role(
+            group_id, requester_user_id, role_id=role_id
+        )
+
+        return 200, resp
+
+
+class GroupRolesServlet(RestServlet):
+    """Get all group roles"""
+
+    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @_validate_group_id
+    async def on_GET(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        requester_user_id = requester.user.to_string()
+
+        category = await self.groups_handler.get_group_roles(
+            group_id, requester_user_id
+        )
+
+        return 200, category
+
+
+class GroupSummaryUsersRoleServlet(RestServlet):
+    """Update/delete a user's entry in the summary.
+
+    Matches both:
+        - /groups/:group/summary/users/:user_id
+        - /groups/:group/summary/roles/:role/users/:user_id
+    """
+
+    PATTERNS = client_patterns(
+        "/groups/(?P<group_id>[^/]*)/summary"
+        "(/roles/(?P<role_id>[^/]+))?"
+        "/users/(?P<user_id>[^/]*)$"
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @_validate_group_id
+    async def on_PUT(
+        self,
+        request: SynapseRequest,
+        group_id: str,
+        role_id: Optional[str],
+        user_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+        if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
+        content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
+        resp = await self.groups_handler.update_group_summary_user(
+            group_id,
+            requester_user_id,
+            user_id=user_id,
+            role_id=role_id,
+            content=content,
+        )
+
+        return 200, resp
+
+    @_validate_group_id
+    async def on_DELETE(
+        self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
+    ):
+        requester = await self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
+ resp = await self.groups_handler.delete_group_summary_user( + group_id, requester_user_id, user_id=user_id, role_id=role_id + ) + + return 200, resp + + +class GroupRoomServlet(RestServlet): + """Get all rooms in a group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/rooms$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = await self.groups_handler.get_rooms_in_group( + group_id, requester_user_id + ) + + return 200, result + + +class GroupUsersServlet(RestServlet): + """Get all users in a group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/users$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = await self.groups_handler.get_users_in_group( + group_id, requester_user_id + ) + + return 200, result + + +class GroupInvitedUsersServlet(RestServlet): + """Get users invited to a group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/invited_users$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_GET( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + result = await self.groups_handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + return 200, result + + +class GroupSettingJoinPolicyServlet(RestServlet): + """Set group join policy""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/settings/m.join_policy$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group join policy." 
+ result = await self.groups_handler.set_group_join_policy( + group_id, requester_user_id, content + ) + + return 200, result + + +class GroupCreateServlet(RestServlet): + """Create a group""" + + PATTERNS = client_patterns("/create_group$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + self.server_name = hs.hostname + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + # TODO: Create group on remote server + content = parse_json_object_from_request(request) + localpart = content.pop("localpart") + group_id = GroupID(localpart, self.server_name).to_string() + + if not localpart: + raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) + + if len(group_id) > MAX_GROUPID_LENGTH: + raise SynapseError( + 400, + "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,), + Codes.INVALID_PARAM, + ) + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot create groups." + result = await self.groups_handler.create_group( + group_id, requester_user_id, content + ) + + return 200, result + + +class GroupAdminRoomsServlet(RestServlet): + """Add a room to the group""" + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify rooms in a group." + result = await self.groups_handler.add_room_to_group( + group_id, requester_user_id, room_id, content + ) + + return 200, result + + @_validate_group_id + async def on_DELETE( + self, request: SynapseRequest, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." + result = await self.groups_handler.remove_room_from_group( + group_id, requester_user_id, room_id + ) + + return 200, result + + +class GroupAdminRoomsConfigServlet(RestServlet): + """Update the config of a room in a group""" + + PATTERNS = client_patterns( + "/groups/(?P[^/]*)/admin/rooms/(?P[^/]*)" + "/config/(?P[^/]*)$" + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str, room_id: str, config_key: str + ): + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." 
+        result = await self.groups_handler.update_room_in_group(
+            group_id, requester_user_id, room_id, config_key, content
+        )
+
+        return 200, result
+
+
+class GroupAdminUsersInviteServlet(RestServlet):
+    """Invite a user to the group"""
+
+    PATTERNS = client_patterns(
+        "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+        self.store = hs.get_datastore()
+        self.is_mine_id = hs.is_mine_id
+
+    @_validate_group_id
+    async def on_PUT(
+        self, request: SynapseRequest, group_id, user_id
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        config = content.get("config", {})
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot invite users to a group."
+        result = await self.groups_handler.invite(
+            group_id, user_id, requester_user_id, config
+        )
+
+        return 200, result
+
+
+class GroupAdminUsersKickServlet(RestServlet):
+    """Kick a user from the group"""
+
+    PATTERNS = client_patterns(
+        "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @_validate_group_id
+    async def on_PUT(
+        self, request: SynapseRequest, group_id, user_id
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot kick users from a group."
+        result = await self.groups_handler.remove_user_from_group(
+            group_id, user_id, requester_user_id, content
+        )
+
+        return 200, result
+
+
+class GroupSelfLeaveServlet(RestServlet):
+    """Leave a joined group"""
+
+    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.groups_handler = hs.get_groups_local_handler()
+
+    @_validate_group_id
+    async def on_PUT(
+        self, request: SynapseRequest, group_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        requester_user_id = requester.user.to_string()
+
+        content = parse_json_object_from_request(request)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot leave a group for a user."
+ result = await self.groups_handler.remove_user_from_group( + group_id, requester_user_id, requester_user_id, content + ) + + return 200, result + + +class GroupSelfJoinServlet(RestServlet): + """Attempt to join a group, or knock""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/join$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot join a user to a group." + result = await self.groups_handler.join_group( + group_id, requester_user_id, content + ) + + return 200, result + + +class GroupSelfAcceptInviteServlet(RestServlet): + """Accept a group invite""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/accept_invite$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot accept an invite to a group." + result = await self.groups_handler.accept_invite( + group_id, requester_user_id, content + ) + + return 200, result + + +class GroupSelfUpdatePublicityServlet(RestServlet): + """Update whether we publicise a users membership of a group""" + + PATTERNS = client_patterns("/groups/(?P[^/]*)/self/update_publicity$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + @_validate_group_id + async def on_PUT( + self, request: SynapseRequest, group_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + requester_user_id = requester.user.to_string() + + content = parse_json_object_from_request(request) + publicise = content["publicise"] + await self.store.update_group_publicity(group_id, requester_user_id, publicise) + + return 200, {} + + +class PublicisedGroupsForUserServlet(RestServlet): + """Get the list of groups a user is advertising""" + + PATTERNS = client_patterns("/publicised_groups/(?P[^/]*)$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + self.groups_handler = hs.get_groups_local_handler() + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + + result = await self.groups_handler.get_publicised_groups_for_user(user_id) + + return 200, result + + +class PublicisedGroupsForUsersServlet(RestServlet): + """Get the list of groups a user is advertising""" + + PATTERNS = client_patterns("/publicised_groups$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.store = hs.get_datastore() + 
self.groups_handler = hs.get_groups_local_handler() + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + + content = parse_json_object_from_request(request) + user_ids = content["user_ids"] + + result = await self.groups_handler.bulk_get_publicised_groups(user_ids) + + return 200, result + + +class GroupsForUserServlet(RestServlet): + """Get all groups the logged in user is joined to""" + + PATTERNS = client_patterns("/joined_groups$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + requester_user_id = requester.user.to_string() + + result = await self.groups_handler.get_joined_groups(requester_user_id) + + return 200, result + + +def register_servlets(hs: "HomeServer", http_server): + GroupServlet(hs).register(http_server) + GroupSummaryServlet(hs).register(http_server) + GroupInvitedUsersServlet(hs).register(http_server) + GroupUsersServlet(hs).register(http_server) + GroupRoomServlet(hs).register(http_server) + GroupSettingJoinPolicyServlet(hs).register(http_server) + GroupCreateServlet(hs).register(http_server) + GroupAdminRoomsServlet(hs).register(http_server) + GroupAdminRoomsConfigServlet(hs).register(http_server) + GroupAdminUsersInviteServlet(hs).register(http_server) + GroupAdminUsersKickServlet(hs).register(http_server) + GroupSelfLeaveServlet(hs).register(http_server) + GroupSelfJoinServlet(hs).register(http_server) + GroupSelfAcceptInviteServlet(hs).register(http_server) + GroupsForUserServlet(hs).register(http_server) + GroupCategoryServlet(hs).register(http_server) + GroupCategoriesServlet(hs).register(http_server) + GroupSummaryRoomsCatServlet(hs).register(http_server) + GroupRoleServlet(hs).register(http_server) + GroupRolesServlet(hs).register(http_server) + GroupSelfUpdatePublicityServlet(hs).register(http_server) + GroupSummaryUsersRoleServlet(hs).register(http_server) + PublicisedGroupsForUserServlet(hs).register(http_server) + PublicisedGroupsForUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py new file mode 100644 index 0000000000..12ba0e91db --- /dev/null +++ b/synapse/rest/client/initial_sync.py @@ -0,0 +1,47 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
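
[Editor's note] All of the group servlets above funnel their URL parameter through the `_validate_group_id` decorator, which rejects malformed IDs before the handler body runs. Below is a standalone sketch of the same wrap-and-reject pattern; the simplified sigil check stands in for the real `GroupID.is_valid()`.

    import functools

    def validate_group_id(f):
        @functools.wraps(f)
        def wrapper(group_id: str, *args, **kwargs):
            # The real code calls GroupID.is_valid(); a plus-sigil check
            # stands in here purely for illustration.
            if not (group_id.startswith("+") and ":" in group_id):
                raise ValueError("%s is not a legal group ID" % (group_id,))
            return f(group_id, *args, **kwargs)
        return wrapper

    @validate_group_id
    def get_group_profile(group_id: str) -> dict:
        return {"group_id": group_id}

    print(get_group_profile("+matrix:example.org"))   # passes validation
    # get_group_profile("not-a-group")                # would raise ValueError
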
+
+
+from synapse.http.servlet import RestServlet, parse_boolean
+from synapse.rest.client._base import client_patterns
+from synapse.streams.config import PaginationConfig
+
+
+# TODO: Needs unit testing
+class InitialSyncRestServlet(RestServlet):
+    PATTERNS = client_patterns("/initialSync$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.initial_sync_handler = hs.get_initial_sync_handler()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request):
+        requester = await self.auth.get_user_by_req(request)
+        as_client_event = b"raw" not in request.args
+        pagination_config = await PaginationConfig.from_request(self.store, request)
+        include_archived = parse_boolean(request, "archived", default=False)
+        content = await self.initial_sync_handler.snapshot_all_rooms(
+            user_id=requester.user.to_string(),
+            pagin_config=pagination_config,
+            as_client_event=as_client_event,
+            include_archived=include_archived,
+        )
+
+        return 200, content
+
+
+def register_servlets(hs, http_server):
+    InitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
new file mode 100644
index 0000000000..d0d9d30d40
--- /dev/null
+++ b/synapse/rest/client/keys.py
@@ -0,0 +1,344 @@
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+    RestServlet,
+    parse_integer,
+    parse_json_object_from_request,
+    parse_string,
+)
+from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.types import StreamToken
+
+from ._base import client_patterns, interactive_auth_handler
+
+logger = logging.getLogger(__name__)
+
+
+class KeyUploadServlet(RestServlet):
+    """
+    POST /keys/upload HTTP/1.1
+    Content-Type: application/json
+
+    {
+        "device_keys": {
+            "user_id": "<user_id>",
+            "device_id": "<device_id>",
+            "valid_until_ts": <millisecond_timestamp>,
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+            ],
+            "keys": {
+                "<algorithm>:<device_id>": "<key_base64>",
+            },
+            "signatures": {
+                "<user_id>": {
+                    "<algorithm>:<device_id>": "<signature_base64>"
+                } } },
+        "one_time_keys": {
+            "<algorithm>:<key_id>": "<key_base64>"
+        },
+    }
+    """
+
+    PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+        self.device_handler = hs.get_device_handler()
+
+    @trace(opname="upload_keys")
+    async def on_POST(self, request, device_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        user_id = requester.user.to_string()
+        body = parse_json_object_from_request(request)
+
+        if device_id is not None:
+            # Providing the device_id should only be done for setting keys
+            # for dehydrated devices; however, we allow it for any device for
+            # compatibility with older clients.
+            if requester.device_id is not None and device_id != requester.device_id:
+                dehydrated_device = await self.device_handler.get_dehydrated_device(
+                    user_id
+                )
+                if dehydrated_device is not None and device_id != dehydrated_device[0]:
+                    set_tag("error", True)
+                    log_kv(
+                        {
+                            "message": "Client uploading keys for a different device",
+                            "logged_in_id": requester.device_id,
+                            "key_being_uploaded": device_id,
+                        }
+                    )
+                    logger.warning(
+                        "Client uploading keys for a different device "
+                        "(logged in as %s, uploading for %s)",
+                        requester.device_id,
+                        device_id,
+                    )
+        else:
+            device_id = requester.device_id
+
+        if device_id is None:
+            raise SynapseError(
+                400, "To upload keys, you must pass device_id when authenticating"
+            )
+
+        result = await self.e2e_keys_handler.upload_keys_for_user(
+            user_id, device_id, body
+        )
+        return 200, result
+
+
+class KeyQueryServlet(RestServlet):
+    """
+    POST /keys/query HTTP/1.1
+    Content-Type: application/json
+    {
+        "device_keys": {
+            "<user_id>": ["<device_id>"]
+        } }
+
+    HTTP/1.1 200 OK
+    {
+        "device_keys": {
+            "<user_id>": {
+                "<device_id>": {
+                    "user_id": "<user_id>", // Duplicated to be signed
+                    "device_id": "<device_id>", // Duplicated to be signed
+                    "valid_until_ts": <millisecond_timestamp>,
+                    "algorithms": [ // List of supported algorithms
+                        "m.olm.curve25519-aes-sha2",
+                    ],
+                    "keys": { // Must include a ed25519 signing key
+                        "<algorithm>:<device_id>": "<key_base64>",
+                    },
+                    "signatures": {
+                        // Must be signed with device's ed25519 key
+                        "<user_id>/<device_id>": {
+                            "<algorithm>:<device_id>": "<signature_base64>"
+                        }
+                        // Must be signed by this server.
+                        "<server_name>": {
+                            "<algorithm>:<key_id>": "<signature_base64>"
+                    } } } } } }
+    """
+
+    PATTERNS = client_patterns("/keys/query$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer):
+        """
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+    async def on_POST(self, request):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        user_id = requester.user.to_string()
+        device_id = requester.device_id
+        timeout = parse_integer(request, "timeout", 10 * 1000)
+        body = parse_json_object_from_request(request)
+        result = await self.e2e_keys_handler.query_devices(
+            body, timeout, user_id, device_id
+        )
+        return 200, result
+
+
+class KeyChangesServlet(RestServlet):
+    """Returns the list of changes of keys between two stream tokens (may return
+    spurious extra results, since we currently ignore the `to` param).
+
+    GET /keys/changes?from=...&to=...
+
+    200 OK
+    { "changed": ["@foo:example.com"] }
+    """
+
+    PATTERNS = client_patterns("/keys/changes$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer):
+        """
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.device_handler = hs.get_device_handler()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+
+        from_token_string = parse_string(request, "from", required=True)
+        set_tag("from", from_token_string)
+
+        # We want to enforce they do pass us one, but we ignore it and return
+        # changes after the "to" as well as before.
+        set_tag("to", parse_string(request, "to"))
+
+        from_token = await StreamToken.from_string(self.store, from_token_string)
+
+        user_id = requester.user.to_string()
+
+        results = await self.device_handler.get_user_ids_changed(user_id, from_token)
+
+        return 200, results
+
+
+class OneTimeKeyServlet(RestServlet):
+    """
+    POST /keys/claim HTTP/1.1
+    {
+        "one_time_keys": {
+            "<user_id>": {
+                "<device_id>": "<algorithm>"
+    } } }
+
+    HTTP/1.1 200 OK
+    {
+        "one_time_keys": {
+            "<user_id>": {
+                "<device_id>": {
+                    "<algorithm>:<key_id>": "<key_base64>"
+    } } } }
+
+    """
+
+    PATTERNS = client_patterns("/keys/claim$")
+
+    def __init__(self, hs):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+    async def on_POST(self, request):
+        await self.auth.get_user_by_req(request, allow_guest=True)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
+        body = parse_json_object_from_request(request)
+        result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
+        return 200, result
+
+
+class SigningKeyUploadServlet(RestServlet):
+    """
+    POST /keys/device_signing/upload HTTP/1.1
+    Content-Type: application/json
+
+    {
+    }
+    """
+
+    PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+        self.auth_handler = hs.get_auth_handler()
+
+    @interactive_auth_handler
+    async def on_POST(self, request):
+        requester = await self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
+        body = parse_json_object_from_request(request)
+
+        await self.auth_handler.validate_user_via_ui_auth(
+            requester,
+            request,
+            body,
+            "add a device signing key to your account",
+            # Allow skipping of UI auth since this is frequently called directly
+            # after login and it is silly to ask users to re-auth immediately.
+            can_skip_ui_auth=True,
+        )
+
+        result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
+        return 200, result
+
+
+class SignaturesUploadServlet(RestServlet):
+    """
+    POST /keys/signatures/upload HTTP/1.1
+    Content-Type: application/json
+
+    {
+        "@alice:example.com": {
+            "<device_id>": {
+                "user_id": "<user_id>",
+                "device_id": "<device_id>",
+                "algorithms": [
+                    "m.olm.curve25519-aes-sha2",
+                    "m.megolm.v1.aes-sha2"
+                ],
+                "keys": {
+                    "<algorithm>:<device_id>": "<key_base64>",
+                },
+                "signatures": {
+                    "<signing_user_id>": {
+                        "<algorithm>:<signing_key_base64>": "<signature_base64>"
+                    }
+                }
+            }
+        }
+    }
+    """
+
+    PATTERNS = client_patterns("/keys/signatures/upload$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+    async def on_POST(self, request):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        user_id = requester.user.to_string()
+        body = parse_json_object_from_request(request)
+
+        result = await self.e2e_keys_handler.upload_signatures_for_device_keys(
+            user_id, body
+        )
+        return 200, result
+
+
+def register_servlets(hs, http_server):
+    KeyUploadServlet(hs).register(http_server)
+    KeyQueryServlet(hs).register(http_server)
+    KeyChangesServlet(hs).register(http_server)
+    OneTimeKeyServlet(hs).register(http_server)
+    SigningKeyUploadServlet(hs).register(http_server)
+    SignaturesUploadServlet(hs).register(http_server)
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
new file mode 100644
index 0000000000..7d1bc40658
--- /dev/null
+++ b/synapse/rest/client/knock.py
@@ -0,0 +1,107 @@
+# Copyright 2020 Sorunome
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
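
[Editor's note] As a client-side illustration of the key-management servlets registered above, the sketch below queries another user's device keys and then claims a one-time key. Everything here (homeserver, token, user and device IDs) is a placeholder, not a value from this patch.

    import requests

    HS = "https://homeserver.example"            # hypothetical homeserver
    AUTH = {"access_token": "<access_token>"}    # placeholder token

    # /keys/query (KeyQueryServlet): ask for all of bob's current device keys.
    resp = requests.post(
        f"{HS}/_matrix/client/r0/keys/query",
        params=AUTH,
        json={"device_keys": {"@bob:example.org": []}},
    )
    print(resp.json().get("device_keys", {}))

    # /keys/claim (OneTimeKeyServlet): claim a one-time key to open a session.
    resp = requests.post(
        f"{HS}/_matrix/client/r0/keys/claim",
        params=AUTH,
        json={"one_time_keys": {"@bob:example.org": {"BOBDEVICE": "signed_curve25519"}}},
    )
    print(resp.json().get("one_time_keys", {}))
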
+import logging +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +from twisted.web.server import Request + +from synapse.api.constants import Membership +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_strings_from_args, +) +from synapse.http.site import SynapseRequest +from synapse.logging.opentracing import set_tag +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import JsonDict, RoomAlias, RoomID + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class KnockRoomAliasServlet(RestServlet): + """ + POST /knock/{roomIdOrAlias} + """ + + PATTERNS = client_patterns("/knock/(?P[^/]*)") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.txns = HttpTransactionCache(hs) + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + + async def on_POST( + self, + request: SynapseRequest, + room_identifier: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + event_content = None + if "reason" in content: + event_content = {"reason": content["reason"]} + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + + remote_room_hosts = parse_strings_from_args( + args, "server_name", required=False + ) + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id_obj.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + await self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action=Membership.KNOCK, + txn_id=txn_id, + third_party_signed=None, + remote_room_hosts=remote_room_hosts, + content=event_content, + ) + + return 200, {"room_id": room_id} + + def on_PUT(self, request: Request, room_identifier: str, txn_id: str): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_identifier, txn_id + ) + + +def register_servlets(hs, http_server): + KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py new file mode 100644 index 0000000000..0c8d8967b7 --- /dev/null +++ b/synapse/rest/client/login.py @@ -0,0 +1,600 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
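
[Editor's note] Knocking via the servlet above is a single POST. When the identifier is a raw room ID rather than an alias, `server_name` query parameters can be supplied to help the server find the room remotely; for an alias, the server resolves the hosts itself. A sketch with placeholder values:

    import requests
    from urllib.parse import quote

    HS = "https://homeserver.example"            # hypothetical homeserver
    AUTH = {"access_token": "<access_token>"}    # placeholder token

    alias = quote("#announcements:example.org", safe="")
    resp = requests.post(
        f"{HS}/_matrix/client/r0/knock/{alias}",
        params=AUTH,
        json={"reason": "Hi, may I join?"},      # becomes the knock event content
    )
    print(resp.json())  # {"room_id": "!...:example.org"} on success
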
+ +import logging +import re +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional + +from typing_extensions import TypedDict + +from synapse.api.errors import Codes, LoginError, SynapseError +from synapse.api.ratelimiting import Ratelimiter +from synapse.api.urls import CLIENT_API_PREFIX +from synapse.appservice import ApplicationService +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http import get_request_uri +from synapse.http.server import HttpServer, finish_request +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_boolean, + parse_bytes_from_args, + parse_json_object_from_request, + parse_string, +) +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns +from synapse.rest.well_known import WellKnownBuilder +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class LoginResponse(TypedDict, total=False): + user_id: str + access_token: str + home_server: str + expires_in_ms: Optional[int] + refresh_token: Optional[str] + device_id: str + well_known: Optional[Dict[str, Any]] + + +class LoginRestServlet(RestServlet): + PATTERNS = client_patterns("/login$", v1=True) + CAS_TYPE = "m.login.cas" + SSO_TYPE = "m.login.sso" + TOKEN_TYPE = "m.login.token" + JWT_TYPE = "org.matrix.login.jwt" + JWT_TYPE_DEPRECATED = "m.login.jwt" + APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" + REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + + # JWT configuration variables. + self.jwt_enabled = hs.config.jwt_enabled + self.jwt_secret = hs.config.jwt_secret + self.jwt_algorithm = hs.config.jwt_algorithm + self.jwt_issuer = hs.config.jwt_issuer + self.jwt_audiences = hs.config.jwt_audiences + + # SSO configuration. + self.saml2_enabled = hs.config.saml2_enabled + self.cas_enabled = hs.config.cas_enabled + self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._msc2918_enabled = hs.config.access_token_lifetime is not None + + self.auth = hs.get_auth() + + self.clock = hs.get_clock() + + self.auth_handler = self.hs.get_auth_handler() + self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + + self._well_known_builder = WellKnownBuilder(hs) + self._address_ratelimiter = Ratelimiter( + store=hs.get_datastore(), + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + store=hs.get_datastore(), + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + + def on_GET(self, request: SynapseRequest): + flows = [] + if self.jwt_enabled: + flows.append({"type": LoginRestServlet.JWT_TYPE}) + flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) + + if self.cas_enabled: + # we advertise CAS for backwards compat, though MSC1721 renamed it + # to SSO. 
+ flows.append({"type": LoginRestServlet.CAS_TYPE}) + + if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: + sso_flow: JsonDict = { + "type": LoginRestServlet.SSO_TYPE, + "identity_providers": [ + _get_auth_flow_dict_for_idp( + idp, + ) + for idp in self._sso_handler.get_identity_providers().values() + ], + } + + if self._msc2858_enabled: + # backwards-compatibility support for clients which don't + # support the stable API yet + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, + # synapse currently only gives out these tokens as part of the + # SSO login flow. + # Generally we don't want to advertise login flows that clients + # don't know how to implement, since they (currently) will always + # fall back to the fallback API if they don't understand one of the + # login flow types returned. + flows.append({"type": LoginRestServlet.TOKEN_TYPE}) + + flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) + + flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) + + return 200, {"flows": flows} + + async def on_POST(self, request: SynapseRequest): + login_submission = parse_json_object_from_request(request) + + if self._msc2918_enabled: + # Check if this login should also issue a refresh token, as per + # MSC2918 + should_issue_refresh_token = parse_boolean( + request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False + ) + else: + should_issue_refresh_token = False + + try: + if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + appservice = self.auth.get_appservice_by_req(request) + + if appservice.is_rate_limited(): + await self._address_ratelimiter.ratelimit( + None, request.getClientIP() + ) + + result = await self._do_appservice_login( + login_submission, + appservice, + should_issue_refresh_token=should_issue_refresh_token, + ) + elif self.jwt_enabled and ( + login_submission["type"] == LoginRestServlet.JWT_TYPE + or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED + ): + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + result = await self._do_jwt_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) + elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + result = await self._do_token_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) + else: + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) + result = await self._do_other_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) + except KeyError: + raise SynapseError(400, "Missing JSON keys.") + + well_known_data = self._well_known_builder.get_well_known() + if well_known_data: + result["well_known"] = well_known_data + return 200, result + + async def _do_appservice_login( + self, + login_submission: JsonDict, + appservice: ApplicationService, + should_issue_refresh_token: bool = False, + ): + identifier = login_submission.get("identifier") + logger.info("Got appservice login request with identifier: %r", identifier) + + if not isinstance(identifier, dict): + raise SynapseError( + 400, "Invalid identifier in login submission", Codes.INVALID_PARAM + ) + + # this login flow only supports identifiers of 
type "m.id.user". + if identifier.get("type") != "m.id.user": + raise SynapseError( + 400, "Unknown login identifier type", Codes.INVALID_PARAM + ) + + user = identifier.get("user") + if not isinstance(user, str): + raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM) + + if user.startswith("@"): + qualified_user_id = user + else: + qualified_user_id = UserID(user, self.hs.hostname).to_string() + + if not appservice.is_interested_in_user(qualified_user_id): + raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) + + return await self._complete_login( + qualified_user_id, + login_submission, + ratelimit=appservice.is_rate_limited(), + should_issue_refresh_token=should_issue_refresh_token, + ) + + async def _do_other_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: + """Handle non-token/saml/jwt logins + + Args: + login_submission: + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. + + Returns: + HTTP response + """ + # Log the request we got, but only certain fields to minimise the chance of + # logging someone's password (even if they accidentally put it in the wrong + # field) + logger.info( + "Got login request with identifier: %r, medium: %r, address: %r, user: %r", + login_submission.get("identifier"), + login_submission.get("medium"), + login_submission.get("address"), + login_submission.get("user"), + ) + canonical_user_id, callback = await self.auth_handler.validate_login( + login_submission, ratelimit=True + ) + result = await self._complete_login( + canonical_user_id, + login_submission, + callback, + should_issue_refresh_token=should_issue_refresh_token, + ) + return result + + async def _complete_login( + self, + user_id: str, + login_submission: JsonDict, + callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, + create_non_existent_users: bool = False, + ratelimit: bool = True, + auth_provider_id: Optional[str] = None, + should_issue_refresh_token: bool = False, + ) -> LoginResponse: + """Called when we've successfully authed the user and now need to + actually login them in (e.g. create devices). This gets called on + all successful logins. + + Applies the ratelimiting for successful login attempts against an + account. + + Args: + user_id: ID of the user to register. + login_submission: Dictionary of login information. + callback: Callback function to run after login. + create_non_existent_users: Whether to create the user if they don't + exist. Defaults to False. + ratelimit: Whether to ratelimit the login request. + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. + + Returns: + result: Dictionary of account information after successful login. + """ + + # Before we actually log them in we check if they've already logged in + # too often. This happens here rather than before as we don't + # necessarily know the user before now. 
+ if ratelimit: + await self._account_ratelimiter.ratelimit(None, user_id.lower()) + + if create_non_existent_users: + canonical_uid = await self.auth_handler.check_user_exists(user_id) + if not canonical_uid: + canonical_uid = await self.registration_handler.register_user( + localpart=UserID.from_string(user_id).localpart + ) + user_id = canonical_uid + + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, + ) + + result = LoginResponse( + user_id=user_id, + access_token=access_token, + home_server=self.hs.hostname, + device_id=device_id, + ) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + if callback is not None: + await callback(result) + + return result + + async def _do_token_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: + """ + Handle the final stage of SSO login. + + Args: + login_submission: The JSON request body. + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. + + Returns: + The body of the JSON response. + """ + token = login_submission["token"] + auth_handler = self.auth_handler + res = await auth_handler.validate_short_term_login_token(token) + + return await self._complete_login( + res.user_id, + login_submission, + self.auth_handler._sso_login_callback, + auth_provider_id=res.auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, + ) + + async def _do_jwt_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: + token = login_submission.get("token", None) + if token is None: + raise LoginError( + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN + ) + + import jwt + + try: + payload = jwt.decode( + token, + self.jwt_secret, + algorithms=[self.jwt_algorithm], + issuer=self.jwt_issuer, + audience=self.jwt_audiences, + ) + except jwt.PyJWTError as e: + # A JWT error occurred, return some info back to the client. 
+ raise LoginError( + 403, + "JWT validation failed: %s" % (str(e),), + errcode=Codes.FORBIDDEN, + ) + + user = payload.get("sub", None) + if user is None: + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) + + user_id = UserID(user, self.hs.hostname).to_string() + result = await self._complete_login( + user_id, + login_submission, + create_non_existent_users=True, + should_issue_refresh_token=should_issue_refresh_token, + ) + return result + + +def _get_auth_flow_dict_for_idp( + idp: SsoIdentityProvider, use_unstable_brands: bool = False +) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + + Args: + idp: the identity provider to describe + use_unstable_brands: whether we should use brand identifiers suitable + for the unstable API + """ + e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} + if idp.idp_icon: + e["icon"] = idp.idp_icon + if idp.idp_brand: + e["brand"] = idp.idp_brand + # use the stable brand identifier if the unstable identifier isn't defined. + if use_unstable_brands and idp.unstable_idp_brand: + e["brand"] = idp.unstable_idp_brand + return e + + +class RefreshTokenServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True + ) + + def __init__(self, hs: "HomeServer"): + self._auth_handler = hs.get_auth_handler() + self._clock = hs.get_clock() + self.access_token_lifetime = hs.config.access_token_lifetime + + async def on_POST( + self, + request: SynapseRequest, + ): + refresh_submission = parse_json_object_from_request(request) + + assert_params_in_dict(refresh_submission, ["refresh_token"]) + token = refresh_submission["refresh_token"] + if not isinstance(token, str): + raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) + + valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + access_token, refresh_token = await self._auth_handler.refresh_token( + token, valid_until_ms + ) + expires_in_ms = valid_until_ms - self._clock.time_msec() + return ( + 200, + { + "access_token": access_token, + "refresh_token": refresh_token, + "expires_in_ms": expires_in_ms, + }, + ) + + +class SsoRedirectServlet(RestServlet): + PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ + re.compile( + "^" + + CLIENT_API_PREFIX + + "/r0/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$" + ) + ] + + def __init__(self, hs: "HomeServer"): + # make sure that the relevant handlers are instantiated, so that they + # register themselves with the main SSOHandler. + if hs.config.cas_enabled: + hs.get_cas_handler() + if hs.config.saml2_enabled: + hs.get_saml_handler() + if hs.config.oidc_enabled: + hs.get_oidc_handler() + self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._public_baseurl = hs.config.public_baseurl + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support: backwards-compat support + # for clients which don't yet support the stable endpoints. 
+ http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) + + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: + if not self._public_baseurl: + raise SynapseError(400, "SSO requires a valid public_baseurl") + + # if this isn't the expected hostname, redirect to the right one, so that we + # get our cookies back. + requested_uri = get_request_uri(request) + baseurl_bytes = self._public_baseurl.encode("utf-8") + if not requested_uri.startswith(baseurl_bytes): + # swap out the incorrect base URL for the right one. + # + # The idea here is to redirect from + # https://foo.bar/whatever/_matrix/... + # to + # https://public.baseurl/_matrix/... + # + i = requested_uri.index(b"/_matrix") + new_uri = baseurl_bytes[:-1] + requested_uri[i:] + logger.info( + "Requested URI %s is not canonical: redirecting to %s", + requested_uri.decode("utf-8", errors="replace"), + new_uri.decode("utf-8", errors="replace"), + ) + request.redirect(new_uri) + finish_request(request) + return + + args: Dict[bytes, List[bytes]] = request.args # type: ignore + client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) + sso_url = await self._sso_handler.handle_redirect_request( + request, + client_redirect_url, + idp_id, + ) + logger.info("Redirecting to %s", sso_url) + request.redirect(sso_url) + finish_request(request) + + +class CasTicketServlet(RestServlet): + PATTERNS = client_patterns("/login/cas/ticket", v1=True) + + def __init__(self, hs): + super().__init__() + self._cas_handler = hs.get_cas_handler() + + async def on_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string(request, "redirectUrl") + ticket = parse_string(request, "ticket", required=True) + + # Maybe get a session ID (if this ticket is from user interactive + # authentication). + session = parse_string(request, "session") + + # Either client_redirect_url or session must be provided. + if not client_redirect_url and not session: + message = "Missing string query parameter redirectUrl or session" + raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) + + await self._cas_handler.handle_ticket( + request, ticket, client_redirect_url, session + ) + + +def register_servlets(hs, http_server): + LoginRestServlet(hs).register(http_server) + if hs.config.access_token_lifetime is not None: + RefreshTokenServlet(hs).register(http_server) + SsoRedirectServlet(hs).register(http_server) + if hs.config.cas_enabled: + CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py new file mode 100644 index 0000000000..6055cac2bd --- /dev/null +++ b/synapse/rest/client/logout.py @@ -0,0 +1,72 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
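[Editorial note: to make the URI surgery in SsoRedirectServlet.on_GET above concrete: everything from the first "/_matrix" onward is kept, and the configured public_baseurl (which ends in a slash) replaces whatever host and prefix the request arrived on. A self-contained trace of that transformation:]

    def canonicalise(requested_uri: bytes, public_baseurl: str) -> bytes:
        baseurl_bytes = public_baseurl.encode("utf-8")
        i = requested_uri.index(b"/_matrix")
        # baseurl_bytes[:-1] drops the trailing "/" so the halves join cleanly.
        return baseurl_bytes[:-1] + requested_uri[i:]

    assert (
        canonicalise(
            b"https://foo.bar/whatever/_matrix/client/r0/login",
            "https://public.example/",
        )
        == b"https://public.example/_matrix/client/r0/login"
    )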
+ +import logging + +from synapse.http.servlet import RestServlet +from synapse.rest.client._base import client_patterns + +logger = logging.getLogger(__name__) + + +class LogoutRestServlet(RestServlet): + PATTERNS = client_patterns("/logout$", v1=True) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_expired=True) + + if requester.device_id is None: + # The access token wasn't associated with a device. + # Just delete the access token + access_token = self.auth.get_access_token_from_request(request) + await self._auth_handler.delete_access_token(access_token) + else: + await self._device_handler.delete_device( + requester.user.to_string(), requester.device_id + ) + + return 200, {} + + +class LogoutAllRestServlet(RestServlet): + PATTERNS = client_patterns("/logout/all$", v1=True) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self._auth_handler = hs.get_auth_handler() + self._device_handler = hs.get_device_handler() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_expired=True) + user_id = requester.user.to_string() + + # first delete all of the user's devices + await self._device_handler.delete_all_devices_for_user(user_id) + + # .. and then delete any access tokens which weren't associated with + # devices. + await self._auth_handler.delete_access_tokens_for_user(user_id) + return 200, {} + + +def register_servlets(hs, http_server): + LogoutRestServlet(hs).register(http_server) + LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py new file mode 100644 index 0000000000..0ede643c2d --- /dev/null +++ b/synapse/rest/client/notifications.py @@ -0,0 +1,91 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
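[Editorial note: the two logout servlets above can be exercised as follows. The homeserver URL and token are placeholders, and requests stands in for any HTTP client; note that /logout only touches the current session's device/token, while /logout/all revokes everything for the account.]

    import requests

    HS = "https://homeserver.example"  # placeholder
    headers = {"Authorization": "Bearer <access_token>"}

    # Invalidate just this session (and its device, if the token has one):
    requests.post(f"{HS}/_matrix/client/r0/logout", headers=headers, json={})

    # Invalidate every device and access token for the account:
    requests.post(f"{HS}/_matrix/client/r0/logout/all", headers=headers, json={})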
+ +import logging + +from synapse.events.utils import format_event_for_client_v2_without_room_id +from synapse.http.servlet import RestServlet, parse_integer, parse_string + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class NotificationsServlet(RestServlet): + PATTERNS = client_patterns("/notifications$") + + def __init__(self, hs): + super().__init__() + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + from_token = parse_string(request, "from", required=False) + limit = parse_integer(request, "limit", default=50) + only = parse_string(request, "only", required=False) + + limit = min(limit, 500) + + push_actions = await self.store.get_push_actions_for_user( + user_id, from_token, limit, only_highlight=(only == "highlight") + ) + + receipts_by_room = await self.store.get_receipts_for_user_with_orderings( + user_id, "m.read" + ) + + notif_event_ids = [pa["event_id"] for pa in push_actions] + notif_events = await self.store.get_events(notif_event_ids) + + returned_push_actions = [] + + next_token = None + + for pa in push_actions: + returned_pa = { + "room_id": pa["room_id"], + "profile_tag": pa["profile_tag"], + "actions": pa["actions"], + "ts": pa["received_ts"], + "event": ( + await self._event_serializer.serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ) + ), + } + + if pa["room_id"] not in receipts_by_room: + returned_pa["read"] = False + else: + receipt = receipts_by_room[pa["room_id"]] + + returned_pa["read"] = ( + receipt["topological_ordering"], + receipt["stream_ordering"], + ) >= (pa["topological_ordering"], pa["stream_ordering"]) + returned_push_actions.append(returned_pa) + next_token = str(pa["stream_ordering"]) + + return 200, {"notifications": returned_push_actions, "next_token": next_token} + + +def register_servlets(hs, http_server): + NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py new file mode 100644 index 0000000000..e8d2673819 --- /dev/null +++ b/synapse/rest/client/openid.py @@ -0,0 +1,94 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from synapse.api.errors import AuthError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.util.stringutils import random_string + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class IdTokenServlet(RestServlet): + """ + Get a bearer token that may be passed to a third party to confirm ownership + of a matrix user id. 
+ + The format of the response could be made compatible with the format given + in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse + + But instead of returning a signed "id_token" the response contains the + name of the issuing matrix homeserver. This means that for now the third + party will need to check the validity of the "id_token" against the + federation /openid/userinfo endpoint of the homeserver. + + Request: + + POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1 + + {} + + Response: + + HTTP/1.1 200 OK + { + "access_token": "ABDEFGH", + "token_type": "Bearer", + "matrix_server_name": "example.com", + "expires_in": 3600, + } + """ + + PATTERNS = client_patterns("/user/(?P[^/]*)/openid/request_token") + + EXPIRES_MS = 3600 * 1000 + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.server_name = hs.config.server_name + + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + if user_id != requester.user.to_string(): + raise AuthError(403, "Cannot request tokens for other users.") + + # Parse the request body to make sure it's JSON, but ignore the contents + # for now. + parse_json_object_from_request(request) + + token = random_string(24) + ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS + + await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) + + return ( + 200, + { + "access_token": token, + "token_type": "Bearer", + "matrix_server_name": self.server_name, + "expires_in": self.EXPIRES_MS // 1000, + }, + ) + + +def register_servlets(hs, http_server): + IdTokenServlet(hs).register(http_server) diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py new file mode 100644 index 0000000000..a83927aee6 --- /dev/null +++ b/synapse/rest/client/password_policy.py @@ -0,0 +1,57 @@ +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
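[Editorial note: the docstring above says the third party must check the token against the homeserver's federation openid/userinfo endpoint. Roughly, that check looks like the sketch below; the federation route is GET /_matrix/federation/v1/openid/userinfo, and the timeout value is an arbitrary choice.]

    import requests

    def verify_openid_token(matrix_server_name: str, token: str) -> str:
        resp = requests.get(
            f"https://{matrix_server_name}/_matrix/federation/v1/openid/userinfo",
            params={"access_token": token},
            timeout=10,
        )
        resp.raise_for_status()
        return resp.json()["sub"]  # the Matrix user ID the token belongs to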
+ +import logging + +from synapse.http.servlet import RestServlet + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PasswordPolicyServlet(RestServlet): + PATTERNS = client_patterns("/password_policy$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + + self.policy = hs.config.password_policy + self.enabled = hs.config.password_policy_enabled + + def on_GET(self, request): + if not self.enabled or not self.policy: + return (200, {}) + + policy = {} + + for param in [ + "minimum_length", + "require_digit", + "require_symbol", + "require_lowercase", + "require_uppercase", + ]: + if param in self.policy: + policy["m.%s" % param] = self.policy[param] + + return (200, policy) + + +def register_servlets(hs, http_server): + PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py new file mode 100644 index 0000000000..6c27e5faf9 --- /dev/null +++ b/synapse/rest/client/presence.py @@ -0,0 +1,95 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" This module contains REST servlets to do with presence: /presence/ +""" +import logging + +from synapse.api.errors import AuthError, SynapseError +from synapse.handlers.presence import format_user_presence_state +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client._base import client_patterns +from synapse.types import UserID + +logger = logging.getLogger(__name__) + + +class PresenceStatusRestServlet(RestServlet): + PATTERNS = client_patterns("/presence/(?P[^/]*)/status", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.presence_handler = hs.get_presence_handler() + self.clock = hs.get_clock() + self.auth = hs.get_auth() + + self._use_presence = hs.config.server.use_presence + + async def on_GET(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + + if not self._use_presence: + return 200, {"presence": "offline"} + + if requester.user != user: + allowed = await self.presence_handler.is_visible( + observed_user=user, observer_user=requester.user + ) + + if not allowed: + raise AuthError(403, "You are not allowed to see their presence.") + + state = await self.presence_handler.get_state(target_user=user) + state = format_user_presence_state( + state, self.clock.time_msec(), include_user_id=False + ) + + return 200, state + + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + + if requester.user != user: + raise AuthError(403, "Can only set your own presence state") + + state = {} + + content = parse_json_object_from_request(request) + + try: + state["presence"] = content.pop("presence") + + if "status_msg" in content: + state["status_msg"] = content.pop("status_msg") + if not isinstance(state["status_msg"], str): + raise 
SynapseError(400, "status_msg must be a string.") + + if content: + raise KeyError() + except SynapseError as e: + raise e + except Exception: + raise SynapseError(400, "Unable to parse state") + + if self._use_presence: + await self.presence_handler.set_state(user, state) + + return 200, {} + + +def register_servlets(hs, http_server): + PresenceStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py new file mode 100644 index 0000000000..5463ed2c4f --- /dev/null +++ b/synapse/rest/client/profile.py @@ -0,0 +1,155 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" This module contains REST servlets to do with profile: /profile/ """ + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client._base import client_patterns +from synapse.types import UserID + + +class ProfileDisplaynameRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = await self.auth.get_user_by_req(request) + requester_user = requester.user + + user = UserID.from_string(user_id) + + await self.profile_handler.check_profile_query_allowed(user, requester_user) + + displayname = await self.profile_handler.get_displayname(user) + + ret = {} + if displayname is not None: + ret["displayname"] = displayname + + return 200, ret + + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = UserID.from_string(user_id) + is_admin = await self.auth.is_server_admin(requester.user) + + content = parse_json_object_from_request(request) + + try: + new_name = content["displayname"] + except Exception: + raise SynapseError( + code=400, + msg="Unable to parse name", + errcode=Codes.BAD_JSON, + ) + + await self.profile_handler.set_displayname(user, requester, new_name, is_admin) + + return 200, {} + + +class ProfileAvatarURLRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = await self.auth.get_user_by_req(request) + requester_user = requester.user + + user = UserID.from_string(user_id) + + await self.profile_handler.check_profile_query_allowed(user, requester_user) + + avatar_url = await self.profile_handler.get_avatar_url(user) + + ret = {} + if avatar_url is not None: + ret["avatar_url"] = avatar_url + + return 200, ret + + async 
def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) + user = UserID.from_string(user_id) + is_admin = await self.auth.is_server_admin(requester.user) + + content = parse_json_object_from_request(request) + try: + new_avatar_url = content["avatar_url"] + except KeyError: + raise SynapseError( + 400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM + ) + + await self.profile_handler.set_avatar_url( + user, requester, new_avatar_url, is_admin + ) + + return 200, {} + + +class ProfileRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() + + async def on_GET(self, request, user_id): + requester_user = None + + if self.hs.config.require_auth_for_profile_requests: + requester = await self.auth.get_user_by_req(request) + requester_user = requester.user + + user = UserID.from_string(user_id) + + await self.profile_handler.check_profile_query_allowed(user, requester_user) + + displayname = await self.profile_handler.get_displayname(user) + avatar_url = await self.profile_handler.get_avatar_url(user) + + ret = {} + if displayname is not None: + ret["displayname"] = displayname + if avatar_url is not None: + ret["avatar_url"] = avatar_url + + return 200, ret + + +def register_servlets(hs, http_server): + ProfileDisplaynameRestServlet(hs).register(http_server) + ProfileAvatarURLRestServlet(hs).register(http_server) + ProfileRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py new file mode 100644 index 0000000000..702b351d18 --- /dev/null +++ b/synapse/rest/client/push_rule.py @@ -0,0 +1,354 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
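[Editorial note: a usage sketch for the profile servlets above; the homeserver, user ID and token are placeholders. GET is public unless require_auth_for_profile_requests is set, while PUT requires the user's own access token.]

    from urllib.parse import quote

    import requests

    HS = "https://homeserver.example"  # placeholder
    user = quote("@alice:homeserver.example")

    # Read both fields at once via /profile/{userId}:
    r = requests.get(f"{HS}/_matrix/client/r0/profile/{user}")
    print(r.json())  # e.g. {"displayname": "Alice", "avatar_url": "mxc://..."}

    # Update one field with the account's own token:
    requests.put(
        f"{HS}/_matrix/client/r0/profile/{user}/displayname",
        headers={"Authorization": "Bearer <access_token>"},
        json={"displayname": "Alice"},
    )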
+ +from synapse.api.errors import ( + NotFoundError, + StoreError, + SynapseError, + UnrecognizedRequestError, +) +from synapse.http.servlet import ( + RestServlet, + parse_json_value_from_request, + parse_string, +) +from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS +from synapse.push.clientformat import format_push_rules_for_user +from synapse.push.rulekinds import PRIORITY_CLASS_MAP +from synapse.rest.client._base import client_patterns +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException + + +class PushRuleRestServlet(RestServlet): + PATTERNS = client_patterns("/(?Ppushrules/.*)$", v1=True) + SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( + "Unrecognised request: You probably wanted a trailing slash" + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None + + self._users_new_default_push_rules = hs.config.users_new_default_push_rules + + async def on_PUT(self, request, path): + if self._is_worker: + raise Exception("Cannot handle PUT /push_rules on worker") + + spec = _rule_spec_from_path(path.split("/")) + try: + priority_class = _priority_class_from_spec(spec) + except InvalidRuleException as e: + raise SynapseError(400, str(e)) + + requester = await self.auth.get_user_by_req(request) + + if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: + raise SynapseError(400, "rule_id may not contain slashes") + + content = parse_json_value_from_request(request) + + user_id = requester.user.to_string() + + if "attr" in spec: + await self.set_rule_attr(user_id, spec, content) + self.notify_user(user_id) + return 200, {} + + if spec["rule_id"].startswith("."): + # Rule ids starting with '.' are reserved for server default rules. 
+ raise SynapseError(400, "cannot add new rule_ids that start with '.'") + + try: + (conditions, actions) = _rule_tuple_from_request_object( + spec["template"], spec["rule_id"], content + ) + except InvalidRuleException as e: + raise SynapseError(400, str(e)) + + before = parse_string(request, "before") + if before: + before = _namespaced_rule_id(spec, before) + + after = parse_string(request, "after") + if after: + after = _namespaced_rule_id(spec, after) + + try: + await self.store.add_push_rule( + user_id=user_id, + rule_id=_namespaced_rule_id_from_spec(spec), + priority_class=priority_class, + conditions=conditions, + actions=actions, + before=before, + after=after, + ) + self.notify_user(user_id) + except InconsistentRuleException as e: + raise SynapseError(400, str(e)) + except RuleNotFoundException as e: + raise SynapseError(400, str(e)) + + return 200, {} + + async def on_DELETE(self, request, path): + if self._is_worker: + raise Exception("Cannot handle DELETE /push_rules on worker") + + spec = _rule_spec_from_path(path.split("/")) + + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + + try: + await self.store.delete_push_rule(user_id, namespaced_rule_id) + self.notify_user(user_id) + return 200, {} + except StoreError as e: + if e.code == 404: + raise NotFoundError() + else: + raise + + async def on_GET(self, request, path): + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + # we build up the full structure and then decide which bits of it + # to send which means doing unnecessary work sometimes but is + # is probably not going to make a whole lot of difference + rules = await self.store.get_push_rules_for_user(user_id) + + rules = format_push_rules_for_user(requester.user, rules) + + path = path.split("/")[1:] + + if path == []: + # we're a reference impl: pedantry is our job. + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == "": + return 200, rules + elif path[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path[1:]) + return 200, result + else: + raise UnrecognizedRequestError() + + def notify_user(self, user_id): + stream_id = self.store.get_max_push_rules_stream_id() + self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) + + async def set_rule_attr(self, user_id, spec, val): + if spec["attr"] not in ("enabled", "actions"): + # for the sake of potential future expansion, shouldn't report + # 404 in the case of an unknown request so check it corresponds to + # a known attribute first. + raise UnrecognizedRequestError() + + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec["rule_id"] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) + if spec["attr"] == "enabled": + if isinstance(val, dict) and "enabled" in val: + val = val["enabled"] + if not isinstance(val, bool): + # Legacy fallback + # This should *actually* take a dict, but many clients pass + # bools directly, so let's not break them. 
+ raise SynapseError(400, "Value for 'enabled' must be boolean") + return await self.store.set_push_rule_enabled( + user_id, namespaced_rule_id, val, is_default_rule + ) + elif spec["attr"] == "actions": + actions = val.get("actions") + _check_actions(actions) + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec["rule_id"] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if user_id in self._users_new_default_push_rules: + rule_ids = NEW_RULE_IDS + else: + rule_ids = BASE_RULE_IDS + + if namespaced_rule_id not in rule_ids: + raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) + return await self.store.set_push_rule_actions( + user_id, namespaced_rule_id, actions, is_default_rule + ) + else: + raise UnrecognizedRequestError() + + +def _rule_spec_from_path(path): + """Turn a sequence of path components into a rule spec + + Args: + path (sequence[unicode]): the URL path components. + + Returns: + dict: rule spec dict, containing scope/template/rule_id entries, + and possibly attr. + + Raises: + UnrecognizedRequestError if the path components cannot be parsed. + """ + if len(path) < 2: + raise UnrecognizedRequestError() + if path[0] != "pushrules": + raise UnrecognizedRequestError() + + scope = path[1] + path = path[2:] + if scope != "global": + raise UnrecognizedRequestError() + + if len(path) == 0: + raise UnrecognizedRequestError() + + template = path[0] + path = path[1:] + + if len(path) == 0 or len(path[0]) == 0: + raise UnrecognizedRequestError() + + rule_id = path[0] + + spec = {"scope": scope, "template": template, "rule_id": rule_id} + + path = path[1:] + + if len(path) > 0 and len(path[0]) > 0: + spec["attr"] = path[0] + + return spec + + +def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): + if rule_template in ["override", "underride"]: + if "conditions" not in req_obj: + raise InvalidRuleException("Missing 'conditions'") + conditions = req_obj["conditions"] + for c in conditions: + if "kind" not in c: + raise InvalidRuleException("Condition without 'kind'") + elif rule_template == "room": + conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}] + elif rule_template == "sender": + conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}] + elif rule_template == "content": + if "pattern" not in req_obj: + raise InvalidRuleException("Content rule missing 'pattern'") + pat = req_obj["pattern"] + + conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}] + else: + raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) + + if "actions" not in req_obj: + raise InvalidRuleException("No actions found") + actions = req_obj["actions"] + + _check_actions(actions) + + return conditions, actions + + +def _check_actions(actions): + if not isinstance(actions, list): + raise InvalidRuleException("No actions found") + + for a in actions: + if a in ["notify", "dont_notify", "coalesce"]: + pass + elif isinstance(a, dict) and "set_tweak" in a: + pass + else: + raise InvalidRuleException("Unrecognised action") + + +def _filter_ruleset_with_path(ruleset, path): + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == "": + return ruleset + template_kind = path[0] + if template_kind not in ruleset: + raise UnrecognizedRequestError() + path = path[1:] + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + if 
path[0] == "": + return ruleset[template_kind] + rule_id = path[0] + + the_rule = None + for r in ruleset[template_kind]: + if r["rule_id"] == rule_id: + the_rule = r + if the_rule is None: + raise NotFoundError + + path = path[1:] + if len(path) == 0: + return the_rule + + attr = path[0] + if attr in the_rule: + # Make sure we return a JSON object as the attribute may be a + # JSON value. + return {attr: the_rule[attr]} + else: + raise UnrecognizedRequestError() + + +def _priority_class_from_spec(spec): + if spec["template"] not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec["template"])) + pc = PRIORITY_CLASS_MAP[spec["template"]] + + return pc + + +def _namespaced_rule_id_from_spec(spec): + return _namespaced_rule_id(spec, spec["rule_id"]) + + +def _namespaced_rule_id(spec, rule_id): + return "global/%s/%s" % (spec["template"], rule_id) + + +class InvalidRuleException(Exception): + pass + + +def register_servlets(hs, http_server): + PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py new file mode 100644 index 0000000000..84619c5e41 --- /dev/null +++ b/synapse/rest/client/pusher.py @@ -0,0 +1,171 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
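[Editorial note: tracing the helpers above with a concrete request shows how the pieces fit together. For PUT /_matrix/client/r0/pushrules/global/content/my_keyword/actions, in the context of this module:]

    path = "pushrules/global/content/my_keyword/actions".split("/")
    spec = _rule_spec_from_path(path)
    # spec == {"scope": "global", "template": "content",
    #          "rule_id": "my_keyword", "attr": "actions"}

    # The storage layer then sees the namespaced form:
    assert _namespaced_rule_id_from_spec(spec) == "global/content/my_keyword"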
+ +import logging + +from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.http.server import respond_with_html_bytes +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, + parse_string, +) +from synapse.push import PusherConfigException +from synapse.rest.client._base import client_patterns + +logger = logging.getLogger(__name__) + + +class PushersRestServlet(RestServlet): + PATTERNS = client_patterns("/pushers$", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) + user = requester.user + + pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + + filtered_pushers = [p.as_dict() for p in pushers] + + return 200, {"pushers": filtered_pushers} + + +class PushersSetRestServlet(RestServlet): + PATTERNS = client_patterns("/pushers/set$", v1=True) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.notifier = hs.get_notifier() + self.pusher_pool = self.hs.get_pusherpool() + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + user = requester.user + + content = parse_json_object_from_request(request) + + if ( + "pushkey" in content + and "app_id" in content + and "kind" in content + and content["kind"] is None + ): + await self.pusher_pool.remove_pusher( + content["app_id"], content["pushkey"], user_id=user.to_string() + ) + return 200, {} + + assert_params_in_dict( + content, + [ + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "lang", + "data", + ], + ) + + logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"]) + logger.debug("Got pushers request with body: %r", content) + + append = False + if "append" in content: + append = content["append"] + + if not append: + await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + app_id=content["app_id"], + pushkey=content["pushkey"], + not_user_id=user.to_string(), + ) + + try: + await self.pusher_pool.add_pusher( + user_id=user.to_string(), + access_token=requester.access_token_id, + kind=content["kind"], + app_id=content["app_id"], + app_display_name=content["app_display_name"], + device_display_name=content["device_display_name"], + pushkey=content["pushkey"], + lang=content["lang"], + data=content["data"], + profile_tag=content.get("profile_tag", ""), + ) + except PusherConfigException as pce: + raise SynapseError( + 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM + ) + + self.notifier.on_new_replication_data() + + return 200, {} + + +class PushersRemoveRestServlet(RestServlet): + """ + To allow pusher to be delete by clicking a link (ie. 
GET request) + """ + + PATTERNS = client_patterns("/pushers/remove$", v1=True) + SUCCESS_HTML = b"You have been unsubscribed" + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.notifier = hs.get_notifier() + self.auth = hs.get_auth() + self.pusher_pool = self.hs.get_pusherpool() + + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, rights="delete_pusher") + user = requester.user + + app_id = parse_string(request, "app_id", required=True) + pushkey = parse_string(request, "pushkey", required=True) + + try: + await self.pusher_pool.remove_pusher( + app_id=app_id, pushkey=pushkey, user_id=user.to_string() + ) + except StoreError as se: + if se.code != 404: + # This is fine: they're already unsubscribed + raise + + self.notifier.on_new_replication_data() + + respond_with_html_bytes( + request, + 200, + PushersRemoveRestServlet.SUCCESS_HTML, + ) + return None + + +def register_servlets(hs, http_server): + PushersRestServlet(hs).register(http_server) + PushersSetRestServlet(hs).register(http_server) + PushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py new file mode 100644 index 0000000000..027f8b81fa --- /dev/null +++ b/synapse/rest/client/read_marker.py @@ -0,0 +1,74 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
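[Editorial note: a usage sketch for /pushers/set above, with the fields that assert_params_in_dict requires. All values are placeholders; for an "http" pusher, data.url conventionally points at a push gateway's /_matrix/push/v1/notify endpoint. Per the code above, resubmitting with "kind": null deletes the pusher.]

    import requests

    requests.post(
        "https://homeserver.example/_matrix/client/r0/pushers/set",
        headers={"Authorization": "Bearer <access_token>"},
        json={
            "kind": "http",
            "app_id": "com.example.app",
            "app_display_name": "Example App",
            "device_display_name": "Alice's phone",
            "pushkey": "<device push key>",
            "lang": "en",
            "data": {"url": "https://push-gateway.example/_matrix/push/v1/notify"},
            # Without append, other users' pushers for this app_id/pushkey
            # are removed first (see remove_pushers_by_app_id_and_pushkey_not_user).
            "append": False,
        },
    )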
+ +import logging + +from synapse.api.constants import ReadReceiptEventFields +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class ReadMarkerRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$") + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.receipts_handler = hs.get_receipts_handler() + self.read_marker_handler = hs.get_read_marker_handler() + self.presence_handler = hs.get_presence_handler() + + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request) + + await self.presence_handler.bump_presence_active_time(requester.user) + + body = parse_json_object_from_request(request) + read_event_id = body.get("m.read", None) + hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) + + if not isinstance(hidden, bool): + raise SynapseError( + 400, + "Param %s must be a boolean, if given" + % ReadReceiptEventFields.MSC2285_HIDDEN, + Codes.BAD_JSON, + ) + + if read_event_id: + await self.receipts_handler.received_client_receipt( + room_id, + "m.read", + user_id=requester.user.to_string(), + event_id=read_event_id, + hidden=hidden, + ) + + read_marker_event_id = body.get("m.fully_read", None) + if read_marker_event_id: + await self.read_marker_handler.received_client_read_marker( + room_id, + user_id=requester.user.to_string(), + event_id=read_marker_event_id, + ) + + return 200, {} + + +def register_servlets(hs, http_server): + ReadMarkerRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py new file mode 100644 index 0000000000..d9ab836cd8 --- /dev/null +++ b/synapse/rest/client/receipts.py @@ -0,0 +1,71 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
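[Editorial note: a usage sketch for the read-markers endpoint above. IDs and token are placeholders, and the unstable field name assumes ReadReceiptEventFields.MSC2285_HIDDEN is the "org.matrix.msc2285.hidden" key.]

    from urllib.parse import quote

    import requests

    HS = "https://homeserver.example"  # placeholder
    room = quote("!room:homeserver.example")

    requests.post(
        f"{HS}/_matrix/client/r0/rooms/{room}/read_markers",
        headers={"Authorization": "Bearer <access_token>"},
        json={
            "m.fully_read": "$up-to-this-event",
            "m.read": "$latest-event",
            # MSC2285 (unstable): ask the server not to share the receipt.
            "org.matrix.msc2285.hidden": True,
        },
    )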
+ +import logging + +from synapse.api.constants import ReadReceiptEventFields +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class ReceiptRestServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)" + "/receipt/(?P[^/]*)" + "/(?P[^/]*)$" + ) + + def __init__(self, hs): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.receipts_handler = hs.get_receipts_handler() + self.presence_handler = hs.get_presence_handler() + + async def on_POST(self, request, room_id, receipt_type, event_id): + requester = await self.auth.get_user_by_req(request) + + if receipt_type != "m.read": + raise SynapseError(400, "Receipt type must be 'm.read'") + + body = parse_json_object_from_request(request, allow_empty_body=True) + hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) + + if not isinstance(hidden, bool): + raise SynapseError( + 400, + "Param %s must be a boolean, if given" + % ReadReceiptEventFields.MSC2285_HIDDEN, + Codes.BAD_JSON, + ) + + await self.presence_handler.bump_presence_active_time(requester.user) + + await self.receipts_handler.received_client_receipt( + room_id, + receipt_type, + user_id=requester.user.to_string(), + event_id=event_id, + hidden=hidden, + ) + + return 200, {} + + +def register_servlets(hs, http_server): + ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py new file mode 100644 index 0000000000..58b8e8f261 --- /dev/null +++ b/synapse/rest/client/register.py @@ -0,0 +1,879 @@ +# Copyright 2015 - 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
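[Editorial note: unlike /read_markers, the receipts servlet above takes the event ID in the path and accepts an empty body (allow_empty_body=True). A minimal sketch with placeholder IDs:]

    from urllib.parse import quote

    import requests

    HS = "https://homeserver.example"  # placeholder
    room = quote("!room:homeserver.example")
    event = quote("$event-id")

    requests.post(
        f"{HS}/_matrix/client/r0/rooms/{room}/receipt/m.read/{event}",
        headers={"Authorization": "Bearer <access_token>"},
        json={},
    )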
+import hmac +import logging +import random +from typing import List, Union + +import synapse +import synapse.api.auth +import synapse.types +from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, + UnrecognizedRequestError, +) +from synapse.config import ConfigError +from synapse.config.captcha import CaptchaConfig +from synapse.config.consent import ConsentConfig +from synapse.config.emailconfig import ThreepidBehaviour +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.config.registration import RegistrationConfig +from synapse.config.server import is_threepid_reserved +from synapse.handlers.auth import AuthHandler +from synapse.handlers.ui_auth import UIAuthSessionDataConstants +from synapse.http.server import finish_request, respond_with_html +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_boolean, + parse_json_object_from_request, + parse_string, +) +from synapse.metrics import threepid_send_requests +from synapse.push.mailer import Mailer +from synapse.types import JsonDict +from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.stringutils import assert_valid_client_secret, random_string +from synapse.util.threepids import ( + canonicalise_email, + check_3pid_allowed, + validate_email, +) + +from ._base import client_patterns, interactive_auth_handler + +# We ought to be using hmac.compare_digest() but on older pythons it doesn't +# exist. It's a _really minor_ security flaw to use plain string comparison +# because the timing attack is so obscured by all the other code here it's +# unlikely to make much difference +if hasattr(hmac, "compare_digest"): + compare_digest = hmac.compare_digest +else: + + def compare_digest(a, b): + return a == b + + +logger = logging.getLogger(__name__) + + +class EmailRegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/register/email/requestToken$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + self.config = hs.config + + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=self.config.email_registration_template_html, + template_text=self.config.email_registration_template_text, + ) + + async def on_POST(self, request): + if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.hs.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "Email registration has been disabled due to lack of email config" + ) + raise SynapseError( + 400, "Email-based registration has been disabled on this server" + ) + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) + + # Extract params from body + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. 
+ # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/account.py) + try: + email = validate_email(body["email"]) + except ValueError as e: + raise SynapseError(400, str(e)) + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + if not check_3pid_allowed(self.hs, "email", email): + raise SynapseError( + 403, + "Your email domain is not authorized to register on this server", + Codes.THREEPID_DENIED, + ) + + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) + + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + "email", email + ) + + if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + assert self.hs.config.account_threepid_delegate_email + + # Have the configured identity server handle the request + ret = await self.identity_handler.requestEmailToken( + self.hs.config.account_threepid_delegate_email, + email, + client_secret, + send_attempt, + next_link, + ) + else: + # Send registration emails from Synapse + sid = await self.identity_handler.send_threepid_validation( + email, + client_secret, + send_attempt, + self.mailer.send_registration_mail, + next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + + threepid_send_requests.labels(type="email", reason="register").observe( + send_attempt + ) + + return 200, ret + + +class MsisdnRegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_patterns("/register/msisdn/requestToken$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.identity_handler = hs.get_identity_handler() + + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + assert_params_in_dict( + body, ["client_secret", "country", "phone_number", "send_attempt"] + ) + client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) + country = body["country"] + phone_number = body["phone_number"] + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + msisdn = phone_number_to_msisdn(country, phone_number) + + if not check_3pid_allowed(self.hs, "msisdn", msisdn): + raise SynapseError( + 403, + "Phone numbers are not authorized to register on this server", + Codes.THREEPID_DENIED, + ) + + await self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( + "msisdn", msisdn + ) + + if existing_user_id is not None: + if self.hs.config.request_token_inhibit_3pid_errors: + # Make the client think the operation succeeded. See the rationale in the + # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. 
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10) + return 200, {"sid": random_string(16)} + + raise SynapseError( + 400, "Phone number is already in use", Codes.THREEPID_IN_USE + ) + + if not self.hs.config.account_threepid_delegate_msisdn: + logger.warning( + "No upstream msisdn account_threepid_delegate configured on the server to " + "handle this request" + ) + raise SynapseError( + 400, "Registration by phone number is not supported on this homeserver" + ) + + ret = await self.identity_handler.requestMsisdnToken( + self.hs.config.account_threepid_delegate_msisdn, + country, + phone_number, + client_secret, + send_attempt, + next_link, + ) + + threepid_send_requests.labels(type="msisdn", reason="register").observe( + send_attempt + ) + + return 200, ret + + +class RegistrationSubmitTokenServlet(RestServlet): + """Handles registration 3PID validation token submission""" + + PATTERNS = client_patterns( + "/registration/(?P[^/]*)/submit_token$", releases=(), unstable=True + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.config = hs.config + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + self._failure_email_template = ( + self.config.email_registration_template_failure_html + ) + + async def on_GET(self, request, medium): + if medium != "email": + raise SynapseError( + 400, "This medium is currently not supported for registration" + ) + if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.local_threepid_handling_disabled_due_to_email_config: + logger.warning( + "User registration via email has been disabled due to lack of email config" + ) + raise SynapseError( + 400, "Email-based registration is disabled on this server" + ) + + sid = parse_string(request, "sid", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + token = parse_string(request, "token", required=True) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warning( + "Not redirecting to next_link as it is a local file: address" + ) + else: + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + return None + + # Otherwise show the success template + html = self.config.email_registration_template_success_html_content + status_code = 200 + except ThreepidValidationError as e: + status_code = e.code + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html = self._failure_email_template.render(**template_vars) + + respond_with_html(request, status_code, html) + + +class UsernameAvailabilityRestServlet(RestServlet): + PATTERNS = client_patterns("/register/available") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.registration_handler = hs.get_registration_handler() + self.ratelimiter = FederationRateLimiter( + hs.get_clock(), + FederationRateLimitConfig( + # Time window of 2s + window_size=2000, + # Artificially delay requests if rate > sleep_limit/window_size + sleep_limit=1, + # Amount of artificial 
delay to apply + sleep_msec=1000, + # Error with 429 if more than reject_limit requests are queued + reject_limit=1, + # Allow 1 request at a time + concurrent_requests=1, + ), + ) + + async def on_GET(self, request): + if not self.hs.config.enable_registration: + raise SynapseError( + 403, "Registration has been disabled", errcode=Codes.FORBIDDEN + ) + + ip = request.getClientIP() + with self.ratelimiter.ratelimit(ip) as wait_deferred: + await wait_deferred + + username = parse_string(request, "username", required=True) + + await self.registration_handler.check_username(username) + + return 200, {"available": True} + + +class RegisterRestServlet(RestServlet): + PATTERNS = client_patterns("/register$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() + self.registration_handler = hs.get_registration_handler() + self.identity_handler = hs.get_identity_handler() + self.room_member_handler = hs.get_room_member_handler() + self.macaroon_gen = hs.get_macaroon_generator() + self.ratelimiter = hs.get_registration_ratelimiter() + self.password_policy_handler = hs.get_password_policy_handler() + self.clock = hs.get_clock() + self._registration_enabled = self.hs.config.enable_registration + self._msc2918_enabled = hs.config.access_token_lifetime is not None + + self._registration_flows = _calculate_registration_flows( + hs.config, self.auth_handler + ) + + @interactive_auth_handler + async def on_POST(self, request): + body = parse_json_object_from_request(request) + + client_addr = request.getClientIP() + + await self.ratelimiter.ratelimit(None, client_addr, update=False) + + kind = b"user" + if b"kind" in request.args: + kind = request.args[b"kind"][0] + + if kind == b"guest": + ret = await self._do_guest_registration(body, address=client_addr) + return ret + elif kind != b"user": + raise UnrecognizedRequestError( + "Do not understand membership kind: %s" % (kind.decode("utf8"),) + ) + + if self._msc2918_enabled: + # Check if this registration should also issue a refresh token, as + # per MSC2918 + should_issue_refresh_token = parse_boolean( + request, name="org.matrix.msc2918.refresh_token", default=False + ) + else: + should_issue_refresh_token = False + + # Pull out the provided username and do basic sanity checks early since + # the auth layer will store these in sessions. + desired_username = None + if "username" in body: + if not isinstance(body["username"], str) or len(body["username"]) > 512: + raise SynapseError(400, "Invalid username") + desired_username = body["username"] + + # fork off as soon as possible for ASes which have completely + # different registration flows to normal users + + # == Application Service Registration == + if body.get("type") == APP_SERVICE_REGISTRATION_TYPE: + if not self.auth.has_access_token(request): + raise SynapseError( + 400, + "Appservice token must be provided when using a type of m.login.application_service", + ) + + # Verify the AS + self.auth.get_appservice_by_req(request) + + # Set the desired user according to the AS API (which uses the + # 'user' key not 'username'). Since this is a new addition, we'll + # fallback to 'username' if they gave one. + desired_username = body.get("user", desired_username) + + # XXX we should check that desired_username is valid. 
Currently + # we give appservices carte blanche for any insanity in mxids, + # because the IRC bridges rely on being able to register stupid + # IDs. + + access_token = self.auth.get_access_token_from_request(request) + + if not isinstance(desired_username, str): + raise SynapseError(400, "Desired Username is missing or not a string") + + result = await self._do_appservice_registration( + desired_username, + access_token, + body, + should_issue_refresh_token=should_issue_refresh_token, + ) + + return 200, result + elif self.auth.has_access_token(request): + raise SynapseError( + 400, + "An access token should not be provided on requests to /register (except if type is m.login.application_service)", + ) + + # == Normal User Registration == (everyone else) + if not self._registration_enabled: + raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) + + # For regular registration, convert the provided username to lowercase + # before attempting to register it. This should mean that people who try + # to register with upper-case in their usernames don't get a nasty surprise. + # + # Note that we treat usernames case-insensitively in login, so they are + # free to carry on imagining that their username is CrAzYh4cKeR if that + # keeps them happy. + if desired_username is not None: + desired_username = desired_username.lower() + + # Check if this account is upgrading from a guest account. + guest_access_token = body.get("guest_access_token", None) + + # Pull out the provided password and do basic sanity checks early. + # + # Note that we remove the password from the body since the auth layer + # will store the body in the session and we don't want a plaintext + # password store there. + password = body.pop("password", None) + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + if "initial_device_display_name" in body and password is None: + # ignore 'initial_device_display_name' if sent without + # a password to work around a client bug where it sent + # the 'initial_device_display_name' param alone, wiping out + # the original registration params + logger.warning("Ignoring initial_device_display_name without password") + del body["initial_device_display_name"] + + session_id = self.auth_handler.get_session_id(body) + registered_user_id = None + password_hash = None + if session_id: + # if we get a registered user id out of here, it means we previously + # registered a user for this session, so we could just return the + # user here. We carry on and go through the auth checks though, + # for paranoia. + registered_user_id = await self.auth_handler.get_session_data( + session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None + ) + # Extract the previously-hashed password from the session. + password_hash = await self.auth_handler.get_session_data( + session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None + ) + + # Ensure that the username is valid. + if desired_username is not None: + await self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, + ) + + # Check if the user-interactive authentication flows are complete, if + # not this will raise a user-interactive auth error. 
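+        # The shapes below are an illustrative sketch (assumed, not part of
+        # this change): check_ui_auth is expected to give back roughly
+        #     auth_result - the completed auth stages, e.g.
+        #                   {LoginType.EMAIL_IDENTITY: {"medium": "email", "address": ...}}
+        #     params      - the non-auth parameters from the original request
+        #     session_id  - the UI auth session ID, used below to stash data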
+ try: + auth_result, params, session_id = await self.auth_handler.check_ui_auth( + self._registration_flows, + request, + body, + "register a new account", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth. + # + # Hash the password and store it with the session since the client + # is not required to provide the password again. + # + # If a password hash was previously stored we will not attempt to + # re-hash and store it for efficiency. This assumes the password + # does not change throughout the authentication flow, but this + # should be fine since the data is meant to be consistent. + if not password_hash and password: + password_hash = await self.auth_handler.hash(password) + await self.auth_handler.set_session_data( + e.session_id, + UIAuthSessionDataConstants.PASSWORD_HASH, + password_hash, + ) + raise + + # Check that we're not trying to register a denied 3pid. + # + # the user-facing checks will probably already have happened in + # /register/email/requestToken when we requested a 3pid, but that's not + # guaranteed. + if auth_result: + for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: + if login_type in auth_result: + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] + + if not check_3pid_allowed(self.hs, medium, address): + raise SynapseError( + 403, + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", + Codes.THREEPID_DENIED, + ) + + if registered_user_id is not None: + logger.info( + "Already registered user ID %r for this session", registered_user_id + ) + # don't re-register the threepids + registered = False + else: + # If we have a password in this request, prefer it. Otherwise, there + # might be a password hash from an earlier request. + if password: + password_hash = await self.auth_handler.hash(password) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + + desired_username = params.get("username", None) + guest_access_token = params.get("guest_access_token", None) + + if desired_username is not None: + desired_username = desired_username.lower() + + threepid = None + if auth_result: + threepid = auth_result.get(LoginType.EMAIL_IDENTITY) + + # Also check that we're not trying to register a 3pid that's already + # been registered. + # + # This has probably happened in /register/email/requestToken as well, + # but if a user hits this endpoint twice then clicks on each link from + # the two activation emails, they would register the same 3pid twice. + for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: + if login_type in auth_result: + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. 
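+                    # (Illustrative example: "Alice@Example.COM" and
+                    # "alice@example.com" should canonicalise to the same
+                    # stored address.)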
+ # (See on_POST in EmailThreepidRequestTokenRestServlet + # in synapse/rest/client/account.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) + + existing_user_id = await self.store.get_user_id_by_threepid( + medium, address + ) + + if existing_user_id is not None: + raise SynapseError( + 400, + "%s is already in use" % medium, + Codes.THREEPID_IN_USE, + ) + + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + + registered_user_id = await self.registration_handler.register_user( + localpart=desired_username, + password_hash=password_hash, + guest_access_token=guest_access_token, + threepid=threepid, + address=client_addr, + user_agent_ips=entries, + ) + # Necessary due to auth checks prior to the threepid being + # written to the db + if threepid: + if is_threepid_reserved( + self.hs.config.mau_limits_reserved_threepids, threepid + ): + await self.store.upsert_monthly_active_user(registered_user_id) + + # Remember that the user account has been registered (and the user + # ID it was registered with, since it might not have been specified). + await self.auth_handler.set_session_data( + session_id, + UIAuthSessionDataConstants.REGISTERED_USER_ID, + registered_user_id, + ) + + registered = True + + return_dict = await self._create_registration_details( + registered_user_id, + params, + should_issue_refresh_token=should_issue_refresh_token, + ) + + if registered: + await self.registration_handler.post_registration_actions( + user_id=registered_user_id, + auth_result=auth_result, + access_token=return_dict.get("access_token"), + ) + + return 200, return_dict + + async def _do_appservice_registration( + self, username, as_token, body, should_issue_refresh_token: bool = False + ): + user_id = await self.registration_handler.appservice_register( + username, as_token + ) + return await self._create_registration_details( + user_id, + body, + is_appservice_ghost=True, + should_issue_refresh_token=should_issue_refresh_token, + ) + + async def _create_registration_details( + self, + user_id: str, + params: JsonDict, + is_appservice_ghost: bool = False, + should_issue_refresh_token: bool = False, + ): + """Complete registration of newly-registered user + + Allocates device_id if one was not given; also creates access_token. + + Args: + user_id: full canonical @user:id + params: registration parameters, from which we pull device_id, + initial_device_name and inhibit_login + is_appservice_ghost + should_issue_refresh_token: True if this registration should issue + a refresh token alongside the access token. 
+ Returns: + dictionary for response from /register + """ + result = {"user_id": user_id, "home_server": self.hs.hostname} + if not params.get("inhibit_login", False): + device_id = params.get("device_id") + initial_display_name = params.get("initial_device_display_name") + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, + device_id, + initial_display_name, + is_guest=False, + is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, + ) + + result.update({"access_token": access_token, "device_id": device_id}) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + return result + + async def _do_guest_registration(self, params, address=None): + if not self.hs.config.allow_guest_access: + raise SynapseError(403, "Guest access is disabled") + user_id = await self.registration_handler.register_user( + make_guest=True, address=address + ) + + # we don't allow guests to specify their own device_id, because + # we have nowhere to store it. + device_id = synapse.api.auth.GUEST_DEVICE_ID + initial_display_name = params.get("initial_device_display_name") + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, device_id, initial_display_name, is_guest=True + ) + + result = { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + } + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + return 200, result + + +def _calculate_registration_flows( + # technically `config` has to provide *all* of these interfaces, not just one + config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], + auth_handler: AuthHandler, +) -> List[List[str]]: + """Get a suitable flows list for registration + + Args: + config: server configuration + auth_handler: authorization handler + + Returns: a list of supported flows + """ + # FIXME: need a better error than "no auth flow found" for scenarios + # where we required 3PID for registration but the user didn't give one + require_email = "email" in config.registrations_require_3pid + require_msisdn = "msisdn" in config.registrations_require_3pid + + show_msisdn = True + show_email = True + + if config.disable_msisdn_registration: + show_msisdn = False + require_msisdn = False + + enabled_auth_types = auth_handler.get_enabled_auth_types() + if LoginType.EMAIL_IDENTITY not in enabled_auth_types: + show_email = False + if require_email: + raise ConfigError( + "Configuration requires email address at registration, but email " + "validation is not configured" + ) + + if LoginType.MSISDN not in enabled_auth_types: + show_msisdn = False + if require_msisdn: + raise ConfigError( + "Configuration requires msisdn at registration, but msisdn " + "validation is not configured" + ) + + flows = [] + + # only support 3PIDless registration if no 3PIDs are required + if not require_email and not require_msisdn: + # Add a dummy step here, otherwise if a client completes + # recaptcha first we'll assume they were going for this flow + # and complete the request, when they could have been trying to + # complete one of 
the flows with email/msisdn auth. + flows.append([LoginType.DUMMY]) + + # only support the email-only flow if we don't require MSISDN 3PIDs + if show_email and not require_msisdn: + flows.append([LoginType.EMAIL_IDENTITY]) + + # only support the MSISDN-only flow if we don't require email 3PIDs + if show_msisdn and not require_email: + flows.append([LoginType.MSISDN]) + + if show_email and show_msisdn: + # always let users provide both MSISDN & email + flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) + + # Prepend m.login.terms to all flows if we're requiring consent + if config.user_consent_at_registration: + for flow in flows: + flow.insert(0, LoginType.TERMS) + + # Prepend recaptcha to all flows if we're requiring captcha + if config.enable_registration_captcha: + for flow in flows: + flow.insert(0, LoginType.RECAPTCHA) + + return flows + + +def register_servlets(hs, http_server): + EmailRegisterRequestTokenRestServlet(hs).register(http_server) + MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) + UsernameAvailabilityRestServlet(hs).register(http_server) + RegistrationSubmitTokenServlet(hs).register(http_server) + RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py new file mode 100644 index 0000000000..0821cd285f --- /dev/null +++ b/synapse/rest/client/relations.py @@ -0,0 +1,381 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This class implements the proposed relation APIs from MSC 1849. + +Since the MSC has not been approved all APIs here are unstable and may change at +any time to reflect changes in the MSC. +""" + +import logging + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.errors import ShadowBanError, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, +) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.relations import ( + AggregationPaginationToken, + PaginationChunk, + RelationPaginationToken, +) +from synapse.util.stringutils import random_string + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class RelationSendServlet(RestServlet): + """Helper API for sending events that have relation data. 
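+
+    (Part of the unstable relations API proposed in MSC 1849 and subject to
+    change; the example below is illustrative.)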
+
+    Example API shape to send a 👍 reaction to a room:
+
+        POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D
+        {}
+
+        {
+            "event_id": "$foobar"
+        }
+    """
+
+    PATTERN = (
+        "/rooms/(?P<room_id>[^/]*)/send_relation"
+        "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.txns = HttpTransactionCache(hs)
+
+    def register(self, http_server):
+        http_server.register_paths(
+            "POST",
+            client_patterns(self.PATTERN + "$", releases=()),
+            self.on_PUT_or_POST,
+            self.__class__.__name__,
+        )
+        http_server.register_paths(
+            "PUT",
+            client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
+            self.on_PUT,
+            self.__class__.__name__,
+        )
+
+    def on_PUT(self, request, *args, **kwargs):
+        return self.txns.fetch_or_execute_request(
+            request, self.on_PUT_or_POST, request, *args, **kwargs
+        )
+
+    async def on_PUT_or_POST(
+        self, request, room_id, parent_id, relation_type, event_type, txn_id=None
+    ):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+
+        if event_type == EventTypes.Member:
+            # Adding relations to a membership event is meaningless, so we just
+            # deny it at the CS API rather than trying to handle it correctly.
+            raise SynapseError(400, "Cannot send member events with relations")
+
+        content = parse_json_object_from_request(request)
+
+        aggregation_key = parse_string(request, "key", encoding="utf-8")
+
+        content["m.relates_to"] = {
+            "event_id": parent_id,
+            "key": aggregation_key,
+            "rel_type": relation_type,
+        }
+
+        event_dict = {
+            "type": event_type,
+            "content": content,
+            "room_id": room_id,
+            "sender": requester.user.to_string(),
+        }
+
+        try:
+            (
+                event,
+                _,
+            ) = await self.event_creation_handler.create_and_send_nonmember_event(
+                requester, event_dict=event_dict, txn_id=txn_id
+            )
+            event_id = event.event_id
+        except ShadowBanError:
+            event_id = "$" + random_string(43)
+
+        return 200, {"event_id": event_id}
+
+
+class RelationPaginationServlet(RestServlet):
+    """API to paginate relations on an event by topological ordering, optionally
+    filtered by relation type and event type.
+    """
+
+    PATTERNS = client_patterns(
+        "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
+        "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+        releases=(),
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self._event_serializer = hs.get_event_client_serializer()
+        self.event_handler = hs.get_event_handler()
+
+    async def on_GET(
+        self, request, room_id, parent_id, relation_type=None, event_type=None
+    ):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+
+        await self.auth.check_user_in_room_or_world_readable(
+            room_id, requester.user.to_string(), allow_departed_users=True
+        )
+
+        # This gets the original event and checks that a) the event exists and
+        # b) the user is allowed to view it.
+        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
+
+        if event.internal_metadata.is_redacted():
+            # If the event is redacted, return an empty list of relations
+            pagination_chunk = PaginationChunk(chunk=[])
+        else:
+            # Return the relations
+            from_token = None
+            if from_token_str:
+                from_token = RelationPaginationToken.from_string(from_token_str)
+
+            to_token = None
+            if to_token_str:
+                to_token = RelationPaginationToken.from_string(to_token_str)
+
+            pagination_chunk = await self.store.get_relations_for_event(
+                event_id=parent_id,
+                relation_type=relation_type,
+                event_type=event_type,
+                limit=limit,
+                from_token=from_token,
+                to_token=to_token,
+            )
+
+        events = await self.store.get_events_as_list(
+            [c["event_id"] for c in pagination_chunk.chunk]
+        )
+
+        now = self.clock.time_msec()
+        # We set bundle_aggregations to False when retrieving the original
+        # event because we want the content before relations were applied to
+        # it.
+        original_event = await self._event_serializer.serialize_event(
+            event, now, bundle_aggregations=False
+        )
+        # Similarly, we don't allow relations to be applied to relations, so we
+        # return the original relations without any aggregations on top of them
+        # here.
+        events = await self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=False
+        )
+
+        return_value = pagination_chunk.to_dict()
+        return_value["chunk"] = events
+        return_value["original_event"] = original_event
+
+        return 200, return_value
+
+
+class RelationAggregationPaginationServlet(RestServlet):
+    """API to paginate aggregation groups of relations, e.g. paginate the
+    types and counts of the reactions on the events.
+
+    Example request and response:
+
+        GET /rooms/{room_id}/aggregations/{parent_id}
+
+        {
+            chunk: [
+                {
+                    "type": "m.reaction",
+                    "key": "👍",
+                    "count": 3
+                }
+            ]
+        }
+    """
+
+    PATTERNS = client_patterns(
+        "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+        "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+        releases=(),
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.event_handler = hs.get_event_handler()
+
+    async def on_GET(
+        self, request, room_id, parent_id, relation_type=None, event_type=None
+    ):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+
+        await self.auth.check_user_in_room_or_world_readable(
+            room_id,
+            requester.user.to_string(),
+            allow_departed_users=True,
+        )
+
+        # This checks that a) the event exists and b) the user is allowed to
+        # view it.
+        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+
+        if relation_type not in (RelationTypes.ANNOTATION, None):
+            raise SynapseError(400, "Relation type must be 'annotation'")
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
+
+        if event.internal_metadata.is_redacted():
+            # If the event is redacted, return an empty list of relations
+            pagination_chunk = PaginationChunk(chunk=[])
+        else:
+            # Return the relations
+            from_token = None
+            if from_token_str:
+                from_token = AggregationPaginationToken.from_string(from_token_str)
+
+            to_token = None
+            if to_token_str:
+                to_token = AggregationPaginationToken.from_string(to_token_str)
+
+            pagination_chunk = await self.store.get_aggregation_groups_for_event(
+                event_id=parent_id,
+                event_type=event_type,
+                limit=limit,
+                from_token=from_token,
+                to_token=to_token,
+            )
+
+        return 200, pagination_chunk.to_dict()
+
+
+class RelationAggregationGroupPaginationServlet(RestServlet):
+    """API to paginate within an aggregation group of relations, e.g. paginate
+    all the 👍 reactions on an event.
+
+    Example request and response:
+
+        GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
+
+        {
+            chunk: [
+                {
+                    "type": "m.reaction",
+                    "content": {
+                        "m.relates_to": {
+                            "rel_type": "m.annotation",
+                            "key": "👍"
+                        }
+                    }
+                },
+                ...
+            ]
+        }
+    """
+
+    PATTERNS = client_patterns(
+        "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+        "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
+        releases=(),
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+        self._event_serializer = hs.get_event_client_serializer()
+        self.event_handler = hs.get_event_handler()
+
+    async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+
+        await self.auth.check_user_in_room_or_world_readable(
+            room_id,
+            requester.user.to_string(),
+            allow_departed_users=True,
+        )
+
+        # This checks that a) the event exists and b) the user is allowed to
+        # view it. 
+ await self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type != RelationTypes.ANNOTATION: + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") + + from_token = None + if from_token_str: + from_token = RelationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = RelationPaginationToken.from_string(to_token_str) + + result = await self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + aggregation_key=key, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = await self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = await self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + return 200, return_value + + +def register_servlets(hs, http_server): + RelationSendServlet(hs).register(http_server) + RelationPaginationServlet(hs).register(http_server) + RelationAggregationPaginationServlet(hs).register(http_server) + RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py new file mode 100644 index 0000000000..07ea39a8a3 --- /dev/null +++ b/synapse/rest/client/report_event.py @@ -0,0 +1,68 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
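+
+"""A servlet for reporting an event, e.g. for abuse.
+
+Illustrative request shape, assumed from the parameter checks below (not part
+of this change):
+
+    POST /_matrix/client/r0/rooms/{roomId}/report/{eventId}
+    {"reason": "spam", "score": -100}
+"""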
+
+import logging
+from http import HTTPStatus
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class ReportEventRestServlet(RestServlet):
+    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
+
+    def __init__(self, hs):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.clock = hs.get_clock()
+        self.store = hs.get_datastore()
+
+    async def on_POST(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
+
+        body = parse_json_object_from_request(request)
+
+        if not isinstance(body.get("reason", ""), str):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Param 'reason' must be a string",
+                Codes.BAD_JSON,
+            )
+        if not isinstance(body.get("score", 0), int):
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Param 'score' must be an integer",
+                Codes.BAD_JSON,
+            )
+
+        await self.store.add_event_report(
+            room_id=room_id,
+            event_id=event_id,
+            user_id=user_id,
+            reason=body.get("reason"),
+            content=body,
+            received_ts=self.clock.time_msec(),
+        )
+
+        return 200, {}
+
+
+def register_servlets(hs, http_server):
+    ReportEventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
new file mode 100644
index 0000000000..ed238b2141
--- /dev/null
+++ b/synapse/rest/client/room.py
@@ -0,0 +1,1152 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. 
+ +""" This module contains REST servlets to do with rooms: /rooms/ """ +import logging +import re +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from urllib import parse as urlparse + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientCredentialsError, + MissingClientTokenError, + ShadowBanError, + SynapseError, +) +from synapse.api.filtering import Filter +from synapse.events.utils import format_event_for_client_v2 +from synapse.http.servlet import ( + ResolveRoomIdMixin, + RestServlet, + assert_params_in_dict, + parse_boolean, + parse_integer, + parse_json_object_from_request, + parse_string, + parse_strings_from_args, +) +from synapse.http.site import SynapseRequest +from synapse.logging.opentracing import set_tag +from synapse.rest.client._base import client_patterns +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.state import StateFilter +from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID +from synapse.util import json_decoder +from synapse.util.stringutils import parse_and_validate_server_name, random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class TransactionRestServlet(RestServlet): + def __init__(self, hs): + super().__init__() + self.txns = HttpTransactionCache(hs) + + +class RoomCreateRestServlet(TransactionRestServlet): + # No PATTERN; we have custom dispatch rules here + + def __init__(self, hs): + super().__init__(hs) + self._room_creation_handler = hs.get_room_creation_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + PATTERNS = "/createRoom" + register_txn_path(self, PATTERNS, http_server) + + def on_PUT(self, request, txn_id): + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request(request, self.on_POST, request) + + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) + + info, _ = await self._room_creation_handler.create_room( + requester, self.get_room_config(request) + ) + + return 200, info + + def get_room_config(self, request): + user_supplied_config = parse_json_object_from_request(request) + return user_supplied_config + + +# TODO: Needs unit testing for generic events +class RoomStateEventRestServlet(TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + # /room/$roomid/state/$eventtype + no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" + + # /room/$roomid/state/$eventtype/$statekey + state_key = ( + "/rooms/(?P[^/]*)/state/" + "(?P[^/]*)/(?P[^/]*)$" + ) + + http_server.register_paths( + "GET", + client_patterns(state_key, v1=True), + self.on_GET, + self.__class__.__name__, + ) + http_server.register_paths( + "PUT", + client_patterns(state_key, v1=True), + self.on_PUT, + self.__class__.__name__, + ) + http_server.register_paths( + "GET", + client_patterns(no_state_key, v1=True), + self.on_GET_no_state_key, + self.__class__.__name__, + ) + http_server.register_paths( + "PUT", + client_patterns(no_state_key, v1=True), + self.on_PUT_no_state_key, + self.__class__.__name__, + ) + + def on_GET_no_state_key(self, request, room_id, event_type): + return 
self.on_GET(request, room_id, event_type, "") + + def on_PUT_no_state_key(self, request, room_id, event_type): + return self.on_PUT(request, room_id, event_type, "") + + async def on_GET(self, request, room_id, event_type, state_key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + format = parse_string( + request, "format", default="content", allowed_values=["content", "event"] + ) + + msg_handler = self.message_handler + data = await msg_handler.get_room_data( + user_id=requester.user.to_string(), + room_id=room_id, + event_type=event_type, + state_key=state_key, + ) + + if not data: + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + + if format == "event": + event = format_event_for_client_v2(data.get_dict()) + return 200, event + elif format == "content": + return 200, data.get_dict()["content"] + + async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + requester = await self.auth.get_user_by_req(request) + + if txn_id: + set_tag("txn_id", txn_id) + + content = parse_json_object_from_request(request) + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + if state_key is not None: + event_dict["state_key"] = state_key + + try: + if event_type == EventTypes.Member: + membership = content.get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + requester, + target=UserID.from_string(state_key), + room_id=room_id, + action=membership, + content=content, + ) + else: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) + + set_tag("event_id", event_id) + ret = {"event_id": event_id} + return 200, ret + + +# TODO: Needs unit testing for generic events + feedback +class RoomSendEventRestServlet(TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + self.event_creation_handler = hs.get_event_creation_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + # /rooms/$roomid/send/$event_type[/$txn_id] + PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" + register_txn_path(self, PATTERNS, http_server, with_get=True) + + async def on_POST(self, request, room_id, event_type, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + content = parse_json_object_from_request(request) + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + if b"ts" in request.args and requester.app_service: + event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) + + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) + + set_tag("event_id", event_id) + return 200, {"event_id": event_id} + + def on_GET(self, request, room_id, event_type, txn_id): + return 200, "Not implemented" + + def on_PUT(self, request, room_id, event_type, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, event_type, txn_id + ) + + +# TODO: Needs unit testing for room ID + alias joins +class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): + def __init__(self, hs): + 
super().__init__(hs) + super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up + self.auth = hs.get_auth() + + def register(self, http_server): + # /join/$room_identifier[/$txn_id] + PATTERNS = "/join/(?P[^/]*)" + register_txn_path(self, PATTERNS, http_server) + + async def on_POST( + self, + request: SynapseRequest, + room_identifier: str, + txn_id: Optional[str] = None, + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + try: + content = parse_json_object_from_request(request) + except Exception: + # Turns out we used to ignore the body entirely, and some clients + # cheekily send invalid bodies. + content = {} + + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "server_name", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) + + await self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + content=content, + third_party_signed=content.get("third_party_signed", None), + ) + + return 200, {"room_id": room_id} + + def on_PUT(self, request, room_identifier, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_identifier, txn_id + ) + + +# TODO: Needs unit testing +class PublicRoomListRestServlet(TransactionRestServlet): + PATTERNS = client_patterns("/publicRooms$", v1=True) + + def __init__(self, hs): + super().__init__(hs) + self.hs = hs + self.auth = hs.get_auth() + + async def on_GET(self, request): + server = parse_string(request, "server") + + try: + await self.auth.get_user_by_req(request, allow_guest=True) + except InvalidClientCredentialsError as e: + # Option to allow servers to require auth when accessing + # /publicRooms via CS API. This is especially helpful in private + # federations. + if not self.hs.config.allow_public_rooms_without_auth: + raise + + # We allow people to not be authed if they're just looking at our + # room list, but require auth when we proxy the request. + # In both cases we call the auth function, as that has the side + # effect of logging who issued this request if an access token was + # provided. + if server: + raise e + + limit: Optional[int] = parse_integer(request, "limit", 0) + since_token = parse_string(request, "since") + + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + + handler = self.hs.get_room_list_handler() + if server and server != self.hs.config.server_name: + # Ensure the server is valid. 
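+            # (Illustrative examples: "matrix.org" and "matrix.org:8448" are
+            # well-formed server names, while values containing spaces or a
+            # malformed port are rejected.)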
+ try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, + "Invalid server name: %s" % (server,), + Codes.INVALID_PARAM, + ) + + data = await handler.get_remote_public_room_list( + server, limit=limit, since_token=since_token + ) + else: + data = await handler.get_local_public_room_list( + limit=limit, since_token=since_token + ) + + return 200, data + + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) + + server = parse_string(request, "server") + content = parse_json_object_from_request(request) + + limit: Optional[int] = int(content.get("limit", 100)) + since_token = content.get("since", None) + search_filter = content.get("filter", None) + + include_all_networks = content.get("include_all_networks", False) + third_party_instance_id = content.get("third_party_instance_id", None) + + if include_all_networks: + network_tuple = None + if third_party_instance_id is not None: + raise SynapseError( + 400, "Can't use include_all_networks with an explicit network" + ) + elif third_party_instance_id is None: + network_tuple = ThirdPartyInstanceID(None, None) + else: + network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) + + if limit == 0: + # zero is a special value which corresponds to no limit. + limit = None + + handler = self.hs.get_room_list_handler() + if server and server != self.hs.config.server_name: + # Ensure the server is valid. + try: + parse_and_validate_server_name(server) + except ValueError: + raise SynapseError( + 400, + "Invalid server name: %s" % (server,), + Codes.INVALID_PARAM, + ) + + data = await handler.get_remote_public_room_list( + server, + limit=limit, + since_token=since_token, + search_filter=search_filter, + include_all_networks=include_all_networks, + third_party_instance_id=third_party_instance_id, + ) + + else: + data = await handler.get_local_public_room_list( + limit=limit, + since_token=since_token, + search_filter=search_filter, + network_tuple=network_tuple, + ) + + return 200, data + + +# TODO: Needs unit testing +class RoomMemberListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P[^/]*)/members$", v1=True) + + def __init__(self, hs): + super().__init__() + self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, room_id): + # TODO support Pagination stream API (limit/tokens) + requester = await self.auth.get_user_by_req(request, allow_guest=True) + handler = self.message_handler + + # request the state as of a given event, as identified by a stream token, + # for consistency with /messages etc. + # useful for getting the membership in retrospect as of a given /sync + # response. + at_token_string = parse_string(request, "at") + if at_token_string is None: + at_token = None + else: + at_token = await StreamToken.from_string(self.store, at_token_string) + + # let you filter down on particular memberships. + # XXX: this may not be the best shape for this API - we could pass in a filter + # instead, except filters aren't currently aware of memberships. + # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details. 
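+        # Illustrative query shapes (assumed, not part of this change):
+        #     GET .../members?membership=join
+        #     GET .../members?not_membership=leave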
+        membership = parse_string(request, "membership")
+        not_membership = parse_string(request, "not_membership")
+
+        events = await handler.get_state_events(
+            room_id=room_id,
+            user_id=requester.user.to_string(),
+            at_token=at_token,
+            state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
+        )
+
+        chunk = []
+
+        for event in events:
+            if (membership and event["content"].get("membership") != membership) or (
+                not_membership and event["content"].get("membership") == not_membership
+            ):
+                continue
+            chunk.append(event)
+
+        return 200, {"chunk": chunk}
+
+
+# deprecated in favour of /members?membership=join?
+# except it does custom AS logic and has a simpler return format
+class JoinedRoomMemberListRestServlet(RestServlet):
+    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.message_handler = hs.get_message_handler()
+        self.auth = hs.get_auth()
+
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
+
+        users_with_profile = await self.message_handler.get_joined_members(
+            requester, room_id
+        )
+
+        return 200, {"joined": users_with_profile}
+
+
+# TODO: Needs better unit testing
+class RoomMessageListRestServlet(RestServlet):
+    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.pagination_handler = hs.get_pagination_handler()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        pagination_config = await PaginationConfig.from_request(
+            self.store, request, default_limit=10
+        )
+        as_client_event = b"raw" not in request.args
+        filter_str = parse_string(request, "filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
+            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
+            if (
+                event_filter
+                and event_filter.filter_json.get("event_format", "client")
+                == "federation"
+            ):
+                as_client_event = False
+        else:
+            event_filter = None
+
+        msgs = await self.pagination_handler.get_messages(
+            room_id=room_id,
+            requester=requester,
+            pagin_config=pagination_config,
+            as_client_event=as_client_event,
+            event_filter=event_filter,
+        )
+
+        return 200, msgs
+
+
+# TODO: Needs unit testing
+class RoomStateRestServlet(RestServlet):
+    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.message_handler = hs.get_message_handler()
+        self.auth = hs.get_auth()
+
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        # Get all the current state for this room
+        events = await self.message_handler.get_state_events(
+            room_id=room_id,
+            user_id=requester.user.to_string(),
+            is_guest=requester.is_guest,
+        )
+        return 200, events
+
+
+# TODO: Needs unit testing
+class RoomInitialSyncRestServlet(RestServlet):
+    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.initial_sync_handler = hs.get_initial_sync_handler()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        pagination_config = await PaginationConfig.from_request(self.store, request)
+        content = await self.initial_sync_handler.room_initial_sync(
+            room_id=room_id, requester=requester, pagin_config=pagination_config
+        )
+        return 200, content
+
+
+class RoomEventServlet(RestServlet):
+    PATTERNS = client_patterns(
+        "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.clock = hs.get_clock()
+        self.event_handler = hs.get_event_handler()
+        self._event_serializer = hs.get_event_client_serializer()
+        self.auth = hs.get_auth()
+
+    async def on_GET(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        try:
+            event = await self.event_handler.get_event(
+                requester.user, room_id, event_id
+            )
+        except AuthError:
+            # This endpoint is supposed to return a 404 when the requester does
+            # not have permission to access the event
+            # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid
+            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+        time_now = self.clock.time_msec()
+        if event:
+            event = await self._event_serializer.serialize_event(event, time_now)
+            return 200, event
+
+        raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+
+class RoomEventContextServlet(RestServlet):
+    PATTERNS = client_patterns(
+        "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.clock = hs.get_clock()
+        self.room_context_handler = hs.get_room_context_handler()
+        self._event_serializer = hs.get_event_client_serializer()
+        self.auth = hs.get_auth()
+
+    async def on_GET(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+
+        limit = parse_integer(request, "limit", default=10)
+
+        # picking the API shape for symmetry with /messages
+        filter_str = parse_string(request, "filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
+            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
+        else:
+            event_filter = None
+
+        results = await self.room_context_handler.get_event_context(
+            requester, room_id, event_id, limit, event_filter
+        )
+
+        if not results:
+            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+        time_now = self.clock.time_msec()
+        results["events_before"] = await self._event_serializer.serialize_events(
+            results["events_before"], time_now
+        )
+        results["event"] = await self._event_serializer.serialize_event(
+            results["event"], time_now
+        )
+        results["events_after"] = await self._event_serializer.serialize_events(
+            results["events_after"], time_now
+        )
+        results["state"] = await self._event_serializer.serialize_events(
+            results["state"],
+            time_now,
+            # No need to bundle aggregations for state events
+            bundle_aggregations=False,
+        )
+
+        return 200, results
+
+
+class RoomForgetRestServlet(TransactionRestServlet):
+    def __init__(self, hs):
+        super().__init__(hs)
+        self.room_member_handler = hs.get_room_member_handler()
+        self.auth = hs.get_auth()
+
+    def register(self, http_server):
+        PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
+        register_txn_path(self, PATTERNS, http_server)
+
+    async def on_POST(self, request, room_id, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+
+        await self.room_member_handler.forget(user=requester.user, room_id=room_id)
+
+        return 200, {}
+
+    def on_PUT(self, request, room_id, txn_id):
+        set_tag("txn_id", txn_id)
+
+        return self.txns.fetch_or_execute_request(
+            request, self.on_POST, request, room_id, txn_id
+        )
+
+
+# TODO: Needs unit testing
+class RoomMembershipRestServlet(TransactionRestServlet):
+    def __init__(self, hs):
+        super().__init__(hs)
+        self.room_member_handler = hs.get_room_member_handler()
+        self.auth = hs.get_auth()
+
+    def register(self, http_server):
+        # /rooms/$roomid/[invite|join|leave]
+        PATTERNS = (
+            "/rooms/(?P<room_id>[^/]*)/"
+            "(?P<membership_action>join|invite|leave|ban|unban|kick)"
+        )
+        register_txn_path(self, PATTERNS, http_server)
+
+    async def on_POST(self, request, room_id, membership_action, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+
+        if requester.is_guest and membership_action not in {
+            Membership.JOIN,
+            Membership.LEAVE,
+        }:
+            raise AuthError(403, "Guest access not allowed")
+
+        try:
+            content = parse_json_object_from_request(request)
+        except Exception:
+            # Turns out we used to ignore the body entirely, and some clients
+            # cheekily send invalid bodies.
+            content = {}
+
+        if membership_action == "invite" and self._has_3pid_invite_keys(content):
+            try:
+                await self.room_member_handler.do_3pid_invite(
+                    room_id,
+                    requester.user,
+                    content["medium"],
+                    content["address"],
+                    content["id_server"],
+                    requester,
+                    txn_id,
+                    content.get("id_access_token"),
+                )
+            except ShadowBanError:
+                # Pretend the request succeeded.
+                pass
+            return 200, {}
+
+        target = requester.user
+        if membership_action in ["invite", "ban", "unban", "kick"]:
+            assert_params_in_dict(content, ["user_id"])
+            target = UserID.from_string(content["user_id"])
+
+        event_content = None
+        if "reason" in content:
+            event_content = {"reason": content["reason"]}
+
+        try:
+            await self.room_member_handler.update_membership(
+                requester=requester,
+                target=target,
+                room_id=room_id,
+                action=membership_action,
+                txn_id=txn_id,
+                third_party_signed=content.get("third_party_signed", None),
+                content=event_content,
+            )
+        except ShadowBanError:
+            # Pretend the request succeeded. 
+ pass + + return_value = {} + + if membership_action == "join": + return_value["room_id"] = room_id + + return 200, return_value + + def _has_3pid_invite_keys(self, content): + for key in {"id_server", "medium", "address"}: + if key not in content: + return False + return True + + def on_PUT(self, request, room_id, membership_action, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, membership_action, txn_id + ) + + +class RoomRedactEventRestServlet(TransactionRestServlet): + def __init__(self, hs): + super().__init__(hs) + self.event_creation_handler = hs.get_event_creation_handler() + self.auth = hs.get_auth() + + def register(self, http_server): + PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" + register_txn_path(self, PATTERNS, http_server) + + async def on_POST(self, request, room_id, event_id, txn_id=None): + requester = await self.auth.get_user_by_req(request) + content = parse_json_object_from_request(request) + + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) + + set_tag("event_id", event_id) + return 200, {"event_id": event_id} + + def on_PUT(self, request, room_id, event_id, txn_id): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id, event_id, txn_id + ) + + +class RoomTypingRestServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P[^/]*)/typing/(?P[^/]*)$", v1=True + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.presence_handler = hs.get_presence_handler() + self.auth = hs.get_auth() + + # If we're not on the typing writer instance we should scream if we get + # requests. + self._is_typing_writer = ( + hs.config.worker.writers.typing == hs.get_instance_name() + ) + + async def on_PUT(self, request, room_id, user_id): + requester = await self.auth.get_user_by_req(request) + + if not self._is_typing_writer: + raise Exception("Got /typing request on instance that is not typing writer") + + room_id = urlparse.unquote(room_id) + target_user = UserID.from_string(urlparse.unquote(user_id)) + + content = parse_json_object_from_request(request) + + await self.presence_handler.bump_presence_active_time(requester.user) + + # Limit timeout to stop people from setting silly typing timeouts. + timeout = min(content.get("timeout", 30000), 120000) + + # Defer getting the typing handler since it will raise on workers. + typing_handler = self.hs.get_typing_writer_handler() + + try: + if content["typing"]: + await typing_handler.started_typing( + target_user=target_user, + requester=requester, + room_id=room_id, + timeout=timeout, + ) + else: + await typing_handler.stopped_typing( + target_user=target_user, requester=requester, room_id=room_id + ) + except ShadowBanError: + # Pretend this worked without error. 
+            pass
+
+        return 200, {}
+
+
+class RoomAliasListServlet(RestServlet):
+    PATTERNS = [
+        re.compile(
+            r"^/_matrix/client/unstable/org\.matrix\.msc2432"
+            r"/rooms/(?P<room_id>[^/]*)/aliases"
+        ),
+    ] + list(client_patterns("/rooms/(?P<room_id>[^/]*)/aliases$", unstable=False))
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.directory_handler = hs.get_directory_handler()
+
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
+
+        alias_list = await self.directory_handler.get_aliases_for_room(
+            requester, room_id
+        )
+
+        return 200, {"aliases": alias_list}
+
+
+class SearchRestServlet(RestServlet):
+    PATTERNS = client_patterns("/search$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.search_handler = hs.get_search_handler()
+        self.auth = hs.get_auth()
+
+    async def on_POST(self, request):
+        requester = await self.auth.get_user_by_req(request)
+
+        content = parse_json_object_from_request(request)
+
+        batch = parse_string(request, "next_batch")
+        results = await self.search_handler.search(requester.user, content, batch)
+
+        return 200, results
+
+
+class JoinedRoomsRestServlet(RestServlet):
+    PATTERNS = client_patterns("/joined_rooms$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    async def on_GET(self, request):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+
+        room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
+        return 200, {"joined_rooms": list(room_ids)}
+
+
+def register_txn_path(servlet, regex_string, http_server, with_get=False):
+    """Registers a transaction-based path.
+
+    This registers two paths:
+        PUT regex_string/$txnid
+        POST regex_string
+
+    Args:
+        regex_string (str): The regex string to register. Must NOT have a
+            trailing $ as this string will be appended to.
+        http_server: The http_server to register paths with.
+        with_get: True to also register respective GET paths for the PUTs.
+    """
+    http_server.register_paths(
+        "POST",
+        client_patterns(regex_string + "$", v1=True),
+        servlet.on_POST,
+        servlet.__class__.__name__,
+    )
+    http_server.register_paths(
+        "PUT",
+        client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
+        servlet.on_PUT,
+        servlet.__class__.__name__,
+    )
+    if with_get:
+        http_server.register_paths(
+            "GET",
+            client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
+            servlet.on_GET,
+            servlet.__class__.__name__,
+        )
+
+
+class RoomSpaceSummaryRestServlet(RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc2946"
+            "/rooms/(?P<room_id>[^/]*)/spaces$"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._auth = hs.get_auth()
+        self._room_summary_handler = hs.get_room_summary_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request, allow_guest=True)
+
+        max_rooms_per_space = parse_integer(request, "max_rooms_per_space")
+        if max_rooms_per_space is not None and max_rooms_per_space < 0:
+            raise SynapseError(
+                400,
+                "Value for 'max_rooms_per_space' must be a non-negative integer",
+                Codes.BAD_JSON,
+            )
+
+        return 200, await self._room_summary_handler.get_space_summary(
+            requester.user.to_string(),
+            room_id,
+            suggested_only=parse_boolean(request, "suggested_only", default=False),
+            max_rooms_per_space=max_rooms_per_space,
+        )
+
+    # TODO When switching to the stable endpoint, remove the POST handler.
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request, allow_guest=True)
+        content = parse_json_object_from_request(request)
+
+        suggested_only = content.get("suggested_only", False)
+        if not isinstance(suggested_only, bool):
+            raise SynapseError(
+                400, "'suggested_only' must be a boolean", Codes.BAD_JSON
+            )
+
+        max_rooms_per_space = content.get("max_rooms_per_space")
+        if max_rooms_per_space is not None:
+            if not isinstance(max_rooms_per_space, int):
+                raise SynapseError(
+                    400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON
+                )
+            if max_rooms_per_space < 0:
+                raise SynapseError(
+                    400,
+                    "Value for 'max_rooms_per_space' must be a non-negative integer",
+                    Codes.BAD_JSON,
+                )
+
+        return 200, await self._room_summary_handler.get_space_summary(
+            requester.user.to_string(),
+            room_id,
+            suggested_only=suggested_only,
+            max_rooms_per_space=max_rooms_per_space,
+        )
+
+
+class RoomHierarchyRestServlet(RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc2946"
+            "/rooms/(?P<room_id>[^/]*)/hierarchy$"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._auth = hs.get_auth()
+        self._room_summary_handler = hs.get_room_summary_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request, allow_guest=True)
+
+        max_depth = parse_integer(request, "max_depth")
+        if max_depth is not None and max_depth < 0:
+            raise SynapseError(
+                400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON
+            )
+
+        limit = parse_integer(request, "limit")
+        if limit is not None and limit <= 0:
+            raise SynapseError(
+                400, "'limit' must be a positive integer", Codes.BAD_JSON
+            )
+
+        return 200, await self._room_summary_handler.get_room_hierarchy(
+            requester.user.to_string(),
+            room_id,
+            suggested_only=parse_boolean(request, "suggested_only", default=False),
+            
max_depth=max_depth, + limit=limit, + from_token=parse_string(request, "from"), + ) + + +class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/im.nheko.summary" + "/rooms/(?P[^/]*)/summary$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self._auth = hs.get_auth() + self._room_summary_handler = hs.get_room_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_identifier: str + ) -> Tuple[int, JsonDict]: + try: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + requester_user_id: Optional[str] = requester.user.to_string() + except MissingClientTokenError: + # auth is optional + requester_user_id = None + + # twisted.web.server.Request.args is incorrectly defined as Optional[Any] + args: Dict[bytes, List[bytes]] = request.args # type: ignore + remote_room_hosts = parse_strings_from_args(args, "via", required=False) + room_id, remote_room_hosts = await self.resolve_room_id( + room_identifier, + remote_room_hosts, + ) + + return 200, await self._room_summary_handler.get_room_summary( + requester_user_id, + room_id, + remote_room_hosts, + ) + + +def register_servlets(hs: "HomeServer", http_server, is_worker=False): + RoomStateEventRestServlet(hs).register(http_server) + RoomMemberListRestServlet(hs).register(http_server) + JoinedRoomMemberListRestServlet(hs).register(http_server) + RoomMessageListRestServlet(hs).register(http_server) + JoinRoomAliasServlet(hs).register(http_server) + RoomMembershipRestServlet(hs).register(http_server) + RoomSendEventRestServlet(hs).register(http_server) + PublicRoomListRestServlet(hs).register(http_server) + RoomStateRestServlet(hs).register(http_server) + RoomRedactEventRestServlet(hs).register(http_server) + RoomTypingRestServlet(hs).register(http_server) + RoomEventContextServlet(hs).register(http_server) + RoomSpaceSummaryRestServlet(hs).register(http_server) + RoomHierarchyRestServlet(hs).register(http_server) + if hs.config.experimental.msc3266_enabled: + RoomSummaryRestServlet(hs).register(http_server) + RoomEventServlet(hs).register(http_server) + JoinedRoomsRestServlet(hs).register(http_server) + RoomAliasListServlet(hs).register(http_server) + SearchRestServlet(hs).register(http_server) + + # Some servlets only get registered for the main process. + if not is_worker: + RoomCreateRestServlet(hs).register(http_server) + RoomForgetRestServlet(hs).register(http_server) + + +def register_deprecated_servlets(hs, http_server): + RoomInitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py new file mode 100644 index 0000000000..3172aba605 --- /dev/null +++ b/synapse/rest/client/room_batch.py @@ -0,0 +1,441 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
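For orientation before the implementation: a minimal sketch of how an application service might drive the batch-send endpoint defined below. The homeserver URL, access token, room ID, timestamps and event bodies are all hypothetical placeholders; only the path and the prev_event/chunk_id query parameters come from the servlet itself.

# Hedged sketch of a caller, assuming a reachable homeserver and a valid
# application-service token; every concrete value here is a placeholder.
import requests

HS = "https://hs.example.com"
TOKEN = "as_token_placeholder"
ROOM_ID = "!history:example.com"

resp = requests.post(
    f"{HS}/_matrix/client/unstable/org.matrix.msc2716/rooms/{ROOM_ID}/batch_send",
    params={"prev_event": "$anchor_event_id"},  # no chunk_id on the first batch
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "state_events_at_start": [
            {
                "type": "m.room.member",
                "state_key": "@alice:example.com",
                "sender": "@alice:example.com",
                "origin_server_ts": 1_500_000_000_000,
                "content": {"membership": "join"},
            }
        ],
        "events": [
            {
                "type": "m.room.message",
                "sender": "@alice:example.com",
                "origin_server_ts": 1_500_000_001_000,
                "content": {"msgtype": "m.text", "body": "imported message"},
            }
        ],
    },
)
# The returned next_chunk_id is passed as ?chunk_id= on the next (older) batch.
print(resp.json().get("next_chunk_id"))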
+
+import logging
+import re
+
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.appservice import ApplicationService
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+ parse_string,
+ parse_strings_from_args,
+)
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import Requester, UserID, create_requester
+from synapse.util.stringutils import random_string
+
+logger = logging.getLogger(__name__)
+
+
+class RoomBatchSendEventRestServlet(RestServlet):
+ """
+ API endpoint which can insert a chunk of events historically back in time
+ next to the given `prev_event`.
+
+ `chunk_id` comes from `next_chunk_id` in the response of the batch send
+ endpoint and is derived from the "insertion" events added to each chunk.
+ It's not required for the first batch send.
+
+ `state_events_at_start` is used to define the historical state events
+ needed to auth the events like join events. These events will float
+ outside of the normal DAG as outliers and won't be visible in the chat
+ history, which also allows us to insert multiple chunks without having a bunch
+ of `@mxid joined the room` noise between each chunk.
+
+ `events` is a chronological chunk/list of events you want to insert.
+ There is a reverse-chronological constraint on chunks so once you insert
+ some messages, you can only insert older ones after that.
+ tldr; Insert chunks from your most recent history -> oldest history.
+
+ POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event=<eventID>&chunk_id=<chunkID>
+ {
+ "events": [ ... ],
+ "state_events_at_start": [ ... ]
+ }
+ """
+
+ PATTERNS = (
+ re.compile(
+ "^/_matrix/client/unstable/org.matrix.msc2716"
+ "/rooms/(?P<room_id>[^/]*)/batch_send$"
+ ),
+ )
+
+ def __init__(self, hs):
+ super().__init__()
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state_store = hs.get_storage().state
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+ self.txns = HttpTransactionCache(hs)
+
+ async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+ (
+ most_recent_prev_event_id,
+ most_recent_prev_event_depth,
+ ) = await self.store.get_max_depth_of(prev_event_ids)
+
+ # We want to insert the historical event after the `prev_event` but before the successor event
+ #
+ # We inherit depth from the successor event instead of the `prev_event`
+ # because events returned from `/messages` are first sorted by `topological_ordering`
+ # which is just the `depth` and then tie-break with `stream_ordering`.
+ #
+ # We mark these inserted historical events as "backfilled" which gives them a
+ # negative `stream_ordering`. If we use the same depth as the `prev_event`,
+ # then our historical event will tie-break and be sorted before the `prev_event`
+ # when it should come after.
+ #
+ # We want to use the successor event depth so they appear after `prev_event` because
+ # it has a larger `depth` but before the successor event because the `stream_ordering`
+ # is negative before the successor event.
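+ #
+ # A worked example for orientation (values are illustrative): suppose
+ # `prev_event` has depth 5 and its successor has depth 6. The historical
+ # events inherit depth 6 and, being backfilled, get negative stream
+ # orderings. Sorting by (topological_ordering, stream_ordering) gives:
+ #
+ # prev_event (5, 100)
+ # historical events (6, -2), (6, -1)
+ # successor event (6, 101)
+ #
+ # so the inserted history sorts after `prev_event` and before the
+ # successor, which is the ordering we want.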
+ successor_event_ids = await self.store.get_successor_events( + [most_recent_prev_event_id] + ) + + # If we can't find any successor events, then it's a forward extremity of + # historical messages and we can just inherit from the previous historical + # event which we can already assume has the correct depth where we want + # to insert into. + if not successor_event_ids: + depth = most_recent_prev_event_depth + else: + ( + _, + oldest_successor_depth, + ) = await self.store.get_min_depth_of(successor_event_ids) + + depth = oldest_successor_depth + + return depth + + def _create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ): + """Creates an event dict for an "insertion" event with the proper fields + and a random chunk ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + Tuple of event ID and stream ordering position + """ + + next_chunk_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + + async def _create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user. + + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + + async def on_POST(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=False) + + if not requester.app_service: + raise AuthError( + 403, + "Only application services can use the /batchsend endpoint", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, ["state_events_at_start", "events"]) + + prev_events_from_query = parse_strings_from_args(request.args, "prev_event") + chunk_id_from_query = parse_string(request, "chunk_id") + + if prev_events_from_query is None: + raise SynapseError( + 400, + "prev_event query parameter is required when inserting historical messages back in time", + errcode=Codes.MISSING_PARAM, + ) + + # For the event we are inserting next to (`prev_events_from_query`), + # find the most recent auth events (derived from state events) that + # allowed that message to be sent. We will use that as a base + # to auth our historical messages against. 
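+ #
+ # (Assumed shapes, for illustration only: the state map fetched below is
+ # keyed by (event type, state key), e.g.
+ #
+ # {("m.room.create", ""): "$create_event_id",
+ # ("m.room.power_levels", ""): "$power_levels_event_id",
+ # ("m.room.member", "@alice:example.com"): "$join_event_id"}
+ #
+ # and its values seed the `auth_event_ids` list used throughout on_POST.)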
+ ( + most_recent_prev_event_id, + _, + ) = await self.store.get_max_depth_of(prev_events_from_query) + # mapping from (type, state_key) -> state_event_id + prev_state_map = await self.state_store.get_state_ids_for_event( + most_recent_prev_event_id + ) + # List of state event ID's + prev_state_ids = list(prev_state_map.values()) + auth_event_ids = prev_state_ids + + state_events_at_start = [] + for state_event in body["state_events_at_start"]: + assert_params_in_dict( + state_event, ["type", "origin_server_ts", "content", "sender"] + ) + + logger.debug( + "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", + state_event, + auth_event_ids, + ) + + event_dict = { + "type": state_event["type"], + "origin_server_ts": state_event["origin_server_ts"], + "content": state_event["content"], + "room_id": room_id, + "sender": state_event["sender"], + "state_key": state_event["state_key"], + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + # Make the state events float off on their own + fake_prev_event_id = "$" + random_string(43) + + # TODO: This is pretty much the same as some other code to handle inserting state in this file + if event_dict["type"] == EventTypes.Member: + membership = event_dict["content"].get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + target=UserID.from_string(event_dict["state_key"]), + room_id=room_id, + action=membership, + content=event_dict["content"], + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + else: + # TODO: Add some complement tests that adds state that is not member joins + # and will use this code path. Maybe we only want to support join state events + # and can get rid of this `else`? + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + state_event["sender"], requester.app_service + ), + event_dict, + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + event_id = event.event_id + + state_events_at_start.append(event_id) + auth_event_ids.append(event_id) + + events_to_create = body["events"] + + inherited_depth = await self._inherit_depth_from_prev_ids( + prev_events_from_query + ) + + # Figure out which chunk to connect to. If they passed in + # chunk_id_from_query let's use it. The chunk ID passed in comes + # from the chunk_id in the "insertion" event from the previous chunk. + last_event_in_chunk = events_to_create[-1] + chunk_id_to_connect_to = chunk_id_from_query + base_insertion_event = None + if chunk_id_from_query: + # All but the first base insertion event should point at a fake + # event, which causes the HS to ask for the state at the start of + # the chunk later. + prev_event_ids = [fake_prev_event_id] + # TODO: Verify the chunk_id_from_query corresponds to an insertion event + pass + # Otherwise, create an insertion event to act as a starting point. 
+ # + # We don't always have an insertion event to start hanging more history + # off of (ideally there would be one in the main DAG, but that's not the + # case if we're wanting to add history to e.g. existing rooms without + # an insertion event), in which case we just create a new insertion event + # that can then get pointed to by a "marker" event later. + else: + prev_event_ids = prev_events_from_query + + base_insertion_event_dict = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_chunk["origin_server_ts"], + ) + base_insertion_event_dict["prev_events"] = prev_event_ids.copy() + + ( + base_insertion_event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self._create_requester_for_user_id_from_app_service( + base_insertion_event_dict["sender"], + requester.app_service, + ), + base_insertion_event_dict, + prev_event_ids=base_insertion_event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + + chunk_id_to_connect_to = base_insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ] + + # Connect this current chunk to the insertion event from the previous chunk + chunk_event = { + "type": EventTypes.MSC2716_CHUNK, + "sender": requester.user.to_string(), + "room_id": room_id, + "content": { + EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, + # Since the chunk event is put at the end of the chunk, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_chunk["origin_server_ts"], + } + # Add the chunk event to the end of the chunk (newest-in-time) + events_to_create.append(chunk_event) + + # Add an "insertion" event to the start of each chunk (next to the oldest-in-time + # event in the chunk) so the next chunk can be connected to this one. 
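+ #
+ # Rough picture of a single batch, oldest on the left (illustrative
+ # only):
+ #
+ # [insertion event] [event 1] ... [event N] [chunk event]
+ #
+ # The chunk event carries the chunk_id of the previous (newer) batch's
+ # insertion event, while this batch's own insertion event advertises the
+ # next_chunk_id that the following (older) batch will connect to.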
+ insertion_event = self._create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + # Since the insertion event is put at the start of the chunk, + # where the oldest-in-time event is, copy the origin_server_ts from + # the first event we're inserting + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) + # Prepend the insertion event to the start of the chunk (oldest-in-time) + events_to_create = [insertion_event] + events_to_create + + event_ids = [] + events_to_persist = [] + for ev in events_to_create: + assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) + + event_dict = { + "type": ev["type"], + "origin_server_ts": ev["origin_server_ts"], + "content": ev["content"], + "room_id": room_id, + "sender": ev["sender"], # requester.user.to_string(), + "prev_events": prev_event_ids.copy(), + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + event, context = await self.event_creation_handler.create_event( + await self._create_requester_for_user_id_from_app_service( + ev["sender"], requester.app_service + ), + event_dict, + prev_event_ids=event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + logger.debug( + "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", + event, + prev_event_ids, + auth_event_ids, + ) + + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, + ) + + events_to_persist.append((event, context)) + event_id = event.event_id + + event_ids.append(event_id) + prev_event_ids = [event_id] + + # Persist events in reverse-chronological order so they have the + # correct stream_ordering as they are backfilled (which decrements). + # Events are sorted by (topological_ordering, stream_ordering) + # where topological_ordering is just depth. + for (event, context) in reversed(events_to_persist): + ev = await self.event_creation_handler.handle_new_client_event( + await self._create_requester_for_user_id_from_app_service( + event["sender"], requester.app_service + ), + event=event, + context=context, + ) + + # Add the base_insertion_event to the bottom of the list we return + if base_insertion_event is not None: + event_ids.append(base_insertion_event.event_id) + + return 200, { + "state_events": state_events_at_start, + "events": event_ids, + "next_chunk_id": insertion_event["content"][ + EventContentFields.MSC2716_NEXT_CHUNK_ID + ], + } + + def on_GET(self, request, room_id): + return 501, "Not implemented" + + def on_PUT(self, request, room_id): + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_id + ) + + +def register_servlets(hs, http_server): + msc2716_enabled = hs.config.experimental.msc2716_enabled + + if msc2716_enabled: + RoomBatchSendEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py new file mode 100644 index 0000000000..263596be86 --- /dev/null +++ b/synapse/rest/client/room_keys.py @@ -0,0 +1,391 @@ +# Copyright 2017, 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_string, +) + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class RoomKeysServlet(RestServlet): + PATTERNS = client_patterns( + "/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + async def on_PUT(self, request, room_id, session_id): + """ + Uploads one or more encrypted E2E room keys for backup purposes. + room_id: the ID of the room the keys are for (optional) + session_id: the ID for the E2E room keys for the room (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /room_keys/version API. + + Each session has: + * first_message_index: a numeric index indicating the oldest message + encrypted by this session. + * forwarded_count: how many times the uploading client claims this key + has been shared (forwarded) + * is_verified: whether the client that uploaded the keys claims they + were sent by a device which they've verified + * session_data: base64-encrypted data describing the session. + + Returns 200 OK on success with body {} + Returns 403 Forbidden if the version in question is not the most recently + created version (i.e. if this is an old client trying to write to a stale backup) + Returns 404 Not Found if the version in question doesn't exist + + The API is designed to be otherwise agnostic to the room_key encryption + algorithm being used. Sessions are merged with existing ones in the + backup using the heuristics: + * is_verified sessions always win over unverified sessions + * older first_message_index always win over newer sessions + * lower forwarded_count always wins over higher forwarded_count + + We trust the clients not to lie and corrupt their own backups. + It also means that if your access_token is stolen, the attacker could + delete your backup. + + POST /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 + Content-Type: application/json + + { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + + Or... + + POST /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 + Content-Type: application/json + + { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + + Or... 
+ + POST /room_keys/keys?version=1 HTTP/1.1 + Content-Type: application/json + + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + version = parse_string(request, "version") + + if session_id: + body = {"sessions": {session_id: body}} + + if room_id: + body = {"rooms": {room_id: body}} + + ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret + + async def on_GET(self, request, room_id, session_id): + """ + Retrieves one or more encrypted E2E room keys for backup purposes. + Symmetric with the PUT version of the API. + + room_id: the ID of the room to retrieve the keys for (optional) + session_id: the ID for the E2E room keys to retrieve the keys for (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /change_secret API. + + Returns as follows: + + GET /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 + { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + + Or... + + GET /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 + { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + + Or... + + GET /room_keys/keys?version=1 HTTP/1.1 + { + "rooms": { + "!abc:matrix.org": { + "sessions": { + "c0ff33": { + "first_message_index": 1, + "forwarded_count": 1, + "is_verified": false, + "session_data": "SSBBTSBBIEZJU0gK" + } + } + } + } + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = parse_string(request, "version", required=True) + + room_keys = await self.e2e_room_keys_handler.get_room_keys( + user_id, version, room_id, session_id + ) + + # Convert room_keys to the right format to return. + if session_id: + # If the client requests a specific session, but that session was + # not backed up, then return an M_NOT_FOUND. + if room_keys["rooms"] == {}: + raise NotFoundError("No room_keys found") + else: + room_keys = room_keys["rooms"][room_id]["sessions"][session_id] + elif room_id: + # If the client requests all sessions from a room, but no sessions + # are found, then return an empty result rather than an error, so + # that clients don't have to handle an error condition, and an + # empty result is valid. (Similarly if the client requests all + # sessions from the backup, but in that case, room_keys is already + # in the right format, so we don't need to do anything about it.) + if room_keys["rooms"] == {}: + room_keys = {"sessions": {}} + else: + room_keys = room_keys["rooms"][room_id] + + return 200, room_keys + + async def on_DELETE(self, request, room_id, session_id): + """ + Deletes one or more encrypted E2E room keys for a user for backup purposes. + + DELETE /room_keys/keys/!abc:matrix.org/c0ff33?version=1 + HTTP/1.1 200 OK + {} + + room_id: the ID of the room whose keys to delete (optional) + session_id: the ID for the E2E session to delete (optional) + version: the version of the user's backup which this data is for. + the version must already have been created via the /change_secret API. 
+ """ + + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + version = parse_string(request, "version") + + ret = await self.e2e_room_keys_handler.delete_room_keys( + user_id, version, room_id, session_id + ) + return 200, ret + + +class RoomKeysNewVersionServlet(RestServlet): + PATTERNS = client_patterns("/room_keys/version$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + async def on_POST(self, request): + """ + Create a new backup version for this user's room_keys with the given + info. The version is allocated by the server and returned to the user + in the response. This API is intended to be used whenever the user + changes the encryption key for their backups, ensuring that backups + encrypted with different keys don't collide. + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + The algorithm passed in the version info is a reverse-DNS namespaced + identifier to describe the format of the encrypted backupped keys. + + The auth_data is { user_id: "user_id", nonce: } + encrypted using the algorithm and current encryption key described above. + + POST /room_keys/version + Content-Type: application/json + { + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + + HTTP/1.1 200 OK + Content-Type: application/json + { + "version": 12345 + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + info = parse_json_object_from_request(request) + + new_version = await self.e2e_room_keys_handler.create_version(user_id, info) + return 200, {"version": new_version} + + # we deliberately don't have a PUT /version, as these things really should + # be immutable to avoid people footgunning + + +class RoomKeysVersionServlet(RestServlet): + PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.auth = hs.get_auth() + self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + + async def on_GET(self, request, version): + """ + Retrieve the version information about a given version of the user's + room_keys backup. If the version part is missing, returns info about the + most current backup version (if any) + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + Returns 404 if the given version does not exist. + + GET /room_keys/version/12345 HTTP/1.1 + { + "version": "12345", + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + try: + info = await self.e2e_room_keys_handler.get_version_info(user_id, version) + except SynapseError as e: + if e.code == 404: + raise SynapseError(404, "No backup found", Codes.NOT_FOUND) + return 200, info + + async def on_DELETE(self, request, version): + """ + Delete the information about a given version of the user's + room_keys backup. If the version part is missing, deletes the most + current backup version (if any). Doesn't delete the actual room data. 
+ + DELETE /room_keys/version/12345 HTTP/1.1 + HTTP/1.1 200 OK + {} + """ + if version is None: + raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) + + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + await self.e2e_room_keys_handler.delete_version(user_id, version) + return 200, {} + + async def on_PUT(self, request, version): + """ + Update the information about a given version of the user's room_keys backup. + + POST /room_keys/version/12345 HTTP/1.1 + Content-Type: application/json + { + "algorithm": "m.megolm_backup.v1", + "auth_data": { + "public_key": "abcdefg", + "signatures": { + "ed25519:something": "hijklmnop" + } + }, + "version": "12345" + } + + HTTP/1.1 200 OK + Content-Type: application/json + {} + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + info = parse_json_object_from_request(request) + + if version is None: + raise SynapseError( + 400, "No version specified to update", Codes.MISSING_PARAM + ) + + await self.e2e_room_keys_handler.update_version(user_id, version, info) + return 200, {} + + +def register_servlets(hs, http_server): + RoomKeysServlet(hs).register(http_server) + RoomKeysVersionServlet(hs).register(http_server) + RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py new file mode 100644 index 0000000000..6d1b083acb --- /dev/null +++ b/synapse/rest/client/room_upgrade_rest_servlet.py @@ -0,0 +1,88 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, ShadowBanError, SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.util import stringutils + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class RoomUpgradeRestServlet(RestServlet): + """Handler for room upgrade requests. + + Handles requests of the form: + + POST /_matrix/client/r0/rooms/$roomid/upgrade HTTP/1.1 + Content-Type: application/json + + { + "new_version": "2", + } + + Creates a new room and shuts down the old one. Returns the ID of the new room. 
+ + Args: + hs (synapse.server.HomeServer): + """ + + PATTERNS = client_patterns( + # /rooms/$roomid/upgrade + "/rooms/(?P[^/]*)/upgrade$" + ) + + def __init__(self, hs): + super().__init__() + self._hs = hs + self._room_creation_handler = hs.get_room_creation_handler() + self._auth = hs.get_auth() + + async def on_POST(self, request, room_id): + requester = await self._auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + assert_params_in_dict(content, ("new_version",)) + + new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) + if new_version is None: + raise SynapseError( + 400, + "Your homeserver does not support this room version", + Codes.UNSUPPORTED_ROOM_VERSION, + ) + + try: + new_room_id = await self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) + except ShadowBanError: + # Generate a random room ID. + new_room_id = stringutils.random_string(18) + + ret = {"replacement_room": new_room_id} + + return 200, ret + + +def register_servlets(hs, http_server): + RoomUpgradeRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py new file mode 100644 index 0000000000..d537d811d8 --- /dev/null +++ b/synapse/rest/client/sendtodevice.py @@ -0,0 +1,67 @@ +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
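Before the servlet itself, a quick sketch of the client-side call it serves. The URL, token, message type and payload are hypothetical placeholders; the wildcard "*" device key and the txn_id-based deduplication are the behaviours the endpoint provides.

# Hedged sketch of sending a to-device message, assuming a reachable
# homeserver; all concrete identifiers are placeholders.
import uuid

import requests

HS = "https://hs.example.com"
TOKEN = "client_token_placeholder"
txn_id = uuid.uuid4().hex  # retried PUTs with the same txn_id are deduplicated

resp = requests.put(
    f"{HS}/_matrix/client/r0/sendToDevice/com.example.ping/{txn_id}",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json={
        "messages": {
            "@bob:example.com": {
                "*": {"body": "ping"},  # "*" targets all of the user's devices
            }
        }
    },
)
assert resp.status_code == 200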
+ +import logging +from typing import Tuple + +from synapse.http import servlet +from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request +from synapse.logging.opentracing import set_tag, trace +from synapse.rest.client.transactions import HttpTransactionCache + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class SendToDeviceRestServlet(servlet.RestServlet): + PATTERNS = client_patterns( + "/sendToDevice/(?P[^/]*)/(?P[^/]*)$" + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.txns = HttpTransactionCache(hs) + self.device_message_handler = hs.get_device_message_handler() + + @trace(opname="sendToDevice") + def on_PUT(self, request, message_type, txn_id): + set_tag("message_type", message_type) + set_tag("txn_id", txn_id) + return self.txns.fetch_or_execute_request( + request, self._put, request, message_type, txn_id + ) + + async def _put(self, request, message_type, txn_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + + content = parse_json_object_from_request(request) + assert_params_in_dict(content, ("messages",)) + + await self.device_message_handler.send_device_message( + requester, message_type, content["messages"] + ) + + response: Tuple[int, dict] = (200, {}) + return response + + +def register_servlets(hs, http_server): + SendToDeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py new file mode 100644 index 0000000000..d2e7f04b40 --- /dev/null +++ b/synapse/rest/client/shared_rooms.py @@ -0,0 +1,67 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet +from synapse.types import UserID + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class UserSharedRoomsServlet(RestServlet): + """ + GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/uk.half-shot.msc2666/user/shared_rooms/(?P[^/]*)", + releases=(), # This is an unstable feature + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.user_directory_active = hs.config.update_user_directory + + async def on_GET(self, request, user_id): + + if not self.user_directory_active: + raise SynapseError( + code=400, + msg="The user directory is disabled on this server. 
Cannot determine shared rooms.", + errcode=Codes.FORBIDDEN, + ) + + UserID.from_string(user_id) + + requester = await self.auth.get_user_by_req(request) + if user_id == requester.user.to_string(): + raise SynapseError( + code=400, + msg="You cannot request a list of shared rooms with yourself", + errcode=Codes.FORBIDDEN, + ) + rooms = await self.store.get_shared_rooms_for_users( + requester.user.to_string(), user_id + ) + + return 200, {"joined": list(rooms)} + + +def register_servlets(hs, http_server): + UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py new file mode 100644 index 0000000000..e18f4d01b3 --- /dev/null +++ b/synapse/rest/client/sync.py @@ -0,0 +1,532 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import logging +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple + +from synapse.api.constants import Membership, PresenceState +from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.events.utils import ( + format_event_for_client_v2_without_room_id, + format_event_raw, +) +from synapse.handlers.presence import format_user_presence_state +from synapse.handlers.sync import KnockedSyncResult, SyncConfig +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, StreamToken +from synapse.util import json_decoder + +from ._base import client_patterns, set_timeline_upper_limit + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class SyncRestServlet(RestServlet): + """ + + GET parameters:: + timeout(int): How long to wait for new events in milliseconds. + since(batch_token): Batch token when asking for incremental deltas. + set_presence(str): What state the device presence should be set to. + default is "online". + filter(filter_id): A filter to apply to the events returned. + + Response JSON:: + { + "next_batch": // batch token for the next /sync + "presence": // presence data for the user. + "rooms": { + "join": { // Joined rooms being updated. + "${room_id}": { // Id of the room being updated + "event_map": // Map of EventID -> event JSON. + "timeline": { // The recent events in the room if gap is "true" + "limited": // Was the per-room event limit exceeded? + // otherwise the next events in the room. + "events": [] // list of EventIDs in the "event_map". + "prev_batch": // back token for getting previous events. + } + "state": {"events": []} // list of EventIDs updating the + // current state to be what it should + // be at the end of the batch. + "ephemeral": {"events": []} // list of event objects + } + }, + "invite": {}, // Invited rooms being updated. + "leave": {} // Archived rooms being updated. 
+ } + } + """ + + PATTERNS = client_patterns("/sync$") + ALLOWED_PRESENCE = {"online", "offline", "unavailable"} + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.sync_handler = hs.get_sync_handler() + self.clock = hs.get_clock() + self.filtering = hs.get_filtering() + self.presence_handler = hs.get_presence_handler() + self._server_notices_sender = hs.get_server_notices_sender() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + # This will always be set by the time Twisted calls us. + assert request.args is not None + + if b"from" in request.args: + # /events used to use 'from', but /sync uses 'since'. + # Lets be helpful and whine if we see a 'from'. + raise SynapseError( + 400, "'from' is not a valid query parameter. Did you mean 'since'?" + ) + + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = requester.user + device_id = requester.device_id + + timeout = parse_integer(request, "timeout", default=0) + since = parse_string(request, "since") + set_presence = parse_string( + request, + "set_presence", + default="online", + allowed_values=self.ALLOWED_PRESENCE, + ) + filter_id = parse_string(request, "filter") + full_state = parse_boolean(request, "full_state", default=False) + + logger.debug( + "/sync: user=%r, timeout=%r, since=%r, " + "set_presence=%r, filter_id=%r, device_id=%r", + user, + timeout, + since, + set_presence, + filter_id, + device_id, + ) + + request_key = (user, timeout, since, filter_id, full_state, device_id) + + if filter_id is None: + filter_collection = DEFAULT_FILTER_COLLECTION + elif filter_id.startswith("{"): + try: + filter_object = json_decoder.decode(filter_id) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) + except Exception: + raise SynapseError(400, "Invalid filter JSON") + self.filtering.check_valid_filter(filter_object) + filter_collection = FilterCollection(filter_object) + else: + try: + filter_collection = await self.filtering.get_user_filter( + user.localpart, filter_id + ) + except StoreError as err: + if err.code != 404: + raise + # fix up the description and errcode to be more useful + raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM) + + sync_config = SyncConfig( + user=user, + filter_collection=filter_collection, + is_guest=requester.is_guest, + request_key=request_key, + device_id=device_id, + ) + + since_token = None + if since is not None: + since_token = await StreamToken.from_string(self.store, since) + + # send any outstanding server notices to the user. + await self._server_notices_sender.on_user_syncing(user.to_string()) + + affect_presence = set_presence != PresenceState.OFFLINE + + if affect_presence: + await self.presence_handler.set_state( + user, {"presence": set_presence}, True + ) + + context = await self.presence_handler.user_syncing( + user.to_string(), affect_presence=affect_presence + ) + with context: + sync_result = await self.sync_handler.wait_for_sync_for_user( + requester, + sync_config, + since_token=since_token, + timeout=timeout, + full_state=full_state, + ) + + # the client may have disconnected by now; don't bother to serialize the + # response if so. 
+ if request._disconnected: + logger.info("Client has disconnected; not serializing response.") + return 200, {} + + time_now = self.clock.time_msec() + response_content = await self.encode_response( + time_now, sync_result, requester.access_token_id, filter_collection + ) + + logger.debug("Event formatting complete") + return 200, response_content + + async def encode_response(self, time_now, sync_result, access_token_id, filter): + logger.debug("Formatting events in sync response") + if filter.event_format == "client": + event_formatter = format_event_for_client_v2_without_room_id + elif filter.event_format == "federation": + event_formatter = format_event_raw + else: + raise Exception("Unknown event format %s" % (filter.event_format,)) + + joined = await self.encode_joined( + sync_result.joined, + time_now, + access_token_id, + filter.event_fields, + event_formatter, + ) + + invited = await self.encode_invited( + sync_result.invited, time_now, access_token_id, event_formatter + ) + + knocked = await self.encode_knocked( + sync_result.knocked, time_now, access_token_id, event_formatter + ) + + archived = await self.encode_archived( + sync_result.archived, + time_now, + access_token_id, + filter.event_fields, + event_formatter, + ) + + logger.debug("building sync response dict") + + response: dict = defaultdict(dict) + response["next_batch"] = await sync_result.next_batch.to_string(self.store) + + if sync_result.account_data: + response["account_data"] = {"events": sync_result.account_data} + if sync_result.presence: + response["presence"] = SyncRestServlet.encode_presence( + sync_result.presence, time_now + ) + + if sync_result.to_device: + response["to_device"] = {"events": sync_result.to_device} + + if sync_result.device_lists.changed: + response["device_lists"]["changed"] = list(sync_result.device_lists.changed) + if sync_result.device_lists.left: + response["device_lists"]["left"] = list(sync_result.device_lists.left) + + # We always include this because https://github.com/vector-im/element-android/issues/3725 + # The spec isn't terribly clear on when this can be omitted and how a client would tell + # the difference between "no keys present" and "nothing changed" in terms of whole field + # absent / individual key type entry absent + # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456 + response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count + + # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md + # states that this field should always be included, as long as the server supports the feature. 
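+ # (Illustrative values, as an assumption for orientation: the count set
+ # above typically looks like {"signed_curve25519": 50}, and the
+ # fallback-key field set just below like ["signed_curve25519"].)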
+ response[ + "org.matrix.msc2732.device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types + + if joined: + response["rooms"][Membership.JOIN] = joined + if invited: + response["rooms"][Membership.INVITE] = invited + if knocked: + response["rooms"][Membership.KNOCK] = knocked + if archived: + response["rooms"][Membership.LEAVE] = archived + + if sync_result.groups.join: + response["groups"][Membership.JOIN] = sync_result.groups.join + if sync_result.groups.invite: + response["groups"][Membership.INVITE] = sync_result.groups.invite + if sync_result.groups.leave: + response["groups"][Membership.LEAVE] = sync_result.groups.leave + + return response + + @staticmethod + def encode_presence(events, time_now): + return { + "events": [ + { + "type": "m.presence", + "sender": event.user_id, + "content": format_user_presence_state( + event, time_now, include_user_id=False + ), + } + for event in events + ] + } + + async def encode_joined( + self, rooms, time_now, token_id, event_fields, event_formatter + ): + """ + Encode the joined rooms in a sync result + + Args: + rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync + results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + event_fields(list): List of event fields to include. If empty, + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format + Returns: + dict[str, dict[str, object]]: the joined rooms list, in our + response format + """ + joined = {} + for room in rooms: + joined[room.room_id] = await self.encode_room( + room, + time_now, + token_id, + joined=True, + only_fields=event_fields, + event_formatter=event_formatter, + ) + + return joined + + async def encode_invited(self, rooms, time_now, token_id, event_formatter): + """ + Encode the invited rooms in a sync result + + Args: + rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of + sync results for rooms this user is invited to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + event_formatter (func[dict]): function to convert from federation format + to client format + + Returns: + dict[str, dict[str, object]]: the invited rooms list, in our + response format + """ + invited = {} + for room in rooms: + invite = await self._event_serializer.serialize_event( + room.invite, + time_now, + token_id=token_id, + event_format=event_formatter, + include_stripped_room_state=True, + ) + unsigned = dict(invite.get("unsigned", {})) + invite["unsigned"] = unsigned + invited_state = list(unsigned.pop("invite_room_state", [])) + invited_state.append(invite) + invited[room.room_id] = {"invite_state": {"events": invited_state}} + + return invited + + async def encode_knocked( + self, + rooms: List[KnockedSyncResult], + time_now: int, + token_id: int, + event_formatter: Callable[[Dict], Dict], + ) -> Dict[str, Dict[str, Any]]: + """ + Encode the rooms we've knocked on in a sync result. 
+ + Args: + rooms: list of sync results for rooms this user is knocking on + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs + event_formatter: function to convert from federation format to client format + + Returns: + The list of rooms the user has knocked on, in our response format. + """ + knocked = {} + for room in rooms: + knock = await self._event_serializer.serialize_event( + room.knock, + time_now, + token_id=token_id, + event_format=event_formatter, + include_stripped_room_state=True, + ) + + # Extract the `unsigned` key from the knock event. + # This is where we (cheekily) store the knock state events + unsigned = knock.setdefault("unsigned", {}) + + # Duplicate the dictionary in order to avoid modifying the original + unsigned = dict(unsigned) + + # Extract the stripped room state from the unsigned dict + # This is for clients to get a little bit of information about + # the room they've knocked on, without revealing any sensitive information + knocked_state = list(unsigned.pop("knock_room_state", [])) + + # Append the actual knock membership event itself as well. This provides + # the client with: + # + # * A knock state event that they can use for easier internal tracking + # * The rough timestamp of when the knock occurred contained within the event + knocked_state.append(knock) + + # Build the `knock_state` dictionary, which will contain the state of the + # room that the client has knocked on + knocked[room.room_id] = {"knock_state": {"events": knocked_state}} + + return knocked + + async def encode_archived( + self, rooms, time_now, token_id, event_fields, event_formatter + ): + """ + Encode the archived rooms in a sync result + + Args: + rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of + sync results for rooms this user is joined to + time_now(int): current time - used as a baseline for age + calculations + token_id(int): ID of the user's auth token - used for namespacing + of transaction IDs + event_fields(list): List of event fields to include. If empty, + all fields will be returned. + event_formatter (func[dict]): function to convert from federation format + to client format + Returns: + dict[str, dict[str, object]]: The invited rooms list, in our + response format + """ + joined = {} + for room in rooms: + joined[room.room_id] = await self.encode_room( + room, + time_now, + token_id, + joined=False, + only_fields=event_fields, + event_formatter=event_formatter, + ) + + return joined + + async def encode_room( + self, room, time_now, token_id, joined, only_fields, event_formatter + ): + """ + Args: + room (JoinedSyncResult|ArchivedSyncResult): sync result for a + single room + time_now (int): current time - used as a baseline for age + calculations + token_id (int): ID of the user's auth token - used for namespacing + of transaction IDs + joined (bool): True if the user is joined to this room - will mean + we handle ephemeral events + only_fields(list): Optional. The list of event fields to include. + event_formatter (func[dict]): function to convert from federation format + to client format + Returns: + dict[str, object]: the room, encoded in our response format + """ + + def serialize(events): + return self._event_serializer.serialize_events( + events, + time_now=time_now, + # We don't bundle "live" events, as otherwise clients + # will end up double counting annotations. 
+ bundle_aggregations=False, + token_id=token_id, + event_format=event_formatter, + only_event_fields=only_fields, + ) + + state_dict = room.state + timeline_events = room.timeline.events + + state_events = state_dict.values() + + for event in itertools.chain(state_events, timeline_events): + # We've had bug reports that events were coming down under the + # wrong room. + if event.room_id != room.room_id: + logger.warning( + "Event %r is under room %r instead of %r", + event.event_id, + room.room_id, + event.room_id, + ) + + serialized_state = await serialize(state_events) + serialized_timeline = await serialize(timeline_events) + + account_data = room.account_data + + result = { + "timeline": { + "events": serialized_timeline, + "prev_batch": await room.timeline.prev_batch.to_string(self.store), + "limited": room.timeline.limited, + }, + "state": {"events": serialized_state}, + "account_data": {"events": account_data}, + } + + if joined: + ephemeral_events = room.ephemeral + result["ephemeral"] = {"events": ephemeral_events} + result["unread_notifications"] = room.unread_notifications + result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count + + return result + + +def register_servlets(hs, http_server): + SyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py new file mode 100644 index 0000000000..c14f83be18 --- /dev/null +++ b/synapse/rest/client/tags.py @@ -0,0 +1,85 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
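As a usage sketch for the tagging servlets below: tags hang off per-user, per-room account data, and the optional "order" value (between 0 and 1) controls client-side sort position. The homeserver URL, token and IDs are placeholders.

# Hedged sketch of the tag round trip; all concrete values are placeholders.
from urllib.parse import quote

import requests

HS = "https://hs.example.com"
HEADERS = {"Authorization": "Bearer client_token_placeholder"}
USER = "@alice:example.com"
ROOM = "!room:example.com"

base = f"{HS}/_matrix/client/r0/user/{quote(USER)}/rooms/{quote(ROOM)}/tags"

# Tag the room as a favourite; "order" influences where clients sort it.
requests.put(f"{base}/m.favourite", headers=HEADERS, json={"order": 0.25})

# List tags; expect something like {"tags": {"m.favourite": {"order": 0.25}}}.
print(requests.get(base, headers=HEADERS).json())

# Remove the tag again.
requests.delete(f"{base}/m.favourite", headers=HEADERS)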
+
+import logging
+
+from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class TagListServlet(RestServlet):
+    """
+    GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
+    """
+
+    PATTERNS = client_patterns(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request, user_id, room_id):
+        requester = await self.auth.get_user_by_req(request)
+        if user_id != requester.user.to_string():
+            raise AuthError(403, "Cannot get tags for other users.")
+
+        tags = await self.store.get_tags_for_room(user_id, room_id)
+
+        return 200, {"tags": tags}
+
+
+class TagServlet(RestServlet):
+    """
+    PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
+    DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
+    """
+
+    PATTERNS = client_patterns(
+        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
+    )
+
+    def __init__(self, hs):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.handler = hs.get_account_data_handler()
+
+    async def on_PUT(self, request, user_id, room_id, tag):
+        requester = await self.auth.get_user_by_req(request)
+        if user_id != requester.user.to_string():
+            raise AuthError(403, "Cannot add tags for other users.")
+
+        body = parse_json_object_from_request(request)
+
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
+
+        return 200, {}
+
+    async def on_DELETE(self, request, user_id, room_id, tag):
+        requester = await self.auth.get_user_by_req(request)
+        if user_id != requester.user.to_string():
+            raise AuthError(403, "Cannot remove tags for other users.")
+
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)
+
+        return 200, {}
+
+
+def register_servlets(hs, http_server):
+    TagListServlet(hs).register(http_server)
+    TagServlet(hs).register(http_server)
diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py
new file mode 100644
index 0000000000..b5c67c9bb6
--- /dev/null
+++ b/synapse/rest/client/thirdparty.py
@@ -0,0 +1,111 @@
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
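
A minimal sketch of exercising the tag endpoints above from a client. The homeserver URL and access token are placeholders, and the room/user IDs are invented; `m.favourite` and the optional `order` field are standard tag content.

import requests

BASE = "https://example.org/_matrix/client/r0"  # placeholder homeserver
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder token

user_id = "@alice:example.org"
room_id = "!abc:example.org"

# PUT a tag with an optional ordering, as handled by TagServlet.on_PUT
requests.put(
    f"{BASE}/user/{user_id}/rooms/{room_id}/tags/m.favourite",
    headers=HEADERS,
    json={"order": 0.5},
)

# List tags for the room, as handled by TagListServlet.on_GET
resp = requests.get(f"{BASE}/user/{user_id}/rooms/{room_id}/tags", headers=HEADERS)
print(resp.json())  # e.g. {"tags": {"m.favourite": {"order": 0.5}}}
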
+
+
+import logging
+
+from synapse.api.constants import ThirdPartyEntityKind
+from synapse.http.servlet import RestServlet
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class ThirdPartyProtocolsServlet(RestServlet):
+    PATTERNS = client_patterns("/thirdparty/protocols")
+
+    def __init__(self, hs):
+        super().__init__()
+
+        self.auth = hs.get_auth()
+        self.appservice_handler = hs.get_application_service_handler()
+
+    async def on_GET(self, request):
+        await self.auth.get_user_by_req(request, allow_guest=True)
+
+        protocols = await self.appservice_handler.get_3pe_protocols()
+        return 200, protocols
+
+
+class ThirdPartyProtocolServlet(RestServlet):
+    PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
+
+    def __init__(self, hs):
+        super().__init__()
+
+        self.auth = hs.get_auth()
+        self.appservice_handler = hs.get_application_service_handler()
+
+    async def on_GET(self, request, protocol):
+        await self.auth.get_user_by_req(request, allow_guest=True)
+
+        protocols = await self.appservice_handler.get_3pe_protocols(
+            only_protocol=protocol
+        )
+        if protocol in protocols:
+            return 200, protocols[protocol]
+        else:
+            return 404, {"error": "Unknown protocol"}
+
+
+class ThirdPartyUserServlet(RestServlet):
+    PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
+
+    def __init__(self, hs):
+        super().__init__()
+
+        self.auth = hs.get_auth()
+        self.appservice_handler = hs.get_application_service_handler()
+
+    async def on_GET(self, request, protocol):
+        await self.auth.get_user_by_req(request, allow_guest=True)
+
+        fields = request.args
+        fields.pop(b"access_token", None)
+
+        results = await self.appservice_handler.query_3pe(
+            ThirdPartyEntityKind.USER, protocol, fields
+        )
+
+        return 200, results
+
+
+class ThirdPartyLocationServlet(RestServlet):
+    PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
+
+    def __init__(self, hs):
+        super().__init__()
+
+        self.auth = hs.get_auth()
+        self.appservice_handler = hs.get_application_service_handler()
+
+    async def on_GET(self, request, protocol):
+        await self.auth.get_user_by_req(request, allow_guest=True)
+
+        fields = request.args
+        fields.pop(b"access_token", None)
+
+        results = await self.appservice_handler.query_3pe(
+            ThirdPartyEntityKind.LOCATION, protocol, fields
+        )
+
+        return 200, results
+
+
+def register_servlets(hs, http_server):
+    ThirdPartyProtocolsServlet(hs).register(http_server)
+    ThirdPartyProtocolServlet(hs).register(http_server)
+    ThirdPartyUserServlet(hs).register(http_server)
+    ThirdPartyLocationServlet(hs).register(http_server)
diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py
new file mode 100644
index 0000000000..b2f858545c
--- /dev/null
+++ b/synapse/rest/client/tokenrefresh.py
@@ -0,0 +1,37 @@
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
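
For illustration, querying the third-party protocol endpoints above might look like the sketch below. The homeserver URL and token are placeholders, and the actual response shape depends on the bridges registered with the homeserver.

import requests

BASE = "https://example.org/_matrix/client/r0"  # placeholder homeserver
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder token

# All protocols known to the homeserver's application services
protocols = requests.get(f"{BASE}/thirdparty/protocols", headers=HEADERS).json()
print(list(protocols))

# A single protocol; ThirdPartyProtocolServlet returns 404 for unknown ones
resp = requests.get(f"{BASE}/thirdparty/protocol/irc", headers=HEADERS)
if resp.status_code == 404:
    print(resp.json())  # {"error": "Unknown protocol"}
else:
    print(resp.json())
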
+
+from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet
+
+from ._base import client_patterns
+
+
+class TokenRefreshRestServlet(RestServlet):
+    """
+    Exchanges refresh tokens for a pair of an access token and a new refresh
+    token.
+    """
+
+    PATTERNS = client_patterns("/tokenrefresh")
+
+    def __init__(self, hs):
+        super().__init__()
+
+    async def on_POST(self, request):
+        raise AuthError(403, "tokenrefresh is no longer supported.")
+
+
+def register_servlets(hs, http_server):
+    TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py
new file mode 100644
index 0000000000..7e8912f0b9
--- /dev/null
+++ b/synapse/rest/client/user_directory.py
@@ -0,0 +1,79 @@
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserDirectorySearchRestServlet(RestServlet):
+    PATTERNS = client_patterns("/user_directory/search$")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        self.user_directory_handler = hs.get_user_directory_handler()
+
+    async def on_POST(self, request):
+        """Searches for users in directory
+
+        Returns:
+            dict of the form::
+
+                {
+                    "limited": <bool>,  # whether there were more results or not
+                    "results": [  # Ordered by best match first
+                        {
+                            "user_id": <user_id>,
+                            "display_name": <display_name>,
+                            "avatar_url": <avatar_url>
+                        }
+                    ]
+                }
+        """
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+        user_id = requester.user.to_string()
+
+        if not self.hs.config.user_directory_search_enabled:
+            return 200, {"limited": False, "results": []}
+
+        body = parse_json_object_from_request(request)
+
+        limit = body.get("limit", 10)
+        limit = min(limit, 50)
+
+        try:
+            search_term = body["search_term"]
+        except Exception:
+            raise SynapseError(400, "`search_term` is a required field")
+
+        results = await self.user_directory_handler.search_users(
+            user_id, search_term, limit
+        )
+
+        return 200, results
+
+
+def register_servlets(hs, http_server):
+    UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py
deleted file mode 100644
index 5e83dba2ed..0000000000
--- a/synapse/rest/client/v1/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
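
A sketch of calling the user directory search endpoint above (placeholder homeserver and token; the search term is invented). Note that the requested limit is clamped to 50 server-side, as in on_POST.

import requests

BASE = "https://example.org/_matrix/client/r0"  # placeholder homeserver
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder token

resp = requests.post(
    f"{BASE}/user_directory/search",
    headers=HEADERS,
    json={"search_term": "alice", "limit": 10},
)
# Response shape per the docstring above:
# {"limited": <bool>, "results": [{"user_id": ..., "display_name": ..., "avatar_url": ...}]}
print(resp.json())
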
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py deleted file mode 100644 index ae92a3df8e..0000000000 --- a/synapse/rest/client/v1/directory.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from synapse.api.errors import ( - AuthError, - Codes, - InvalidClientCredentialsError, - NotFoundError, - SynapseError, -) -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.types import RoomAlias - -logger = logging.getLogger(__name__) - - -def register_servlets(hs, http_server): - ClientDirectoryServer(hs).register(http_server) - ClientDirectoryListServer(hs).register(http_server) - ClientAppserviceDirectoryListServer(hs).register(http_server) - - -class ClientDirectoryServer(RestServlet): - PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.directory_handler = hs.get_directory_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, room_alias): - room_alias = RoomAlias.from_string(room_alias) - - res = await self.directory_handler.get_association(room_alias) - - return 200, res - - async def on_PUT(self, request, room_alias): - room_alias = RoomAlias.from_string(room_alias) - - content = parse_json_object_from_request(request) - if "room_id" not in content: - raise SynapseError( - 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON - ) - - logger.debug("Got content: %s", content) - logger.debug("Got room name: %s", room_alias.to_string()) - - room_id = content["room_id"] - servers = content["servers"] if "servers" in content else None - - logger.debug("Got room_id: %s", room_id) - logger.debug("Got servers: %s", servers) - - # TODO(erikj): Check types. 
- - room = await self.store.get_room(room_id) - if room is None: - raise SynapseError(400, "Room does not exist") - - requester = await self.auth.get_user_by_req(request) - - await self.directory_handler.create_association( - requester, room_alias, room_id, servers - ) - - return 200, {} - - async def on_DELETE(self, request, room_alias): - try: - service = self.auth.get_appservice_by_req(request) - room_alias = RoomAlias.from_string(room_alias) - await self.directory_handler.delete_appservice_association( - service, room_alias - ) - logger.info( - "Application service at %s deleted alias %s", - service.url, - room_alias.to_string(), - ) - return 200, {} - except InvalidClientCredentialsError: - # fallback to default user behaviour if they aren't an AS - pass - - requester = await self.auth.get_user_by_req(request) - user = requester.user - - room_alias = RoomAlias.from_string(room_alias) - - await self.directory_handler.delete_association(requester, room_alias) - - logger.info( - "User %s deleted alias %s", user.to_string(), room_alias.to_string() - ) - - return 200, {} - - -class ClientDirectoryListServer(RestServlet): - PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.directory_handler = hs.get_directory_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, room_id): - room = await self.store.get_room(room_id) - if room is None: - raise NotFoundError("Unknown room") - - return 200, {"visibility": "public" if room["is_public"] else "private"} - - async def on_PUT(self, request, room_id): - requester = await self.auth.get_user_by_req(request) - - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") - - await self.directory_handler.edit_published_room_list( - requester, room_id, visibility - ) - - return 200, {} - - async def on_DELETE(self, request, room_id): - requester = await self.auth.get_user_by_req(request) - - await self.directory_handler.edit_published_room_list( - requester, room_id, "private" - ) - - return 200, {} - - -class ClientAppserviceDirectoryListServer(RestServlet): - PATTERNS = client_patterns( - "/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True - ) - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.directory_handler = hs.get_directory_handler() - self.auth = hs.get_auth() - - def on_PUT(self, request, network_id, room_id): - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") - return self._edit(request, network_id, room_id, visibility) - - def on_DELETE(self, request, network_id, room_id): - return self._edit(request, network_id, room_id, "private") - - async def _edit(self, request, network_id, room_id, visibility): - requester = await self.auth.get_user_by_req(request) - if not requester.app_service: - raise AuthError( - 403, "Only appservices can edit the appservice published room list" - ) - - await self.directory_handler.edit_published_appservice_room_list( - requester.app_service.id, network_id, room_id, visibility - ) - - return 200, {} diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py deleted file mode 100644 index ee7454996e..0000000000 --- a/synapse/rest/client/v1/events.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in 
compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains REST servlets to do with event streaming, /events.""" -import logging - -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.streams.config import PaginationConfig - -logger = logging.getLogger(__name__) - - -class EventStreamRestServlet(RestServlet): - PATTERNS = client_patterns("/events$", v1=True) - - DEFAULT_LONGPOLL_TIME_MS = 30000 - - def __init__(self, hs): - super().__init__() - self.event_stream_handler = hs.get_event_stream_handler() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - is_guest = requester.is_guest - room_id = None - if is_guest: - if b"room_id" not in request.args: - raise SynapseError(400, "Guest users must specify room_id param") - if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode("ascii") - - pagin_config = await PaginationConfig.from_request(self.store, request) - timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if b"timeout" in request.args: - try: - timeout = int(request.args[b"timeout"][0]) - except ValueError: - raise SynapseError(400, "timeout must be in milliseconds.") - - as_client_event = b"raw" not in request.args - - chunk = await self.event_stream_handler.get_stream( - requester.user.to_string(), - pagin_config, - timeout=timeout, - as_client_event=as_client_event, - affect_presence=(not is_guest), - room_id=room_id, - is_guest=is_guest, - ) - - return 200, chunk - - -class EventRestServlet(RestServlet): - PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) - - def __init__(self, hs): - super().__init__() - self.clock = hs.get_clock() - self.event_handler = hs.get_event_handler() - self.auth = hs.get_auth() - self._event_serializer = hs.get_event_client_serializer() - - async def on_GET(self, request, event_id): - requester = await self.auth.get_user_by_req(request) - event = await self.event_handler.get_event(requester.user, None, event_id) - - time_now = self.clock.time_msec() - if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event - else: - return 404, "Event not found." - - -def register_servlets(hs, http_server): - EventStreamRestServlet(hs).register(http_server) - EventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py deleted file mode 100644 index bef1edc838..0000000000 --- a/synapse/rest/client/v1/initial_sync.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from synapse.http.servlet import RestServlet, parse_boolean -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.streams.config import PaginationConfig - - -# TODO: Needs unit testing -class InitialSyncRestServlet(RestServlet): - PATTERNS = client_patterns("/initialSync$", v1=True) - - def __init__(self, hs): - super().__init__() - self.initial_sync_handler = hs.get_initial_sync_handler() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request) - as_client_event = b"raw" not in request.args - pagination_config = await PaginationConfig.from_request(self.store, request) - include_archived = parse_boolean(request, "archived", default=False) - content = await self.initial_sync_handler.snapshot_all_rooms( - user_id=requester.user.to_string(), - pagin_config=pagination_config, - as_client_event=as_client_event, - include_archived=include_archived, - ) - - return 200, content - - -def register_servlets(hs, http_server): - InitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py deleted file mode 100644 index 11567bf32c..0000000000 --- a/synapse/rest/client/v1/login.py +++ /dev/null @@ -1,600 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
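
For reference, the legacy /initialSync endpoint removed above was driven entirely by query parameters; a sketch of a call is below (placeholder homeserver and token; the `archived` flag maps to the `include_archived` argument via parse_boolean above).

import requests

BASE = "https://example.org/_matrix/client/r0"  # placeholder homeserver
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder token

resp = requests.get(
    f"{BASE}/initialSync",
    headers=HEADERS,
    params={"limit": 10, "archived": "true"},  # include rooms the user has left
)
print(resp.status_code)
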
- -import logging -import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional - -from typing_extensions import TypedDict - -from synapse.api.errors import Codes, LoginError, SynapseError -from synapse.api.ratelimiting import Ratelimiter -from synapse.api.urls import CLIENT_API_PREFIX -from synapse.appservice import ApplicationService -from synapse.handlers.sso import SsoIdentityProvider -from synapse.http import get_request_uri -from synapse.http.server import HttpServer, finish_request -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_boolean, - parse_bytes_from_args, - parse_json_object_from_request, - parse_string, -) -from synapse.http.site import SynapseRequest -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.rest.well_known import WellKnownBuilder -from synapse.types import JsonDict, UserID - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class LoginResponse(TypedDict, total=False): - user_id: str - access_token: str - home_server: str - expires_in_ms: Optional[int] - refresh_token: Optional[str] - device_id: str - well_known: Optional[Dict[str, Any]] - - -class LoginRestServlet(RestServlet): - PATTERNS = client_patterns("/login$", v1=True) - CAS_TYPE = "m.login.cas" - SSO_TYPE = "m.login.sso" - TOKEN_TYPE = "m.login.token" - JWT_TYPE = "org.matrix.login.jwt" - JWT_TYPE_DEPRECATED = "m.login.jwt" - APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" - REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - - # JWT configuration variables. - self.jwt_enabled = hs.config.jwt_enabled - self.jwt_secret = hs.config.jwt_secret - self.jwt_algorithm = hs.config.jwt_algorithm - self.jwt_issuer = hs.config.jwt_issuer - self.jwt_audiences = hs.config.jwt_audiences - - # SSO configuration. - self.saml2_enabled = hs.config.saml2_enabled - self.cas_enabled = hs.config.cas_enabled - self.oidc_enabled = hs.config.oidc_enabled - self._msc2858_enabled = hs.config.experimental.msc2858_enabled - self._msc2918_enabled = hs.config.access_token_lifetime is not None - - self.auth = hs.get_auth() - - self.clock = hs.get_clock() - - self.auth_handler = self.hs.get_auth_handler() - self.registration_handler = hs.get_registration_handler() - self._sso_handler = hs.get_sso_handler() - - self._well_known_builder = WellKnownBuilder(hs) - self._address_ratelimiter = Ratelimiter( - store=hs.get_datastore(), - clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, - ) - self._account_ratelimiter = Ratelimiter( - store=hs.get_datastore(), - clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, - ) - - def on_GET(self, request: SynapseRequest): - flows = [] - if self.jwt_enabled: - flows.append({"type": LoginRestServlet.JWT_TYPE}) - flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) - - if self.cas_enabled: - # we advertise CAS for backwards compat, though MSC1721 renamed it - # to SSO. 
- flows.append({"type": LoginRestServlet.CAS_TYPE}) - - if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow: JsonDict = { - "type": LoginRestServlet.SSO_TYPE, - "identity_providers": [ - _get_auth_flow_dict_for_idp( - idp, - ) - for idp in self._sso_handler.get_identity_providers().values() - ], - } - - if self._msc2858_enabled: - # backwards-compatibility support for clients which don't - # support the stable API yet - sso_flow["org.matrix.msc2858.identity_providers"] = [ - _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) - for idp in self._sso_handler.get_identity_providers().values() - ] - - flows.append(sso_flow) - - # While it's valid for us to advertise this login type generally, - # synapse currently only gives out these tokens as part of the - # SSO login flow. - # Generally we don't want to advertise login flows that clients - # don't know how to implement, since they (currently) will always - # fall back to the fallback API if they don't understand one of the - # login flow types returned. - flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - - flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types()) - - flows.append({"type": LoginRestServlet.APPSERVICE_TYPE}) - - return 200, {"flows": flows} - - async def on_POST(self, request: SynapseRequest): - login_submission = parse_json_object_from_request(request) - - if self._msc2918_enabled: - # Check if this login should also issue a refresh token, as per - # MSC2918 - should_issue_refresh_token = parse_boolean( - request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False - ) - else: - should_issue_refresh_token = False - - try: - if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: - appservice = self.auth.get_appservice_by_req(request) - - if appservice.is_rate_limited(): - await self._address_ratelimiter.ratelimit( - None, request.getClientIP() - ) - - result = await self._do_appservice_login( - login_submission, - appservice, - should_issue_refresh_token=should_issue_refresh_token, - ) - elif self.jwt_enabled and ( - login_submission["type"] == LoginRestServlet.JWT_TYPE - or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED - ): - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_jwt_login( - login_submission, - should_issue_refresh_token=should_issue_refresh_token, - ) - elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_token_login( - login_submission, - should_issue_refresh_token=should_issue_refresh_token, - ) - else: - await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_other_login( - login_submission, - should_issue_refresh_token=should_issue_refresh_token, - ) - except KeyError: - raise SynapseError(400, "Missing JSON keys.") - - well_known_data = self._well_known_builder.get_well_known() - if well_known_data: - result["well_known"] = well_known_data - return 200, result - - async def _do_appservice_login( - self, - login_submission: JsonDict, - appservice: ApplicationService, - should_issue_refresh_token: bool = False, - ): - identifier = login_submission.get("identifier") - logger.info("Got appservice login request with identifier: %r", identifier) - - if not isinstance(identifier, dict): - raise SynapseError( - 400, "Invalid identifier in login submission", Codes.INVALID_PARAM - ) - - # this login flow only supports identifiers of 
type "m.id.user". - if identifier.get("type") != "m.id.user": - raise SynapseError( - 400, "Unknown login identifier type", Codes.INVALID_PARAM - ) - - user = identifier.get("user") - if not isinstance(user, str): - raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM) - - if user.startswith("@"): - qualified_user_id = user - else: - qualified_user_id = UserID(user, self.hs.hostname).to_string() - - if not appservice.is_interested_in_user(qualified_user_id): - raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) - - return await self._complete_login( - qualified_user_id, - login_submission, - ratelimit=appservice.is_rate_limited(), - should_issue_refresh_token=should_issue_refresh_token, - ) - - async def _do_other_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False - ) -> LoginResponse: - """Handle non-token/saml/jwt logins - - Args: - login_submission: - should_issue_refresh_token: True if this login should issue - a refresh token alongside the access token. - - Returns: - HTTP response - """ - # Log the request we got, but only certain fields to minimise the chance of - # logging someone's password (even if they accidentally put it in the wrong - # field) - logger.info( - "Got login request with identifier: %r, medium: %r, address: %r, user: %r", - login_submission.get("identifier"), - login_submission.get("medium"), - login_submission.get("address"), - login_submission.get("user"), - ) - canonical_user_id, callback = await self.auth_handler.validate_login( - login_submission, ratelimit=True - ) - result = await self._complete_login( - canonical_user_id, - login_submission, - callback, - should_issue_refresh_token=should_issue_refresh_token, - ) - return result - - async def _complete_login( - self, - user_id: str, - login_submission: JsonDict, - callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, - create_non_existent_users: bool = False, - ratelimit: bool = True, - auth_provider_id: Optional[str] = None, - should_issue_refresh_token: bool = False, - ) -> LoginResponse: - """Called when we've successfully authed the user and now need to - actually login them in (e.g. create devices). This gets called on - all successful logins. - - Applies the ratelimiting for successful login attempts against an - account. - - Args: - user_id: ID of the user to register. - login_submission: Dictionary of login information. - callback: Callback function to run after login. - create_non_existent_users: Whether to create the user if they don't - exist. Defaults to False. - ratelimit: Whether to ratelimit the login request. - auth_provider_id: The SSO IdP the user used, if any (just used for the - prometheus metrics). - should_issue_refresh_token: True if this login should issue - a refresh token alongside the access token. - - Returns: - result: Dictionary of account information after successful login. - """ - - # Before we actually log them in we check if they've already logged in - # too often. This happens here rather than before as we don't - # necessarily know the user before now. 
- if ratelimit: - await self._account_ratelimiter.ratelimit(None, user_id.lower()) - - if create_non_existent_users: - canonical_uid = await self.auth_handler.check_user_exists(user_id) - if not canonical_uid: - canonical_uid = await self.registration_handler.register_user( - localpart=UserID.from_string(user_id).localpart - ) - user_id = canonical_uid - - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get("initial_device_display_name") - ( - device_id, - access_token, - valid_until_ms, - refresh_token, - ) = await self.registration_handler.register_device( - user_id, - device_id, - initial_display_name, - auth_provider_id=auth_provider_id, - should_issue_refresh_token=should_issue_refresh_token, - ) - - result = LoginResponse( - user_id=user_id, - access_token=access_token, - home_server=self.hs.hostname, - device_id=device_id, - ) - - if valid_until_ms is not None: - expires_in_ms = valid_until_ms - self.clock.time_msec() - result["expires_in_ms"] = expires_in_ms - - if refresh_token is not None: - result["refresh_token"] = refresh_token - - if callback is not None: - await callback(result) - - return result - - async def _do_token_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False - ) -> LoginResponse: - """ - Handle the final stage of SSO login. - - Args: - login_submission: The JSON request body. - should_issue_refresh_token: True if this login should issue - a refresh token alongside the access token. - - Returns: - The body of the JSON response. - """ - token = login_submission["token"] - auth_handler = self.auth_handler - res = await auth_handler.validate_short_term_login_token(token) - - return await self._complete_login( - res.user_id, - login_submission, - self.auth_handler._sso_login_callback, - auth_provider_id=res.auth_provider_id, - should_issue_refresh_token=should_issue_refresh_token, - ) - - async def _do_jwt_login( - self, login_submission: JsonDict, should_issue_refresh_token: bool = False - ) -> LoginResponse: - token = login_submission.get("token", None) - if token is None: - raise LoginError( - 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN - ) - - import jwt - - try: - payload = jwt.decode( - token, - self.jwt_secret, - algorithms=[self.jwt_algorithm], - issuer=self.jwt_issuer, - audience=self.jwt_audiences, - ) - except jwt.PyJWTError as e: - # A JWT error occurred, return some info back to the client. 
- raise LoginError( - 403, - "JWT validation failed: %s" % (str(e),), - errcode=Codes.FORBIDDEN, - ) - - user = payload.get("sub", None) - if user is None: - raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) - - user_id = UserID(user, self.hs.hostname).to_string() - result = await self._complete_login( - user_id, - login_submission, - create_non_existent_users=True, - should_issue_refresh_token=should_issue_refresh_token, - ) - return result - - -def _get_auth_flow_dict_for_idp( - idp: SsoIdentityProvider, use_unstable_brands: bool = False -) -> JsonDict: - """Return an entry for the login flow dict - - Returns an entry suitable for inclusion in "identity_providers" in the - response to GET /_matrix/client/r0/login - - Args: - idp: the identity provider to describe - use_unstable_brands: whether we should use brand identifiers suitable - for the unstable API - """ - e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name} - if idp.idp_icon: - e["icon"] = idp.idp_icon - if idp.idp_brand: - e["brand"] = idp.idp_brand - # use the stable brand identifier if the unstable identifier isn't defined. - if use_unstable_brands and idp.unstable_idp_brand: - e["brand"] = idp.unstable_idp_brand - return e - - -class RefreshTokenServlet(RestServlet): - PATTERNS = client_patterns( - "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True - ) - - def __init__(self, hs: "HomeServer"): - self._auth_handler = hs.get_auth_handler() - self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.access_token_lifetime - - async def on_POST( - self, - request: SynapseRequest, - ): - refresh_submission = parse_json_object_from_request(request) - - assert_params_in_dict(refresh_submission, ["refresh_token"]) - token = refresh_submission["refresh_token"] - if not isinstance(token, str): - raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) - - valid_until_ms = self._clock.time_msec() + self.access_token_lifetime - access_token, refresh_token = await self._auth_handler.refresh_token( - token, valid_until_ms - ) - expires_in_ms = valid_until_ms - self._clock.time_msec() - return ( - 200, - { - "access_token": access_token, - "refresh_token": refresh_token, - "expires_in_ms": expires_in_ms, - }, - ) - - -class SsoRedirectServlet(RestServlet): - PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ - re.compile( - "^" - + CLIENT_API_PREFIX - + "/r0/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$" - ) - ] - - def __init__(self, hs: "HomeServer"): - # make sure that the relevant handlers are instantiated, so that they - # register themselves with the main SSOHandler. - if hs.config.cas_enabled: - hs.get_cas_handler() - if hs.config.saml2_enabled: - hs.get_saml_handler() - if hs.config.oidc_enabled: - hs.get_oidc_handler() - self._sso_handler = hs.get_sso_handler() - self._msc2858_enabled = hs.config.experimental.msc2858_enabled - self._public_baseurl = hs.config.public_baseurl - - def register(self, http_server: HttpServer) -> None: - super().register(http_server) - if self._msc2858_enabled: - # expose additional endpoint for MSC2858 support: backwards-compat support - # for clients which don't yet support the stable endpoints. 
- http_server.register_paths( - "GET", - client_patterns( - "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", - releases=(), - unstable=True, - ), - self.on_GET, - self.__class__.__name__, - ) - - async def on_GET( - self, request: SynapseRequest, idp_id: Optional[str] = None - ) -> None: - if not self._public_baseurl: - raise SynapseError(400, "SSO requires a valid public_baseurl") - - # if this isn't the expected hostname, redirect to the right one, so that we - # get our cookies back. - requested_uri = get_request_uri(request) - baseurl_bytes = self._public_baseurl.encode("utf-8") - if not requested_uri.startswith(baseurl_bytes): - # swap out the incorrect base URL for the right one. - # - # The idea here is to redirect from - # https://foo.bar/whatever/_matrix/... - # to - # https://public.baseurl/_matrix/... - # - i = requested_uri.index(b"/_matrix") - new_uri = baseurl_bytes[:-1] + requested_uri[i:] - logger.info( - "Requested URI %s is not canonical: redirecting to %s", - requested_uri.decode("utf-8", errors="replace"), - new_uri.decode("utf-8", errors="replace"), - ) - request.redirect(new_uri) - finish_request(request) - return - - args: Dict[bytes, List[bytes]] = request.args # type: ignore - client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True) - sso_url = await self._sso_handler.handle_redirect_request( - request, - client_redirect_url, - idp_id, - ) - logger.info("Redirecting to %s", sso_url) - request.redirect(sso_url) - finish_request(request) - - -class CasTicketServlet(RestServlet): - PATTERNS = client_patterns("/login/cas/ticket", v1=True) - - def __init__(self, hs): - super().__init__() - self._cas_handler = hs.get_cas_handler() - - async def on_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl") - ticket = parse_string(request, "ticket", required=True) - - # Maybe get a session ID (if this ticket is from user interactive - # authentication). - session = parse_string(request, "session") - - # Either client_redirect_url or session must be provided. - if not client_redirect_url and not session: - message = "Missing string query parameter redirectUrl or session" - raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) - - await self._cas_handler.handle_ticket( - request, ticket, client_redirect_url, session - ) - - -def register_servlets(hs, http_server): - LoginRestServlet(hs).register(http_server) - if hs.config.access_token_lifetime is not None: - RefreshTokenServlet(hs).register(http_server) - SsoRedirectServlet(hs).register(http_server) - if hs.config.cas_enabled: - CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py deleted file mode 100644 index 5aa7908d73..0000000000 --- a/synapse/rest/client/v1/logout.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
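
The `_do_jwt_login` method above validated tokens with PyJWT and used the `sub` claim as the localpart of the Matrix user ID. A standalone sketch of the same validation is below; the secret and algorithm values are placeholders standing in for `hs.config.jwt_secret` and `hs.config.jwt_algorithm`.

import jwt  # PyJWT, as imported lazily in _do_jwt_login above

JWT_SECRET = "not-a-real-secret"  # placeholder
JWT_ALGORITHM = "HS256"           # placeholder

token = jwt.encode({"sub": "alice"}, JWT_SECRET, algorithm=JWT_ALGORITHM)

try:
    payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
except jwt.PyJWTError as e:
    print("JWT validation failed:", e)  # mapped to a 403 LoginError above
else:
    # `sub` becomes the localpart, e.g. @alice:example.org
    print("localpart:", payload["sub"])
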
- -import logging - -from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns - -logger = logging.getLogger(__name__) - - -class LogoutRestServlet(RestServlet): - PATTERNS = client_patterns("/logout$", v1=True) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request, allow_expired=True) - - if requester.device_id is None: - # The access token wasn't associated with a device. - # Just delete the access token - access_token = self.auth.get_access_token_from_request(request) - await self._auth_handler.delete_access_token(access_token) - else: - await self._device_handler.delete_device( - requester.user.to_string(), requester.device_id - ) - - return 200, {} - - -class LogoutAllRestServlet(RestServlet): - PATTERNS = client_patterns("/logout/all$", v1=True) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self._auth_handler = hs.get_auth_handler() - self._device_handler = hs.get_device_handler() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request, allow_expired=True) - user_id = requester.user.to_string() - - # first delete all of the user's devices - await self._device_handler.delete_all_devices_for_user(user_id) - - # .. and then delete any access tokens which weren't associated with - # devices. - await self._auth_handler.delete_access_tokens_for_user(user_id) - return 200, {} - - -def register_servlets(hs, http_server): - LogoutRestServlet(hs).register(http_server) - LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py deleted file mode 100644 index 2b24fe5aa6..0000000000 --- a/synapse/rest/client/v1/presence.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
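
A sketch of the corresponding client calls for the logout servlets above (placeholders throughout). Note that /logout deletes the device behind the access token, while /logout/all removes every device and any remaining device-less tokens, as the handlers show.

import requests

BASE = "https://example.org/_matrix/client/r0"  # placeholder homeserver
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder token

requests.post(f"{BASE}/logout", headers=HEADERS, json={})
# or, to sign out everywhere:
requests.post(f"{BASE}/logout/all", headers=HEADERS, json={})
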
-
-""" This module contains REST servlets to do with presence: /presence/<paths>
-"""
-import logging
-
-from synapse.api.errors import AuthError, SynapseError
-from synapse.handlers.presence import format_user_presence_state
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.rest.client.v2_alpha._base import client_patterns
-from synapse.types import UserID
-
-logger = logging.getLogger(__name__)
-
-
-class PresenceStatusRestServlet(RestServlet):
-    PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.presence_handler = hs.get_presence_handler()
-        self.clock = hs.get_clock()
-        self.auth = hs.get_auth()
-
-        self._use_presence = hs.config.server.use_presence
-
-    async def on_GET(self, request, user_id):
-        requester = await self.auth.get_user_by_req(request)
-        user = UserID.from_string(user_id)
-
-        if not self._use_presence:
-            return 200, {"presence": "offline"}
-
-        if requester.user != user:
-            allowed = await self.presence_handler.is_visible(
-                observed_user=user, observer_user=requester.user
-            )
-
-            if not allowed:
-                raise AuthError(403, "You are not allowed to see their presence.")
-
-        state = await self.presence_handler.get_state(target_user=user)
-        state = format_user_presence_state(
-            state, self.clock.time_msec(), include_user_id=False
-        )
-
-        return 200, state
-
-    async def on_PUT(self, request, user_id):
-        requester = await self.auth.get_user_by_req(request)
-        user = UserID.from_string(user_id)
-
-        if requester.user != user:
-            raise AuthError(403, "Can only set your own presence state")
-
-        state = {}
-
-        content = parse_json_object_from_request(request)
-
-        try:
-            state["presence"] = content.pop("presence")
-
-            if "status_msg" in content:
-                state["status_msg"] = content.pop("status_msg")
-                if not isinstance(state["status_msg"], str):
-                    raise SynapseError(400, "status_msg must be a string.")
-
-            if content:
-                raise KeyError()
-        except SynapseError as e:
-            raise e
-        except Exception:
-            raise SynapseError(400, "Unable to parse state")
-
-        if self._use_presence:
-            await self.presence_handler.set_state(user, state)
-
-        return 200, {}
-
-
-def register_servlets(hs, http_server):
-    PresenceStatusRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
deleted file mode 100644
index f42f4b3567..0000000000
--- a/synapse/rest/client/v1/profile.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
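
Setting presence via the servlet above takes a small JSON body; anything beyond `presence` and an optional string `status_msg` is rejected, as the on_PUT parsing shows. A sketch (placeholder homeserver, token and user ID):

import requests

BASE = "https://example.org/_matrix/client/r0"  # placeholder homeserver
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder token
user_id = "@alice:example.org"  # must be the requester's own user ID

requests.put(
    f"{BASE}/presence/{user_id}/status",
    headers=HEADERS,
    json={"presence": "online", "status_msg": "at my desk"},
)
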
- -""" This module contains REST servlets to do with profile: /profile/ """ - -from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.types import UserID - - -class ProfileDisplaynameRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)/displayname", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.profile_handler = hs.get_profile_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, user_id): - requester_user = None - - if self.hs.config.require_auth_for_profile_requests: - requester = await self.auth.get_user_by_req(request) - requester_user = requester.user - - user = UserID.from_string(user_id) - - await self.profile_handler.check_profile_query_allowed(user, requester_user) - - displayname = await self.profile_handler.get_displayname(user) - - ret = {} - if displayname is not None: - ret["displayname"] = displayname - - return 200, ret - - async def on_PUT(self, request, user_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester.user) - - content = parse_json_object_from_request(request) - - try: - new_name = content["displayname"] - except Exception: - raise SynapseError( - code=400, - msg="Unable to parse name", - errcode=Codes.BAD_JSON, - ) - - await self.profile_handler.set_displayname(user, requester, new_name, is_admin) - - return 200, {} - - -class ProfileAvatarURLRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)/avatar_url", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.profile_handler = hs.get_profile_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, user_id): - requester_user = None - - if self.hs.config.require_auth_for_profile_requests: - requester = await self.auth.get_user_by_req(request) - requester_user = requester.user - - user = UserID.from_string(user_id) - - await self.profile_handler.check_profile_query_allowed(user, requester_user) - - avatar_url = await self.profile_handler.get_avatar_url(user) - - ret = {} - if avatar_url is not None: - ret["avatar_url"] = avatar_url - - return 200, ret - - async def on_PUT(self, request, user_id): - requester = await self.auth.get_user_by_req(request) - user = UserID.from_string(user_id) - is_admin = await self.auth.is_server_admin(requester.user) - - content = parse_json_object_from_request(request) - try: - new_avatar_url = content["avatar_url"] - except KeyError: - raise SynapseError( - 400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM - ) - - await self.profile_handler.set_avatar_url( - user, requester, new_avatar_url, is_admin - ) - - return 200, {} - - -class ProfileRestServlet(RestServlet): - PATTERNS = client_patterns("/profile/(?P[^/]*)", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.profile_handler = hs.get_profile_handler() - self.auth = hs.get_auth() - - async def on_GET(self, request, user_id): - requester_user = None - - if self.hs.config.require_auth_for_profile_requests: - requester = await self.auth.get_user_by_req(request) - requester_user = requester.user - - user = UserID.from_string(user_id) - - await self.profile_handler.check_profile_query_allowed(user, requester_user) - - displayname = await self.profile_handler.get_displayname(user) - 
avatar_url = await self.profile_handler.get_avatar_url(user) - - ret = {} - if displayname is not None: - ret["displayname"] = displayname - if avatar_url is not None: - ret["avatar_url"] = avatar_url - - return 200, ret - - -def register_servlets(hs, http_server): - ProfileDisplaynameRestServlet(hs).register(http_server) - ProfileAvatarURLRestServlet(hs).register(http_server) - ProfileRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py deleted file mode 100644 index be29a0b39e..0000000000 --- a/synapse/rest/client/v1/push_rule.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.api.errors import ( - NotFoundError, - StoreError, - SynapseError, - UnrecognizedRequestError, -) -from synapse.http.servlet import ( - RestServlet, - parse_json_value_from_request, - parse_string, -) -from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS -from synapse.push.clientformat import format_push_rules_for_user -from synapse.push.rulekinds import PRIORITY_CLASS_MAP -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException - - -class PushRuleRestServlet(RestServlet): - PATTERNS = client_patterns("/(?Ppushrules/.*)$", v1=True) - SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( - "Unrecognised request: You probably wanted a trailing slash" - ) - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None - - self._users_new_default_push_rules = hs.config.users_new_default_push_rules - - async def on_PUT(self, request, path): - if self._is_worker: - raise Exception("Cannot handle PUT /push_rules on worker") - - spec = _rule_spec_from_path(path.split("/")) - try: - priority_class = _priority_class_from_spec(spec) - except InvalidRuleException as e: - raise SynapseError(400, str(e)) - - requester = await self.auth.get_user_by_req(request) - - if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: - raise SynapseError(400, "rule_id may not contain slashes") - - content = parse_json_value_from_request(request) - - user_id = requester.user.to_string() - - if "attr" in spec: - await self.set_rule_attr(user_id, spec, content) - self.notify_user(user_id) - return 200, {} - - if spec["rule_id"].startswith("."): - # Rule ids starting with '.' are reserved for server default rules. 
- raise SynapseError(400, "cannot add new rule_ids that start with '.'") - - try: - (conditions, actions) = _rule_tuple_from_request_object( - spec["template"], spec["rule_id"], content - ) - except InvalidRuleException as e: - raise SynapseError(400, str(e)) - - before = parse_string(request, "before") - if before: - before = _namespaced_rule_id(spec, before) - - after = parse_string(request, "after") - if after: - after = _namespaced_rule_id(spec, after) - - try: - await self.store.add_push_rule( - user_id=user_id, - rule_id=_namespaced_rule_id_from_spec(spec), - priority_class=priority_class, - conditions=conditions, - actions=actions, - before=before, - after=after, - ) - self.notify_user(user_id) - except InconsistentRuleException as e: - raise SynapseError(400, str(e)) - except RuleNotFoundException as e: - raise SynapseError(400, str(e)) - - return 200, {} - - async def on_DELETE(self, request, path): - if self._is_worker: - raise Exception("Cannot handle DELETE /push_rules on worker") - - spec = _rule_spec_from_path(path.split("/")) - - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - - try: - await self.store.delete_push_rule(user_id, namespaced_rule_id) - self.notify_user(user_id) - return 200, {} - except StoreError as e: - if e.code == 404: - raise NotFoundError() - else: - raise - - async def on_GET(self, request, path): - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - # we build up the full structure and then decide which bits of it - # to send which means doing unnecessary work sometimes but is - # is probably not going to make a whole lot of difference - rules = await self.store.get_push_rules_for_user(user_id) - - rules = format_push_rules_for_user(requester.user, rules) - - path = path.split("/")[1:] - - if path == []: - # we're a reference impl: pedantry is our job. - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - - if path[0] == "": - return 200, rules - elif path[0] == "global": - result = _filter_ruleset_with_path(rules["global"], path[1:]) - return 200, result - else: - raise UnrecognizedRequestError() - - def notify_user(self, user_id): - stream_id = self.store.get_max_push_rules_stream_id() - self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - - async def set_rule_attr(self, user_id, spec, val): - if spec["attr"] not in ("enabled", "actions"): - # for the sake of potential future expansion, shouldn't report - # 404 in the case of an unknown request so check it corresponds to - # a known attribute first. - raise UnrecognizedRequestError() - - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: - raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) - if spec["attr"] == "enabled": - if isinstance(val, dict) and "enabled" in val: - val = val["enabled"] - if not isinstance(val, bool): - # Legacy fallback - # This should *actually* take a dict, but many clients pass - # bools directly, so let's not break them. 
- raise SynapseError(400, "Value for 'enabled' must be boolean") - return await self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val, is_default_rule - ) - elif spec["attr"] == "actions": - actions = val.get("actions") - _check_actions(actions) - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] - is_default_rule = rule_id.startswith(".") - if is_default_rule: - if user_id in self._users_new_default_push_rules: - rule_ids = NEW_RULE_IDS - else: - rule_ids = BASE_RULE_IDS - - if namespaced_rule_id not in rule_ids: - raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return await self.store.set_push_rule_actions( - user_id, namespaced_rule_id, actions, is_default_rule - ) - else: - raise UnrecognizedRequestError() - - -def _rule_spec_from_path(path): - """Turn a sequence of path components into a rule spec - - Args: - path (sequence[unicode]): the URL path components. - - Returns: - dict: rule spec dict, containing scope/template/rule_id entries, - and possibly attr. - - Raises: - UnrecognizedRequestError if the path components cannot be parsed. - """ - if len(path) < 2: - raise UnrecognizedRequestError() - if path[0] != "pushrules": - raise UnrecognizedRequestError() - - scope = path[1] - path = path[2:] - if scope != "global": - raise UnrecognizedRequestError() - - if len(path) == 0: - raise UnrecognizedRequestError() - - template = path[0] - path = path[1:] - - if len(path) == 0 or len(path[0]) == 0: - raise UnrecognizedRequestError() - - rule_id = path[0] - - spec = {"scope": scope, "template": template, "rule_id": rule_id} - - path = path[1:] - - if len(path) > 0 and len(path[0]) > 0: - spec["attr"] = path[0] - - return spec - - -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): - if rule_template in ["override", "underride"]: - if "conditions" not in req_obj: - raise InvalidRuleException("Missing 'conditions'") - conditions = req_obj["conditions"] - for c in conditions: - if "kind" not in c: - raise InvalidRuleException("Condition without 'kind'") - elif rule_template == "room": - conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}] - elif rule_template == "sender": - conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}] - elif rule_template == "content": - if "pattern" not in req_obj: - raise InvalidRuleException("Content rule missing 'pattern'") - pat = req_obj["pattern"] - - conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}] - else: - raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - - if "actions" not in req_obj: - raise InvalidRuleException("No actions found") - actions = req_obj["actions"] - - _check_actions(actions) - - return conditions, actions - - -def _check_actions(actions): - if not isinstance(actions, list): - raise InvalidRuleException("No actions found") - - for a in actions: - if a in ["notify", "dont_notify", "coalesce"]: - pass - elif isinstance(a, dict) and "set_tweak" in a: - pass - else: - raise InvalidRuleException("Unrecognised action") - - -def _filter_ruleset_with_path(ruleset, path): - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - - if path[0] == "": - return ruleset - template_kind = path[0] - if template_kind not in ruleset: - raise UnrecognizedRequestError() - path = path[1:] - if path == []: - raise UnrecognizedRequestError( - PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR - ) - if 
path[0] == "": - return ruleset[template_kind] - rule_id = path[0] - - the_rule = None - for r in ruleset[template_kind]: - if r["rule_id"] == rule_id: - the_rule = r - if the_rule is None: - raise NotFoundError - - path = path[1:] - if len(path) == 0: - return the_rule - - attr = path[0] - if attr in the_rule: - # Make sure we return a JSON object as the attribute may be a - # JSON value. - return {attr: the_rule[attr]} - else: - raise UnrecognizedRequestError() - - -def _priority_class_from_spec(spec): - if spec["template"] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec["template"])) - pc = PRIORITY_CLASS_MAP[spec["template"]] - - return pc - - -def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec["rule_id"]) - - -def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec["template"], rule_id) - - -class InvalidRuleException(Exception): - pass - - -def register_servlets(hs, http_server): - PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py deleted file mode 100644 index 18102eca6c..0000000000 --- a/synapse/rest/client/v1/pusher.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-import logging
-
-from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.http.server import respond_with_html_bytes
-from synapse.http.servlet import (
-    RestServlet,
-    assert_params_in_dict,
-    parse_json_object_from_request,
-    parse_string,
-)
-from synapse.push import PusherConfigException
-from synapse.rest.client.v2_alpha._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class PushersRestServlet(RestServlet):
-    PATTERNS = client_patterns("/pushers$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request):
-        requester = await self.auth.get_user_by_req(request)
-        user = requester.user
-
-        pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
-
-        filtered_pushers = [p.as_dict() for p in pushers]
-
-        return 200, {"pushers": filtered_pushers}
-
-
-class PushersSetRestServlet(RestServlet):
-    PATTERNS = client_patterns("/pushers/set$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.notifier = hs.get_notifier()
-        self.pusher_pool = self.hs.get_pusherpool()
-
-    async def on_POST(self, request):
-        requester = await self.auth.get_user_by_req(request)
-        user = requester.user
-
-        content = parse_json_object_from_request(request)
-
-        if (
-            "pushkey" in content
-            and "app_id" in content
-            and "kind" in content
-            and content["kind"] is None
-        ):
-            await self.pusher_pool.remove_pusher(
-                content["app_id"], content["pushkey"], user_id=user.to_string()
-            )
-            return 200, {}
-
-        assert_params_in_dict(
-            content,
-            [
-                "kind",
-                "app_id",
-                "app_display_name",
-                "device_display_name",
-                "pushkey",
-                "lang",
-                "data",
-            ],
-        )
-
-        logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"])
-        logger.debug("Got pushers request with body: %r", content)
-
-        append = False
-        if "append" in content:
-            append = content["append"]
-
-        if not append:
-            await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
-                app_id=content["app_id"],
-                pushkey=content["pushkey"],
-                not_user_id=user.to_string(),
-            )
-
-        try:
-            await self.pusher_pool.add_pusher(
-                user_id=user.to_string(),
-                access_token=requester.access_token_id,
-                kind=content["kind"],
-                app_id=content["app_id"],
-                app_display_name=content["app_display_name"],
-                device_display_name=content["device_display_name"],
-                pushkey=content["pushkey"],
-                lang=content["lang"],
-                data=content["data"],
-                profile_tag=content.get("profile_tag", ""),
-            )
-        except PusherConfigException as pce:
-            raise SynapseError(
-                400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM
-            )
-
-        self.notifier.on_new_replication_data()
-
-        return 200, {}
-
-
-class PushersRemoveRestServlet(RestServlet):
-    """
-    To allow a pusher to be deleted by clicking a link (ie. GET request)
-    """
-
-    PATTERNS = client_patterns("/pushers/remove$", v1=True)
-    SUCCESS_HTML = b"<html><body>You have been unsubscribed</body></html>"
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.notifier = hs.get_notifier()
-        self.auth = hs.get_auth()
-        self.pusher_pool = self.hs.get_pusherpool()
-
-    async def on_GET(self, request):
-        requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
-        user = requester.user
-
-        app_id = parse_string(request, "app_id", required=True)
-        pushkey = parse_string(request, "pushkey", required=True)
-
-        try:
-            await self.pusher_pool.remove_pusher(
-                app_id=app_id, pushkey=pushkey, user_id=user.to_string()
-            )
-        except StoreError as se:
-            if se.code != 404:
-                # This is fine: they're already unsubscribed
-                raise
-
-        self.notifier.on_new_replication_data()
-
-        respond_with_html_bytes(
-            request,
-            200,
-            PushersRemoveRestServlet.SUCCESS_HTML,
-        )
-        return None
-
-
-def register_servlets(hs, http_server):
-    PushersRestServlet(hs).register(http_server)
-    PushersSetRestServlet(hs).register(http_server)
-    PushersRemoveRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
deleted file mode 100644
index ba7250ad8e..0000000000
--- a/synapse/rest/client/v1/room.py
+++ /dev/null
@@ -1,1152 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
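[Editor's note: for reference, a hypothetical request body for the PushersSetRestServlet above; the field names come from its assert_params_in_dict call, the values are invented. Sending "kind": None for an existing app_id/pushkey instead removes that pusher.]

# Hypothetical body for POST /pushers/set (values are examples only).
body = {
    "kind": "http",
    "app_id": "com.example.app",
    "app_display_name": "Example App",
    "device_display_name": "Alice's phone",
    "pushkey": "a_unique_pushkey",
    "lang": "en",
    "data": {"url": "https://push.example.com/_matrix/push/v1/notify"},
}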
- -""" This module contains REST servlets to do with rooms: /rooms/ """ -import logging -import re -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from urllib import parse as urlparse - -from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import ( - AuthError, - Codes, - InvalidClientCredentialsError, - MissingClientTokenError, - ShadowBanError, - SynapseError, -) -from synapse.api.filtering import Filter -from synapse.events.utils import format_event_for_client_v2 -from synapse.http.servlet import ( - ResolveRoomIdMixin, - RestServlet, - assert_params_in_dict, - parse_boolean, - parse_integer, - parse_json_object_from_request, - parse_string, - parse_strings_from_args, -) -from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import set_tag -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.rest.client.v2_alpha._base import client_patterns -from synapse.storage.state import StateFilter -from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID -from synapse.util import json_decoder -from synapse.util.stringutils import parse_and_validate_server_name, random_string - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class TransactionRestServlet(RestServlet): - def __init__(self, hs): - super().__init__() - self.txns = HttpTransactionCache(hs) - - -class RoomCreateRestServlet(TransactionRestServlet): - # No PATTERN; we have custom dispatch rules here - - def __init__(self, hs): - super().__init__(hs) - self._room_creation_handler = hs.get_room_creation_handler() - self.auth = hs.get_auth() - - def register(self, http_server): - PATTERNS = "/createRoom" - register_txn_path(self, PATTERNS, http_server) - - def on_PUT(self, request, txn_id): - set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request(request, self.on_POST, request) - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) - - info, _ = await self._room_creation_handler.create_room( - requester, self.get_room_config(request) - ) - - return 200, info - - def get_room_config(self, request): - user_supplied_config = parse_json_object_from_request(request) - return user_supplied_config - - -# TODO: Needs unit testing for generic events -class RoomStateEventRestServlet(TransactionRestServlet): - def __init__(self, hs): - super().__init__(hs) - self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() - self.message_handler = hs.get_message_handler() - self.auth = hs.get_auth() - - def register(self, http_server): - # /room/$roomid/state/$eventtype - no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" - - # /room/$roomid/state/$eventtype/$statekey - state_key = ( - "/rooms/(?P[^/]*)/state/" - "(?P[^/]*)/(?P[^/]*)$" - ) - - http_server.register_paths( - "GET", - client_patterns(state_key, v1=True), - self.on_GET, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(state_key, v1=True), - self.on_PUT, - self.__class__.__name__, - ) - http_server.register_paths( - "GET", - client_patterns(no_state_key, v1=True), - self.on_GET_no_state_key, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(no_state_key, v1=True), - self.on_PUT_no_state_key, - self.__class__.__name__, - ) - - def on_GET_no_state_key(self, request, room_id, event_type): - 
-        return self.on_GET(request, room_id, event_type, "")
-
-    def on_PUT_no_state_key(self, request, room_id, event_type):
-        return self.on_PUT(request, room_id, event_type, "")
-
-    async def on_GET(self, request, room_id, event_type, state_key):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        format = parse_string(
-            request, "format", default="content", allowed_values=["content", "event"]
-        )
-
-        msg_handler = self.message_handler
-        data = await msg_handler.get_room_data(
-            user_id=requester.user.to_string(),
-            room_id=room_id,
-            event_type=event_type,
-            state_key=state_key,
-        )
-
-        if not data:
-            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
-
-        if format == "event":
-            event = format_event_for_client_v2(data.get_dict())
-            return 200, event
-        elif format == "content":
-            return 200, data.get_dict()["content"]
-
-    async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
-        requester = await self.auth.get_user_by_req(request)
-
-        if txn_id:
-            set_tag("txn_id", txn_id)
-
-        content = parse_json_object_from_request(request)
-
-        event_dict = {
-            "type": event_type,
-            "content": content,
-            "room_id": room_id,
-            "sender": requester.user.to_string(),
-        }
-
-        if state_key is not None:
-            event_dict["state_key"] = state_key
-
-        try:
-            if event_type == EventTypes.Member:
-                membership = content.get("membership", None)
-                event_id, _ = await self.room_member_handler.update_membership(
-                    requester,
-                    target=UserID.from_string(state_key),
-                    room_id=room_id,
-                    action=membership,
-                    content=content,
-                )
-            else:
-                (
-                    event,
-                    _,
-                ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                    requester, event_dict, txn_id=txn_id
-                )
-                event_id = event.event_id
-        except ShadowBanError:
-            event_id = "$" + random_string(43)
-
-        set_tag("event_id", event_id)
-        ret = {"event_id": event_id}
-        return 200, ret
-
-
-# TODO: Needs unit testing for generic events + feedback
-class RoomSendEventRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
-        super().__init__(hs)
-        self.event_creation_handler = hs.get_event_creation_handler()
-        self.auth = hs.get_auth()
-
-    def register(self, http_server):
-        # /rooms/$roomid/send/$event_type[/$txn_id]
-        PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
-        register_txn_path(self, PATTERNS, http_server, with_get=True)
-
-    async def on_POST(self, request, room_id, event_type, txn_id=None):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        content = parse_json_object_from_request(request)
-
-        event_dict = {
-            "type": event_type,
-            "content": content,
-            "room_id": room_id,
-            "sender": requester.user.to_string(),
-        }
-
-        if b"ts" in request.args and requester.app_service:
-            event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
-
-        try:
-            (
-                event,
-                _,
-            ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                requester, event_dict, txn_id=txn_id
-            )
-            event_id = event.event_id
-        except ShadowBanError:
-            event_id = "$" + random_string(43)
-
-        set_tag("event_id", event_id)
-        return 200, {"event_id": event_id}
-
-    def on_GET(self, request, room_id, event_type, txn_id):
-        return 200, "Not implemented"
-
-    def on_PUT(self, request, room_id, event_type, txn_id):
-        set_tag("txn_id", txn_id)
-
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, event_type, txn_id
-        )
-
-
-# TODO: Needs unit testing for room ID + alias joins
-class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
-    def __init__(self, hs):
-        super().__init__(hs)
-        super(ResolveRoomIdMixin, self).__init__(hs)  # ensure the Mixin is set up
-        self.auth = hs.get_auth()
-
-    def register(self, http_server):
-        # /join/$room_identifier[/$txn_id]
-        PATTERNS = "/join/(?P<room_identifier>[^/]*)"
-        register_txn_path(self, PATTERNS, http_server)
-
-    async def on_POST(
-        self,
-        request: SynapseRequest,
-        room_identifier: str,
-        txn_id: Optional[str] = None,
-    ):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        try:
-            content = parse_json_object_from_request(request)
-        except Exception:
-            # Turns out we used to ignore the body entirely, and some clients
-            # cheekily send invalid bodies.
-            content = {}
-
-        # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
-        args: Dict[bytes, List[bytes]] = request.args  # type: ignore
-        remote_room_hosts = parse_strings_from_args(args, "server_name", required=False)
-        room_id, remote_room_hosts = await self.resolve_room_id(
-            room_identifier,
-            remote_room_hosts,
-        )
-
-        await self.room_member_handler.update_membership(
-            requester=requester,
-            target=requester.user,
-            room_id=room_id,
-            action="join",
-            txn_id=txn_id,
-            remote_room_hosts=remote_room_hosts,
-            content=content,
-            third_party_signed=content.get("third_party_signed", None),
-        )
-
-        return 200, {"room_id": room_id}
-
-    def on_PUT(self, request, room_identifier, txn_id):
-        set_tag("txn_id", txn_id)
-
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_identifier, txn_id
-        )
-
-
-# TODO: Needs unit testing
-class PublicRoomListRestServlet(TransactionRestServlet):
-    PATTERNS = client_patterns("/publicRooms$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__(hs)
-        self.hs = hs
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request):
-        server = parse_string(request, "server")
-
-        try:
-            await self.auth.get_user_by_req(request, allow_guest=True)
-        except InvalidClientCredentialsError as e:
-            # Option to allow servers to require auth when accessing
-            # /publicRooms via CS API. This is especially helpful in private
-            # federations.
-            if not self.hs.config.allow_public_rooms_without_auth:
-                raise
-
-            # We allow people to not be authed if they're just looking at our
-            # room list, but require auth when we proxy the request.
-            # In both cases we call the auth function, as that has the side
-            # effect of logging who issued this request if an access token was
-            # provided.
-            if server:
-                raise e
-
-        limit: Optional[int] = parse_integer(request, "limit", 0)
-        since_token = parse_string(request, "since")
-
-        if limit == 0:
-            # zero is a special value which corresponds to no limit.
-            limit = None
-
-        handler = self.hs.get_room_list_handler()
-        if server and server != self.hs.config.server_name:
-            # Ensure the server is valid.
-            try:
-                parse_and_validate_server_name(server)
-            except ValueError:
-                raise SynapseError(
-                    400,
-                    "Invalid server name: %s" % (server,),
-                    Codes.INVALID_PARAM,
-                )
-
-            data = await handler.get_remote_public_room_list(
-                server, limit=limit, since_token=since_token
-            )
-        else:
-            data = await handler.get_local_public_room_list(
-                limit=limit, since_token=since_token
-            )
-
-        return 200, data
-
-    async def on_POST(self, request):
-        await self.auth.get_user_by_req(request, allow_guest=True)
-
-        server = parse_string(request, "server")
-        content = parse_json_object_from_request(request)
-
-        limit: Optional[int] = int(content.get("limit", 100))
-        since_token = content.get("since", None)
-        search_filter = content.get("filter", None)
-
-        include_all_networks = content.get("include_all_networks", False)
-        third_party_instance_id = content.get("third_party_instance_id", None)
-
-        if include_all_networks:
-            network_tuple = None
-            if third_party_instance_id is not None:
-                raise SynapseError(
-                    400, "Can't use include_all_networks with an explicit network"
-                )
-        elif third_party_instance_id is None:
-            network_tuple = ThirdPartyInstanceID(None, None)
-        else:
-            network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
-
-        if limit == 0:
-            # zero is a special value which corresponds to no limit.
-            limit = None
-
-        handler = self.hs.get_room_list_handler()
-        if server and server != self.hs.config.server_name:
-            # Ensure the server is valid.
-            try:
-                parse_and_validate_server_name(server)
-            except ValueError:
-                raise SynapseError(
-                    400,
-                    "Invalid server name: %s" % (server,),
-                    Codes.INVALID_PARAM,
-                )
-
-            data = await handler.get_remote_public_room_list(
-                server,
-                limit=limit,
-                since_token=since_token,
-                search_filter=search_filter,
-                include_all_networks=include_all_networks,
-                third_party_instance_id=third_party_instance_id,
-            )
-
-        else:
-            data = await handler.get_local_public_room_list(
-                limit=limit,
-                since_token=since_token,
-                search_filter=search_filter,
-                network_tuple=network_tuple,
-            )
-
-        return 200, data
-
-
-# TODO: Needs unit testing
-class RoomMemberListRestServlet(RestServlet):
-    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.message_handler = hs.get_message_handler()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-
-    async def on_GET(self, request, room_id):
-        # TODO support Pagination stream API (limit/tokens)
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        handler = self.message_handler
-
-        # request the state as of a given event, as identified by a stream token,
-        # for consistency with /messages etc.
-        # useful for getting the membership in retrospect as of a given /sync
-        # response.
-        at_token_string = parse_string(request, "at")
-        if at_token_string is None:
-            at_token = None
-        else:
-            at_token = await StreamToken.from_string(self.store, at_token_string)
-
-        # let you filter down on particular memberships.
-        # XXX: this may not be the best shape for this API - we could pass in a filter
-        # instead, except filters aren't currently aware of memberships.
-        # See https://github.com/matrix-org/matrix-doc/issues/1337 for more details.
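[Editor's note: a self-contained sketch of the membership filtering performed by the loop just below, mirroring its condition; the sample events are invented.]

# With ?membership=join&not_membership=leave only "join" member events survive.
events = [
    {"type": "m.room.member", "content": {"membership": "join"}},
    {"type": "m.room.member", "content": {"membership": "leave"}},
]
membership, not_membership = "join", "leave"

chunk = [
    event
    for event in events
    if not (membership and event["content"].get("membership") != membership)
    and not (not_membership and event["content"].get("membership") == not_membership)
]
assert len(chunk) == 1 and chunk[0]["content"]["membership"] == "join"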
-        membership = parse_string(request, "membership")
-        not_membership = parse_string(request, "not_membership")
-
-        events = await handler.get_state_events(
-            room_id=room_id,
-            user_id=requester.user.to_string(),
-            at_token=at_token,
-            state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
-        )
-
-        chunk = []
-
-        for event in events:
-            if (membership and event["content"].get("membership") != membership) or (
-                not_membership and event["content"].get("membership") == not_membership
-            ):
-                continue
-            chunk.append(event)
-
-        return 200, {"chunk": chunk}
-
-
-# deprecated in favour of /members?membership=join?
-# except it does custom AS logic and has a simpler return format
-class JoinedRoomMemberListRestServlet(RestServlet):
-    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.message_handler = hs.get_message_handler()
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request, room_id):
-        requester = await self.auth.get_user_by_req(request)
-
-        users_with_profile = await self.message_handler.get_joined_members(
-            requester, room_id
-        )
-
-        return 200, {"joined": users_with_profile}
-
-
-# TODO: Needs better unit testing
-class RoomMessageListRestServlet(RestServlet):
-    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.pagination_handler = hs.get_pagination_handler()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-
-    async def on_GET(self, request, room_id):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        pagination_config = await PaginationConfig.from_request(
-            self.store, request, default_limit=10
-        )
-        as_client_event = b"raw" not in request.args
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
-            if (
-                event_filter
-                and event_filter.filter_json.get("event_format", "client")
-                == "federation"
-            ):
-                as_client_event = False
-        else:
-            event_filter = None
-
-        msgs = await self.pagination_handler.get_messages(
-            room_id=room_id,
-            requester=requester,
-            pagin_config=pagination_config,
-            as_client_event=as_client_event,
-            event_filter=event_filter,
-        )
-
-        return 200, msgs
-
-
-# TODO: Needs unit testing
-class RoomStateRestServlet(RestServlet):
-    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.message_handler = hs.get_message_handler()
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request, room_id):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        # Get all the current state for this room
-        events = await self.message_handler.get_state_events(
-            room_id=room_id,
-            user_id=requester.user.to_string(),
-            is_guest=requester.is_guest,
-        )
-        return 200, events
-
-
-# TODO: Needs unit testing
-class RoomInitialSyncRestServlet(RestServlet):
-    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.initial_sync_handler = hs.get_initial_sync_handler()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-
-    async def on_GET(self, request, room_id):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        pagination_config = await PaginationConfig.from_request(self.store, request)
-        content = await self.initial_sync_handler.room_initial_sync(
-            room_id=room_id, requester=requester, pagin_config=pagination_config
-        )
-        return 200, content
-
-
-class RoomEventServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.clock = hs.get_clock()
-        self.event_handler = hs.get_event_handler()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request, room_id, event_id):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        try:
-            event = await self.event_handler.get_event(
-                requester.user, room_id, event_id
-            )
-        except AuthError:
-            # This endpoint is supposed to return a 404 when the requester does
-            # not have permission to access the event
-            # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid
-            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
-
-        time_now = self.clock.time_msec()
-        if event:
-            event = await self._event_serializer.serialize_event(event, time_now)
-            return 200, event
-
-        return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
-
-
-class RoomEventContextServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.clock = hs.get_clock()
-        self.room_context_handler = hs.get_room_context_handler()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request, room_id, event_id):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        limit = parse_integer(request, "limit", default=10)
-
-        # picking the API shape for symmetry with /messages
-        filter_str = parse_string(request, "filter", encoding="utf-8")
-        if filter_str:
-            filter_json = urlparse.unquote(filter_str)
-            event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
-        else:
-            event_filter = None
-
-        results = await self.room_context_handler.get_event_context(
-            requester, room_id, event_id, limit, event_filter
-        )
-
-        if not results:
-            raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
-
-        time_now = self.clock.time_msec()
-        results["events_before"] = await self._event_serializer.serialize_events(
-            results["events_before"], time_now
-        )
-        results["event"] = await self._event_serializer.serialize_event(
-            results["event"], time_now
-        )
-        results["events_after"] = await self._event_serializer.serialize_events(
-            results["events_after"], time_now
-        )
-        results["state"] = await self._event_serializer.serialize_events(
-            results["state"],
-            time_now,
-            # No need to bundle aggregations for state events
-            bundle_aggregations=False,
-        )
-
-        return 200, results
-
-
-class RoomForgetRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
-        super().__init__(hs)
-        self.room_member_handler = hs.get_room_member_handler()
-        self.auth = hs.get_auth()
-
-    def register(self, http_server):
-        PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
-        register_txn_path(self, PATTERNS, http_server)
-
-    async def on_POST(self, request, room_id, txn_id=None):
-        requester = await self.auth.get_user_by_req(request, allow_guest=False)
-
-        await self.room_member_handler.forget(user=requester.user, room_id=room_id)
-
-        return 200, {}
-
-    def on_PUT(self, request, room_id, txn_id):
-        set_tag("txn_id", txn_id)
-
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, txn_id
-        )
-
-
-# TODO: Needs unit testing
-class RoomMembershipRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
-        super().__init__(hs)
-        self.room_member_handler = hs.get_room_member_handler()
-        self.auth = hs.get_auth()
-
-    def register(self, http_server):
-        # /rooms/$roomid/[invite|join|leave]
-        PATTERNS = (
-            "/rooms/(?P<room_id>[^/]*)/"
-            "(?P<membership_action>join|invite|leave|ban|unban|kick)"
-        )
-        register_txn_path(self, PATTERNS, http_server)
-
-    async def on_POST(self, request, room_id, membership_action, txn_id=None):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        if requester.is_guest and membership_action not in {
-            Membership.JOIN,
-            Membership.LEAVE,
-        }:
-            raise AuthError(403, "Guest access not allowed")
-
-        try:
-            content = parse_json_object_from_request(request)
-        except Exception:
-            # Turns out we used to ignore the body entirely, and some clients
-            # cheekily send invalid bodies.
-            content = {}
-
-        if membership_action == "invite" and self._has_3pid_invite_keys(content):
-            try:
-                await self.room_member_handler.do_3pid_invite(
-                    room_id,
-                    requester.user,
-                    content["medium"],
-                    content["address"],
-                    content["id_server"],
-                    requester,
-                    txn_id,
-                    content.get("id_access_token"),
-                )
-            except ShadowBanError:
-                # Pretend the request succeeded.
-                pass
-            return 200, {}
-
-        target = requester.user
-        if membership_action in ["invite", "ban", "unban", "kick"]:
-            assert_params_in_dict(content, ["user_id"])
-            target = UserID.from_string(content["user_id"])
-
-        event_content = None
-        if "reason" in content:
-            event_content = {"reason": content["reason"]}
-
-        try:
-            await self.room_member_handler.update_membership(
-                requester=requester,
-                target=target,
-                room_id=room_id,
-                action=membership_action,
-                txn_id=txn_id,
-                third_party_signed=content.get("third_party_signed", None),
-                content=event_content,
-            )
-        except ShadowBanError:
-            # Pretend the request succeeded.
-            pass
-
-        return_value = {}
-
-        if membership_action == "join":
-            return_value["room_id"] = room_id
-
-        return 200, return_value
-
-    def _has_3pid_invite_keys(self, content):
-        for key in {"id_server", "medium", "address"}:
-            if key not in content:
-                return False
-        return True
-
-    def on_PUT(self, request, room_id, membership_action, txn_id):
-        set_tag("txn_id", txn_id)
-
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, membership_action, txn_id
-        )
-
-
-class RoomRedactEventRestServlet(TransactionRestServlet):
-    def __init__(self, hs):
-        super().__init__(hs)
-        self.event_creation_handler = hs.get_event_creation_handler()
-        self.auth = hs.get_auth()
-
-    def register(self, http_server):
-        PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
-        register_txn_path(self, PATTERNS, http_server)
-
-    async def on_POST(self, request, room_id, event_id, txn_id=None):
-        requester = await self.auth.get_user_by_req(request)
-        content = parse_json_object_from_request(request)
-
-        try:
-            (
-                event,
-                _,
-            ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                requester,
-                {
-                    "type": EventTypes.Redaction,
-                    "content": content,
-                    "room_id": room_id,
-                    "sender": requester.user.to_string(),
-                    "redacts": event_id,
-                },
-                txn_id=txn_id,
-            )
-            event_id = event.event_id
-        except ShadowBanError:
-            event_id = "$" + random_string(43)
-
-        set_tag("event_id", event_id)
-        return 200, {"event_id": event_id}
-
-    def on_PUT(self, request, room_id, event_id, txn_id):
-        set_tag("txn_id", txn_id)
-
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, event_id, txn_id
-        )
-
-
-class RoomTypingRestServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.hs = hs
-        self.presence_handler = hs.get_presence_handler()
-        self.auth = hs.get_auth()
-
-        # If we're not on the typing writer instance we should scream if we get
-        # requests.
-        self._is_typing_writer = (
-            hs.config.worker.writers.typing == hs.get_instance_name()
-        )
-
-    async def on_PUT(self, request, room_id, user_id):
-        requester = await self.auth.get_user_by_req(request)
-
-        if not self._is_typing_writer:
-            raise Exception("Got /typing request on instance that is not typing writer")
-
-        room_id = urlparse.unquote(room_id)
-        target_user = UserID.from_string(urlparse.unquote(user_id))
-
-        content = parse_json_object_from_request(request)
-
-        await self.presence_handler.bump_presence_active_time(requester.user)
-
-        # Limit timeout to stop people from setting silly typing timeouts.
-        timeout = min(content.get("timeout", 30000), 120000)
-
-        # Defer getting the typing handler since it will raise on workers.
-        typing_handler = self.hs.get_typing_writer_handler()
-
-        try:
-            if content["typing"]:
-                await typing_handler.started_typing(
-                    target_user=target_user,
-                    requester=requester,
-                    room_id=room_id,
-                    timeout=timeout,
-                )
-            else:
-                await typing_handler.stopped_typing(
-                    target_user=target_user, requester=requester, room_id=room_id
-                )
-        except ShadowBanError:
-            # Pretend this worked without error.
-            pass
-
-        return 200, {}
-
-
-class RoomAliasListServlet(RestServlet):
-    PATTERNS = [
-        re.compile(
-            r"^/_matrix/client/unstable/org\.matrix\.msc2432"
-            r"/rooms/(?P<room_id>[^/]*)/aliases"
-        ),
-    ] + list(client_patterns("/rooms/(?P<room_id>[^/]*)/aliases$", unstable=False))
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.directory_handler = hs.get_directory_handler()
-
-    async def on_GET(self, request, room_id):
-        requester = await self.auth.get_user_by_req(request)
-
-        alias_list = await self.directory_handler.get_aliases_for_room(
-            requester, room_id
-        )
-
-        return 200, {"aliases": alias_list}
-
-
-class SearchRestServlet(RestServlet):
-    PATTERNS = client_patterns("/search$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.search_handler = hs.get_search_handler()
-        self.auth = hs.get_auth()
-
-    async def on_POST(self, request):
-        requester = await self.auth.get_user_by_req(request)
-
-        content = parse_json_object_from_request(request)
-
-        batch = parse_string(request, "next_batch")
-        results = await self.search_handler.search(requester.user, content, batch)
-
-        return 200, results
-
-
-class JoinedRoomsRestServlet(RestServlet):
-    PATTERNS = client_patterns("/joined_rooms$", v1=True)
-
-    def __init__(self, hs):
-        super().__init__()
-        self.store = hs.get_datastore()
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
-        return 200, {"joined_rooms": list(room_ids)}
-
-
-def register_txn_path(servlet, regex_string, http_server, with_get=False):
-    """Registers a transaction-based path.
-
-    This registers two paths:
-        PUT regex_string/$txnid
-        POST regex_string
-
-    Args:
-        regex_string (str): The regex string to register. Must NOT have a
-            trailing $ as this string will be appended to.
-        http_server : The http_server to register paths with.
-        with_get: True to also register respective GET paths for the PUTs.
- """ - http_server.register_paths( - "POST", - client_patterns(regex_string + "$", v1=True), - servlet.on_POST, - servlet.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_PUT, - servlet.__class__.__name__, - ) - if with_get: - http_server.register_paths( - "GET", - client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_GET, - servlet.__class__.__name__, - ) - - -class RoomSpaceSummaryRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" - "/rooms/(?P[^/]*)/spaces$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - - max_rooms_per_space = parse_integer(request, "max_rooms_per_space") - if max_rooms_per_space is not None and max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=parse_boolean(request, "suggested_only", default=False), - max_rooms_per_space=max_rooms_per_space, - ) - - # TODO When switching to the stable endpoint, remove the POST handler. - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request) - - suggested_only = content.get("suggested_only", False) - if not isinstance(suggested_only, bool): - raise SynapseError( - 400, "'suggested_only' must be a boolean", Codes.BAD_JSON - ) - - max_rooms_per_space = content.get("max_rooms_per_space") - if max_rooms_per_space is not None: - if not isinstance(max_rooms_per_space, int): - raise SynapseError( - 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON - ) - if max_rooms_per_space < 0: - raise SynapseError( - 400, - "Value for 'max_rooms_per_space' must be a non-negative integer", - Codes.BAD_JSON, - ) - - return 200, await self._room_summary_handler.get_space_summary( - requester.user.to_string(), - room_id, - suggested_only=suggested_only, - max_rooms_per_space=max_rooms_per_space, - ) - - -class RoomHierarchyRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/unstable/org.matrix.msc2946" - "/rooms/(?P[^/]*)/hierarchy$" - ), - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self._auth = hs.get_auth() - self._room_summary_handler = hs.get_room_summary_handler() - - async def on_GET( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self._auth.get_user_by_req(request, allow_guest=True) - - max_depth = parse_integer(request, "max_depth") - if max_depth is not None and max_depth < 0: - raise SynapseError( - 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON - ) - - limit = parse_integer(request, "limit") - if limit is not None and limit <= 0: - raise SynapseError( - 400, "'limit' must be a positive integer", Codes.BAD_JSON - ) - - return 200, await self._room_summary_handler.get_room_hierarchy( - requester.user.to_string(), - room_id, - suggested_only=parse_boolean(request, "suggested_only", default=False), - 
-            max_depth=max_depth,
-            limit=limit,
-            from_token=parse_string(request, "from"),
-        )
-
-
-class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
-    PATTERNS = (
-        re.compile(
-            "^/_matrix/client/unstable/im.nheko.summary"
-            "/rooms/(?P<room_identifier>[^/]*)/summary$"
-        ),
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
-        self._auth = hs.get_auth()
-        self._room_summary_handler = hs.get_room_summary_handler()
-
-    async def on_GET(
-        self, request: SynapseRequest, room_identifier: str
-    ) -> Tuple[int, JsonDict]:
-        try:
-            requester = await self._auth.get_user_by_req(request, allow_guest=True)
-            requester_user_id: Optional[str] = requester.user.to_string()
-        except MissingClientTokenError:
-            # auth is optional
-            requester_user_id = None
-
-        # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
-        args: Dict[bytes, List[bytes]] = request.args  # type: ignore
-        remote_room_hosts = parse_strings_from_args(args, "via", required=False)
-        room_id, remote_room_hosts = await self.resolve_room_id(
-            room_identifier,
-            remote_room_hosts,
-        )
-
-        return 200, await self._room_summary_handler.get_room_summary(
-            requester_user_id,
-            room_id,
-            remote_room_hosts,
-        )
-
-
-def register_servlets(hs: "HomeServer", http_server, is_worker=False):
-    RoomStateEventRestServlet(hs).register(http_server)
-    RoomMemberListRestServlet(hs).register(http_server)
-    JoinedRoomMemberListRestServlet(hs).register(http_server)
-    RoomMessageListRestServlet(hs).register(http_server)
-    JoinRoomAliasServlet(hs).register(http_server)
-    RoomMembershipRestServlet(hs).register(http_server)
-    RoomSendEventRestServlet(hs).register(http_server)
-    PublicRoomListRestServlet(hs).register(http_server)
-    RoomStateRestServlet(hs).register(http_server)
-    RoomRedactEventRestServlet(hs).register(http_server)
-    RoomTypingRestServlet(hs).register(http_server)
-    RoomEventContextServlet(hs).register(http_server)
-    RoomSpaceSummaryRestServlet(hs).register(http_server)
-    RoomHierarchyRestServlet(hs).register(http_server)
-    if hs.config.experimental.msc3266_enabled:
-        RoomSummaryRestServlet(hs).register(http_server)
-    RoomEventServlet(hs).register(http_server)
-    JoinedRoomsRestServlet(hs).register(http_server)
-    RoomAliasListServlet(hs).register(http_server)
-    SearchRestServlet(hs).register(http_server)
-
-    # Some servlets only get registered for the main process.
-    if not is_worker:
-        RoomCreateRestServlet(hs).register(http_server)
-        RoomForgetRestServlet(hs).register(http_server)
-
-
-def register_deprecated_servlets(hs, http_server):
-    RoomInitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
deleted file mode 100644
index c780ffded5..0000000000
--- a/synapse/rest/client/v1/voip.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
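[Editor's note: a sketch of the two paths that register_txn_path (in room.py above) wires up for a transaction endpoint, using RoomForgetRestServlet's regex_string and the r0 prefix that client_patterns(..., v1=True) includes among its patterns.]

import re

base = "/rooms/(?P<room_id>[^/]*)/forget"
post_pattern = re.compile("^/_matrix/client/r0" + base + "$")
put_pattern = re.compile("^/_matrix/client/r0" + base + "/(?P<txn_id>[^/]*)$")

# POST creates; PUT with a transaction ID allows idempotent retries via
# HttpTransactionCache.
assert post_pattern.match("/_matrix/client/r0/rooms/!abc:example.org/forget")
m = put_pattern.match("/_matrix/client/r0/rooms/!abc:example.org/forget/txn1")
assert m and m.group("txn_id") == "txn1"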
- -import base64 -import hashlib -import hmac - -from synapse.http.servlet import RestServlet -from synapse.rest.client.v2_alpha._base import client_patterns - - -class VoipRestServlet(RestServlet): - PATTERNS = client_patterns("/voip/turnServer$", v1=True) - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req( - request, self.hs.config.turn_allow_guests - ) - - turnUris = self.hs.config.turn_uris - turnSecret = self.hs.config.turn_shared_secret - turnUsername = self.hs.config.turn_username - turnPassword = self.hs.config.turn_password - userLifetime = self.hs.config.turn_user_lifetime - - if turnUris and turnSecret and userLifetime: - expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 - username = "%d:%s" % (expiry, requester.user.to_string()) - - mac = hmac.new( - turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1 - ) - # We need to use standard padded base64 encoding here - # encode_base64 because we need to add the standard padding to get the - # same result as the TURN server. - password = base64.b64encode(mac.digest()).decode("ascii") - - elif turnUris and turnUsername and turnPassword and userLifetime: - username = turnUsername - password = turnPassword - - else: - return 200, {} - - return ( - 200, - { - "username": username, - "password": password, - "ttl": userLifetime / 1000, - "uris": turnUris, - }, - ) - - -def register_servlets(hs, http_server): - VoipRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py deleted file mode 100644 index 5e83dba2ed..0000000000 --- a/synapse/rest/client/v2_alpha/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py deleted file mode 100644 index 0443f4571c..0000000000 --- a/synapse/rest/client/v2_alpha/_base.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains base REST classes for constructing client v1 servlets. 
-""" -import logging -import re -from typing import Iterable, Pattern - -from synapse.api.errors import InteractiveAuthIncompleteError -from synapse.api.urls import CLIENT_API_PREFIX -from synapse.types import JsonDict - -logger = logging.getLogger(__name__) - - -def client_patterns( - path_regex: str, - releases: Iterable[int] = (0,), - unstable: bool = True, - v1: bool = False, -) -> Iterable[Pattern]: - """Creates a regex compiled client path with the correct client path - prefix. - - Args: - path_regex: The regex string to match. This should NOT have a ^ - as this will be prefixed. - releases: An iterable of releases to include this endpoint under. - unstable: If true, include this endpoint under the "unstable" prefix. - v1: If true, include this endpoint under the "api/v1" prefix. - Returns: - An iterable of patterns. - """ - patterns = [] - - if unstable: - unstable_prefix = CLIENT_API_PREFIX + "/unstable" - patterns.append(re.compile("^" + unstable_prefix + path_regex)) - if v1: - v1_prefix = CLIENT_API_PREFIX + "/api/v1" - patterns.append(re.compile("^" + v1_prefix + path_regex)) - for release in releases: - new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) - patterns.append(re.compile("^" + new_prefix + path_regex)) - - return patterns - - -def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None: - """ - Enforces a maximum limit of a timeline query. - - Params: - filter_json: The timeline query to modify. - filter_timeline_limit: The maximum limit to allow, passing -1 will - disable enforcing a maximum limit. - """ - if filter_timeline_limit < 0: - return # no upper limits - timeline = filter_json.get("room", {}).get("timeline", {}) - if "limit" in timeline: - filter_json["room"]["timeline"]["limit"] = min( - filter_json["room"]["timeline"]["limit"], filter_timeline_limit - ) - - -def interactive_auth_handler(orig): - """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors - - Takes a on_POST method which returns an Awaitable (errcode, body) response - and adds exception handling to turn a InteractiveAuthIncompleteError into - a 401 response. - - Normal usage is: - - @interactive_auth_handler - async def on_POST(self, request): - # ... - await self.auth_handler.check_auth - """ - - async def wrapped(*args, **kwargs): - try: - return await orig(*args, **kwargs) - except InteractiveAuthIncompleteError as e: - return 401, e.result - - return wrapped diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py deleted file mode 100644 index fb5ad2906e..0000000000 --- a/synapse/rest/client/v2_alpha/account.py +++ /dev/null @@ -1,910 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -import random -from http import HTTPStatus -from typing import TYPE_CHECKING -from urllib.parse import urlparse - -from synapse.api.constants import LoginType -from synapse.api.errors import ( - Codes, - InteractiveAuthIncompleteError, - SynapseError, - ThreepidValidationError, -) -from synapse.config.emailconfig import ThreepidBehaviour -from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, - parse_string, -) -from synapse.metrics import threepid_send_requests -from synapse.push.mailer import Mailer -from synapse.util.msisdn import phone_number_to_msisdn -from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import check_3pid_allowed, validate_email - -from ._base import client_patterns, interactive_auth_handler - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -logger = logging.getLogger(__name__) - - -class EmailPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/password/email/requestToken$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.hs = hs - self.datastore = hs.get_datastore() - self.config = hs.config - self.identity_handler = hs.get_identity_handler() - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.mailer = Mailer( - hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_password_reset_template_html, - template_text=self.config.email_password_reset_template_text, - ) - - async def on_POST(self, request): - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "User password resets have been disabled due to lack of email config" - ) - raise SynapseError( - 400, "Email-based password resets have been disabled on this server" - ) - - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - - # Extract params from body - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # Canonicalise the email address. The addresses are all stored canonicalised - # in the database. This allows the user to reset his password without having to - # know the exact spelling (eg. upper and lower case) of address in the database. - # Stored in the database "foo@bar.com" - # User requests with "FOO@bar.com" would raise a Not Found error - try: - email = validate_email(body["email"]) - except ValueError as e: - raise SynapseError(400, str(e)) - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - if next_link: - # Raise if the provided next_link value isn't valid - assert_valid_next_link(self.hs, next_link) - - await self.identity_handler.ratelimit_request_token_requests( - request, "email", email - ) - - # The email will be sent to the stored address. - # This avoids a potential account hijack by requesting a password reset to - # an email address which is controlled by the attacker but which, after - # canonicalisation, matches the one in our database. 
- existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "email", email - ) - - if existing_user_id is None: - if self.config.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send password reset emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_password_reset_mail, - next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} - - threepid_send_requests.labels(type="email", reason="password_reset").observe( - send_attempt - ) - - return 200, ret - - -class PasswordRestServlet(RestServlet): - PATTERNS = client_patterns("/account/password$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() - self.password_policy_handler = hs.get_password_policy_handler() - self._set_password_handler = hs.get_set_password_handler() - - @interactive_auth_handler - async def on_POST(self, request): - body = parse_json_object_from_request(request) - - # we do basic sanity checks here because the auth layer will store these - # in sessions. Pull out the new password provided to us. - new_password = body.pop("new_password", None) - if new_password is not None: - if not isinstance(new_password, str) or len(new_password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(new_password) - - # there are two possibilities here. Either the user does not have an - # access token, and needs to do a password reset; or they have one and - # need to validate their identity. - # - # In the first case, we offer a couple of means of identifying - # themselves (email and msisdn, though it's unclear if msisdn actually - # works). - # - # In the second case, we require a password to confirm their identity. - - if self.auth.has_access_token(request): - requester = await self.auth.get_user_by_req(request) - try: - params, session_id = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "modify your account password", - ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. 
- if new_password: - password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, - ) - raise - user_id = requester.user.to_string() - else: - requester = None - try: - result, params, session_id = await self.auth_handler.check_ui_auth( - [[LoginType.EMAIL_IDENTITY]], - request, - body, - "modify your account password", - ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, - ) - raise - - if LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - if "medium" not in threepid or "address" not in threepid: - raise SynapseError(500, "Malformed threepid") - if threepid["medium"] == "email": - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - try: - threepid["address"] = validate_email(threepid["address"]) - except ValueError as e: - raise SynapseError(400, str(e)) - # if using email, we must know about the email they're authing with! - threepid_user_id = await self.datastore.get_user_id_by_threepid( - threepid["medium"], threepid["address"] - ) - if not threepid_user_id: - raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) - user_id = threepid_user_id - else: - logger.error("Auth succeeded but no known type! %r", result.keys()) - raise SynapseError(500, "", Codes.UNKNOWN) - - # If we have a password in this request, prefer it. Otherwise, use the - # password hash from an earlier request. - if new_password: - password_hash = await self.auth_handler.hash(new_password) - elif session_id is not None: - password_hash = await self.auth_handler.get_session_data( - session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None - ) - else: - # UI validation was skipped, but the request did not include a new - # password. 
- password_hash = None - if not password_hash: - raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - - logout_devices = params.get("logout_devices", True) - - await self._set_password_handler.set_password( - user_id, password_hash, logout_devices, requester - ) - - return 200, {} - - -class DeactivateAccountRestServlet(RestServlet): - PATTERNS = client_patterns("/account/deactivate$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.auth_handler = hs.get_auth_handler() - self._deactivate_account_handler = hs.get_deactivate_account_handler() - - @interactive_auth_handler - async def on_POST(self, request): - body = parse_json_object_from_request(request) - erase = body.get("erase", False) - if not isinstance(erase, bool): - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "Param 'erase' must be a boolean, if given", - Codes.BAD_JSON, - ) - - requester = await self.auth.get_user_by_req(request) - - # allow ASes to deactivate their own users - if requester.app_service: - await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, requester - ) - return 200, {} - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "deactivate your account", - ) - result = await self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), - erase, - requester, - id_server=body.get("id_server"), - ) - if result: - id_server_unbind_result = "success" - else: - id_server_unbind_result = "no-support" - - return 200, {"id_server_unbind_result": id_server_unbind_result} - - -class EmailThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/account/3pid/email/requestToken$") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.config = hs.config - self.identity_handler = hs.get_identity_handler() - self.store = self.hs.get_datastore() - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.mailer = Mailer( - hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_add_threepid_template_html, - template_text=self.config.email_add_threepid_template_text, - ) - - async def on_POST(self, request): - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Adding emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Adding an email to your account is disabled on this server" - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # Canonicalise the email address. The addresses are all stored canonicalised - # in the database. - # This ensures that the validation email is sent to the canonicalised address - # as it will later be entered into the database. - # Otherwise the email will be sent to "FOO@bar.com" and stored as - # "foo@bar.com" in database. 
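[Editor's note: a minimal sketch of the email canonicalisation described in the comments above; the real validate_email in synapse.util.threepids also length- and format-checks, so this is illustrative only.]

def canonicalise_email(address: str) -> str:
    # Lower-case the address so "FOO@bar.com" and "foo@bar.com" refer to
    # the same stored 3PID.
    parts = address.strip().split("@")
    if len(parts) != 2:
        raise ValueError("Unable to parse email address")
    return parts[0].lower() + "@" + parts[1].lower()

assert canonicalise_email("FOO@bar.com") == "foo@bar.com"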
-
-
-class MsisdnThreepidRequestTokenRestServlet(RestServlet):
-    PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
-
-    def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        super().__init__()
-        self.store = self.hs.get_datastore()
-        self.identity_handler = hs.get_identity_handler()
-
-    async def on_POST(self, request):
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(
-            body, ["client_secret", "country", "phone_number", "send_attempt"]
-        )
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        country = body["country"]
-        phone_number = body["phone_number"]
-        send_attempt = body["send_attempt"]
-        next_link = body.get("next_link")  # Optional param
-
-        msisdn = phone_number_to_msisdn(country, phone_number)
-
-        if not check_3pid_allowed(self.hs, "msisdn", msisdn):
-            raise SynapseError(
-                403,
-                "Account phone numbers are not authorized on this server",
-                Codes.THREEPID_DENIED,
-            )
-
-        await self.identity_handler.ratelimit_request_token_requests(
-            request, "msisdn", msisdn
-        )
-
-        if next_link:
-            # Raise if the provided next_link value isn't valid
-            assert_valid_next_link(self.hs, next_link)
-
-        existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
-
-        if existing_user_id is not None:
-            if self.hs.config.request_token_inhibit_3pid_errors:
-                # Make the client think the operation succeeded. See the rationale in the
-                # comments for request_token_inhibit_3pid_errors.
-                # Also wait for some random amount of time between 100ms and 1s to make it
-                # look like we did something.
-                await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
-                return 200, {"sid": random_string(16)}
-
-            raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
-
-        if not self.hs.config.account_threepid_delegate_msisdn:
-            logger.warning(
-                "No upstream msisdn account_threepid_delegate configured on the server to "
-                "handle this request"
-            )
-            raise SynapseError(
-                400,
-                "Adding phone numbers to user account is not supported by this homeserver",
-            )
-
-        ret = await self.identity_handler.requestMsisdnToken(
-            self.hs.config.account_threepid_delegate_msisdn,
-            country,
-            phone_number,
-            client_secret,
-            send_attempt,
-            next_link,
-        )
-
-        threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
-            send_attempt
-        )
-
-        return 200, ret
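phone_number_to_msisdn above reduces the country/phone_number pair to an E.164 MSISDN with the leading "+" stripped, so the same number always keys the same database row. A sketch of that conversion, assuming the phonenumbers library; the illustrative number is one of Ofcom's reserved test numbers:

import phonenumbers

def to_msisdn(country: str, number: str) -> str:
    # Parse relative to the given ISO country code, then format as E.164.
    parsed = phonenumbers.parse(number, country)
    e164 = phonenumbers.format_number(parsed, phonenumbers.PhoneNumberFormat.E164)
    return e164[1:]  # drop the leading "+"

print(to_msisdn("GB", "07700 900123"))  # -> "447700900123"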
-
-
-class AddThreepidEmailSubmitTokenServlet(RestServlet):
-    """Handles 3PID validation token submission for adding an email to a user's account"""
-
-    PATTERNS = client_patterns(
-        "/add_threepid/email/submit_token$", releases=(), unstable=True
-    )
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.config = hs.config
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            self._failure_email_template = (
-                self.config.email_add_threepid_template_failure_html
-            )
-
-    async def on_GET(self, request):
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "Adding emails have been disabled due to lack of an email config"
-                )
-            raise SynapseError(
-                400, "Adding an email to your account is disabled on this server"
-            )
-        elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            raise SynapseError(
-                400,
-                "This homeserver is not validating threepids. Use an identity server "
-                "instead.",
-            )
-
-        sid = parse_string(request, "sid", required=True)
-        token = parse_string(request, "token", required=True)
-        client_secret = parse_string(request, "client_secret", required=True)
-        assert_valid_client_secret(client_secret)
-
-        # Attempt to validate a 3PID session
-        try:
-            # Mark the session as valid
-            next_link = await self.store.validate_threepid_session(
-                sid, client_secret, token, self.clock.time_msec()
-            )
-
-            # Perform a 302 redirect if next_link is set
-            if next_link:
-                request.setResponseCode(302)
-                request.setHeader("Location", next_link)
-                finish_request(request)
-                return None
-
-            # Otherwise show the success template
-            html = self.config.email_add_threepid_template_success_html_content
-            status_code = 200
-        except ThreepidValidationError as e:
-            status_code = e.code
-
-            # Show a failure page with a reason
-            template_vars = {"failure_reason": e.msg}
-            html = self._failure_email_template.render(**template_vars)
-
-        respond_with_html(request, status_code, html)
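End to end, the two email servlets above cooperate: requestToken sends the validation mail, and the link in that mail lands on submit_token. A sketch of the round trip, assuming a homeserver at https://example.org and the requests library; the sid comes from the first response, while the token normally arrives only inside the emailed link:

import requests

BASE = "https://example.org/_matrix/client"
SECRET = "d0n7telln0b0dy"  # client_secret, illustrative

# 1. Ask the server to send a validation email.
resp = requests.post(
    f"{BASE}/r0/account/3pid/email/requestToken",
    json={"client_secret": SECRET, "email": "foo@bar.com", "send_attempt": 1},
)
sid = resp.json()["sid"]

# 2. The emailed link hits submit_token with sid, token and client_secret
#    (reproduced here manually; the token placeholder is not a real value).
requests.get(
    f"{BASE}/unstable/add_threepid/email/submit_token",
    params={"sid": sid, "token": "<token from email>", "client_secret": SECRET},
)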
-
-
-class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
-    """Handles 3PID validation token submission for adding a phone number to a user's
-    account
-    """
-
-    PATTERNS = client_patterns(
-        "/add_threepid/msisdn/submit_token$", releases=(), unstable=True
-    )
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.config = hs.config
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-        self.identity_handler = hs.get_identity_handler()
-
-    async def on_POST(self, request):
-        if not self.config.account_threepid_delegate_msisdn:
-            raise SynapseError(
-                400,
-                "This homeserver is not validating phone numbers. Use an identity server "
-                "instead.",
-            )
-
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ["client_secret", "sid", "token"])
-        assert_valid_client_secret(body["client_secret"])
-
-        # Proxy submit_token request to msisdn threepid delegate
-        response = await self.identity_handler.proxy_msisdn_submit_token(
-            self.config.account_threepid_delegate_msisdn,
-            body["client_secret"],
-            body["sid"],
-            body["token"],
-        )
-        return 200, response
-
-
-class ThreepidRestServlet(RestServlet):
-    PATTERNS = client_patterns("/account/3pid$")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.identity_handler = hs.get_identity_handler()
-        self.auth = hs.get_auth()
-        self.auth_handler = hs.get_auth_handler()
-        self.datastore = self.hs.get_datastore()
-
-    async def on_GET(self, request):
-        requester = await self.auth.get_user_by_req(request)
-
-        threepids = await self.datastore.user_get_threepids(requester.user.to_string())
-
-        return 200, {"threepids": threepids}
-
-    async def on_POST(self, request):
-        if not self.hs.config.enable_3pid_changes:
-            raise SynapseError(
-                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
-            )
-
-        requester = await self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
-        body = parse_json_object_from_request(request)
-
-        threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
-        if threepid_creds is None:
-            raise SynapseError(
-                400, "Missing param three_pid_creds", Codes.MISSING_PARAM
-            )
-        assert_params_in_dict(threepid_creds, ["client_secret", "sid"])
-
-        sid = threepid_creds["sid"]
-        client_secret = threepid_creds["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        validation_session = await self.identity_handler.validate_threepid_session(
-            client_secret, sid
-        )
-        if validation_session:
-            await self.auth_handler.add_threepid(
-                user_id,
-                validation_session["medium"],
-                validation_session["address"],
-                validation_session["validated_at"],
-            )
-            return 200, {}
-
-        raise SynapseError(
-            400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
-        )
-
-
-class ThreepidAddRestServlet(RestServlet):
-    PATTERNS = client_patterns("/account/3pid/add$")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.identity_handler = hs.get_identity_handler()
-        self.auth = hs.get_auth()
-        self.auth_handler = hs.get_auth_handler()
-
-    @interactive_auth_handler
-    async def on_POST(self, request):
-        if not self.hs.config.enable_3pid_changes:
-            raise SynapseError(
-                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
-            )
-
-        requester = await self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
-        body = parse_json_object_from_request(request)
-
-        assert_params_in_dict(body, ["client_secret", "sid"])
-        sid = body["sid"]
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            "add a third-party identifier to your account",
-        )
-
-        validation_session = await self.identity_handler.validate_threepid_session(
-            client_secret, sid
-        )
-        if validation_session:
-            await self.auth_handler.add_threepid(
-                user_id,
-                validation_session["medium"],
-                validation_session["address"],
-                validation_session["validated_at"],
-            )
-            return 200, {}
-
-        raise SynapseError(
-            400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
-        )
-
-
-class ThreepidBindRestServlet(RestServlet):
-    PATTERNS = client_patterns("/account/3pid/bind$")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.identity_handler = hs.get_identity_handler()
-        self.auth = hs.get_auth()
-
-    async def on_POST(self, request):
-        body = parse_json_object_from_request(request)
-
-        assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
-        id_server = body["id_server"]
-        sid = body["sid"]
-        id_access_token = body.get("id_access_token")  # optional
-        client_secret = body["client_secret"]
-        assert_valid_client_secret(client_secret)
-
-        requester = await self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
-
-        await self.identity_handler.bind_threepid(
-            client_secret, sid, user_id, id_server, id_access_token
-        )
-
-        return 200, {}
-
-
-class ThreepidUnbindRestServlet(RestServlet):
-    PATTERNS = client_patterns("/account/3pid/unbind$")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.identity_handler = hs.get_identity_handler()
-        self.auth = hs.get_auth()
-        self.datastore = self.hs.get_datastore()
-
-    async def on_POST(self, request):
-        """Unbind the given 3pid from a specific identity server, or identity servers that are
-        known to have this 3pid bound
-        """
-        requester = await self.auth.get_user_by_req(request)
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ["medium", "address"])
-
-        medium = body.get("medium")
-        address = body.get("address")
-        id_server = body.get("id_server")
-
-        # Attempt to unbind the threepid from an identity server. If id_server is None, try to
-        # unbind from all identity servers this threepid has been added to in the past
-        result = await self.identity_handler.try_unbind_threepid(
-            requester.user.to_string(),
-            {"address": address, "medium": medium, "id_server": id_server},
-        )
-        return 200, {"id_server_unbind_result": "success" if result else "no-support"}
-
-
-class ThreepidDeleteRestServlet(RestServlet):
-    PATTERNS = client_patterns("/account/3pid/delete$")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.auth_handler = hs.get_auth_handler()
-
-    async def on_POST(self, request):
-        if not self.hs.config.enable_3pid_changes:
-            raise SynapseError(
-                400, "3PID changes are disabled on this server", Codes.FORBIDDEN
-            )
-
-        body = parse_json_object_from_request(request)
-        assert_params_in_dict(body, ["medium", "address"])
-
-        requester = await self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
-
-        try:
-            ret = await self.auth_handler.delete_threepid(
-                user_id, body["medium"], body["address"], body.get("id_server")
-            )
-        except Exception:
-            # NB. This endpoint should succeed if there is nothing to
-            # delete, so it should only throw if something is wrong
-            # that we ought to care about.
-            logger.exception("Failed to remove threepid")
-            raise SynapseError(500, "Failed to remove threepid")
-
-        if ret:
-            id_server_unbind_result = "success"
-        else:
-            id_server_unbind_result = "no-support"
-
-        return 200, {"id_server_unbind_result": id_server_unbind_result}
-
-
-def assert_valid_next_link(hs: "HomeServer", next_link: str):
-    """
-    Raises a SynapseError if a given next_link value is invalid
-
-    next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config
-    option is either empty or contains a domain that matches the one in the given next_link
-
-    Args:
-        hs: The homeserver object
-        next_link: The next_link value given by the client
-
-    Raises:
-        SynapseError: If the next_link is invalid
-    """
-    valid = True
-
-    # Parse the contents of the URL
-    next_link_parsed = urlparse(next_link)
-
-    # Scheme must not point to the local drive
-    if next_link_parsed.scheme == "file":
-        valid = False
-
-    # If the domain whitelist is set, the domain must be in it
-    if (
-        valid
-        and hs.config.next_link_domain_whitelist is not None
-        and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
-    ):
-        valid = False
-
-    if not valid:
-        raise SynapseError(
-            400,
-            "'next_link' domain not included in whitelist, or not http(s)",
-            errcode=Codes.INVALID_PARAM,
-        )
-
-
-class WhoamiRestServlet(RestServlet):
-    PATTERNS = client_patterns("/account/whoami$")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-
-    async def on_GET(self, request):
-        requester = await self.auth.get_user_by_req(request)
-
-        response = {"user_id": requester.user.to_string()}
-
-        # Appservices and similar accounts do not have device IDs
-        # that we can report on, so exclude them for compliance.
-        if requester.device_id is not None:
-            response["device_id"] = requester.device_id
-
-        return 200, response
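assert_valid_next_link above is pure URL inspection, so its rules are easy to pin down in isolation. A standalone sketch of the same checks, with a plain list standing in for hs.config.next_link_domain_whitelist (None meaning no whitelist is configured):

from typing import List, Optional
from urllib.parse import urlparse

def is_valid_next_link(next_link: str, whitelist: Optional[List[str]]) -> bool:
    parsed = urlparse(next_link)
    if parsed.scheme == "file":
        return False  # must not point at the local drive
    if whitelist is not None and parsed.hostname not in whitelist:
        return False  # whitelist configured, but domain not in it
    return True

assert is_valid_next_link("https://example.org/done", ["example.org"])
assert not is_valid_next_link("file:///etc/passwd", None)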
-
-
-def register_servlets(hs, http_server):
-    EmailPasswordRequestTokenRestServlet(hs).register(http_server)
-    PasswordRestServlet(hs).register(http_server)
-    DeactivateAccountRestServlet(hs).register(http_server)
-    EmailThreepidRequestTokenRestServlet(hs).register(http_server)
-    MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
-    AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
-    AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
-    ThreepidRestServlet(hs).register(http_server)
-    ThreepidAddRestServlet(hs).register(http_server)
-    ThreepidBindRestServlet(hs).register(http_server)
-    ThreepidUnbindRestServlet(hs).register(http_server)
-    ThreepidDeleteRestServlet(hs).register(http_server)
-    WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
deleted file mode 100644
index 7517e9304e..0000000000
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.api.errors import AuthError, NotFoundError, SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class AccountDataServlet(RestServlet):
-    """
-    PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
-    GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
-    """
-
-    PATTERNS = client_patterns(
-        "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.handler = hs.get_account_data_handler()
-
-    async def on_PUT(self, request, user_id, account_data_type):
-        requester = await self.auth.get_user_by_req(request)
-        if user_id != requester.user.to_string():
-            raise AuthError(403, "Cannot add account data for other users.")
-
-        body = parse_json_object_from_request(request)
-
-        await self.handler.add_account_data_for_user(user_id, account_data_type, body)
-
-        return 200, {}
-
-    async def on_GET(self, request, user_id, account_data_type):
-        requester = await self.auth.get_user_by_req(request)
-        if user_id != requester.user.to_string():
-            raise AuthError(403, "Cannot get account data for other users.")
-
-        event = await self.store.get_global_account_data_by_type_for_user(
-            account_data_type, user_id
-        )
-
-        if event is None:
-            raise NotFoundError("Account data not found")
-
-        return 200, event
-
-
-class RoomAccountDataServlet(RestServlet):
-    """
-    PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
-    GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
-    """
-
-    PATTERNS = client_patterns(
-        "/user/(?P<user_id>[^/]*)"
-        "/rooms/(?P<room_id>[^/]*)"
-        "/account_data/(?P<account_data_type>[^/]*)"
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.handler = hs.get_account_data_handler()
-
-    async def on_PUT(self, request, user_id, room_id, account_data_type):
-        requester = await self.auth.get_user_by_req(request)
-        if user_id != requester.user.to_string():
-            raise AuthError(403, "Cannot add account data for other users.")
-
-        body = parse_json_object_from_request(request)
-
-        if account_data_type == "m.fully_read":
-            raise SynapseError(
-                405,
-                "Cannot set m.fully_read through this API."
-                " Use /rooms/!roomId:server.name/read_markers",
-            )
-
-        await self.handler.add_account_data_to_room(
-            user_id, room_id, account_data_type, body
-        )
-
-        return 200, {}
-
-    async def on_GET(self, request, user_id, room_id, account_data_type):
-        requester = await self.auth.get_user_by_req(request)
-        if user_id != requester.user.to_string():
-            raise AuthError(403, "Cannot get account data for other users.")
-
-        event = await self.store.get_account_data_for_room_and_type(
-            user_id, room_id, account_data_type
-        )
-
-        if event is None:
-            raise NotFoundError("Room account data not found")
-
-        return 200, event
-
-
-def register_servlets(hs, http_server):
-    AccountDataServlet(hs).register(http_server)
-    RoomAccountDataServlet(hs).register(http_server)
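In client terms, the two servlets above give each user a small keyed JSON store, globally and per room. A usage sketch, assuming a homeserver at https://example.org, user @alice:example.org, a placeholder access token, and the requests library:

import requests

BASE = "https://example.org/_matrix/client/r0"
USER = "@alice:example.org"
HEADERS = {"Authorization": "Bearer syt_placeholder"}

# Store global account data under a custom event type...
requests.put(
    f"{BASE}/user/{USER}/account_data/com.example.settings",
    json={"theme": "dark"},
    headers=HEADERS,
)

# ...and read it back; a 404 means nothing of that type was ever stored.
resp = requests.get(
    f"{BASE}/user/{USER}/account_data/com.example.settings",
    headers=HEADERS,
)
print(resp.json())  # {"theme": "dark"}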
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
deleted file mode 100644
index 3ebe401861..0000000000
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.api.errors import SynapseError
-from synapse.http.server import respond_with_html
-from synapse.http.servlet import RestServlet
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class AccountValidityRenewServlet(RestServlet):
-    PATTERNS = client_patterns("/account_validity/renew$")
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-
-        self.hs = hs
-        self.account_activity_handler = hs.get_account_validity_handler()
-        self.auth = hs.get_auth()
-        self.account_renewed_template = (
-            hs.config.account_validity.account_validity_account_renewed_template
-        )
-        self.account_previously_renewed_template = (
-            hs.config.account_validity.account_validity_account_previously_renewed_template
-        )
-        self.invalid_token_template = (
-            hs.config.account_validity.account_validity_invalid_token_template
-        )
-
-    async def on_GET(self, request):
-        if b"token" not in request.args:
-            raise SynapseError(400, "Missing renewal token")
-        renewal_token = request.args[b"token"][0]
-
-        (
-            token_valid,
-            token_stale,
-            expiration_ts,
-        ) = await self.account_activity_handler.renew_account(
-            renewal_token.decode("utf8")
-        )
-
-        if token_valid:
-            status_code = 200
-            response = self.account_renewed_template.render(expiration_ts=expiration_ts)
-        elif token_stale:
-            status_code = 200
-            response = self.account_previously_renewed_template.render(
-                expiration_ts=expiration_ts
-            )
-        else:
-            status_code = 404
-            response = self.invalid_token_template.render(expiration_ts=expiration_ts)
-
-        respond_with_html(request, status_code, response)
-
-
-class AccountValiditySendMailServlet(RestServlet):
-    PATTERNS = client_patterns("/account_validity/send_mail$")
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-
-        self.hs = hs
-        self.account_activity_handler = hs.get_account_validity_handler()
-        self.auth = hs.get_auth()
-        self.account_validity_renew_by_email_enabled = (
-            hs.config.account_validity.account_validity_renew_by_email_enabled
-        )
-
-    async def on_POST(self, request):
-        requester = await self.auth.get_user_by_req(request, allow_expired=True)
-        user_id = requester.user.to_string()
-        await self.account_activity_handler.send_renewal_email_to_user(user_id)
-
-        return 200, {}
-
-
-def register_servlets(hs, http_server):
-    AccountValidityRenewServlet(hs).register(http_server)
-    AccountValiditySendMailServlet(hs).register(http_server)
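The renew servlet above is driven from a link in a renewal email, so it is a plain GET. A sketch of exercising it manually, assuming the unstable client prefix, a homeserver at https://example.org, and an illustrative token:

import requests

resp = requests.get(
    "https://example.org/_matrix/client/unstable/account_validity/renew",
    params={"token": "abc123"},  # illustrative renewal token
)
# Per the handler above: 200 for a valid or previously-used token (each with
# its own template), 404 for an unknown token.
print(resp.status_code)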
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
deleted file mode 100644
index 6ea1b50a62..0000000000
--- a/synapse/rest/client/v2_alpha/auth.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING
-
-from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError
-from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.http.server import respond_with_html
-from synapse.http.servlet import RestServlet, parse_string
-
-from ._base import client_patterns
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class AuthRestServlet(RestServlet):
-    """
-    Handles Client / Server API authentication in any situations where it
-    cannot be handled in the normal flow (with requests to the same endpoint).
-    Current use is for web fallback auth.
-    """
-
-    PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.auth_handler = hs.get_auth_handler()
-        self.registration_handler = hs.get_registration_handler()
-        self.recaptcha_template = hs.config.recaptcha_template
-        self.terms_template = hs.config.terms_template
-        self.success_template = hs.config.fallback_success_template
-
-    async def on_GET(self, request, stagetype):
-        session = parse_string(request, "session")
-        if not session:
-            raise SynapseError(400, "No session supplied")
-
-        if stagetype == LoginType.RECAPTCHA:
-            html = self.recaptcha_template.render(
-                session=session,
-                myurl="%s/r0/auth/%s/fallback/web"
-                % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
-                sitekey=self.hs.config.recaptcha_public_key,
-            )
-        elif stagetype == LoginType.TERMS:
-            html = self.terms_template.render(
-                session=session,
-                terms_url="%s_matrix/consent?v=%s"
-                % (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
-                myurl="%s/r0/auth/%s/fallback/web"
-                % (CLIENT_API_PREFIX, LoginType.TERMS),
-            )
-
-        elif stagetype == LoginType.SSO:
-            # Display a confirmation page which prompts the user to
-            # re-authenticate with their SSO provider.
-            html = await self.auth_handler.start_sso_ui_auth(request, session)
-
-        else:
-            raise SynapseError(404, "Unknown auth stage type")
-
-        # Render the HTML and return.
-        respond_with_html(request, 200, html)
-        return None
-
-    async def on_POST(self, request, stagetype):
-
-        session = parse_string(request, "session")
-        if not session:
-            raise SynapseError(400, "No session supplied")
-
-        if stagetype == LoginType.RECAPTCHA:
-            response = parse_string(request, "g-recaptcha-response")
-
-            if not response:
-                raise SynapseError(400, "No captcha response supplied")
-
-            authdict = {"response": response, "session": session}
-
-            success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, request.getClientIP()
-            )
-
-            if success:
-                html = self.success_template.render()
-            else:
-                html = self.recaptcha_template.render(
-                    session=session,
-                    myurl="%s/r0/auth/%s/fallback/web"
-                    % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
-                    sitekey=self.hs.config.recaptcha_public_key,
-                )
-        elif stagetype == LoginType.TERMS:
-            authdict = {"session": session}
-
-            success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, request.getClientIP()
-            )
-
-            if success:
-                html = self.success_template.render()
-            else:
-                html = self.terms_template.render(
-                    session=session,
-                    terms_url="%s_matrix/consent?v=%s"
-                    % (
-                        self.hs.config.public_baseurl,
-                        self.hs.config.user_consent_version,
-                    ),
-                    myurl="%s/r0/auth/%s/fallback/web"
-                    % (CLIENT_API_PREFIX, LoginType.TERMS),
-                )
-        elif stagetype == LoginType.SSO:
-            # The SSO fallback workflow should not post here,
-            raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
-        else:
-            raise SynapseError(404, "Unknown auth stage type")
-
-        # Render the HTML and return.
-        respond_with_html(request, 200, html)
-        return None
-
-
-def register_servlets(hs, http_server):
-    AuthRestServlet(hs).register(http_server)
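A client drives the web fallback above by opening the fallback URL in a browser for any stage it cannot handle natively, then re-submitting its original request once the page reports success. A sketch, assuming a homeserver at https://example.org and an illustrative UI-auth session id from an earlier 401:

import webbrowser

def open_fallback(stage_type: str, session: str) -> None:
    # Mirrors the myurl pattern rendered into the templates above.
    url = (
        "https://example.org/_matrix/client/r0"
        f"/auth/{stage_type}/fallback/web?session={session}"
    )
    webbrowser.open(url)

open_fallback("m.login.recaptcha", "illustrative_session_id")
# Afterwards, retry the original request with
# {"auth": {"session": "illustrative_session_id"}} to finish the flow.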
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
deleted file mode 100644
index 88e3aac797..0000000000
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2019 New Vector
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Tuple
-
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES
-from synapse.http.servlet import RestServlet
-from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict
-
-from ._base import client_patterns
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class CapabilitiesRestServlet(RestServlet):
-    """End point to expose the capabilities of the server."""
-
-    PATTERNS = client_patterns("/capabilities$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.hs = hs
-        self.config = hs.config
-        self.auth = hs.get_auth()
-        self.auth_handler = hs.get_auth_handler()
-
-    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        await self.auth.get_user_by_req(request, allow_guest=True)
-        change_password = self.auth_handler.can_change_password()
-
-        response = {
-            "capabilities": {
-                "m.room_versions": {
-                    "default": self.config.default_room_version.identifier,
-                    "available": {
-                        v.identifier: v.disposition
-                        for v in KNOWN_ROOM_VERSIONS.values()
-                    },
-                },
-                "m.change_password": {"enabled": change_password},
-            }
-        }
-
-        if self.config.experimental.msc3244_enabled:
-            response["capabilities"]["m.room_versions"][
-                "org.matrix.msc3244.room_capabilities"
-            ] = MSC3244_CAPABILITIES
-
-        return 200, response
-
-
-def register_servlets(hs: "HomeServer", http_server):
-    CapabilitiesRestServlet(hs).register(http_server)
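Reading the advertisement served above lets a client adapt its UI to the server. A usage sketch, assuming a homeserver at https://example.org, a placeholder token, and the requests library:

import requests

resp = requests.get(
    "https://example.org/_matrix/client/r0/capabilities",
    headers={"Authorization": "Bearer syt_placeholder"},
)
caps = resp.json()["capabilities"]

# e.g. only show a "change password" button when the server supports it;
# the spec default is enabled=True when the capability is absent.
if caps.get("m.change_password", {}).get("enabled", True):
    print("password changes allowed")
print("default room version:", caps["m.room_versions"]["default"])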
- """ - - PATTERNS = client_patterns("/delete_devices") - - def __init__(self, hs): - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - self.auth_handler = hs.get_auth_handler() - - @interactive_auth_handler - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) - - try: - body = parse_json_object_from_request(request) - except errors.SynapseError as e: - if e.errcode == errors.Codes.NOT_JSON: - # DELETE - # deal with older clients which didn't pass a JSON dict - # the same as those that pass an empty dict - body = {} - else: - raise e - - assert_params_in_dict(body, ["devices"]) - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "remove device(s) from your account", - # Users might call this multiple times in a row while cleaning up - # devices, allow a single UI auth session to be re-used. - can_skip_ui_auth=True, - ) - - await self.device_handler.delete_devices( - requester.user.to_string(), body["devices"] - ) - return 200, {} - - -class DeviceRestServlet(RestServlet): - PATTERNS = client_patterns("/devices/(?P[^/]*)$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - self.auth_handler = hs.get_auth_handler() - - async def on_GET(self, request, device_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - device = await self.device_handler.get_device( - requester.user.to_string(), device_id - ) - return 200, device - - @interactive_auth_handler - async def on_DELETE(self, request, device_id): - requester = await self.auth.get_user_by_req(request) - - try: - body = parse_json_object_from_request(request) - - except errors.SynapseError as e: - if e.errcode == errors.Codes.NOT_JSON: - # deal with older clients which didn't pass a JSON dict - # the same as those that pass an empty dict - body = {} - else: - raise - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "remove a device from your account", - # Users might call this multiple times in a row while cleaning up - # devices, allow a single UI auth session to be re-used. - can_skip_ui_auth=True, - ) - - await self.device_handler.delete_device(requester.user.to_string(), device_id) - return 200, {} - - async def on_PUT(self, request, device_id): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - body = parse_json_object_from_request(request) - await self.device_handler.update_device( - requester.user.to_string(), device_id, body - ) - return 200, {} - - -class DehydratedDeviceServlet(RestServlet): - """Retrieve or store a dehydrated device. 
-
-
-class DehydratedDeviceServlet(RestServlet):
-    """Retrieve or store a dehydrated device.
-
-    GET /org.matrix.msc2697.v2/dehydrated_device
-
-    HTTP/1.1 200 OK
-    Content-Type: application/json
-
-    {
-        "device_id": "dehydrated_device_id",
-        "device_data": {
-            "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
-            "account": "dehydrated_device"
-        }
-    }
-
-    PUT /org.matrix.msc2697/dehydrated_device
-    Content-Type: application/json
-
-    {
-        "device_data": {
-            "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
-            "account": "dehydrated_device"
-        }
-    }
-
-    HTTP/1.1 200 OK
-    Content-Type: application/json
-
-    {
-        "device_id": "dehydrated_device_id"
-    }
-
-    """
-
-    PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
-
-    async def on_GET(self, request: SynapseRequest):
-        requester = await self.auth.get_user_by_req(request)
-        dehydrated_device = await self.device_handler.get_dehydrated_device(
-            requester.user.to_string()
-        )
-        if dehydrated_device is not None:
-            (device_id, device_data) = dehydrated_device
-            result = {"device_id": device_id, "device_data": device_data}
-            return (200, result)
-        else:
-            raise errors.NotFoundError("No dehydrated device available")
-
-    async def on_PUT(self, request: SynapseRequest):
-        submission = parse_json_object_from_request(request)
-        requester = await self.auth.get_user_by_req(request)
-
-        if "device_data" not in submission:
-            raise errors.SynapseError(
-                400,
-                "device_data missing",
-                errcode=errors.Codes.MISSING_PARAM,
-            )
-        elif not isinstance(submission["device_data"], dict):
-            raise errors.SynapseError(
-                400,
-                "device_data must be an object",
-                errcode=errors.Codes.INVALID_PARAM,
-            )
-
-        device_id = await self.device_handler.store_dehydrated_device(
-            requester.user.to_string(),
-            submission["device_data"],
-            submission.get("initial_device_display_name", None),
-        )
-        return 200, {"device_id": device_id}
-
-
-class ClaimDehydratedDeviceServlet(RestServlet):
-    """Claim a dehydrated device.
-
-    POST /org.matrix.msc2697.v2/dehydrated_device/claim
-    Content-Type: application/json
-
-    {
-        "device_id": "dehydrated_device_id"
-    }
-
-    HTTP/1.1 200 OK
-    Content-Type: application/json
-
-    {
-        "success": true,
-    }
-
-    """
-
-    PATTERNS = client_patterns(
-        "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()
-
-    async def on_POST(self, request: SynapseRequest):
-        requester = await self.auth.get_user_by_req(request)
-
-        submission = parse_json_object_from_request(request)
-
-        if "device_id" not in submission:
-            raise errors.SynapseError(
-                400,
-                "device_id missing",
-                errcode=errors.Codes.MISSING_PARAM,
-            )
-        elif not isinstance(submission["device_id"], str):
-            raise errors.SynapseError(
-                400,
-                "device_id must be a string",
-                errcode=errors.Codes.INVALID_PARAM,
-            )
-
-        result = await self.device_handler.rehydrate_device(
-            requester.user.to_string(),
-            self.auth.get_access_token_from_request(request),
-            submission["device_id"],
-        )
-
-        return (200, result)
-
-
-def register_servlets(hs, http_server):
-    DeleteDevicesRestServlet(hs).register(http_server)
-    DevicesRestServlet(hs).register(http_server)
-    DeviceRestServlet(hs).register(http_server)
-    DehydratedDeviceServlet(hs).register(http_server)
-    ClaimDehydratedDeviceServlet(hs).register(http_server)
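The MSC2697 flow above, seen from the client: park a dehydrated device before the last session goes away, then claim it on the next login so the new session can pick up its identity. A sketch, assuming a homeserver at https://example.org, a placeholder token, and the requests library; the device_data blob is opaque to the server and illustrative here:

import requests

BASE = "https://example.org/_matrix/client/unstable/org.matrix.msc2697.v2"
HEADERS = {"Authorization": "Bearer syt_placeholder"}

# Before logging out: store a dehydrated device on the server.
resp = requests.put(
    f"{BASE}/dehydrated_device",
    json={
        "device_data": {
            "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
            "account": "pickled-olm-account",  # illustrative blob
        }
    },
    headers=HEADERS,
)
dehydrated_id = resp.json()["device_id"]

# On the next login: claim it, so the new session inherits that device.
requests.post(
    f"{BASE}/dehydrated_device/claim",
    json={"device_id": dehydrated_id},
    headers=HEADERS,
)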
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
deleted file mode 100644
index 411667a9c8..0000000000
--- a/synapse/rest/client/v2_alpha/filter.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
-
-from ._base import client_patterns, set_timeline_upper_limit
-
-logger = logging.getLogger(__name__)
-
-
-class GetFilterRestServlet(RestServlet):
-    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.filtering = hs.get_filtering()
-
-    async def on_GET(self, request, user_id, filter_id):
-        target_user = UserID.from_string(user_id)
-        requester = await self.auth.get_user_by_req(request)
-
-        if target_user != requester.user:
-            raise AuthError(403, "Cannot get filters for other users")
-
-        if not self.hs.is_mine(target_user):
-            raise AuthError(403, "Can only get filters for local users")
-
-        try:
-            filter_id = int(filter_id)
-        except Exception:
-            raise SynapseError(400, "Invalid filter_id")
-
-        try:
-            filter_collection = await self.filtering.get_user_filter(
-                user_localpart=target_user.localpart, filter_id=filter_id
-            )
-        except StoreError as e:
-            if e.code != 404:
-                raise
-            raise NotFoundError("No such filter")
-
-        return 200, filter_collection.get_filter_json()
-
-
-class CreateFilterRestServlet(RestServlet):
-    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.filtering = hs.get_filtering()
-
-    async def on_POST(self, request, user_id):
-
-        target_user = UserID.from_string(user_id)
-        requester = await self.auth.get_user_by_req(request)
-
-        if target_user != requester.user:
-            raise AuthError(403, "Cannot create filters for other users")
-
-        if not self.hs.is_mine(target_user):
-            raise AuthError(403, "Can only create filters for local users")
-
-        content = parse_json_object_from_request(request)
-        set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
-
-        filter_id = await self.filtering.add_user_filter(
-            user_localpart=target_user.localpart, user_filter=content
-        )
-
-        return 200, {"filter_id": str(filter_id)}
-
-
-def register_servlets(hs, http_server):
-    GetFilterRestServlet(hs).register(http_server)
-    CreateFilterRestServlet(hs).register(http_server)
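The point of the server-side filters defined above is to upload a filter once and reference it by id thereafter, rather than inlining it into every /sync. A usage sketch, assuming a homeserver at https://example.org, user @alice:example.org, a placeholder token, and the requests library:

import requests

BASE = "https://example.org/_matrix/client/r0"
USER = "@alice:example.org"
HEADERS = {"Authorization": "Bearer syt_placeholder"}

# Create a filter that caps each room timeline at 10 events.
resp = requests.post(
    f"{BASE}/user/{USER}/filter",
    json={"room": {"timeline": {"limit": 10}}},
    headers=HEADERS,
)
filter_id = resp.json()["filter_id"]

# Hand the server-side filter id to /sync instead of an inline definition.
requests.get(f"{BASE}/sync", params={"filter": filter_id}, headers=HEADERS)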
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
deleted file mode 100644
index 6285680c00..0000000000
--- a/synapse/rest/client/v2_alpha/groups.py
+++ /dev/null
@@ -1,957 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from functools import wraps
-from typing import TYPE_CHECKING, Optional, Tuple
-
-from twisted.web.server import Request
-
-from synapse.api.constants import (
-    MAX_GROUP_CATEGORYID_LENGTH,
-    MAX_GROUP_ROLEID_LENGTH,
-    MAX_GROUPID_LENGTH,
-)
-from synapse.api.errors import Codes, SynapseError
-from synapse.handlers.groups_local import GroupsLocalHandler
-from synapse.http.servlet import (
-    RestServlet,
-    assert_params_in_dict,
-    parse_json_object_from_request,
-)
-from synapse.http.site import SynapseRequest
-from synapse.types import GroupID, JsonDict
-
-from ._base import client_patterns
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-def _validate_group_id(f):
-    """Wrapper to validate the form of the group ID.
-
-    Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
-    """
-
-    @wraps(f)
-    def wrapper(self, request: Request, group_id: str, *args, **kwargs):
-        if not GroupID.is_valid(group_id):
-            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
-
-        return f(self, request, group_id, *args, **kwargs)
-
-    return wrapper
-
-
-class GroupServlet(RestServlet):
-    """Get the group profile"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_GET(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        requester_user_id = requester.user.to_string()
-
-        group_description = await self.groups_handler.get_group_profile(
-            group_id, requester_user_id
-        )
-
-        return 200, group_description
-
-    @_validate_group_id
-    async def on_POST(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert_params_in_dict(
-            content, ("name", "avatar_url", "short_description", "long_description")
-        )
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot create group profiles."
-        await self.groups_handler.update_group_profile(
-            group_id, requester_user_id, content
-        )
-
-        return 200, {}
-
-
-class GroupSummaryServlet(RestServlet):
-    """Get the full group summary"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_GET(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        requester_user_id = requester.user.to_string()
-
-        get_group_summary = await self.groups_handler.get_group_summary(
-            group_id, requester_user_id
-        )
-
-        return 200, get_group_summary
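_validate_group_id above factors a single validity check out of the couple of dozen on_FOO handlers that follow. The same pattern in isolation, with a toy validity test standing in for GroupID.is_valid:

from functools import wraps

def validate_group_id(f):
    """Reject calls whose group_id does not look like +localpart:domain."""

    @wraps(f)
    def wrapper(group_id: str, *args, **kwargs):
        if not (group_id.startswith("+") and ":" in group_id):  # toy check
            raise ValueError("%s is not a legal group ID" % (group_id,))
        return f(group_id, *args, **kwargs)

    return wrapper

@validate_group_id
def get_group_profile(group_id: str) -> dict:
    return {"group_id": group_id}

print(get_group_profile("+matrix:example.org"))  # ok
# get_group_profile("bad id")  # would raise ValueError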
- "/rooms/(?P[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, - request: SynapseRequest, - group_id: str, - category_id: Optional[str], - room_id: str, - ): - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) - - if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, category_id: str, room_id: str - ): - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group profiles." - resp = await self.groups_handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class GroupCategoryServlet(RestServlet): - """Get/add/update/delete a group category""" - - PATTERNS = client_patterns( - "/groups/(?P[^/]*)/categories/(?P[^/]+)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_category( - group_id, requester_user_id, category_id=category_id - ) - - return 200, category - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if not category_id: - raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." 
-
-
-class GroupCategoriesServlet(RestServlet):
-    """Get all group categories"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_GET(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        requester_user_id = requester.user.to_string()
-
-        category = await self.groups_handler.get_group_categories(
-            group_id, requester_user_id
-        )
-
-        return 200, category
-
-
-class GroupRoleServlet(RestServlet):
-    """Get/add/update/delete a group role"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_GET(
-        self, request: SynapseRequest, group_id: str, role_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        requester_user_id = requester.user.to_string()
-
-        category = await self.groups_handler.get_group_role(
-            group_id, requester_user_id, role_id=role_id
-        )
-
-        return 200, category
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str, role_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        if not role_id:
-            raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
-
-        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
-            raise SynapseError(
-                400,
-                "role_id may not be longer than %s characters"
-                % (MAX_GROUP_ROLEID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify group roles."
-        resp = await self.groups_handler.update_group_role(
-            group_id, requester_user_id, role_id=role_id, content=content
-        )
-
-        return 200, resp
-
-    @_validate_group_id
-    async def on_DELETE(
-        self, request: SynapseRequest, group_id: str, role_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify group roles."
-        resp = await self.groups_handler.delete_group_role(
-            group_id, requester_user_id, role_id=role_id
-        )
-
-        return 200, resp
-
-
-class GroupRolesServlet(RestServlet):
-    """Get all group roles"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_GET(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        requester_user_id = requester.user.to_string()
-
-        category = await self.groups_handler.get_group_roles(
-            group_id, requester_user_id
-        )
-
-        return 200, category
-
-
-class GroupSummaryUsersRoleServlet(RestServlet):
-    """Update/delete a user's entry in the summary.
-
-    Matches both:
-        - /groups/:group/summary/users/:room_id
-        - /groups/:group/summary/roles/:role/users/:user_id
-    """
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/summary"
-        "(/roles/(?P<role_id>[^/]+))?"
-        "/users/(?P<user_id>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self,
-        request: SynapseRequest,
-        group_id: str,
-        role_id: Optional[str],
-        user_id: str,
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
-
-        if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
-            raise SynapseError(
-                400,
-                "role_id may not be longer than %s characters"
-                % (MAX_GROUP_ROLEID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify group summaries."
-        resp = await self.groups_handler.update_group_summary_user(
-            group_id,
-            requester_user_id,
-            user_id=user_id,
-            role_id=role_id,
-            content=content,
-        )
-
-        return 200, resp
-
-    @_validate_group_id
-    async def on_DELETE(
-        self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
-    ):
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify group summaries."
-        resp = await self.groups_handler.delete_group_summary_user(
-            group_id, requester_user_id, user_id=user_id, role_id=role_id
-        )
-
-        return 200, resp
-
-
-class GroupRoomServlet(RestServlet):
-    """Get all rooms in a group"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_GET(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        requester_user_id = requester.user.to_string()
-
-        result = await self.groups_handler.get_rooms_in_group(
-            group_id, requester_user_id
-        )
-
-        return 200, result
-
-
-class GroupUsersServlet(RestServlet):
-    """Get all users in a group"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_GET(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        requester_user_id = requester.user.to_string()
-
-        result = await self.groups_handler.get_users_in_group(
-            group_id, requester_user_id
-        )
-
-        return 200, result
-
-
-class GroupInvitedUsersServlet(RestServlet):
-    """Get users invited to a group"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_GET(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        result = await self.groups_handler.get_invited_users_in_group(
-            group_id, requester_user_id
-        )
-
-        return 200, result
-
-
-class GroupSettingJoinPolicyServlet(RestServlet):
-    """Set group join policy"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify group join policy."
-        result = await self.groups_handler.set_group_join_policy(
-            group_id, requester_user_id, content
-        )
-
-        return 200, result
-        result = await self.groups_handler.set_group_join_policy(
-            group_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupCreateServlet(RestServlet):
-    """Create a group"""
-
-    PATTERNS = client_patterns("/create_group$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-        self.server_name = hs.hostname
-
-    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        # TODO: Create group on remote server
-        content = parse_json_object_from_request(request)
-        localpart = content.pop("localpart")
-        group_id = GroupID(localpart, self.server_name).to_string()
-
-        if not localpart:
-            raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
-
-        if len(group_id) > MAX_GROUPID_LENGTH:
-            raise SynapseError(
-                400,
-                "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
-                Codes.INVALID_PARAM,
-            )
-
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot create groups."
-        result = await self.groups_handler.create_group(
-            group_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupAdminRoomsServlet(RestServlet):
-    """Add a room to the group"""
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify rooms in a group."
-        result = await self.groups_handler.add_room_to_group(
-            group_id, requester_user_id, room_id, content
-        )
-
-        return 200, result
-
-    @_validate_group_id
-    async def on_DELETE(
-        self, request: SynapseRequest, group_id: str, room_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify group categories."
-        result = await self.groups_handler.remove_room_from_group(
-            group_id, requester_user_id, room_id
-        )
-
-        return 200, result
-
-
-class GroupAdminRoomsConfigServlet(RestServlet):
-    """Update the config of a room in a group"""
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
-        "/config/(?P<config_key>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
-    ):
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot modify group categories."
-        result = await self.groups_handler.update_room_in_group(
-            group_id, requester_user_id, room_id, config_key, content
-        )
-
-        return 200, result
-
-
-class GroupAdminUsersInviteServlet(RestServlet):
-    """Invite a user to the group"""
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-        self.store = hs.get_datastore()
-        self.is_mine_id = hs.is_mine_id
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id, user_id
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        config = content.get("config", {})
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot invite users to a group."
-        result = await self.groups_handler.invite(
-            group_id, user_id, requester_user_id, config
-        )
-
-        return 200, result
-
-
-class GroupAdminUsersKickServlet(RestServlet):
-    """Kick a user from the group"""
-
-    PATTERNS = client_patterns(
-        "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
-    )
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id, user_id
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot kick users from a group."
-        result = await self.groups_handler.remove_user_from_group(
-            group_id, user_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupSelfLeaveServlet(RestServlet):
-    """Leave a joined group"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot leave a group for a users."
-        result = await self.groups_handler.remove_user_from_group(
-            group_id, requester_user_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupSelfJoinServlet(RestServlet):
-    """Attempt to join a group, or knock"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot join a user to a group."
-        result = await self.groups_handler.join_group(
-            group_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupSelfAcceptInviteServlet(RestServlet):
-    """Accept a group invite"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        assert isinstance(
-            self.groups_handler, GroupsLocalHandler
-        ), "Workers cannot accept an invite to a group."
-        result = await self.groups_handler.accept_invite(
-            group_id, requester_user_id, content
-        )
-
-        return 200, result
-
-
-class GroupSelfUpdatePublicityServlet(RestServlet):
-    """Update whether we publicise a users membership of a group"""
-
-    PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-
-    @_validate_group_id
-    async def on_PUT(
-        self, request: SynapseRequest, group_id: str
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-        requester_user_id = requester.user.to_string()
-
-        content = parse_json_object_from_request(request)
-        publicise = content["publicise"]
-        await self.store.update_group_publicity(group_id, requester_user_id, publicise)
-
-        return 200, {}
-
-
-class PublicisedGroupsForUserServlet(RestServlet):
-    """Get the list of groups a user is advertising"""
-
-    PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    async def on_GET(
-        self, request: SynapseRequest, user_id: str
-    ) -> Tuple[int, JsonDict]:
-        await self.auth.get_user_by_req(request, allow_guest=True)
-
-        result = await self.groups_handler.get_publicised_groups_for_user(user_id)
-
-        return 200, result
-
-
-class PublicisedGroupsForUsersServlet(RestServlet):
-    """Get the list of groups a user is advertising"""
-
-    PATTERNS = client_patterns("/publicised_groups$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        await self.auth.get_user_by_req(request, allow_guest=True)
-
-        content = parse_json_object_from_request(request)
-        user_ids = content["user_ids"]
-
-        result = await self.groups_handler.bulk_get_publicised_groups(user_ids)
-
-        return 200, result
-
-
-class GroupsForUserServlet(RestServlet):
-    """Get all groups the logged in user is joined to"""
-
-    PATTERNS = client_patterns("/joined_groups$")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.groups_handler = hs.get_groups_local_handler()
-
-    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        requester_user_id = requester.user.to_string()
-
-        result = await self.groups_handler.get_joined_groups(requester_user_id)
-
-        return 200, result
-
-
-def register_servlets(hs: "HomeServer", http_server):
-    GroupServlet(hs).register(http_server)
-    GroupSummaryServlet(hs).register(http_server)
-    GroupInvitedUsersServlet(hs).register(http_server)
-    GroupUsersServlet(hs).register(http_server)
-    GroupRoomServlet(hs).register(http_server)
-    GroupSettingJoinPolicyServlet(hs).register(http_server)
-    GroupCreateServlet(hs).register(http_server)
-    GroupAdminRoomsServlet(hs).register(http_server)
-    GroupAdminRoomsConfigServlet(hs).register(http_server)
-    GroupAdminUsersInviteServlet(hs).register(http_server)
-    GroupAdminUsersKickServlet(hs).register(http_server)
-    GroupSelfLeaveServlet(hs).register(http_server)
-    GroupSelfJoinServlet(hs).register(http_server)
-    GroupSelfAcceptInviteServlet(hs).register(http_server)
-    GroupsForUserServlet(hs).register(http_server)
-    GroupCategoryServlet(hs).register(http_server)
-    GroupCategoriesServlet(hs).register(http_server)
-    GroupSummaryRoomsCatServlet(hs).register(http_server)
-    GroupRoleServlet(hs).register(http_server)
-    GroupRolesServlet(hs).register(http_server)
-    GroupSelfUpdatePublicityServlet(hs).register(http_server)
-    GroupSummaryUsersRoleServlet(hs).register(http_server)
-    PublicisedGroupsForUserServlet(hs).register(http_server)
-    PublicisedGroupsForUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
deleted file mode 100644
index d0d9d30d40..0000000000
--- a/synapse/rest/client/v2_alpha/keys.py
+++ /dev/null
@@ -1,344 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
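Both the group servlets above and the key-management servlets below follow the same convention: the named groups in each servlet's PATTERNS regex become keyword arguments of its on_GET/on_PUT/on_POST handler. A minimal, self-contained sketch of that idea (a toy router with hypothetical names, not Synapse's actual JsonResource machinery):

    import re

    ROUTES = []

    def register(pattern, handler):
        # Compile once; named groups feed the handler as keyword arguments.
        ROUTES.append((re.compile(pattern), handler))

    def dispatch(path):
        for regex, handler in ROUTES:
            match = regex.match(path)
            if match:
                return handler(**match.groupdict())
        raise KeyError("no route for %s" % (path,))

    register(
        "/groups/(?P<group_id>[^/]*)/rooms$",
        lambda group_id: (200, {"group_id": group_id}),
    )

    print(dispatch("/groups/+sales:example.com/rooms"))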
-
-import logging
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import (
-    RestServlet,
-    parse_integer,
-    parse_json_object_from_request,
-    parse_string,
-)
-from synapse.logging.opentracing import log_kv, set_tag, trace
-from synapse.types import StreamToken
-
-from ._base import client_patterns, interactive_auth_handler
-
-logger = logging.getLogger(__name__)
-
-
-class KeyUploadServlet(RestServlet):
-    """
-    POST /keys/upload HTTP/1.1
-    Content-Type: application/json
-
-    {
-        "device_keys": {
-            "user_id": "<user_id>",
-            "device_id": "<device_id>",
-            "valid_until_ts": <millisecond_timestamp>,
-            "algorithms": [
-                "m.olm.curve25519-aes-sha2",
-            ]
-            "keys": {
-                "<algorithm>:<device_id>": "<key_base64>",
-            },
-            "signatures:" {
-                "<user_id>" {
-                    "<algorithm>:<device_id>": "<signature_base64>"
-        } } },
-        "one_time_keys": {
-            "<algorithm>:<key_id>": "<key_base64>"
-        },
-    }
-    """
-
-    PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.e2e_keys_handler = hs.get_e2e_keys_handler()
-        self.device_handler = hs.get_device_handler()
-
-    @trace(opname="upload_keys")
-    async def on_POST(self, request, device_id):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        user_id = requester.user.to_string()
-        body = parse_json_object_from_request(request)
-
-        if device_id is not None:
-            # Providing the device_id should only be done for setting keys
-            # for dehydrated devices; however, we allow it for any device for
-            # compatibility with older clients.
-            if requester.device_id is not None and device_id != requester.device_id:
-                dehydrated_device = await self.device_handler.get_dehydrated_device(
-                    user_id
-                )
-                if dehydrated_device is not None and device_id != dehydrated_device[0]:
-                    set_tag("error", True)
-                    log_kv(
-                        {
-                            "message": "Client uploading keys for a different device",
-                            "logged_in_id": requester.device_id,
-                            "key_being_uploaded": device_id,
-                        }
-                    )
-                    logger.warning(
-                        "Client uploading keys for a different device "
-                        "(logged in as %s, uploading for %s)",
-                        requester.device_id,
-                        device_id,
-                    )
-        else:
-            device_id = requester.device_id
-
-        if device_id is None:
-            raise SynapseError(
-                400, "To upload keys, you must pass device_id when authenticating"
-            )
-
-        result = await self.e2e_keys_handler.upload_keys_for_user(
-            user_id, device_id, body
-        )
-        return 200, result
-
-
-class KeyQueryServlet(RestServlet):
-    """
-    POST /keys/query HTTP/1.1
-    Content-Type: application/json
-    {
-      "device_keys": {
-        "<user_id>": ["<device_id>"]
-    } }
-
-    HTTP/1.1 200 OK
-    {
-      "device_keys": {
-        "<user_id>": {
-          "<device_id>": {
-            "user_id": "<user_id>", // Duplicated to be signed
-            "device_id": "<device_id>", // Duplicated to be signed
-            "valid_until_ts": <millisecond_timestamp>,
-            "algorithms": [ // List of supported algorithms
-              "m.olm.curve25519-aes-sha2",
-            ],
-            "keys": { // Must include a ed25519 signing key
-              "<algorithm>:<key_id>": "<key_base64>",
-            },
-            "signatures:" {
-              // Must be signed with device's ed25519 key
-              "<user_id>/<device_id>": {
-                "<algorithm>:<key_id>": "<signature_base64>"
-              }
-              // Must be signed by this server.
- "": { - ":": "" - } } } } } } - """ - - PATTERNS = client_patterns("/keys/query$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user_id = requester.user.to_string() - device_id = requester.device_id - timeout = parse_integer(request, "timeout", 10 * 1000) - body = parse_json_object_from_request(request) - result = await self.e2e_keys_handler.query_devices( - body, timeout, user_id, device_id - ) - return 200, result - - -class KeyChangesServlet(RestServlet): - """Returns the list of changes of keys between two stream tokens (may return - spurious extra results, since we currently ignore the `to` param). - - GET /keys/changes?from=...&to=... - - 200 OK - { "changed": ["@foo:example.com"] } - """ - - PATTERNS = client_patterns("/keys/changes$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ - super().__init__() - self.auth = hs.get_auth() - self.device_handler = hs.get_device_handler() - self.store = hs.get_datastore() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - from_token_string = parse_string(request, "from", required=True) - set_tag("from", from_token_string) - - # We want to enforce they do pass us one, but we ignore it and return - # changes after the "to" as well as before. - set_tag("to", parse_string(request, "to")) - - from_token = await StreamToken.from_string(self.store, from_token_string) - - user_id = requester.user.to_string() - - results = await self.device_handler.get_user_ids_changed(user_id, from_token) - - return 200, results - - -class OneTimeKeyServlet(RestServlet): - """ - POST /keys/claim HTTP/1.1 - { - "one_time_keys": { - "": { - "": "" - } } } - - HTTP/1.1 200 OK - { - "one_time_keys": { - "": { - "": { - ":": "" - } } } } - - """ - - PATTERNS = client_patterns("/keys/claim$") - - def __init__(self, hs): - super().__init__() - self.auth = hs.get_auth() - self.e2e_keys_handler = hs.get_e2e_keys_handler() - - async def on_POST(self, request): - await self.auth.get_user_by_req(request, allow_guest=True) - timeout = parse_integer(request, "timeout", 10 * 1000) - body = parse_json_object_from_request(request) - result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) - return 200, result - - -class SigningKeyUploadServlet(RestServlet): - """ - POST /keys/device_signing/upload HTTP/1.1 - Content-Type: application/json - - { - } - """ - - PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.auth = hs.get_auth() - self.e2e_keys_handler = hs.get_e2e_keys_handler() - self.auth_handler = hs.get_auth_handler() - - @interactive_auth_handler - async def on_POST(self, request): - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - - await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - "add a device signing key to your account", - # Allow skipping of UI auth since this is frequently called directly - # after login and it is silly to ask users to re-auth immediately. 
-            can_skip_ui_auth=True,
-        )
-
-        result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
-        return 200, result
-
-
-class SignaturesUploadServlet(RestServlet):
-    """
-    POST /keys/signatures/upload HTTP/1.1
-    Content-Type: application/json
-
-    {
-      "@alice:example.com": {
-        "<device_id>": {
-          "user_id": "<user_id>",
-          "device_id": "<device_id>",
-          "algorithms": [
-            "m.olm.curve25519-aes-sha2",
-            "m.megolm.v1.aes-sha2"
-          ],
-          "keys": {
-            "<algorithm>:<device_id>": "<key_base64>",
-          },
-          "signatures": {
-            "<signing_user_id>": {
-              "<algorithm>:<signing_key_base64>": "<signature_base64>"
-            }
-          }
-        }
-      }
-    }
-    """
-
-    PATTERNS = client_patterns("/keys/signatures/upload$")
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.e2e_keys_handler = hs.get_e2e_keys_handler()
-
-    async def on_POST(self, request):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        user_id = requester.user.to_string()
-        body = parse_json_object_from_request(request)
-
-        result = await self.e2e_keys_handler.upload_signatures_for_device_keys(
-            user_id, body
-        )
-        return 200, result
-
-
-def register_servlets(hs, http_server):
-    KeyUploadServlet(hs).register(http_server)
-    KeyQueryServlet(hs).register(http_server)
-    KeyChangesServlet(hs).register(http_server)
-    OneTimeKeyServlet(hs).register(http_server)
-    SigningKeyUploadServlet(hs).register(http_server)
-    SignaturesUploadServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/v2_alpha/knock.py
deleted file mode 100644
index 7d1bc40658..0000000000
--- a/synapse/rest/client/v2_alpha/knock.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright 2020 Sorunome
-# Copyright 2020 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
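For reference, a /keys/upload request body of the shape documented in KeyUploadServlet above might be assembled as follows. This is purely illustrative: the identifiers are examples and the key placeholders are stand-ins for values a real client derives and signs with its Olm libraries.

    import json

    body = {
        "device_keys": {
            "user_id": "@alice:example.com",
            "device_id": "JLAFKJWSCS",
            "algorithms": ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
            "keys": {
                "curve25519:JLAFKJWSCS": "<curve25519_key_base64>",
                "ed25519:JLAFKJWSCS": "<ed25519_key_base64>",
            },
            "signatures": {
                "@alice:example.com": {"ed25519:JLAFKJWSCS": "<signature_base64>"}
            },
        },
        "one_time_keys": {
            "signed_curve25519:AAAAHg": {
                "key": "<key_base64>",
                "signatures": {
                    "@alice:example.com": {"ed25519:JLAFKJWSCS": "<signature_base64>"}
                },
            }
        },
    }

    print(json.dumps(body, indent=2))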
-import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
-
-from twisted.web.server import Request
-
-from synapse.api.constants import Membership
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import (
-    RestServlet,
-    parse_json_object_from_request,
-    parse_strings_from_args,
-)
-from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag
-from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import JsonDict, RoomAlias, RoomID
-
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class KnockRoomAliasServlet(RestServlet):
-    """
-    POST /knock/{roomIdOrAlias}
-    """
-
-    PATTERNS = client_patterns("/knock/(?P<room_identifier>[^/]*)")
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.txns = HttpTransactionCache(hs)
-        self.room_member_handler = hs.get_room_member_handler()
-        self.auth = hs.get_auth()
-
-    async def on_POST(
-        self,
-        request: SynapseRequest,
-        room_identifier: str,
-        txn_id: Optional[str] = None,
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
-
-        content = parse_json_object_from_request(request)
-        event_content = None
-        if "reason" in content:
-            event_content = {"reason": content["reason"]}
-
-        if RoomID.is_valid(room_identifier):
-            room_id = room_identifier
-
-            # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
-            args: Dict[bytes, List[bytes]] = request.args  # type: ignore
-
-            remote_room_hosts = parse_strings_from_args(
-                args, "server_name", required=False
-            )
-        elif RoomAlias.is_valid(room_identifier):
-            handler = self.room_member_handler
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias)
-            room_id = room_id_obj.to_string()
-        else:
-            raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
-            )
-
-        await self.room_member_handler.update_membership(
-            requester=requester,
-            target=requester.user,
-            room_id=room_id,
-            action=Membership.KNOCK,
-            txn_id=txn_id,
-            third_party_signed=None,
-            remote_room_hosts=remote_room_hosts,
-            content=event_content,
-        )
-
-        return 200, {"room_id": room_id}
-
-    def on_PUT(self, request: Request, room_identifier: str, txn_id: str):
-        set_tag("txn_id", txn_id)
-
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_identifier, txn_id
-        )
-
-
-def register_servlets(hs, http_server):
-    KnockRoomAliasServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
deleted file mode 100644
index 0ede643c2d..0000000000
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
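The removed KnockRoomAliasServlet above accepts either a room ID or a room alias. A rough sketch of that branch using plain string checks in place of synapse.types.RoomID/RoomAlias; resolve_knock_target and its alias handling are simplified stand-ins, not the real directory lookup:

    from typing import List, Tuple

    def resolve_knock_target(room_identifier: str, via: List[str]) -> Tuple[str, List[str]]:
        if room_identifier.startswith("!"):
            # Already a room ID: candidate servers come from ?server_name= params.
            return room_identifier, via
        if room_identifier.startswith("#") and ":" in room_identifier:
            # A real server resolves the alias via the directory handler; here we
            # just fall back to the alias's own server as the remote hop.
            server = room_identifier.split(":", 1)[1]
            return "!unresolved-placeholder:" + server, [server]
        raise ValueError("%s was not legal room ID or room alias" % (room_identifier,))

    print(resolve_knock_target("#knock-me:example.org", []))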
- -import logging - -from synapse.events.utils import format_event_for_client_v2_without_room_id -from synapse.http.servlet import RestServlet, parse_integer, parse_string - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class NotificationsServlet(RestServlet): - PATTERNS = client_patterns("/notifications$") - - def __init__(self, hs): - super().__init__() - self.store = hs.get_datastore() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self._event_serializer = hs.get_event_client_serializer() - - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request) - user_id = requester.user.to_string() - - from_token = parse_string(request, "from", required=False) - limit = parse_integer(request, "limit", default=50) - only = parse_string(request, "only", required=False) - - limit = min(limit, 500) - - push_actions = await self.store.get_push_actions_for_user( - user_id, from_token, limit, only_highlight=(only == "highlight") - ) - - receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, "m.read" - ) - - notif_event_ids = [pa["event_id"] for pa in push_actions] - notif_events = await self.store.get_events(notif_event_ids) - - returned_push_actions = [] - - next_token = None - - for pa in push_actions: - returned_pa = { - "room_id": pa["room_id"], - "profile_tag": pa["profile_tag"], - "actions": pa["actions"], - "ts": pa["received_ts"], - "event": ( - await self._event_serializer.serialize_event( - notif_events[pa["event_id"]], - self.clock.time_msec(), - event_format=format_event_for_client_v2_without_room_id, - ) - ), - } - - if pa["room_id"] not in receipts_by_room: - returned_pa["read"] = False - else: - receipt = receipts_by_room[pa["room_id"]] - - returned_pa["read"] = ( - receipt["topological_ordering"], - receipt["stream_ordering"], - ) >= (pa["topological_ordering"], pa["stream_ordering"]) - returned_push_actions.append(returned_pa) - next_token = str(pa["stream_ordering"]) - - return 200, {"notifications": returned_push_actions, "next_token": next_token} - - -def register_servlets(hs, http_server): - NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py deleted file mode 100644 index e8d2673819..0000000000 --- a/synapse/rest/client/v2_alpha/openid.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging - -from synapse.api.errors import AuthError -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.util.stringutils import random_string - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class IdTokenServlet(RestServlet): - """ - Get a bearer token that may be passed to a third party to confirm ownership - of a matrix user id. 
-
-    The format of the response could be made compatible with the format given
-    in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
-
-    But instead of returning a signed "id_token" the response contains the
-    name of the issuing matrix homeserver. This means that for now the third
-    party will need to check the validity of the "id_token" against the
-    federation /openid/userinfo endpoint of the homeserver.
-
-    Request:
-
-    POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1
-
-    {}
-
-    Response:
-
-    HTTP/1.1 200 OK
-    {
-        "access_token": "ABDEFGH",
-        "token_type": "Bearer",
-        "matrix_server_name": "example.com",
-        "expires_in": 3600,
-    }
-    """
-
-    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/openid/request_token")
-
-    EXPIRES_MS = 3600 * 1000
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
-        self.server_name = hs.config.server_name
-
-    async def on_POST(self, request, user_id):
-        requester = await self.auth.get_user_by_req(request)
-        if user_id != requester.user.to_string():
-            raise AuthError(403, "Cannot request tokens for other users.")
-
-        # Parse the request body to make sure it's JSON, but ignore the contents
-        # for now.
-        parse_json_object_from_request(request)
-
-        token = random_string(24)
-        ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
-
-        await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
-
-        return (
-            200,
-            {
-                "access_token": token,
-                "token_type": "Bearer",
-                "matrix_server_name": self.server_name,
-                "expires_in": self.EXPIRES_MS // 1000,
-            },
-        )
-
-
-def register_servlets(hs, http_server):
-    IdTokenServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
deleted file mode 100644
index a83927aee6..0000000000
--- a/synapse/rest/client/v2_alpha/password_policy.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
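The expiry arithmetic in IdTokenServlet above is simple: the token is stored with an absolute ts_valid_until_ms, while the client is told a relative expires_in in seconds. A self-contained sketch, where secrets.token_urlsafe stands in for Synapse's random_string(24):

    import secrets
    import time

    EXPIRES_MS = 3600 * 1000

    def make_openid_token_response(server_name: str):
        token = secrets.token_urlsafe(18)
        # Absolute validity bound, stored server-side alongside the token.
        ts_valid_until_ms = int(time.time() * 1000) + EXPIRES_MS
        response = {
            "access_token": token,
            "token_type": "Bearer",
            "matrix_server_name": server_name,
            "expires_in": EXPIRES_MS // 1000,  # relative seconds for the client
        }
        return ts_valid_until_ms, response

    print(make_openid_token_response("example.com")[1]["expires_in"])  # 3600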
-
-import logging
-
-from synapse.http.servlet import RestServlet
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class PasswordPolicyServlet(RestServlet):
-    PATTERNS = client_patterns("/password_policy$")
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-
-        self.policy = hs.config.password_policy
-        self.enabled = hs.config.password_policy_enabled
-
-    def on_GET(self, request):
-        if not self.enabled or not self.policy:
-            return (200, {})
-
-        policy = {}
-
-        for param in [
-            "minimum_length",
-            "require_digit",
-            "require_symbol",
-            "require_lowercase",
-            "require_uppercase",
-        ]:
-            if param in self.policy:
-                policy["m.%s" % param] = self.policy[param]
-
-        return (200, policy)
-
-
-def register_servlets(hs, http_server):
-    PasswordPolicyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
deleted file mode 100644
index 027f8b81fa..0000000000
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.api.constants import ReadReceiptEventFields
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class ReadMarkerRestServlet(RestServlet):
-    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.receipts_handler = hs.get_receipts_handler()
-        self.read_marker_handler = hs.get_read_marker_handler()
-        self.presence_handler = hs.get_presence_handler()
-
-    async def on_POST(self, request, room_id):
-        requester = await self.auth.get_user_by_req(request)
-
-        await self.presence_handler.bump_presence_active_time(requester.user)
-
-        body = parse_json_object_from_request(request)
-        read_event_id = body.get("m.read", None)
-        hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
-
-        if not isinstance(hidden, bool):
-            raise SynapseError(
-                400,
-                "Param %s must be a boolean, if given"
-                % ReadReceiptEventFields.MSC2285_HIDDEN,
-                Codes.BAD_JSON,
-            )
-
-        if read_event_id:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                "m.read",
-                user_id=requester.user.to_string(),
-                event_id=read_event_id,
-                hidden=hidden,
-            )
-
-        read_marker_event_id = body.get("m.fully_read", None)
-        if read_marker_event_id:
-            await self.read_marker_handler.received_client_read_marker(
-                room_id,
-                user_id=requester.user.to_string(),
-                event_id=read_marker_event_id,
-            )
-
-        return 200, {}
-
-
-def register_servlets(hs, http_server):
-    ReadMarkerRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
deleted file mode 100644
index d9ab836cd8..0000000000
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.api.constants import ReadReceiptEventFields
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class ReceiptRestServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)"
-        "/receipt/(?P<receipt_type>[^/]*)"
-        "/(?P<event_id>[^/]*)$"
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.receipts_handler = hs.get_receipts_handler()
-        self.presence_handler = hs.get_presence_handler()
-
-    async def on_POST(self, request, room_id, receipt_type, event_id):
-        requester = await self.auth.get_user_by_req(request)
-
-        if receipt_type != "m.read":
-            raise SynapseError(400, "Receipt type must be 'm.read'")
-
-        body = parse_json_object_from_request(request, allow_empty_body=True)
-        hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
-
-        if not isinstance(hidden, bool):
-            raise SynapseError(
-                400,
-                "Param %s must be a boolean, if given"
-                % ReadReceiptEventFields.MSC2285_HIDDEN,
-                Codes.BAD_JSON,
-            )
-
-        await self.presence_handler.bump_presence_active_time(requester.user)
-
-        await self.receipts_handler.received_client_receipt(
-            room_id,
-            receipt_type,
-            user_id=requester.user.to_string(),
-            event_id=event_id,
-            hidden=hidden,
-        )
-
-        return 200, {}
-
-
-def register_servlets(hs, http_server):
-    ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
deleted file mode 100644
index 4d31584acd..0000000000
--- a/synapse/rest/client/v2_alpha/register.py
+++ /dev/null
@@ -1,879 +0,0 @@
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
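Both ReadMarkerRestServlet and ReceiptRestServlet above validate the MSC2285 hidden flag the same way: absent means False, and any non-boolean value is rejected. A toy version of that check; the field name is assumed to match ReadReceiptEventFields.MSC2285_HIDDEN:

    MSC2285_HIDDEN = "org.matrix.msc2285.hidden"

    def parse_hidden_flag(body: dict) -> bool:
        hidden = body.get(MSC2285_HIDDEN, False)
        if not isinstance(hidden, bool):
            raise ValueError("Param %s must be a boolean, if given" % (MSC2285_HIDDEN,))
        return hidden

    assert parse_hidden_flag({}) is False
    assert parse_hidden_flag({MSC2285_HIDDEN: True}) is True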
-import hmac -import logging -import random -from typing import List, Union - -import synapse -import synapse.api.auth -import synapse.types -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType -from synapse.api.errors import ( - Codes, - InteractiveAuthIncompleteError, - SynapseError, - ThreepidValidationError, - UnrecognizedRequestError, -) -from synapse.config import ConfigError -from synapse.config.captcha import CaptchaConfig -from synapse.config.consent import ConsentConfig -from synapse.config.emailconfig import ThreepidBehaviour -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.config.registration import RegistrationConfig -from synapse.config.server import is_threepid_reserved -from synapse.handlers.auth import AuthHandler -from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_boolean, - parse_json_object_from_request, - parse_string, -) -from synapse.metrics import threepid_send_requests -from synapse.push.mailer import Mailer -from synapse.types import JsonDict -from synapse.util.msisdn import phone_number_to_msisdn -from synapse.util.ratelimitutils import FederationRateLimiter -from synapse.util.stringutils import assert_valid_client_secret, random_string -from synapse.util.threepids import ( - canonicalise_email, - check_3pid_allowed, - validate_email, -) - -from ._base import client_patterns, interactive_auth_handler - -# We ought to be using hmac.compare_digest() but on older pythons it doesn't -# exist. It's a _really minor_ security flaw to use plain string comparison -# because the timing attack is so obscured by all the other code here it's -# unlikely to make much difference -if hasattr(hmac, "compare_digest"): - compare_digest = hmac.compare_digest -else: - - def compare_digest(a, b): - return a == b - - -logger = logging.getLogger(__name__) - - -class EmailRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/register/email/requestToken$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - self.config = hs.config - - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - self.mailer = Mailer( - hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_registration_template_html, - template_text=self.config.email_registration_template_text, - ) - - async def on_POST(self, request): - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.hs.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Email registration has been disabled due to lack of email config" - ) - raise SynapseError( - 400, "Email-based registration has been disabled on this server" - ) - body = parse_json_object_from_request(request) - - assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) - - # Extract params from body - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. 
- # (See on_POST in EmailThreepidRequestTokenRestServlet - # in synapse/rest/client/v2_alpha/account.py) - try: - email = validate_email(body["email"]) - except ValueError as e: - raise SynapseError(400, str(e)) - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - if not check_3pid_allowed(self.hs, "email", email): - raise SynapseError( - 403, - "Your email domain is not authorized to register on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "email", email - ) - - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "email", email - ) - - if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. - await self.hs.get_clock().sleep(random.randint(1, 10) / 10) - return 200, {"sid": random_string(16)} - - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email - - # Have the configured identity server handle the request - ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, - email, - client_secret, - send_attempt, - next_link, - ) - else: - # Send registration emails from Synapse - sid = await self.identity_handler.send_threepid_validation( - email, - client_secret, - send_attempt, - self.mailer.send_registration_mail, - next_link, - ) - - # Wrap the session id in a JSON object - ret = {"sid": sid} - - threepid_send_requests.labels(type="email", reason="register").observe( - send_attempt - ) - - return 200, ret - - -class MsisdnRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_patterns("/register/msisdn/requestToken$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.hs = hs - self.identity_handler = hs.get_identity_handler() - - async def on_POST(self, request): - body = parse_json_object_from_request(request) - - assert_params_in_dict( - body, ["client_secret", "country", "phone_number", "send_attempt"] - ) - client_secret = body["client_secret"] - assert_valid_client_secret(client_secret) - country = body["country"] - phone_number = body["phone_number"] - send_attempt = body["send_attempt"] - next_link = body.get("next_link") # Optional param - - msisdn = phone_number_to_msisdn(country, phone_number) - - if not check_3pid_allowed(self.hs, "msisdn", msisdn): - raise SynapseError( - 403, - "Phone numbers are not authorized to register on this server", - Codes.THREEPID_DENIED, - ) - - await self.identity_handler.ratelimit_request_token_requests( - request, "msisdn", msisdn - ) - - existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( - "msisdn", msisdn - ) - - if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: - # Make the client think the operation succeeded. See the rationale in the - # comments for request_token_inhibit_3pid_errors. - # Also wait for some random amount of time between 100ms and 1s to make it - # look like we did something. 
-                await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
-                return 200, {"sid": random_string(16)}
-
-            raise SynapseError(
-                400, "Phone number is already in use", Codes.THREEPID_IN_USE
-            )
-
-        if not self.hs.config.account_threepid_delegate_msisdn:
-            logger.warning(
-                "No upstream msisdn account_threepid_delegate configured on the server to "
-                "handle this request"
-            )
-            raise SynapseError(
-                400, "Registration by phone number is not supported on this homeserver"
-            )
-
-        ret = await self.identity_handler.requestMsisdnToken(
-            self.hs.config.account_threepid_delegate_msisdn,
-            country,
-            phone_number,
-            client_secret,
-            send_attempt,
-            next_link,
-        )
-
-        threepid_send_requests.labels(type="msisdn", reason="register").observe(
-            send_attempt
-        )
-
-        return 200, ret
-
-
-class RegistrationSubmitTokenServlet(RestServlet):
-    """Handles registration 3PID validation token submission"""
-
-    PATTERNS = client_patterns(
-        "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
-    )
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.config = hs.config
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
-            self._failure_email_template = (
-                self.config.email_registration_template_failure_html
-            )
-
-    async def on_GET(self, request, medium):
-        if medium != "email":
-            raise SynapseError(
-                400, "This medium is currently not supported for registration"
-            )
-        if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
-            if self.config.local_threepid_handling_disabled_due_to_email_config:
-                logger.warning(
-                    "User registration via email has been disabled due to lack of email config"
-                )
-            raise SynapseError(
-                400, "Email-based registration is disabled on this server"
-            )
-
-        sid = parse_string(request, "sid", required=True)
-        client_secret = parse_string(request, "client_secret", required=True)
-        assert_valid_client_secret(client_secret)
-        token = parse_string(request, "token", required=True)
-
-        # Attempt to validate a 3PID session
-        try:
-            # Mark the session as valid
-            next_link = await self.store.validate_threepid_session(
-                sid, client_secret, token, self.clock.time_msec()
-            )
-
-            # Perform a 302 redirect if next_link is set
-            if next_link:
-                if next_link.startswith("file:///"):
-                    logger.warning(
-                        "Not redirecting to next_link as it is a local file: address"
-                    )
-                else:
-                    request.setResponseCode(302)
-                    request.setHeader("Location", next_link)
-                    finish_request(request)
-                    return None
-
-            # Otherwise show the success template
-            html = self.config.email_registration_template_success_html_content
-            status_code = 200
-        except ThreepidValidationError as e:
-            status_code = e.code
-
-            # Show a failure page with a reason
-            template_vars = {"failure_reason": e.msg}
-            html = self._failure_email_template.render(**template_vars)
-
-        respond_with_html(request, status_code, html)
-
-
-class UsernameAvailabilityRestServlet(RestServlet):
-    PATTERNS = client_patterns("/register/available")
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.hs = hs
-        self.registration_handler = hs.get_registration_handler()
-        self.ratelimiter = FederationRateLimiter(
-            hs.get_clock(),
-            FederationRateLimitConfig(
-                # Time window of 2s
-                window_size=2000,
-                # Artificially delay requests if rate > sleep_limit/window_size
-                sleep_limit=1,
-                # Amount of artificial
delay to apply - sleep_msec=1000, - # Error with 429 if more than reject_limit requests are queued - reject_limit=1, - # Allow 1 request at a time - concurrent_requests=1, - ), - ) - - async def on_GET(self, request): - if not self.hs.config.enable_registration: - raise SynapseError( - 403, "Registration has been disabled", errcode=Codes.FORBIDDEN - ) - - ip = request.getClientIP() - with self.ratelimiter.ratelimit(ip) as wait_deferred: - await wait_deferred - - username = parse_string(request, "username", required=True) - - await self.registration_handler.check_username(username) - - return 200, {"available": True} - - -class RegisterRestServlet(RestServlet): - PATTERNS = client_patterns("/register$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - - self.hs = hs - self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.auth_handler = hs.get_auth_handler() - self.registration_handler = hs.get_registration_handler() - self.identity_handler = hs.get_identity_handler() - self.room_member_handler = hs.get_room_member_handler() - self.macaroon_gen = hs.get_macaroon_generator() - self.ratelimiter = hs.get_registration_ratelimiter() - self.password_policy_handler = hs.get_password_policy_handler() - self.clock = hs.get_clock() - self._registration_enabled = self.hs.config.enable_registration - self._msc2918_enabled = hs.config.access_token_lifetime is not None - - self._registration_flows = _calculate_registration_flows( - hs.config, self.auth_handler - ) - - @interactive_auth_handler - async def on_POST(self, request): - body = parse_json_object_from_request(request) - - client_addr = request.getClientIP() - - await self.ratelimiter.ratelimit(None, client_addr, update=False) - - kind = b"user" - if b"kind" in request.args: - kind = request.args[b"kind"][0] - - if kind == b"guest": - ret = await self._do_guest_registration(body, address=client_addr) - return ret - elif kind != b"user": - raise UnrecognizedRequestError( - "Do not understand membership kind: %s" % (kind.decode("utf8"),) - ) - - if self._msc2918_enabled: - # Check if this registration should also issue a refresh token, as - # per MSC2918 - should_issue_refresh_token = parse_boolean( - request, name="org.matrix.msc2918.refresh_token", default=False - ) - else: - should_issue_refresh_token = False - - # Pull out the provided username and do basic sanity checks early since - # the auth layer will store these in sessions. - desired_username = None - if "username" in body: - if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") - desired_username = body["username"] - - # fork off as soon as possible for ASes which have completely - # different registration flows to normal users - - # == Application Service Registration == - if body.get("type") == APP_SERVICE_REGISTRATION_TYPE: - if not self.auth.has_access_token(request): - raise SynapseError( - 400, - "Appservice token must be provided when using a type of m.login.application_service", - ) - - # Verify the AS - self.auth.get_appservice_by_req(request) - - # Set the desired user according to the AS API (which uses the - # 'user' key not 'username'). Since this is a new addition, we'll - # fallback to 'username' if they gave one. - desired_username = body.get("user", desired_username) - - # XXX we should check that desired_username is valid. 
Currently - # we give appservices carte blanche for any insanity in mxids, - # because the IRC bridges rely on being able to register stupid - # IDs. - - access_token = self.auth.get_access_token_from_request(request) - - if not isinstance(desired_username, str): - raise SynapseError(400, "Desired Username is missing or not a string") - - result = await self._do_appservice_registration( - desired_username, - access_token, - body, - should_issue_refresh_token=should_issue_refresh_token, - ) - - return 200, result - elif self.auth.has_access_token(request): - raise SynapseError( - 400, - "An access token should not be provided on requests to /register (except if type is m.login.application_service)", - ) - - # == Normal User Registration == (everyone else) - if not self._registration_enabled: - raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) - - # For regular registration, convert the provided username to lowercase - # before attempting to register it. This should mean that people who try - # to register with upper-case in their usernames don't get a nasty surprise. - # - # Note that we treat usernames case-insensitively in login, so they are - # free to carry on imagining that their username is CrAzYh4cKeR if that - # keeps them happy. - if desired_username is not None: - desired_username = desired_username.lower() - - # Check if this account is upgrading from a guest account. - guest_access_token = body.get("guest_access_token", None) - - # Pull out the provided password and do basic sanity checks early. - # - # Note that we remove the password from the body since the auth layer - # will store the body in the session and we don't want a plaintext - # password store there. - password = body.pop("password", None) - if password is not None: - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(password) - - if "initial_device_display_name" in body and password is None: - # ignore 'initial_device_display_name' if sent without - # a password to work around a client bug where it sent - # the 'initial_device_display_name' param alone, wiping out - # the original registration params - logger.warning("Ignoring initial_device_display_name without password") - del body["initial_device_display_name"] - - session_id = self.auth_handler.get_session_id(body) - registered_user_id = None - password_hash = None - if session_id: - # if we get a registered user id out of here, it means we previously - # registered a user for this session, so we could just return the - # user here. We carry on and go through the auth checks though, - # for paranoia. - registered_user_id = await self.auth_handler.get_session_data( - session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None - ) - # Extract the previously-hashed password from the session. - password_hash = await self.auth_handler.get_session_data( - session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None - ) - - # Ensure that the username is valid. - if desired_username is not None: - await self.registration_handler.check_username( - desired_username, - guest_access_token=guest_access_token, - assigned_user_id=registered_user_id, - ) - - # Check if the user-interactive authentication flows are complete, if - # not this will raise a user-interactive auth error. 
- try: - auth_result, params, session_id = await self.auth_handler.check_ui_auth( - self._registration_flows, - request, - body, - "register a new account", - ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth. - # - # Hash the password and store it with the session since the client - # is not required to provide the password again. - # - # If a password hash was previously stored we will not attempt to - # re-hash and store it for efficiency. This assumes the password - # does not change throughout the authentication flow, but this - # should be fine since the data is meant to be consistent. - if not password_hash and password: - password_hash = await self.auth_handler.hash(password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, - ) - raise - - # Check that we're not trying to register a denied 3pid. - # - # the user-facing checks will probably already have happened in - # /register/email/requestToken when we requested a 3pid, but that's not - # guaranteed. - if auth_result: - for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: - if login_type in auth_result: - medium = auth_result[login_type]["medium"] - address = auth_result[login_type]["address"] - - if not check_3pid_allowed(self.hs, medium, address): - raise SynapseError( - 403, - "Third party identifiers (email/phone numbers)" - + " are not authorized on this server", - Codes.THREEPID_DENIED, - ) - - if registered_user_id is not None: - logger.info( - "Already registered user ID %r for this session", registered_user_id - ) - # don't re-register the threepids - registered = False - else: - # If we have a password in this request, prefer it. Otherwise, there - # might be a password hash from an earlier request. - if password: - password_hash = await self.auth_handler.hash(password) - if not password_hash: - raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - - desired_username = params.get("username", None) - guest_access_token = params.get("guest_access_token", None) - - if desired_username is not None: - desired_username = desired_username.lower() - - threepid = None - if auth_result: - threepid = auth_result.get(LoginType.EMAIL_IDENTITY) - - # Also check that we're not trying to register a 3pid that's already - # been registered. - # - # This has probably happened in /register/email/requestToken as well, - # but if a user hits this endpoint twice then clicks on each link from - # the two activation emails, they would register the same 3pid twice. - for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: - if login_type in auth_result: - medium = auth_result[login_type]["medium"] - address = auth_result[login_type]["address"] - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. 
- # (See on_POST in EmailThreepidRequestTokenRestServlet - # in synapse/rest/client/v2_alpha/account.py) - if medium == "email": - try: - address = canonicalise_email(address) - except ValueError as e: - raise SynapseError(400, str(e)) - - existing_user_id = await self.store.get_user_id_by_threepid( - medium, address - ) - - if existing_user_id is not None: - raise SynapseError( - 400, - "%s is already in use" % medium, - Codes.THREEPID_IN_USE, - ) - - entries = await self.store.get_user_agents_ips_to_ui_auth_session( - session_id - ) - - registered_user_id = await self.registration_handler.register_user( - localpart=desired_username, - password_hash=password_hash, - guest_access_token=guest_access_token, - threepid=threepid, - address=client_addr, - user_agent_ips=entries, - ) - # Necessary due to auth checks prior to the threepid being - # written to the db - if threepid: - if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid - ): - await self.store.upsert_monthly_active_user(registered_user_id) - - # Remember that the user account has been registered (and the user - # ID it was registered with, since it might not have been specified). - await self.auth_handler.set_session_data( - session_id, - UIAuthSessionDataConstants.REGISTERED_USER_ID, - registered_user_id, - ) - - registered = True - - return_dict = await self._create_registration_details( - registered_user_id, - params, - should_issue_refresh_token=should_issue_refresh_token, - ) - - if registered: - await self.registration_handler.post_registration_actions( - user_id=registered_user_id, - auth_result=auth_result, - access_token=return_dict.get("access_token"), - ) - - return 200, return_dict - - async def _do_appservice_registration( - self, username, as_token, body, should_issue_refresh_token: bool = False - ): - user_id = await self.registration_handler.appservice_register( - username, as_token - ) - return await self._create_registration_details( - user_id, - body, - is_appservice_ghost=True, - should_issue_refresh_token=should_issue_refresh_token, - ) - - async def _create_registration_details( - self, - user_id: str, - params: JsonDict, - is_appservice_ghost: bool = False, - should_issue_refresh_token: bool = False, - ): - """Complete registration of newly-registered user - - Allocates device_id if one was not given; also creates access_token. - - Args: - user_id: full canonical @user:id - params: registration parameters, from which we pull device_id, - initial_device_name and inhibit_login - is_appservice_ghost - should_issue_refresh_token: True if this registration should issue - a refresh token alongside the access token. 
- Returns: - dictionary for response from /register - """ - result = {"user_id": user_id, "home_server": self.hs.hostname} - if not params.get("inhibit_login", False): - device_id = params.get("device_id") - initial_display_name = params.get("initial_device_display_name") - ( - device_id, - access_token, - valid_until_ms, - refresh_token, - ) = await self.registration_handler.register_device( - user_id, - device_id, - initial_display_name, - is_guest=False, - is_appservice_ghost=is_appservice_ghost, - should_issue_refresh_token=should_issue_refresh_token, - ) - - result.update({"access_token": access_token, "device_id": device_id}) - - if valid_until_ms is not None: - expires_in_ms = valid_until_ms - self.clock.time_msec() - result["expires_in_ms"] = expires_in_ms - - if refresh_token is not None: - result["refresh_token"] = refresh_token - - return result - - async def _do_guest_registration(self, params, address=None): - if not self.hs.config.allow_guest_access: - raise SynapseError(403, "Guest access is disabled") - user_id = await self.registration_handler.register_user( - make_guest=True, address=address - ) - - # we don't allow guests to specify their own device_id, because - # we have nowhere to store it. - device_id = synapse.api.auth.GUEST_DEVICE_ID - initial_display_name = params.get("initial_device_display_name") - ( - device_id, - access_token, - valid_until_ms, - refresh_token, - ) = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=True - ) - - result = { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - } - - if valid_until_ms is not None: - expires_in_ms = valid_until_ms - self.clock.time_msec() - result["expires_in_ms"] = expires_in_ms - - if refresh_token is not None: - result["refresh_token"] = refresh_token - - return 200, result - - -def _calculate_registration_flows( - # technically `config` has to provide *all* of these interfaces, not just one - config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], - auth_handler: AuthHandler, -) -> List[List[str]]: - """Get a suitable flows list for registration - - Args: - config: server configuration - auth_handler: authorization handler - - Returns: a list of supported flows - """ - # FIXME: need a better error than "no auth flow found" for scenarios - # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registrations_require_3pid - require_msisdn = "msisdn" in config.registrations_require_3pid - - show_msisdn = True - show_email = True - - if config.disable_msisdn_registration: - show_msisdn = False - require_msisdn = False - - enabled_auth_types = auth_handler.get_enabled_auth_types() - if LoginType.EMAIL_IDENTITY not in enabled_auth_types: - show_email = False - if require_email: - raise ConfigError( - "Configuration requires email address at registration, but email " - "validation is not configured" - ) - - if LoginType.MSISDN not in enabled_auth_types: - show_msisdn = False - if require_msisdn: - raise ConfigError( - "Configuration requires msisdn at registration, but msisdn " - "validation is not configured" - ) - - flows = [] - - # only support 3PIDless registration if no 3PIDs are required - if not require_email and not require_msisdn: - # Add a dummy step here, otherwise if a client completes - # recaptcha first we'll assume they were going for this flow - # and complete the request, when they could have been trying to - # complete one of 
the flows with email/msisdn auth. - flows.append([LoginType.DUMMY]) - - # only support the email-only flow if we don't require MSISDN 3PIDs - if show_email and not require_msisdn: - flows.append([LoginType.EMAIL_IDENTITY]) - - # only support the MSISDN-only flow if we don't require email 3PIDs - if show_msisdn and not require_email: - flows.append([LoginType.MSISDN]) - - if show_email and show_msisdn: - # always let users provide both MSISDN & email - flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) - - # Prepend m.login.terms to all flows if we're requiring consent - if config.user_consent_at_registration: - for flow in flows: - flow.insert(0, LoginType.TERMS) - - # Prepend recaptcha to all flows if we're requiring captcha - if config.enable_registration_captcha: - for flow in flows: - flow.insert(0, LoginType.RECAPTCHA) - - return flows - - -def register_servlets(hs, http_server): - EmailRegisterRequestTokenRestServlet(hs).register(http_server) - MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) - UsernameAvailabilityRestServlet(hs).register(http_server) - RegistrationSubmitTokenServlet(hs).register(http_server) - RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py deleted file mode 100644 index 0821cd285f..0000000000 --- a/synapse/rest/client/v2_alpha/relations.py +++ /dev/null @@ -1,381 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This class implements the proposed relation APIs from MSC 1849. - -Since the MSC has not been approved all APIs here are unstable and may change at -any time to reflect changes in the MSC. -""" - -import logging - -from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import ShadowBanError, SynapseError -from synapse.http.servlet import ( - RestServlet, - parse_integer, - parse_json_object_from_request, - parse_string, -) -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.util.stringutils import random_string - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class RelationSendServlet(RestServlet): - """Helper API for sending events that have relation data. 
-
-    Example API shape to send a 👍 reaction to a room:
-
-        POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D
-        {}
-
-        {
-            "event_id": "$foobar"
-        }
-    """
-
-    PATTERN = (
-        "/rooms/(?P<room_id>[^/]*)/send_relation"
-        "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.event_creation_handler = hs.get_event_creation_handler()
-        self.txns = HttpTransactionCache(hs)
-
-    def register(self, http_server):
-        http_server.register_paths(
-            "POST",
-            client_patterns(self.PATTERN + "$", releases=()),
-            self.on_PUT_or_POST,
-            self.__class__.__name__,
-        )
-        http_server.register_paths(
-            "PUT",
-            client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
-            self.on_PUT,
-            self.__class__.__name__,
-        )
-
-    def on_PUT(self, request, *args, **kwargs):
-        return self.txns.fetch_or_execute_request(
-            request, self.on_PUT_or_POST, request, *args, **kwargs
-        )
-
-    async def on_PUT_or_POST(
-        self, request, room_id, parent_id, relation_type, event_type, txn_id=None
-    ):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        if event_type == EventTypes.Member:
-            # Adding relations to a membership event is meaningless, so we just
-            # deny it at the CS API rather than trying to handle it correctly.
-            raise SynapseError(400, "Cannot send member events with relations")
-
-        content = parse_json_object_from_request(request)
-
-        aggregation_key = parse_string(request, "key", encoding="utf-8")
-
-        content["m.relates_to"] = {
-            "event_id": parent_id,
-            "key": aggregation_key,
-            "rel_type": relation_type,
-        }
-
-        event_dict = {
-            "type": event_type,
-            "content": content,
-            "room_id": room_id,
-            "sender": requester.user.to_string(),
-        }
-
-        try:
-            (
-                event,
-                _,
-            ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                requester, event_dict=event_dict, txn_id=txn_id
-            )
-            event_id = event.event_id
-        except ShadowBanError:
-            event_id = "$" + random_string(43)
-
-        return 200, {"event_id": event_id}
-
-
-class RelationPaginationServlet(RestServlet):
-    """API to paginate relations on an event by topological ordering, optionally
-    filtered by relation type and event type.
-    """
-
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
-        "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
-        releases=(),
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
-
-    async def on_GET(
-        self, request, room_id, parent_id, relation_type=None, event_type=None
-    ):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id, requester.user.to_string(), allow_departed_users=True
-        )
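The send_relation servlet above folds its URL parameters into the new event's content as an `m.relates_to` block; for a 👍 reaction the stored content ends up roughly as follows (a sketch of the resulting shape):

    reaction_content = {
        "m.relates_to": {
            "rel_type": "m.annotation",  # relation_type from the URL
            "event_id": "$bar",          # parent_id from the URL
            "key": "👍",                 # aggregation key from ?key=
        }
    }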
-        # This gets the original event and checks that a) the event exists and
-        # b) the user is allowed to view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-
-        limit = parse_integer(request, "limit", default=5)
-        from_token_str = parse_string(request, "from")
-        to_token_str = parse_string(request, "to")
-
-        if event.internal_metadata.is_redacted():
-            # If the event is redacted, return an empty list of relations
-            pagination_chunk = PaginationChunk(chunk=[])
-        else:
-            # Return the relations
-            from_token = None
-            if from_token_str:
-                from_token = RelationPaginationToken.from_string(from_token_str)
-
-            to_token = None
-            if to_token_str:
-                to_token = RelationPaginationToken.from_string(to_token_str)
-
-            pagination_chunk = await self.store.get_relations_for_event(
-                event_id=parent_id,
-                relation_type=relation_type,
-                event_type=event_type,
-                limit=limit,
-                from_token=from_token,
-                to_token=to_token,
-            )
-
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in pagination_chunk.chunk]
-        )
-
-        now = self.clock.time_msec()
-        # We set bundle_aggregations to False when retrieving the original
-        # event because we want the content before relations were applied to
-        # it.
-        original_event = await self._event_serializer.serialize_event(
-            event, now, bundle_aggregations=False
-        )
-        # Similarly, we don't allow relations to be applied to relations, so we
-        # return the original relations without any aggregations on top of them
-        # here.
-        events = await self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=False
-        )
-
-        return_value = pagination_chunk.to_dict()
-        return_value["chunk"] = events
-        return_value["original_event"] = original_event
-
-        return 200, return_value
-
-
-class RelationAggregationPaginationServlet(RestServlet):
-    """API to paginate aggregation groups of relations, e.g. paginate the
-    types and counts of the reactions on the events.
-
-    Example request and response:
-
-    GET /rooms/{room_id}/aggregations/{parent_id}
-
-    {
-        chunk: [
-            {
-                "type": "m.reaction",
-                "key": "👍",
-                "count": 3
-            }
-        ]
-    }
-    """
-
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
-        "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
-        releases=(),
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.event_handler = hs.get_event_handler()
-
-    async def on_GET(
-        self, request, room_id, parent_id, relation_type=None, event_type=None
-    ):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id,
-            requester.user.to_string(),
-            allow_departed_users=True,
-        )
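A hedged sketch of paging through the relations endpoint above from a client's point of view; the `fetch_json` helper is an assumption, and the `next_batch` field comes from the `PaginationChunk.to_dict()` response shape:

    async def fetch_all_relations(fetch_json, room_id, parent_id):
        # Follow next_batch tokens until the server stops returning them.
        results, from_token = [], None
        while True:
            params = {"limit": 50}
            if from_token:
                params["from"] = from_token
            page = await fetch_json(f"/rooms/{room_id}/relations/{parent_id}", params)
            results.extend(page["chunk"])
            from_token = page.get("next_batch")
            if not from_token:
                return results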
-        # This checks that a) the event exists and b) the user is allowed to
-        # view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-
-        if relation_type not in (RelationTypes.ANNOTATION, None):
-            raise SynapseError(400, "Relation type must be 'annotation'")
-
-        limit = parse_integer(request, "limit", default=5)
-        from_token_str = parse_string(request, "from")
-        to_token_str = parse_string(request, "to")
-
-        if event.internal_metadata.is_redacted():
-            # If the event is redacted, return an empty list of relations
-            pagination_chunk = PaginationChunk(chunk=[])
-        else:
-            # Return the relations
-            from_token = None
-            if from_token_str:
-                from_token = AggregationPaginationToken.from_string(from_token_str)
-
-            to_token = None
-            if to_token_str:
-                to_token = AggregationPaginationToken.from_string(to_token_str)
-
-            pagination_chunk = await self.store.get_aggregation_groups_for_event(
-                event_id=parent_id,
-                event_type=event_type,
-                limit=limit,
-                from_token=from_token,
-                to_token=to_token,
-            )
-
-        return 200, pagination_chunk.to_dict()
-
-
-class RelationAggregationGroupPaginationServlet(RestServlet):
-    """API to paginate within an aggregation group of relations, e.g. paginate
-    all the 👍 reactions on an event.
-
-    Example request and response:
-
-    GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
-
-    {
-        chunk: [
-            {
-                "type": "m.reaction",
-                "content": {
-                    "m.relates_to": {
-                        "rel_type": "m.annotation",
-                        "key": "👍"
-                    }
-                }
-            },
-            ...
-        ]
-    }
-    """
-
-    PATTERNS = client_patterns(
-        "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
-        "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
-        releases=(),
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
-
-    async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id,
-            requester.user.to_string(),
-            allow_departed_users=True,
-        )
-
-        # This checks that a) the event exists and b) the user is allowed to
-        # view it.
- await self.event_handler.get_event(requester.user, room_id, parent_id) - - if relation_type != RelationTypes.ANNOTATION: - raise SynapseError(400, "Relation type must be 'annotation'") - - limit = parse_integer(request, "limit", default=5) - from_token_str = parse_string(request, "from") - to_token_str = parse_string(request, "to") - - from_token = None - if from_token_str: - from_token = RelationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = RelationPaginationToken.from_string(to_token_str) - - result = await self.store.get_relations_for_event( - event_id=parent_id, - relation_type=relation_type, - event_type=event_type, - aggregation_key=key, - limit=limit, - from_token=from_token, - to_token=to_token, - ) - - events = await self.store.get_events_as_list( - [c["event_id"] for c in result.chunk] - ) - - now = self.clock.time_msec() - events = await self._event_serializer.serialize_events(events, now) - - return_value = result.to_dict() - return_value["chunk"] = events - - return 200, return_value - - -def register_servlets(hs, http_server): - RelationSendServlet(hs).register(http_server) - RelationPaginationServlet(hs).register(http_server) - RelationAggregationPaginationServlet(hs).register(http_server) - RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py deleted file mode 100644 index 07ea39a8a3..0000000000 --- a/synapse/rest/client/v2_alpha/report_event.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-import logging
-from http import HTTPStatus
-
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class ReportEventRestServlet(RestServlet):
-    PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-
-    async def on_POST(self, request, room_id, event_id):
-        requester = await self.auth.get_user_by_req(request)
-        user_id = requester.user.to_string()
-
-        body = parse_json_object_from_request(request)
-
-        if not isinstance(body.get("reason", ""), str):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Param 'reason' must be a string",
-                Codes.BAD_JSON,
-            )
-        if not isinstance(body.get("score", 0), int):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Param 'score' must be an integer",
-                Codes.BAD_JSON,
-            )
-
-        await self.store.add_event_report(
-            room_id=room_id,
-            event_id=event_id,
-            user_id=user_id,
-            reason=body.get("reason"),
-            content=body,
-            received_ts=self.clock.time_msec(),
-        )
-
-        return 200, {}
-
-
-def register_servlets(hs, http_server):
-    ReportEventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/room.py b/synapse/rest/client/v2_alpha/room.py
deleted file mode 100644
index 3172aba605..0000000000
--- a/synapse/rest/client/v2_alpha/room.py
+++ /dev/null
@@ -1,441 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import re
-
-from synapse.api.constants import EventContentFields, EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.appservice import ApplicationService
-from synapse.http.servlet import (
-    RestServlet,
-    assert_params_in_dict,
-    parse_json_object_from_request,
-    parse_string,
-    parse_strings_from_args,
-)
-from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import Requester, UserID, create_requester
-from synapse.util.stringutils import random_string
-
-logger = logging.getLogger(__name__)
-
-
-class RoomBatchSendEventRestServlet(RestServlet):
-    """
-    API endpoint which can insert a chunk of events historically back in time
-    next to the given `prev_event`.
-
-    `chunk_id` comes from `next_chunk_id` in the response of the batch send
-    endpoint and is derived from the "insertion" events added to each chunk.
-    It's not required for the first batch send.
-
-    `state_events_at_start` is used to define the historical state events
-    needed to auth the events like join events. These events will float
-    outside of the normal DAG as outliers and won't be visible in the chat
-    history which also allows us to insert multiple chunks without having a bunch
-    of `@mxid joined the room` noise between each chunk.
-
-    `events` is a chronological chunk/list of events you want to insert.
-    There is a reverse-chronological constraint on chunks so once you insert
-    some messages, you can only insert older ones after that.
-    tldr; Insert chunks from your most recent history -> oldest history.
-
-    POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<room_id>/batch_send?prev_event=<event_id>&chunk_id=<chunk_id>
-    {
-        "events": [ ... ],
-        "state_events_at_start": [ ... ]
-    }
-    """
-
-    PATTERNS = (
-        re.compile(
-            "^/_matrix/client/unstable/org.matrix.msc2716"
-            "/rooms/(?P<room_id>[^/]*)/batch_send$"
-        ),
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.hs = hs
-        self.store = hs.get_datastore()
-        self.state_store = hs.get_storage().state
-        self.event_creation_handler = hs.get_event_creation_handler()
-        self.room_member_handler = hs.get_room_member_handler()
-        self.auth = hs.get_auth()
-        self.txns = HttpTransactionCache(hs)
-
-    async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
-        (
-            most_recent_prev_event_id,
-            most_recent_prev_event_depth,
-        ) = await self.store.get_max_depth_of(prev_event_ids)
-
-        # We want to insert the historical event after the `prev_event` but before the successor event
-        #
-        # We inherit depth from the successor event instead of the `prev_event`
-        # because events returned from `/messages` are first sorted by `topological_ordering`
-        # which is just the `depth` and then tie-break with `stream_ordering`.
-        #
-        # We mark these inserted historical events as "backfilled" which gives them a
-        # negative `stream_ordering`. If we use the same depth as the `prev_event`,
-        # then our historical event will tie-break and be sorted before the `prev_event`
-        # when it should come after.
-        #
-        # We want to use the successor event depth so they appear after `prev_event` because
-        # it has a larger `depth` but before the successor event because the `stream_ordering`
-        # is negative before the successor event.
-        successor_event_ids = await self.store.get_successor_events(
-            [most_recent_prev_event_id]
-        )
-
-        # If we can't find any successor events, then it's a forward extremity of
-        # historical messages and we can just inherit from the previous historical
-        # event which we can already assume has the correct depth where we want
-        # to insert into.
-        if not successor_event_ids:
-            depth = most_recent_prev_event_depth
-        else:
-            (
-                _,
-                oldest_successor_depth,
-            ) = await self.store.get_min_depth_of(successor_event_ids)
-
-            depth = oldest_successor_depth
-
-        return depth
-
-    def _create_insertion_event_dict(
-        self, sender: str, room_id: str, origin_server_ts: int
-    ):
-        """Creates an event dict for an "insertion" event with the proper fields
-        and a random chunk ID.
-
-        Args:
-            sender: The event author MXID
-            room_id: The room ID that the event belongs to
-            origin_server_ts: Timestamp when the event was sent
-
-        Returns:
-            The new insertion event dictionary.
-        """
-
-        next_chunk_id = random_string(8)
-        insertion_event = {
-            "type": EventTypes.MSC2716_INSERTION,
-            "sender": sender,
-            "room_id": room_id,
-            "content": {
-                EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id,
-                EventContentFields.MSC2716_HISTORICAL: True,
-            },
-            "origin_server_ts": origin_server_ts,
-        }
-
-        return insertion_event
-
-    async def _create_requester_for_user_id_from_app_service(
-        self, user_id: str, app_service: ApplicationService
-    ) -> Requester:
-        """Creates a new requester for the given user_id
-        and validates that the app service is allowed to control
-        the given user.
- - Args: - user_id: The author MXID that the app service is controlling - app_service: The app service that controls the user - - Returns: - Requester object - """ - - await self.auth.validate_appservice_can_control_user_id(app_service, user_id) - - return create_requester(user_id, app_service=app_service) - - async def on_POST(self, request, room_id): - requester = await self.auth.get_user_by_req(request, allow_guest=False) - - if not requester.app_service: - raise AuthError( - 403, - "Only application services can use the /batchsend endpoint", - ) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["state_events_at_start", "events"]) - - prev_events_from_query = parse_strings_from_args(request.args, "prev_event") - chunk_id_from_query = parse_string(request, "chunk_id") - - if prev_events_from_query is None: - raise SynapseError( - 400, - "prev_event query parameter is required when inserting historical messages back in time", - errcode=Codes.MISSING_PARAM, - ) - - # For the event we are inserting next to (`prev_events_from_query`), - # find the most recent auth events (derived from state events) that - # allowed that message to be sent. We will use that as a base - # to auth our historical messages against. - ( - most_recent_prev_event_id, - _, - ) = await self.store.get_max_depth_of(prev_events_from_query) - # mapping from (type, state_key) -> state_event_id - prev_state_map = await self.state_store.get_state_ids_for_event( - most_recent_prev_event_id - ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - - state_events_at_start = [] - for state_event in body["state_events_at_start"]: - assert_params_in_dict( - state_event, ["type", "origin_server_ts", "content", "sender"] - ) - - logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", - state_event, - auth_event_ids, - ) - - event_dict = { - "type": state_event["type"], - "origin_server_ts": state_event["origin_server_ts"], - "content": state_event["content"], - "room_id": room_id, - "sender": state_event["sender"], - "state_key": state_event["state_key"], - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - # Make the state events float off on their own - fake_prev_event_id = "$" + random_string(43) - - # TODO: This is pretty much the same as some other code to handle inserting state in this file - if event_dict["type"] == EventTypes.Member: - membership = event_dict["content"].get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - target=UserID.from_string(event_dict["state_key"]), - room_id=room_id, - action=membership, - content=event_dict["content"], - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - else: - # TODO: Add some complement tests that adds state that is not member joins - # and will use this code path. Maybe we only want to support join state events - # and can get rid of this `else`? 
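The `auth_event_ids.copy()` calls above matter because Python lists are passed by reference; a minimal illustration of the aliasing bug they avoid:

    auth_ids = ["$a"]
    event_auth = auth_ids           # alias: both names see later appends
    auth_ids.append("$b")
    assert event_auth == ["$a", "$b"]

    event_auth = auth_ids.copy()    # snapshot: later appends don't leak in
    auth_ids.append("$c")
    assert event_auth == ["$a", "$b"]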
- ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - event_dict, - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - event_id = event.event_id - - state_events_at_start.append(event_id) - auth_event_ids.append(event_id) - - events_to_create = body["events"] - - inherited_depth = await self._inherit_depth_from_prev_ids( - prev_events_from_query - ) - - # Figure out which chunk to connect to. If they passed in - # chunk_id_from_query let's use it. The chunk ID passed in comes - # from the chunk_id in the "insertion" event from the previous chunk. - last_event_in_chunk = events_to_create[-1] - chunk_id_to_connect_to = chunk_id_from_query - base_insertion_event = None - if chunk_id_from_query: - # All but the first base insertion event should point at a fake - # event, which causes the HS to ask for the state at the start of - # the chunk later. - prev_event_ids = [fake_prev_event_id] - # TODO: Verify the chunk_id_from_query corresponds to an insertion event - pass - # Otherwise, create an insertion event to act as a starting point. - # - # We don't always have an insertion event to start hanging more history - # off of (ideally there would be one in the main DAG, but that's not the - # case if we're wanting to add history to e.g. existing rooms without - # an insertion event), in which case we just create a new insertion event - # that can then get pointed to by a "marker" event later. - else: - prev_event_ids = prev_events_from_query - - base_insertion_event_dict = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - origin_server_ts=last_event_in_chunk["origin_server_ts"], - ) - base_insertion_event_dict["prev_events"] = prev_event_ids.copy() - - ( - base_insertion_event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - base_insertion_event_dict["sender"], - requester.app_service, - ), - base_insertion_event_dict, - prev_event_ids=base_insertion_event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - - chunk_id_to_connect_to = base_insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ] - - # Connect this current chunk to the insertion event from the previous chunk - chunk_event = { - "type": EventTypes.MSC2716_CHUNK, - "sender": requester.user.to_string(), - "room_id": room_id, - "content": { - EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, - EventContentFields.MSC2716_HISTORICAL: True, - }, - # Since the chunk event is put at the end of the chunk, - # where the newest-in-time event is, copy the origin_server_ts from - # the last event we're inserting - "origin_server_ts": last_event_in_chunk["origin_server_ts"], - } - # Add the chunk event to the end of the chunk (newest-in-time) - events_to_create.append(chunk_event) - - # Add an "insertion" event to the start of each chunk (next to the oldest-in-time - # event in the chunk) so the next chunk can be connected to this one. 
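To make the chunk linking concrete, a sketch of the ordering (token values are illustrative): each batch's chunk event presents the token minted by the previous batch, and its insertion event mints the token for the next one.

    # batch N:   [insertion(next_chunk_id="T2"), ...events..., chunk(chunk_id="T1")]
    # batch N+1: [insertion(next_chunk_id="T3"), ...events..., chunk(chunk_id="T2")]
    #
    # "T1" came from the insertion event of the batch before N; each request
    # passes the previous insertion event's next_chunk_id as ?chunk_id= so the
    # new chunk's chunk event connects to it.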
- insertion_event = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - # Since the insertion event is put at the start of the chunk, - # where the oldest-in-time event is, copy the origin_server_ts from - # the first event we're inserting - origin_server_ts=events_to_create[0]["origin_server_ts"], - ) - # Prepend the insertion event to the start of the chunk (oldest-in-time) - events_to_create = [insertion_event] + events_to_create - - event_ids = [] - events_to_persist = [] - for ev in events_to_create: - assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - - event_dict = { - "type": ev["type"], - "origin_server_ts": ev["origin_server_ts"], - "content": ev["content"], - "room_id": room_id, - "sender": ev["sender"], # requester.user.to_string(), - "prev_events": prev_event_ids.copy(), - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - event, context = await self.event_creation_handler.create_event( - await self._create_requester_for_user_id_from_app_service( - ev["sender"], requester.app_service - ), - event_dict, - prev_event_ids=event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", - event, - prev_event_ids, - auth_event_ids, - ) - - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) - - events_to_persist.append((event, context)) - event_id = event.event_id - - event_ids.append(event_id) - prev_event_ids = [event_id] - - # Persist events in reverse-chronological order so they have the - # correct stream_ordering as they are backfilled (which decrements). - # Events are sorted by (topological_ordering, stream_ordering) - # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): - ev = await self.event_creation_handler.handle_new_client_event( - await self._create_requester_for_user_id_from_app_service( - event["sender"], requester.app_service - ), - event=event, - context=context, - ) - - # Add the base_insertion_event to the bottom of the list we return - if base_insertion_event is not None: - event_ids.append(base_insertion_event.event_id) - - return 200, { - "state_events": state_events_at_start, - "events": event_ids, - "next_chunk_id": insertion_event["content"][ - EventContentFields.MSC2716_NEXT_CHUNK_ID - ], - } - - def on_GET(self, request, room_id): - return 501, "Not implemented" - - def on_PUT(self, request, room_id): - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id - ) - - -def register_servlets(hs, http_server): - msc2716_enabled = hs.config.experimental.msc2716_enabled - - if msc2716_enabled: - RoomBatchSendEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py deleted file mode 100644 index 263596be86..0000000000 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ /dev/null @@ -1,391 +0,0 @@ -# Copyright 2017, 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.api.errors import Codes, NotFoundError, SynapseError
-from synapse.http.servlet import (
-    RestServlet,
-    parse_json_object_from_request,
-    parse_string,
-)
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class RoomKeysServlet(RestServlet):
-    PATTERNS = client_patterns(
-        "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
-    )
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
-
-    async def on_PUT(self, request, room_id, session_id):
-        """
-        Uploads one or more encrypted E2E room keys for backup purposes.
-        room_id: the ID of the room the keys are for (optional)
-        session_id: the ID for the E2E room keys for the room (optional)
-        version: the version of the user's backup which this data is for.
-            the version must already have been created via the /room_keys/version API.
-
-        Each session has:
-         * first_message_index: a numeric index indicating the oldest message
-           encrypted by this session.
-         * forwarded_count: how many times the uploading client claims this key
-           has been shared (forwarded)
-         * is_verified: whether the client that uploaded the keys claims they
-           were sent by a device which they've verified
-         * session_data: base64-encrypted data describing the session.
-
-        Returns 200 OK on success with body {}
-        Returns 403 Forbidden if the version in question is not the most recently
-        created version (i.e. if this is an old client trying to write to a stale backup)
-        Returns 404 Not Found if the version in question doesn't exist
-
-        The API is designed to be otherwise agnostic to the room_key encryption
-        algorithm being used. Sessions are merged with existing ones in the
-        backup using the heuristics:
-         * is_verified sessions always win over unverified sessions
-         * older first_message_index always wins over newer sessions
-         * lower forwarded_count always wins over higher forwarded_count
-
-        We trust the clients not to lie and corrupt their own backups.
-        It also means that if your access_token is stolen, the attacker could
-        delete your backup.
-
-        POST /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1
-        Content-Type: application/json
-
-        {
-            "first_message_index": 1,
-            "forwarded_count": 1,
-            "is_verified": false,
-            "session_data": "SSBBTSBBIEZJU0gK"
-        }
-
-        Or...
-
-        POST /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1
-        Content-Type: application/json
-
-        {
-            "sessions": {
-                "c0ff33": {
-                    "first_message_index": 1,
-                    "forwarded_count": 1,
-                    "is_verified": false,
-                    "session_data": "SSBBTSBBIEZJU0gK"
-                }
-            }
-        }
-
-        Or...
- - POST /room_keys/keys?version=1 HTTP/1.1 - Content-Type: application/json - - { - "rooms": { - "!abc:matrix.org": { - "sessions": { - "c0ff33": { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - } - } - } - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - body = parse_json_object_from_request(request) - version = parse_string(request, "version") - - if session_id: - body = {"sessions": {session_id: body}} - - if room_id: - body = {"rooms": {room_id: body}} - - ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, ret - - async def on_GET(self, request, room_id, session_id): - """ - Retrieves one or more encrypted E2E room keys for backup purposes. - Symmetric with the PUT version of the API. - - room_id: the ID of the room to retrieve the keys for (optional) - session_id: the ID for the E2E room keys to retrieve the keys for (optional) - version: the version of the user's backup which this data is for. - the version must already have been created via the /change_secret API. - - Returns as follows: - - GET /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1 - { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - - Or... - - GET /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1 - { - "sessions": { - "c0ff33": { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - } - } - - Or... - - GET /room_keys/keys?version=1 HTTP/1.1 - { - "rooms": { - "!abc:matrix.org": { - "sessions": { - "c0ff33": { - "first_message_index": 1, - "forwarded_count": 1, - "is_verified": false, - "session_data": "SSBBTSBBIEZJU0gK" - } - } - } - } - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - version = parse_string(request, "version", required=True) - - room_keys = await self.e2e_room_keys_handler.get_room_keys( - user_id, version, room_id, session_id - ) - - # Convert room_keys to the right format to return. - if session_id: - # If the client requests a specific session, but that session was - # not backed up, then return an M_NOT_FOUND. - if room_keys["rooms"] == {}: - raise NotFoundError("No room_keys found") - else: - room_keys = room_keys["rooms"][room_id]["sessions"][session_id] - elif room_id: - # If the client requests all sessions from a room, but no sessions - # are found, then return an empty result rather than an error, so - # that clients don't have to handle an error condition, and an - # empty result is valid. (Similarly if the client requests all - # sessions from the backup, but in that case, room_keys is already - # in the right format, so we don't need to do anything about it.) - if room_keys["rooms"] == {}: - room_keys = {"sessions": {}} - else: - room_keys = room_keys["rooms"][room_id] - - return 200, room_keys - - async def on_DELETE(self, request, room_id, session_id): - """ - Deletes one or more encrypted E2E room keys for a user for backup purposes. - - DELETE /room_keys/keys/!abc:matrix.org/c0ff33?version=1 - HTTP/1.1 200 OK - {} - - room_id: the ID of the room whose keys to delete (optional) - session_id: the ID for the E2E session to delete (optional) - version: the version of the user's backup which this data is for. - the version must already have been created via the /change_secret API. 
- """ - - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - version = parse_string(request, "version") - - ret = await self.e2e_room_keys_handler.delete_room_keys( - user_id, version, room_id, session_id - ) - return 200, ret - - -class RoomKeysNewVersionServlet(RestServlet): - PATTERNS = client_patterns("/room_keys/version$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - - async def on_POST(self, request): - """ - Create a new backup version for this user's room_keys with the given - info. The version is allocated by the server and returned to the user - in the response. This API is intended to be used whenever the user - changes the encryption key for their backups, ensuring that backups - encrypted with different keys don't collide. - - It takes out an exclusive lock on this user's room_key backups, to ensure - clients only upload to the current backup. - - The algorithm passed in the version info is a reverse-DNS namespaced - identifier to describe the format of the encrypted backupped keys. - - The auth_data is { user_id: "user_id", nonce: } - encrypted using the algorithm and current encryption key described above. - - POST /room_keys/version - Content-Type: application/json - { - "algorithm": "m.megolm_backup.v1", - "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" - } - - HTTP/1.1 200 OK - Content-Type: application/json - { - "version": 12345 - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - info = parse_json_object_from_request(request) - - new_version = await self.e2e_room_keys_handler.create_version(user_id, info) - return 200, {"version": new_version} - - # we deliberately don't have a PUT /version, as these things really should - # be immutable to avoid people footgunning - - -class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_patterns("/room_keys/version(/(?P[^/]+))?$") - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super().__init__() - self.auth = hs.get_auth() - self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - - async def on_GET(self, request, version): - """ - Retrieve the version information about a given version of the user's - room_keys backup. If the version part is missing, returns info about the - most current backup version (if any) - - It takes out an exclusive lock on this user's room_key backups, to ensure - clients only upload to the current backup. - - Returns 404 if the given version does not exist. - - GET /room_keys/version/12345 HTTP/1.1 - { - "version": "12345", - "algorithm": "m.megolm_backup.v1", - "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" - } - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - - try: - info = await self.e2e_room_keys_handler.get_version_info(user_id, version) - except SynapseError as e: - if e.code == 404: - raise SynapseError(404, "No backup found", Codes.NOT_FOUND) - return 200, info - - async def on_DELETE(self, request, version): - """ - Delete the information about a given version of the user's - room_keys backup. If the version part is missing, deletes the most - current backup version (if any). Doesn't delete the actual room data. 
- - DELETE /room_keys/version/12345 HTTP/1.1 - HTTP/1.1 200 OK - {} - """ - if version is None: - raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - - await self.e2e_room_keys_handler.delete_version(user_id, version) - return 200, {} - - async def on_PUT(self, request, version): - """ - Update the information about a given version of the user's room_keys backup. - - POST /room_keys/version/12345 HTTP/1.1 - Content-Type: application/json - { - "algorithm": "m.megolm_backup.v1", - "auth_data": { - "public_key": "abcdefg", - "signatures": { - "ed25519:something": "hijklmnop" - } - }, - "version": "12345" - } - - HTTP/1.1 200 OK - Content-Type: application/json - {} - """ - requester = await self.auth.get_user_by_req(request, allow_guest=False) - user_id = requester.user.to_string() - info = parse_json_object_from_request(request) - - if version is None: - raise SynapseError( - 400, "No version specified to update", Codes.MISSING_PARAM - ) - - await self.e2e_room_keys_handler.update_version(user_id, version, info) - return 200, {} - - -def register_servlets(hs, http_server): - RoomKeysServlet(hs).register(http_server) - RoomKeysVersionServlet(hs).register(http_server) - RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py deleted file mode 100644 index 6d1b083acb..0000000000 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from synapse.api.errors import Codes, ShadowBanError, SynapseError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) -from synapse.util import stringutils - -from ._base import client_patterns - -logger = logging.getLogger(__name__) - - -class RoomUpgradeRestServlet(RestServlet): - """Handler for room upgrade requests. - - Handles requests of the form: - - POST /_matrix/client/r0/rooms/$roomid/upgrade HTTP/1.1 - Content-Type: application/json - - { - "new_version": "2", - } - - Creates a new room and shuts down the old one. Returns the ID of the new room. 
-
-    Args:
-        hs (synapse.server.HomeServer):
-    """
-
-    PATTERNS = client_patterns(
-        # /rooms/$roomid/upgrade
-        "/rooms/(?P<room_id>[^/]*)/upgrade$"
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self._hs = hs
-        self._room_creation_handler = hs.get_room_creation_handler()
-        self._auth = hs.get_auth()
-
-    async def on_POST(self, request, room_id):
-        requester = await self._auth.get_user_by_req(request)
-
-        content = parse_json_object_from_request(request)
-        assert_params_in_dict(content, ("new_version",))
-
-        new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
-        if new_version is None:
-            raise SynapseError(
-                400,
-                "Your homeserver does not support this room version",
-                Codes.UNSUPPORTED_ROOM_VERSION,
-            )
-
-        try:
-            new_room_id = await self._room_creation_handler.upgrade_room(
-                requester, room_id, new_version
-            )
-        except ShadowBanError:
-            # Generate a random room ID.
-            new_room_id = stringutils.random_string(18)
-
-        ret = {"replacement_room": new_room_id}
-
-        return 200, ret
-
-
-def register_servlets(hs, http_server):
-    RoomUpgradeRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
deleted file mode 100644
index d537d811d8..0000000000
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
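The send-to-device servlet below routes PUTs through HttpTransactionCache, the same idempotency mechanism the relations and batch-send servlets use; a minimal sketch of the idea (illustrative only, not the real cache, which also de-duplicates concurrent in-flight retries):

    class TxnCacheSketch:
        def __init__(self):
            self._results = {}

        async def fetch_or_execute(self, txn_key, handler, *args):
            # Re-running the handler for a retried txn_id would duplicate the
            # side effect; replay the stored response instead.
            if txn_key not in self._results:
                self._results[txn_key] = await handler(*args)
            return self._results[txn_key]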
-
-import logging
-from typing import Tuple
-
-from synapse.http import servlet
-from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
-from synapse.logging.opentracing import set_tag, trace
-from synapse.rest.client.transactions import HttpTransactionCache
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class SendToDeviceRestServlet(servlet.RestServlet):
-    PATTERNS = client_patterns(
-        "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
-    )
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.txns = HttpTransactionCache(hs)
-        self.device_message_handler = hs.get_device_message_handler()
-
-    @trace(opname="sendToDevice")
-    def on_PUT(self, request, message_type, txn_id):
-        set_tag("message_type", message_type)
-        set_tag("txn_id", txn_id)
-        return self.txns.fetch_or_execute_request(
-            request, self._put, request, message_type, txn_id
-        )
-
-    async def _put(self, request, message_type, txn_id):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
-        content = parse_json_object_from_request(request)
-        assert_params_in_dict(content, ("messages",))
-
-        await self.device_message_handler.send_device_message(
-            requester, message_type, content["messages"]
-        )
-
-        response: Tuple[int, dict] = (200, {})
-        return response
-
-
-def register_servlets(hs, http_server):
-    SendToDeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
deleted file mode 100644
index d2e7f04b40..0000000000
--- a/synapse/rest/client/v2_alpha/shared_rooms.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright 2020 Half-Shot
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-
-from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import RestServlet
-from synapse.types import UserID
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class UserSharedRoomsServlet(RestServlet):
-    """
-    GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
-    """
-
-    PATTERNS = client_patterns(
-        "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
-        releases=(),  # This is an unstable feature
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.user_directory_active = hs.config.update_user_directory
-
-    async def on_GET(self, request, user_id):
-
-        if not self.user_directory_active:
-            raise SynapseError(
-                code=400,
-                msg="The user directory is disabled on this server.
Cannot determine shared rooms.", - errcode=Codes.FORBIDDEN, - ) - - UserID.from_string(user_id) - - requester = await self.auth.get_user_by_req(request) - if user_id == requester.user.to_string(): - raise SynapseError( - code=400, - msg="You cannot request a list of shared rooms with yourself", - errcode=Codes.FORBIDDEN, - ) - rooms = await self.store.get_shared_rooms_for_users( - requester.user.to_string(), user_id - ) - - return 200, {"joined": list(rooms)} - - -def register_servlets(hs, http_server): - UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py deleted file mode 100644 index e18f4d01b3..0000000000 --- a/synapse/rest/client/v2_alpha/sync.py +++ /dev/null @@ -1,532 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import itertools -import logging -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple - -from synapse.api.constants import Membership, PresenceState -from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection -from synapse.events.utils import ( - format_event_for_client_v2_without_room_id, - format_event_raw, -) -from synapse.handlers.presence import format_user_presence_state -from synapse.handlers.sync import KnockedSyncResult, SyncConfig -from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string -from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, StreamToken -from synapse.util import json_decoder - -from ._base import client_patterns, set_timeline_upper_limit - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class SyncRestServlet(RestServlet): - """ - - GET parameters:: - timeout(int): How long to wait for new events in milliseconds. - since(batch_token): Batch token when asking for incremental deltas. - set_presence(str): What state the device presence should be set to. - default is "online". - filter(filter_id): A filter to apply to the events returned. - - Response JSON:: - { - "next_batch": // batch token for the next /sync - "presence": // presence data for the user. - "rooms": { - "join": { // Joined rooms being updated. - "${room_id}": { // Id of the room being updated - "event_map": // Map of EventID -> event JSON. - "timeline": { // The recent events in the room if gap is "true" - "limited": // Was the per-room event limit exceeded? - // otherwise the next events in the room. - "events": [] // list of EventIDs in the "event_map". - "prev_batch": // back token for getting previous events. - } - "state": {"events": []} // list of EventIDs updating the - // current state to be what it should - // be at the end of the batch. - "ephemeral": {"events": []} // list of event objects - } - }, - "invite": {}, // Invited rooms being updated. - "leave": {} // Archived rooms being updated. 
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
deleted file mode 100644
index e18f4d01b3..0000000000
--- a/synapse/rest/client/v2_alpha/sync.py
+++ /dev/null
@@ -1,532 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import itertools
-import logging
-from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
-
-from synapse.api.constants import Membership, PresenceState
-from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
-from synapse.events.utils import (
-    format_event_for_client_v2_without_room_id,
-    format_event_raw,
-)
-from synapse.handlers.presence import format_user_presence_state
-from synapse.handlers.sync import KnockedSyncResult, SyncConfig
-from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
-from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, StreamToken
-from synapse.util import json_decoder
-
-from ._base import client_patterns, set_timeline_upper_limit
-
-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
-logger = logging.getLogger(__name__)
-
-
-class SyncRestServlet(RestServlet):
-    """
-
-    GET parameters::
-        timeout(int): How long to wait for new events in milliseconds.
-        since(batch_token): Batch token when asking for incremental deltas.
-        set_presence(str): What state the device presence should be set to.
-            default is "online".
-        filter(filter_id): A filter to apply to the events returned.
-
-    Response JSON::
-        {
-          "next_batch": // batch token for the next /sync
-          "presence": // presence data for the user.
-          "rooms": {
-            "join": { // Joined rooms being updated.
-              "${room_id}": { // Id of the room being updated
-                "event_map": // Map of EventID -> event JSON.
-                "timeline": { // The recent events in the room if gap is "true"
-                  "limited": // Was the per-room event limit exceeded?
-                             // otherwise the next events in the room.
-                  "events": [] // list of EventIDs in the "event_map".
-                  "prev_batch": // back token for getting previous events.
-                }
-                "state": {"events": []} // list of EventIDs updating the
-                // current state to be what it should
-                // be at the end of the batch.
-                "ephemeral": {"events": []} // list of event objects
-              }
-            },
-            "invite": {}, // Invited rooms being updated.
-            "leave": {} // Archived rooms being updated.
-          }
-        }
-    """
-
-    PATTERNS = client_patterns("/sync$")
-    ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.sync_handler = hs.get_sync_handler()
-        self.clock = hs.get_clock()
-        self.filtering = hs.get_filtering()
-        self.presence_handler = hs.get_presence_handler()
-        self._server_notices_sender = hs.get_server_notices_sender()
-        self._event_serializer = hs.get_event_client_serializer()
-
-    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
-        # This will always be set by the time Twisted calls us.
-        assert request.args is not None
-
-        if b"from" in request.args:
-            # /events used to use 'from', but /sync uses 'since'.
-            # Lets be helpful and whine if we see a 'from'.
-            raise SynapseError(
-                400, "'from' is not a valid query parameter. Did you mean 'since'?"
-            )
-
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        user = requester.user
-        device_id = requester.device_id
-
-        timeout = parse_integer(request, "timeout", default=0)
-        since = parse_string(request, "since")
-        set_presence = parse_string(
-            request,
-            "set_presence",
-            default="online",
-            allowed_values=self.ALLOWED_PRESENCE,
-        )
-        filter_id = parse_string(request, "filter")
-        full_state = parse_boolean(request, "full_state", default=False)
-
-        logger.debug(
-            "/sync: user=%r, timeout=%r, since=%r, "
-            "set_presence=%r, filter_id=%r, device_id=%r",
-            user,
-            timeout,
-            since,
-            set_presence,
-            filter_id,
-            device_id,
-        )
-
-        request_key = (user, timeout, since, filter_id, full_state, device_id)
-
-        if filter_id is None:
-            filter_collection = DEFAULT_FILTER_COLLECTION
-        elif filter_id.startswith("{"):
-            try:
-                filter_object = json_decoder.decode(filter_id)
-                set_timeline_upper_limit(
-                    filter_object, self.hs.config.filter_timeline_limit
-                )
-            except Exception:
-                raise SynapseError(400, "Invalid filter JSON")
-            self.filtering.check_valid_filter(filter_object)
-            filter_collection = FilterCollection(filter_object)
-        else:
-            try:
-                filter_collection = await self.filtering.get_user_filter(
-                    user.localpart, filter_id
-                )
-            except StoreError as err:
-                if err.code != 404:
-                    raise
-                # fix up the description and errcode to be more useful
-                raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM)
-
-        sync_config = SyncConfig(
-            user=user,
-            filter_collection=filter_collection,
-            is_guest=requester.is_guest,
-            request_key=request_key,
-            device_id=device_id,
-        )
-
-        since_token = None
-        if since is not None:
-            since_token = await StreamToken.from_string(self.store, since)
-
-        # send any outstanding server notices to the user.
-        await self._server_notices_sender.on_user_syncing(user.to_string())
-
-        affect_presence = set_presence != PresenceState.OFFLINE
-
-        if affect_presence:
-            await self.presence_handler.set_state(
-                user, {"presence": set_presence}, True
-            )
-
-        context = await self.presence_handler.user_syncing(
-            user.to_string(), affect_presence=affect_presence
-        )
-        with context:
-            sync_result = await self.sync_handler.wait_for_sync_for_user(
-                requester,
-                sync_config,
-                since_token=since_token,
-                timeout=timeout,
-                full_state=full_state,
-            )
-
-        # the client may have disconnected by now; don't bother to serialize the
-        # response if so.
-        if request._disconnected:
-            logger.info("Client has disconnected; not serializing response.")
-            return 200, {}
-
-        time_now = self.clock.time_msec()
-        response_content = await self.encode_response(
-            time_now, sync_result, requester.access_token_id, filter_collection
-        )
-
-        logger.debug("Event formatting complete")
-        return 200, response_content
-
-    async def encode_response(self, time_now, sync_result, access_token_id, filter):
-        logger.debug("Formatting events in sync response")
-        if filter.event_format == "client":
-            event_formatter = format_event_for_client_v2_without_room_id
-        elif filter.event_format == "federation":
-            event_formatter = format_event_raw
-        else:
-            raise Exception("Unknown event format %s" % (filter.event_format,))
-
-        joined = await self.encode_joined(
-            sync_result.joined,
-            time_now,
-            access_token_id,
-            filter.event_fields,
-            event_formatter,
-        )
-
-        invited = await self.encode_invited(
-            sync_result.invited, time_now, access_token_id, event_formatter
-        )
-
-        knocked = await self.encode_knocked(
-            sync_result.knocked, time_now, access_token_id, event_formatter
-        )
-
-        archived = await self.encode_archived(
-            sync_result.archived,
-            time_now,
-            access_token_id,
-            filter.event_fields,
-            event_formatter,
-        )
-
-        logger.debug("building sync response dict")
-
-        response: dict = defaultdict(dict)
-        response["next_batch"] = await sync_result.next_batch.to_string(self.store)
-
-        if sync_result.account_data:
-            response["account_data"] = {"events": sync_result.account_data}
-        if sync_result.presence:
-            response["presence"] = SyncRestServlet.encode_presence(
-                sync_result.presence, time_now
-            )
-
-        if sync_result.to_device:
-            response["to_device"] = {"events": sync_result.to_device}
-
-        if sync_result.device_lists.changed:
-            response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
-        if sync_result.device_lists.left:
-            response["device_lists"]["left"] = list(sync_result.device_lists.left)
-
-        # We always include this because https://github.com/vector-im/element-android/issues/3725
-        # The spec isn't terribly clear on when this can be omitted and how a client would tell
-        # the difference between "no keys present" and "nothing changed" in terms of whole field
-        # absent / individual key type entry absent
-        # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
-        response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
-
-        # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
-        # states that this field should always be included, as long as the server supports the feature.
-        response[
-            "org.matrix.msc2732.device_unused_fallback_key_types"
-        ] = sync_result.device_unused_fallback_key_types
-
-        if joined:
-            response["rooms"][Membership.JOIN] = joined
-        if invited:
-            response["rooms"][Membership.INVITE] = invited
-        if knocked:
-            response["rooms"][Membership.KNOCK] = knocked
-        if archived:
-            response["rooms"][Membership.LEAVE] = archived
-
-        if sync_result.groups.join:
-            response["groups"][Membership.JOIN] = sync_result.groups.join
-        if sync_result.groups.invite:
-            response["groups"][Membership.INVITE] = sync_result.groups.invite
-        if sync_result.groups.leave:
-            response["groups"][Membership.LEAVE] = sync_result.groups.leave
-
-        return response
-
-    @staticmethod
-    def encode_presence(events, time_now):
-        return {
-            "events": [
-                {
-                    "type": "m.presence",
-                    "sender": event.user_id,
-                    "content": format_user_presence_state(
-                        event, time_now, include_user_id=False
-                    ),
-                }
-                for event in events
-            ]
-        }
-
-    async def encode_joined(
-        self, rooms, time_now, token_id, event_fields, event_formatter
-    ):
-        """
-        Encode the joined rooms in a sync result
-
-        Args:
-            rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
-                results for rooms this user is joined to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
-                of transaction IDs
-            event_fields(list): List of event fields to include. If empty,
-                all fields will be returned.
-            event_formatter (func[dict]): function to convert from federation format
-                to client format
-        Returns:
-            dict[str, dict[str, object]]: the joined rooms list, in our
-                response format
-        """
-        joined = {}
-        for room in rooms:
-            joined[room.room_id] = await self.encode_room(
-                room,
-                time_now,
-                token_id,
-                joined=True,
-                only_fields=event_fields,
-                event_formatter=event_formatter,
-            )
-
-        return joined
-
-    async def encode_invited(self, rooms, time_now, token_id, event_formatter):
-        """
-        Encode the invited rooms in a sync result
-
-        Args:
-            rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
-                sync results for rooms this user is invited to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
-                of transaction IDs
-            event_formatter (func[dict]): function to convert from federation format
-                to client format
-
-        Returns:
-            dict[str, dict[str, object]]: the invited rooms list, in our
-                response format
-        """
-        invited = {}
-        for room in rooms:
-            invite = await self._event_serializer.serialize_event(
-                room.invite,
-                time_now,
-                token_id=token_id,
-                event_format=event_formatter,
-                include_stripped_room_state=True,
-            )
-            unsigned = dict(invite.get("unsigned", {}))
-            invite["unsigned"] = unsigned
-            invited_state = list(unsigned.pop("invite_room_state", []))
-            invited_state.append(invite)
-            invited[room.room_id] = {"invite_state": {"events": invited_state}}
-
-        return invited
-
-    async def encode_knocked(
-        self,
-        rooms: List[KnockedSyncResult],
-        time_now: int,
-        token_id: int,
-        event_formatter: Callable[[Dict], Dict],
-    ) -> Dict[str, Dict[str, Any]]:
-        """
-        Encode the rooms we've knocked on in a sync result.
-
-        Args:
-            rooms: list of sync results for rooms this user is knocking on
-            time_now: current time - used as a baseline for age calculations
-            token_id: ID of the user's auth token - used for namespacing of transaction IDs
-            event_formatter: function to convert from federation format to client format
-
-        Returns:
-            The list of rooms the user has knocked on, in our response format.
-        """
-        knocked = {}
-        for room in rooms:
-            knock = await self._event_serializer.serialize_event(
-                room.knock,
-                time_now,
-                token_id=token_id,
-                event_format=event_formatter,
-                include_stripped_room_state=True,
-            )
-
-            # Extract the `unsigned` key from the knock event.
-            # This is where we (cheekily) store the knock state events
-            unsigned = knock.setdefault("unsigned", {})
-
-            # Duplicate the dictionary in order to avoid modifying the original
-            unsigned = dict(unsigned)
-
-            # Extract the stripped room state from the unsigned dict
-            # This is for clients to get a little bit of information about
-            # the room they've knocked on, without revealing any sensitive information
-            knocked_state = list(unsigned.pop("knock_room_state", []))
-
-            # Append the actual knock membership event itself as well. This provides
-            # the client with:
-            #
-            # * A knock state event that they can use for easier internal tracking
-            # * The rough timestamp of when the knock occurred contained within the event
-            knocked_state.append(knock)
-
-            # Build the `knock_state` dictionary, which will contain the state of the
-            # room that the client has knocked on
-            knocked[room.room_id] = {"knock_state": {"events": knocked_state}}
-
-        return knocked
-
-    async def encode_archived(
-        self, rooms, time_now, token_id, event_fields, event_formatter
-    ):
-        """
-        Encode the archived rooms in a sync result
-
-        Args:
-            rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
-                sync results for rooms this user is joined to
-            time_now(int): current time - used as a baseline for age
-                calculations
-            token_id(int): ID of the user's auth token - used for namespacing
-                of transaction IDs
-            event_fields(list): List of event fields to include. If empty,
-                all fields will be returned.
-            event_formatter (func[dict]): function to convert from federation format
-                to client format
-        Returns:
-            dict[str, dict[str, object]]: The invited rooms list, in our
-                response format
-        """
-        joined = {}
-        for room in rooms:
-            joined[room.room_id] = await self.encode_room(
-                room,
-                time_now,
-                token_id,
-                joined=False,
-                only_fields=event_fields,
-                event_formatter=event_formatter,
-            )
-
-        return joined
-
-    async def encode_room(
-        self, room, time_now, token_id, joined, only_fields, event_formatter
-    ):
-        """
-        Args:
-            room (JoinedSyncResult|ArchivedSyncResult): sync result for a
-                single room
-            time_now (int): current time - used as a baseline for age
-                calculations
-            token_id (int): ID of the user's auth token - used for namespacing
-                of transaction IDs
-            joined (bool): True if the user is joined to this room - will mean
-                we handle ephemeral events
-            only_fields(list): Optional. The list of event fields to include.
-            event_formatter (func[dict]): function to convert from federation format
-                to client format
-        Returns:
-            dict[str, object]: the room, encoded in our response format
-        """
-
-        def serialize(events):
-            return self._event_serializer.serialize_events(
-                events,
-                time_now=time_now,
-                # We don't bundle "live" events, as otherwise clients
-                # will end up double counting annotations.
-                bundle_aggregations=False,
-                token_id=token_id,
-                event_format=event_formatter,
-                only_event_fields=only_fields,
-            )
-
-        state_dict = room.state
-        timeline_events = room.timeline.events
-
-        state_events = state_dict.values()
-
-        for event in itertools.chain(state_events, timeline_events):
-            # We've had bug reports that events were coming down under the
-            # wrong room.
-            if event.room_id != room.room_id:
-                logger.warning(
-                    "Event %r is under room %r instead of %r",
-                    event.event_id,
-                    room.room_id,
-                    event.room_id,
-                )
-
-        serialized_state = await serialize(state_events)
-        serialized_timeline = await serialize(timeline_events)
-
-        account_data = room.account_data
-
-        result = {
-            "timeline": {
-                "events": serialized_timeline,
-                "prev_batch": await room.timeline.prev_batch.to_string(self.store),
-                "limited": room.timeline.limited,
-            },
-            "state": {"events": serialized_state},
-            "account_data": {"events": account_data},
-        }
-
-        if joined:
-            ephemeral_events = room.ephemeral
-            result["ephemeral"] = {"events": ephemeral_events}
-            result["unread_notifications"] = room.unread_notifications
-            result["summary"] = room.summary
-            result["org.matrix.msc2654.unread_count"] = room.unread_count
-
-        return result
-
-
-def register_servlets(hs, http_server):
-    SyncRestServlet(hs).register(http_server)
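The servlet deleted above drives the long-polling /sync loop; the since/next_batch hand-off its docstring describes can be sketched from the client side as follows (illustrative only; assumes the `requests` library, the r0 prefix, a homeserver at https://example.org, and a valid access token):

    import requests  # assumed available; any HTTP client works

    def sync_forever(base_url: str, token: str) -> None:
        since = None
        while True:
            params = {"timeout": 30000}  # long-poll for up to 30s
            if since is not None:
                params["since"] = since  # resume from the previous batch
            resp = requests.get(
                f"{base_url}/_matrix/client/r0/sync",
                params=params,
                headers={"Authorization": f"Bearer {token}"},
            ).json()
            # The returned token is fed back as `since` on the next request.
            since = resp["next_batch"]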
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
deleted file mode 100644
index c14f83be18..0000000000
--- a/synapse/rest/client/v2_alpha/tags.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.api.errors import AuthError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class TagListServlet(RestServlet):
-    """
-    GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
-    """
-
-    PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-
-    async def on_GET(self, request, user_id, room_id):
-        requester = await self.auth.get_user_by_req(request)
-        if user_id != requester.user.to_string():
-            raise AuthError(403, "Cannot get tags for other users.")
-
-        tags = await self.store.get_tags_for_room(user_id, room_id)
-
-        return 200, {"tags": tags}
-
-
-class TagServlet(RestServlet):
-    """
-    PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
-    DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
-    """
-
-    PATTERNS = client_patterns(
-        "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
-    )
-
-    def __init__(self, hs):
-        super().__init__()
-        self.auth = hs.get_auth()
-        self.handler = hs.get_account_data_handler()
-
-    async def on_PUT(self, request, user_id, room_id, tag):
-        requester = await self.auth.get_user_by_req(request)
-        if user_id != requester.user.to_string():
-            raise AuthError(403, "Cannot add tags for other users.")
-
-        body = parse_json_object_from_request(request)
-
-        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
-
-        return 200, {}
-
-    async def on_DELETE(self, request, user_id, room_id, tag):
-        requester = await self.auth.get_user_by_req(request)
-        if user_id != requester.user.to_string():
-            raise AuthError(403, "Cannot add tags for other users.")
-
-        await self.handler.remove_tag_from_room(user_id, room_id, tag)
-
-        return 200, {}
-
-
-def register_servlets(hs, http_server):
-    TagListServlet(hs).register(http_server)
-    TagServlet(hs).register(http_server)
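A short client-side sketch of the tag endpoints removed above (illustrative only; assumes the `requests` library, the r0 prefix, and a valid access token):

    import requests  # assumed available; any HTTP client works

    def set_favourite(base_url: str, token: str, user_id: str, room_id: str) -> dict:
        # Tag the room as a favourite, then read the tag map back.
        hdrs = {"Authorization": f"Bearer {token}"}
        path = f"{base_url}/_matrix/client/r0/user/{user_id}/rooms/{room_id}/tags"
        requests.put(f"{path}/m.favourite", json={"order": 0.5}, headers=hdrs)
        return requests.get(path, headers=hdrs).json()["tags"]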
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
deleted file mode 100644
index b5c67c9bb6..0000000000
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-
-from synapse.api.constants import ThirdPartyEntityKind
-from synapse.http.servlet import RestServlet
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class ThirdPartyProtocolsServlet(RestServlet):
-    PATTERNS = client_patterns("/thirdparty/protocols")
-
-    def __init__(self, hs):
-        super().__init__()
-
-        self.auth = hs.get_auth()
-        self.appservice_handler = hs.get_application_service_handler()
-
-    async def on_GET(self, request):
-        await self.auth.get_user_by_req(request, allow_guest=True)
-
-        protocols = await self.appservice_handler.get_3pe_protocols()
-        return 200, protocols
-
-
-class ThirdPartyProtocolServlet(RestServlet):
-    PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
-
-    def __init__(self, hs):
-        super().__init__()
-
-        self.auth = hs.get_auth()
-        self.appservice_handler = hs.get_application_service_handler()
-
-    async def on_GET(self, request, protocol):
-        await self.auth.get_user_by_req(request, allow_guest=True)
-
-        protocols = await self.appservice_handler.get_3pe_protocols(
-            only_protocol=protocol
-        )
-        if protocol in protocols:
-            return 200, protocols[protocol]
-        else:
-            return 404, {"error": "Unknown protocol"}
-
-
-class ThirdPartyUserServlet(RestServlet):
-    PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
-
-    def __init__(self, hs):
-        super().__init__()
-
-        self.auth = hs.get_auth()
-        self.appservice_handler = hs.get_application_service_handler()
-
-    async def on_GET(self, request, protocol):
-        await self.auth.get_user_by_req(request, allow_guest=True)
-
-        fields = request.args
-        fields.pop(b"access_token", None)
-
-        results = await self.appservice_handler.query_3pe(
-            ThirdPartyEntityKind.USER, protocol, fields
-        )
-
-        return 200, results
-
-
-class ThirdPartyLocationServlet(RestServlet):
-    PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
-
-    def __init__(self, hs):
-        super().__init__()
-
-        self.auth = hs.get_auth()
-        self.appservice_handler = hs.get_application_service_handler()
-
-    async def on_GET(self, request, protocol):
-        await self.auth.get_user_by_req(request, allow_guest=True)
-
-        fields = request.args
-        fields.pop(b"access_token", None)
-
-        results = await self.appservice_handler.query_3pe(
-            ThirdPartyEntityKind.LOCATION, protocol, fields
-        )
-
-        return 200, results
-
-
-def register_servlets(hs, http_server):
-    ThirdPartyProtocolsServlet(hs).register(http_server)
-    ThirdPartyProtocolServlet(hs).register(http_server)
-    ThirdPartyUserServlet(hs).register(http_server)
-    ThirdPartyLocationServlet(hs).register(http_server)
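For orientation, the protocols endpoint above simply proxies to the application-service handler; querying it from a client looks like this (illustrative only; assumes the `requests` library, the r0 prefix, and a valid access token):

    import requests  # assumed available; any HTTP client works

    # Look up the bridged third-party protocols the homeserver's appservices expose.
    resp = requests.get(
        "https://example.org/_matrix/client/r0/thirdparty/protocols",
        headers={"Authorization": "Bearer <access_token>"},
    )
    protocols = resp.json()  # e.g. {"irc": {...}}; empty dict if no bridges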
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
deleted file mode 100644
index b2f858545c..0000000000
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.api.errors import AuthError
-from synapse.http.servlet import RestServlet
-
-from ._base import client_patterns
-
-
-class TokenRefreshRestServlet(RestServlet):
-    """
-    Exchanges refresh tokens for a pair of an access token and a new refresh
-    token.
-    """
-
-    PATTERNS = client_patterns("/tokenrefresh")
-
-    def __init__(self, hs):
-        super().__init__()
-
-    async def on_POST(self, request):
-        raise AuthError(403, "tokenrefresh is no longer supported.")
-
-
-def register_servlets(hs, http_server):
-    TokenRefreshRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
deleted file mode 100644
index 7e8912f0b9..0000000000
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-
-from ._base import client_patterns
-
-logger = logging.getLogger(__name__)
-
-
-class UserDirectorySearchRestServlet(RestServlet):
-    PATTERNS = client_patterns("/user_directory/search$")
-
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
-        super().__init__()
-        self.hs = hs
-        self.auth = hs.get_auth()
-        self.user_directory_handler = hs.get_user_directory_handler()
-
-    async def on_POST(self, request):
-        """Searches for users in directory
-
-        Returns:
-            dict of the form::
-
-                {
-                    "limited": <bool>,  # whether there were more results or not
-                    "results": [  # Ordered by best match first
-                        {
-                            "user_id": <user_id>,
-                            "display_name": <display_name>,
-                            "avatar_url": <avatar_url>
-                        }
-                    ]
-                }
-        """
-        requester = await self.auth.get_user_by_req(request, allow_guest=False)
-        user_id = requester.user.to_string()
-
-        if not self.hs.config.user_directory_search_enabled:
-            return 200, {"limited": False, "results": []}
-
-        body = parse_json_object_from_request(request)
-
-        limit = body.get("limit", 10)
-        limit = min(limit, 50)
-
-        try:
-            search_term = body["search_term"]
-        except Exception:
-            raise SynapseError(400, "`search_term` is required field")
-
-        results = await self.user_directory_handler.search_users(
-            user_id, search_term, limit
-        )
-
-        return 200, results
-
-
-def register_servlets(hs, http_server):
-    UserDirectorySearchRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py
new file mode 100644
index 0000000000..f53020520d
--- /dev/null
+++ b/synapse/rest/client/voip.py
@@ -0,0 +1,73 @@
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import hashlib
+import hmac
+
+from synapse.http.servlet import RestServlet
+from synapse.rest.client._base import client_patterns
+
+
+class VoipRestServlet(RestServlet):
+    PATTERNS = client_patterns("/voip/turnServer$", v1=True)
+
+    def __init__(self, hs):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+
+    async def on_GET(self, request):
+        requester = await self.auth.get_user_by_req(
+            request, self.hs.config.turn_allow_guests
+        )
+
+        turnUris = self.hs.config.turn_uris
+        turnSecret = self.hs.config.turn_shared_secret
+        turnUsername = self.hs.config.turn_username
+        turnPassword = self.hs.config.turn_password
+        userLifetime = self.hs.config.turn_user_lifetime
+
+        if turnUris and turnSecret and userLifetime:
+            expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
+            username = "%d:%s" % (expiry, requester.user.to_string())
+
+            mac = hmac.new(
+                turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1
+            )
+            # We need to use standard padded base64 encoding here
+            # encode_base64 because we need to add the standard padding to get the
+            # same result as the TURN server.
+            password = base64.b64encode(mac.digest()).decode("ascii")
+
+        elif turnUris and turnUsername and turnPassword and userLifetime:
+            username = turnUsername
+            password = turnPassword
+
+        else:
+            return 200, {}
+
+        return (
+            200,
+            {
+                "username": username,
+                "password": password,
+                "ttl": userLifetime / 1000,
+                "uris": turnUris,
+            },
+        )
+
+
+def register_servlets(hs, http_server):
+    VoipRestServlet(hs).register(http_server)
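The HMAC construction in voip.py above follows the common ephemeral-credential convention for TURN servers: the username encodes an expiry timestamp, and the password is the standard padded base64 of an HMAC-SHA1 over that username. A standalone sketch of the same derivation (illustrative only; the secret and lifetime are assumed inputs) makes the padding point concrete:

    import base64
    import hashlib
    import hmac
    import time

    def turn_credentials(shared_secret: str, user_id: str, lifetime_ms: int):
        # The username encodes the expiry so the TURN server can reject stale creds.
        expiry = int((time.time() * 1000 + lifetime_ms) / 1000)
        username = "%d:%s" % (expiry, user_id)
        mac = hmac.new(
            shared_secret.encode(), msg=username.encode(), digestmod=hashlib.sha1
        )
        # Standard padded base64, matching what the TURN server itself computes.
        password = base64.b64encode(mac.digest()).decode("ascii")
        return username, password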
diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py
index 5527e278db..d66aeb00eb 100644
--- a/tests/app/test_phone_stats_home.py
+++ b/tests/app/test_phone_stats_home.py
@@ -1,6 +1,6 @@
 import synapse
 from synapse.app.phone_stats_home import start_phone_stats_home
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
 
 from tests import unittest
 from tests.unittest import HomeserverTestCase
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 3f41e99950..6b87f571b8 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -22,7 +22,7 @@
 from synapse.federation.units import Transaction
 from synapse.handlers.presence import UserPresenceState
 from synapse.module_api import ModuleApi
 from synapse.rest import admin
-from synapse.rest.client.v1 import login, presence, room
+from synapse.rest.client import login, presence, room
 from synapse.types import JsonDict, StreamToken, create_requester
 
 from tests.handlers.test_sync import generate_sync_config
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 48e98aac79..ca27388ae8 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -14,7 +14,7 @@
 
 from synapse.events.snapshot import EventContext
 from synapse.rest import admin
-from synapse.rest.client.v1 import login, room
+from synapse.rest.client import login, room
 
 from tests import unittest
 from tests.test_utils.event_injection import create_event
diff
--git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 1a809b2a6a..7b486aba4a 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -16,7 +16,7 @@ from unittest.mock import Mock from synapse.api.errors import Codes, SynapseError from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID from tests import unittest diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 802c5ad299..f0aa8ed9db 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -6,7 +6,7 @@ from synapse.events import EventBase from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b00dd143d6..65b18fbd7a 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.api.constants import RoomEncryptionAlgorithms from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.types import JsonDict, ReadReceipt from tests.test_utils import make_awaitable diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 1737891564..0b60cc4261 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -19,7 +19,7 @@ from parameterized import parameterized from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index aab44bce4a..383214ab50 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.events import builder from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import RoomAlias diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 18a734daf4..59de1142b1 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -15,12 +15,10 @@ from collections import Counter from unittest.mock import Mock -import synapse.api.errors -import synapse.handlers.admin import synapse.rest.admin import synapse.storage from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 7a8041ab44..a0a48b564e 100644 --- 
a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -19,7 +19,7 @@ import synapse import synapse.api.errors from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig -from synapse.rest.client.v1 import directory, login, room +from synapse.rest.client import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 4140fcefc2..c72a8972a3 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -22,7 +22,7 @@ from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.util.stringutils import random_string from tests import unittest diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index a8a9fc5b62..8a8d369fac 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -18,7 +18,7 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import create_requester from synapse.util.stringutils import random_string diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 32651db096..38e6d9f536 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,8 +20,7 @@ from unittest.mock import Mock from twisted.internet import defer import synapse -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import devices +from synapse.rest.client import devices, login from synapse.types import JsonDict from tests import unittest diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 29845a80da..0a52bc8b72 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -33,7 +33,7 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest import admin -from synapse.rest.client.v1 import room +from synapse.rest.client import room from synapse.types import UserID, get_domain_from_id from tests import unittest diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index 732d746e38..ac800afa7d 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -28,7 +28,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import JsonDict, UserID diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e4059acda3..1ba4c05b9b 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -13,7 +13,7 @@ # limitations under the License. 
from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main import stats from tests import unittest diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 549876dc85..e44bf2b3b1 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -18,8 +18,7 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import user_directory +from synapse.rest.client import login, room, user_directory from synapse.storage.roommember import ProfileInfo from tests import unittest diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 0b817cc701..7dd519cd44 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -20,7 +20,7 @@ from synapse.events import EventBase from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.rest import admin -from synapse.rest.client.v1 import login, presence, room +from synapse.rest.client import login, presence, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence diff --git a/tests/push/test_email.py b/tests/push/test_email.py index a487706758..e0a3342088 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.api.errors import Codes, SynapseError -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.unittest import HomeserverTestCase diff --git a/tests/push/test_http.py b/tests/push/test_http.py index ffd75b1491..c068d329a9 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -18,8 +18,7 @@ from twisted.internet.defer import Deferred import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfigException -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import receipts +from synapse.rest.client import login, receipts, room from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 666008425a..f198a94887 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -24,7 +24,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamRow, ) from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.replication._base import BaseStreamTestCase from tests.test_utils.event_injection import inject_event, inject_member_event diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index 1346e0e160..43a16bb141 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -13,7 +13,7 @@ # limitations under the License. 
import logging -from synapse.rest.client.v2_alpha import register +from synapse.rest.client import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeChannel, make_request diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index b9751efdc5..995097d72c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from synapse.rest.client.v2_alpha import register +from synapse.rest.client import register from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index a0c710f855..af5dfca752 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests.replication._base import BaseMultiWorkerStreamTestCase diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index ffa425328f..ac419f0db3 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -22,7 +22,7 @@ from twisted.web.http import HTTPChannel from twisted.web.server import Request from synapse.rest import admin -from synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.server import HomeServer from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 1e4e3821b9..4094a75f36 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.internet import defer from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.replication._base import BaseMultiWorkerStreamTestCase diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index f3615af97e..0a6e4795ee 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -16,8 +16,7 @@ from unittest.mock import patch from synapse.api.room_versions import RoomVersion from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index a7c6e595b9..bfa638fb4b 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -24,8 +24,7 @@ import synapse.rest.admin from synapse.http.server import JsonResource from synapse.logging.context import make_deferred_yieldable from synapse.rest.admin import VersionServlet -from synapse.rest.client.v1 
import login, room -from synapse.rest.client.v2_alpha import groups +from synapse.rest.client import groups, login, room from tests import unittest from tests.server import FakeSite, make_request diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 120730b764..c4afe5c3d9 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -17,7 +17,7 @@ import urllib.parse import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index f15d1cf6f7..e9ef89731f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -16,8 +16,7 @@ import json import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import report_event +from synapse.rest.client import login, report_event, room from tests import unittest diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 7198fd293f..972d60570c 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -20,7 +20,7 @@ from parameterized import parameterized import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client import login, profile, room from synapse.rest.media.v1.filepath import MediaFilePaths from tests import unittest diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 17ec8bfd3b..c9d4731017 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -22,7 +22,7 @@ from parameterized import parameterized_class import synapse.rest.admin from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes -from synapse.rest.client.v1 import directory, events, login, room +from synapse.rest.client import directory, events, login, room from tests import unittest diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 79cac4266b..5cd82209c4 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional import synapse.rest.admin from synapse.api.errors import Codes -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a736ec4754..ef77275238 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -27,8 +27,7 @@ import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions -from synapse.rest.client.v1 import login, logout, profile, room -from synapse.rest.client.v2_alpha import devices, sync +from synapse.rest.client import devices, login, logout, profile, room, sync from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.types import JsonDict, UserID diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 53cbc8ddab..4e1c49c28b 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -14,7 +14,7 @@ import synapse.rest.admin 
from synapse.api.errors import Codes, SynapseError -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 5cc62a910a..65c58ce70a 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -16,7 +16,7 @@ import os import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.rest.consent import consent_resource from tests import unittest diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index eec0fc01f9..3d7aa8ec86 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin -from synapse.rest.client.v1 import room +from synapse.rest.client import room from tests import unittest diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 478296ba0e..ca2e8ff8ef 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -15,7 +15,7 @@ import json import synapse.rest.admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests import unittest diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index ba5ad47df5..91d0762cb0 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index dfd85221d0..433d715f69 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -13,8 +13,7 @@ # limitations under the License. 
from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index e1a6e73e17..b58452195a 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -15,7 +15,7 @@ from unittest.mock import Mock from synapse.api.constants import EventTypes from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.visibility import filter_events_for_client from tests import unittest diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 288ee12888..6a0d9a82be 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -16,8 +16,13 @@ from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.rest.client import ( + directory, + login, + profile, + room, + room_upgrade_rest_servlet, +) from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 28dd47a28b..0ae4029640 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -19,7 +19,7 @@ from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.module_api import ModuleApi from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import Requester, StateMap from synapse.util.frozenutils import unfreeze diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py index 8ed470490b..d2181ea907 100644 --- a/tests/rest/client/v1/test_directory.py +++ b/tests/rest/client/v1/test_directory.py @@ -15,7 +15,7 @@ import json from synapse.rest import admin -from synapse.rest.client.v1 import directory, login, room +from synapse.rest.client import directory, login, room from synapse.types import RoomAlias from synapse.util.stringutils import random_string diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 2789d51546..a90294003e 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import synapse.rest.admin -from synapse.rest.client.v1 import events, login, room +from synapse.rest.client import events, login, room from tests import unittest diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 7eba69642a..eba3552b19 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -24,9 +24,8 @@ from twisted.web.resource import Resource import synapse.rest.admin from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import devices, register -from synapse.rest.client.v2_alpha.account import WhoamiRestServlet +from synapse.rest.client import devices, login, logout, register +from synapse.rest.client.account import 
WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import create_requester diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 597e4c67de..1d152352d1 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.internet import defer from synapse.handlers.presence import PresenceHandler -from synapse.rest.client.v1 import presence +from synapse.rest.client import presence from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 165ad33fb7..2860579c2e 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -14,7 +14,7 @@ """Tests REST events for /profile paths.""" from synapse.rest import admin -from synapse.rest.client.v1 import login, profile, room +from synapse.rest.client import login, profile, room from tests import unittest diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py index d077616082..d0ce91ccd9 100644 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ b/tests/rest/client/v1/test_push_rule_attrs.py @@ -13,7 +13,7 @@ # limitations under the License. import synapse from synapse.api.errors import Codes -from synapse.rest.client.v1 import login, push_rule, room +from synapse.rest.client import login, push_rule, room from tests.unittest import HomeserverTestCase diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 1a9528ec20..0c9cbb9aff 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,8 +29,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin -from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import account +from synapse.rest.client import account, directory, login, profile, room from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 44e22ca999..b54b004733 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -17,7 +17,7 @@ from unittest.mock import Mock -from synapse.rest.client.v1 import room +from synapse.rest.client import room from synapse.types import UserID from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index e7e617e9df..b946fca8b3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -25,8 +25,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import account, register +from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_auth.py 
b/tests/rest/client/v2_alpha/test_auth.py index 6b90f838b6..cf5cfb910c 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -19,8 +19,7 @@ from twisted.internet.defer import succeed import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account, auth, devices, register +from synapse.rest.client import account, auth, devices, login, register from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index f80f48a455..ad83b3d2ff 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -13,8 +13,7 @@ # limitations under the License. import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import capabilities +from synapse.rest.client import capabilities, login from tests import unittest from tests.unittest import override_config diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index c7e47725b7..475c6bed3d 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,7 +15,7 @@ from twisted.internet import defer from synapse.api.errors import Codes -from synapse.rest.client.v2_alpha import filter +from synapse.rest.client import filter from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py index 6f07ff6cbb..3cf5871899 100644 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ b/tests/rest/client/v2_alpha/test_password_policy.py @@ -17,8 +17,7 @@ import json from synapse.api.constants import LoginType from synapse.api.errors import Codes from synapse.rest import admin -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import account, password_policy, register +from synapse.rest.client import account, login, password_policy, register from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index a52e5e608a..fecda037a5 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -23,8 +23,7 @@ import synapse.rest.admin from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService -from synapse.rest.client.v1 import login, logout -from synapse.rest.client.v2_alpha import account, account_validity, register, sync +from synapse.rest.client import account, account_validity, login, logout, register, sync from tests import unittest from tests.unittest import override_config diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 2e2f94742e..02b5e9a8d0 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -19,8 +19,7 @@ from typing import Optional from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from 
synapse.rest.client.v2_alpha import register, relations +from synapse.rest.client import login, register, relations, room from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py index a76a6fef1e..ee6b0b9ebf 100644 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -15,8 +15,7 @@ import json import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import report_event +from synapse.rest.client import login, report_event, room from tests import unittest diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/v2_alpha/test_sendtodevice.py index c9c99cc5d7..6db7062a8e 100644 --- a/tests/rest/client/v2_alpha/test_sendtodevice.py +++ b/tests/rest/client/v2_alpha/test_sendtodevice.py @@ -13,8 +13,7 @@ # limitations under the License. from synapse.rest import admin -from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import sendtodevice, sync +from synapse.rest.client import login, sendtodevice, sync from tests.unittest import HomeserverTestCase, override_config diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py index cedb9614a8..283eccd53f 100644 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ b/tests/rest/client/v2_alpha/test_shared_rooms.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import shared_rooms +from synapse.rest.client import login, room, shared_rooms from tests import unittest from tests.server import FakeChannel diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 15748ed4fd..95be369d4b 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -21,8 +21,7 @@ from synapse.api.constants import ( ReadReceiptEventFields, RelationTypes, ) -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import knock, read_marker, receipts, sync +from synapse.rest.client import knock, login, read_marker, receipts, room, sync from tests import unittest from tests.federation.transport.test_knocking import ( diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py index 5f3f15fc57..72f976d8e2 100644 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ b/tests/rest/client/v2_alpha/test_upgrade_room.py @@ -15,8 +15,7 @@ from typing import Optional from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet +from synapse.rest.client import login, room, room_upgrade_rest_servlet from tests import unittest from tests.server import FakeChannel diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 2d6b49692e..6085444b9d 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -30,7 +30,7 @@ from twisted.internet.defer import Deferred from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.logging.context import make_deferred_yieldable from synapse.rest import admin -from 
synapse.rest.client.v1 import login +from synapse.rest.client import login from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.media_storage import MediaStorage diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index ac98259b7e..58b399a043 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -15,8 +15,7 @@ import os import synapse.rest.admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from tests import unittest diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3245aa91ca..8701b5f7e3 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -19,8 +19,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin -from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client import login, room, sync from synapse.server_notices.resource_limits_server_notices import ( ResourceLimitsServerNotices, ) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index d05d367685..a649e8c618 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -15,7 +15,7 @@ import json from synapse.logging.context import LoggingContext from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.util.async_helpers import yieldable_gather_results diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 77c4fe721c..da98733ce8 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -17,7 +17,7 @@ from unittest.mock import Mock, patch import synapse.rest.admin from synapse.api.constants import EventTypes -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.storage import prepare_database from synapse.types import UserID, create_requester diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index e57fce9694..1c2df54ecc 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import synapse.rest.admin from synapse.http.site import XForwardedForRequest -from synapse.rest.client.v1 import login +from synapse.rest.client import login from tests import unittest from tests.server import make_request diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index d87f124c26..93136f0717 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from 
synapse.storage.databases.main.events import _LinkMap from synapse.types import create_requester diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 617bc8091f..f462a8b1c7 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -17,7 +17,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.federation.federation_base import event_from_pdu_json from synapse.rest import admin -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from tests.unittest import HomeserverTestCase diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index e5574063f1..22a77c3ccc 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,7 +13,7 @@ # limitations under the License. from synapse.api.errors import NotFoundError, SynapseError -from synapse.rest.client.v1 import room +from synapse.rest.client import room from tests.unittest import HomeserverTestCase diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 9fa968f6bb..c72dc40510 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -15,7 +15,7 @@ from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource -from synapse.rest.client.v1 import login, room +from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests import unittest diff --git a/tests/test_mau.py b/tests/test_mau.py index fa6ef92b3b..66111eb367 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -17,7 +17,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService -from synapse.rest.client.v2_alpha import register, sync +from synapse.rest.client import register, sync from tests import unittest from tests.unittest import override_config diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 0df480db9f..67dcf567cd 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -17,7 +17,7 @@ from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactorClock -from synapse.rest.client.v2_alpha.register import register_servlets +from synapse.rest.client.register import register_servlets from synapse.util import Clock from tests import unittest -- cgit 1.5.1 From 5f7b1e1f276fdd25304ff06076e1cd77cf3a9640 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 17 Aug 2021 13:13:11 +0100 Subject: Make `PeriodicallyFlushingMemoryHandler` the default logging handler. (#10518) --- changelog.d/10518.feature | 1 + docker/conf/log.config | 27 ++++++++++++++++++++------- docs/sample_log_config.yaml | 27 ++++++++++++++++++++------- synapse/config/logger.py | 27 ++++++++++++++++++++------- 4 files changed, 61 insertions(+), 21 deletions(-) create mode 100644 changelog.d/10518.feature diff --git a/changelog.d/10518.feature b/changelog.d/10518.feature new file mode 100644 index 0000000000..112e4d105c --- /dev/null +++ b/changelog.d/10518.feature @@ -0,0 +1 @@ +The default logging handler for new installations is now `PeriodicallyFlushingMemoryHandler`, a buffered logging handler which periodically flushes itself. 
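The commit below replaces the stdlib `logging.handlers.MemoryHandler` with `synapse.logging.handlers.PeriodicallyFlushingMemoryHandler` in the default logging configs. The stdlib handler only flushes its buffer when it fills up (`capacity`) or when a sufficiently severe record arrives (`flushLevel`), so on a quiet homeserver buffered INFO/DEBUG lines could sit unwritten indefinitely; the new handler additionally flushes on a timer (`period`). As a rough illustration of the idea — a hedged sketch built only on the Python standard library, not Synapse's actual implementation, with the class name `TimedFlushMemoryHandler` invented for the example — such a handler can be written like this:

```python
import logging
import logging.handlers
import threading


class TimedFlushMemoryHandler(logging.handlers.MemoryHandler):
    """Buffer records like MemoryHandler, but also flush every `period`
    seconds so buffered records are never delayed indefinitely.

    Illustrative sketch only: Synapse's real class is
    synapse.logging.handlers.PeriodicallyFlushingMemoryHandler, whose
    internals may differ.
    """

    def __init__(self, capacity, flushLevel=logging.ERROR, target=None, period=5.0):
        super().__init__(capacity, flushLevel=flushLevel, target=target)
        self._period = period
        self._schedule_flush()

    def _schedule_flush(self):
        # Use a daemon timer so a pending flush never blocks interpreter shutdown.
        timer = threading.Timer(self._period, self._flush_and_reschedule)
        timer.daemon = True
        timer.start()

    def _flush_and_reschedule(self):
        # MemoryHandler.flush() takes the handler lock, so it is safe to call
        # from the timer thread while other threads are logging.
        self.flush()
        self._schedule_flush()
```

With a handler of this shape wired up as the `buffer` handler (as the config diffs below do), WARNING/ERROR records still reach the target immediately via `flushLevel`, while the worst-case delay for INFO/DEBUG records is bounded by `period` instead of being unbounded.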
diff --git a/docker/conf/log.config b/docker/conf/log.config index a994626926..7a216a36a0 100644 --- a/docker/conf/log.config +++ b/docker/conf/log.config @@ -18,18 +18,31 @@ handlers: backupCount: 6 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # there will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expensive of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately. + # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + # Default value: 5 seconds + period: 5 {% endif %} console: diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 669e600081..2485ad25ed 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -24,18 +24,31 @@ handlers: backupCount: 3 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expensive of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately. + # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. 
+ # Default value: 5 seconds + period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index ad4e6e61c3..4a398a7932 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -67,18 +67,31 @@ handlers: backupCount: 3 # Does not include the current log file. encoding: utf8 - # Default to buffering writes to log file for efficiency. This means that - # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR - # logs will still be flushed immediately. + # Default to buffering writes to log file for efficiency. + # WARNING/ERROR logs will still be flushed immediately, but there will be a + # delay (of up to `period` seconds, or until the buffer is full with + # `capacity` messages) before INFO/DEBUG logs get written. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file - # The capacity is the number of log lines that are buffered before - # being written to disk. Increasing this will lead to better + + # The capacity is the maximum number of log lines that are buffered + # before being written to disk. Increasing this will lead to better # performance, at the expensive of it taking longer for log lines to # be written to disk. + # This parameter is required. capacity: 10 - flushLevel: 30 # Flush for WARNING logs as well + + # Logs with a level at or above the flush level will cause the buffer to + # be flushed immediately. + # Default value: 40 (ERROR) + # Other values: 50 (CRITICAL), 30 (WARNING), 20 (INFO), 10 (DEBUG) + flushLevel: 30 # Flush immediately for WARNING logs and higher + + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + # Default value: 5 seconds + period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. -- cgit 1.5.1 From 272b89d547a80628fd90ed35953d2d7c95410b65 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 17 Aug 2021 13:13:42 +0100 Subject: Stop setting the outlier flag for things that aren't (#10614) Marking things as outliers to inhibit pushes is a sledgehammer to crack a nut. Move the test further down the stack so that we just inhibit the thing we want. --- changelog.d/10614.misc | 1 + synapse/handlers/federation.py | 9 ++------- 2 files changed, 3 insertions(+), 7 deletions(-) create mode 100644 changelog.d/10614.misc diff --git a/changelog.d/10614.misc b/changelog.d/10614.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10614.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c0e13bdaac..a257d87fed 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -293,13 +293,7 @@ class FederationHandler(BaseHandler): prevs = set(pdu.prev_event_ids()) seen = await self.store.have_events_in_timeline(prevs) - if min_depth is not None and pdu.depth < min_depth: - # This is so that we don't notify the user about this - # message, to work around the fact that some events will - # reference really really old events we really don't want to - # send to the clients. 
- pdu.internal_metadata.outlier = True - elif min_depth is not None and pdu.depth > min_depth: + if min_depth is not None and pdu.depth > min_depth: missing_prevs = prevs - seen if sent_to_us_directly and missing_prevs: # If we're missing stuff, ensure we only fetch stuff one @@ -2375,6 +2369,7 @@ class FederationHandler(BaseHandler): not event.internal_metadata.is_outlier() and not backfilled and not context.rejected + and (await self.store.get_min_depth(event.room_id)) <= event.depth ): await self.action_generator.handle_push_actions_for_event( event, context -- cgit 1.5.1 From c4cf0c047329e125f0940281fd53688474d26581 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 17 Aug 2021 08:19:12 -0400 Subject: Attempt to pull from the legacy spaces summary API over federation. (#10583) If the new /hierarchy API does not exist on all destinations, fall back to querying the /spaces API and translating the results. This is a backwards-compatibility hack, since not all of the federated homeservers will update at the same time. --- changelog.d/10583.feature | 1 + synapse/federation/federation_client.py | 64 ++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 9 deletions(-) create mode 100644 changelog.d/10583.feature diff --git a/changelog.d/10583.feature b/changelog.d/10583.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10583.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 0af953a5d6..29979414e3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1364,13 +1364,59 @@ class FederationClient(FederationBase): return room, children, inaccessible_children - # TODO Fallback to the old federation API and translate the results. - return await self._try_destination_list( - "fetch room hierarchy", - destinations, - send_request, - failover_on_unknown_endpoint=True, - ) + try: + return await self._try_destination_list( + "fetch room hierarchy", + destinations, + send_request, + failover_on_unknown_endpoint=True, + ) + except SynapseError as e: + # Fall back to the old federation API and translate the results if + # no servers implement the new API. + # + # The algorithm below is a bit inefficient as it only attempts to + # get information for the requested room, but the legacy API may + # return additional layers. + if e.code == 502: + legacy_result = await self.get_space_summary( + destinations, + room_id, + suggested_only, + max_rooms_per_space=None, + exclude_rooms=[], + ) + + # Find the requested room in the response (and remove it). + for _i, room in enumerate(legacy_result.rooms): + if room.get("room_id") == room_id: + break + else: + # The requested room was not returned, nothing we can do. + raise + requested_room = legacy_result.rooms.pop(_i) + + # Find any child events of the requested room. + children_events = [] + children_room_ids = set() + for event in legacy_result.events: + if event.room_id == room_id: + children_events.append(event.data) + children_room_ids.add(event.state_key) + # And add them under the requested room. + requested_room["children_state"] = children_events + + # Find the child rooms.
+ children = [] + for room in legacy_result.rooms: + if room.get("room_id") in children_room_ids: + children.append(room) + + # It isn't clear from the response whether some of the rooms are + # not accessible. + return requested_room, children, () + + raise @attr.s(frozen=True, slots=True, auto_attribs=True) @@ -1430,7 +1476,7 @@ class FederationSpaceSummaryEventResult: class FederationSpaceSummaryResult: """Represents the data returned by a successful get_space_summary call.""" - rooms: Sequence[JsonDict] + rooms: List[JsonDict] events: Sequence[FederationSpaceSummaryEventResult] @classmethod @@ -1444,7 +1490,7 @@ class FederationSpaceSummaryResult: ValueError if d is not a valid /spaces/ response """ rooms = d.get("rooms") - if not isinstance(rooms, Sequence): + if not isinstance(rooms, List): raise ValueError("'rooms' must be a list") if any(not isinstance(r, dict) for r in rooms): raise ValueError("Invalid room in 'rooms' list") -- cgit 1.5.1 From 56397599809e131174daaeb4c6dc18fde9db6c3f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 17 Aug 2021 14:45:24 +0200 Subject: Centralise the custom template directory (#10596) Several configuration sections are using separate settings for custom template directories, which can be confusing. This PR adds a new top-level configuration for a custom template directory which is then used for every module. The only exception is the consent templates: the consent template directory requires a specific hierarchy, so it's probably better that it stays separate from everything else. --- changelog.d/10596.removal | 1 + docs/SUMMARY.md | 1 + docs/sample_config.yaml | 225 ++-------------------- docs/templates.md | 239 ++++++++++++++++++++++++ docs/upgrade.md | 11 ++ synapse/config/account_validity.py | 7 +- synapse/config/emailconfig.py | 71 +++---- synapse/config/server.py | 25 +++ synapse/config/sso.py | 173 +---------------- synapse/module_api/__init__.py | 3 +- synapse/rest/synapse/client/new_user_consent.py | 2 + synapse/rest/synapse/client/pick_username.py | 2 + 12 files changed, 342 insertions(+), 418 deletions(-) create mode 100644 changelog.d/10596.removal create mode 100644 docs/templates.md diff --git a/changelog.d/10596.removal b/changelog.d/10596.removal new file mode 100644 index 0000000000..e69f632db4 --- /dev/null +++ b/changelog.d/10596.removal @@ -0,0 +1 @@ +The `template_dir` configuration settings in the `sso`, `account_validity` and `email` sections of the configuration file are now deprecated in favour of the global `templates.custom_template_directory` setting. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 3d320a1c43..56e0141c2b 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -21,6 +21,7 @@ - [Homeserver Sample Config File](usage/configuration/homeserver_sample_config.md) - [Logging Sample Config File](usage/configuration/logging_sample_config.md) - [Structured Logging](structured_logging.md) + - [Templates](templates.md) - [User Authentication](usage/configuration/user_authentication/README.md) - [Single-Sign On]() - [OpenID Connect](openid.md) diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index aeebcaf45f..3ec76d5abf 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -565,6 +565,19 @@ retention: # #next_link_domain_whitelist: ["matrix.org"] +# Templates to use when generating email or HTML page contents.
+# +templates: + # Directory in which Synapse will try to find template files to use to generate + # email or HTML page contents. + # If not set, or a file is not found within the template directory, a default + # template from within the Synapse package will be used. + # + # See https://matrix-org.github.io/synapse/latest/templates.html for more + # information about using custom templates. + # + #custom_template_directory: /path/to/custom/templates/ + ## TLS ## @@ -1895,6 +1908,9 @@ cas_config: # Additional settings to use with single-sign on systems such as OpenID Connect, # SAML2 and CAS. # +# Server admins can configure custom templates for pages related to SSO. See +# https://matrix-org.github.io/synapse/latest/templates.html for more information. +# sso: # A list of client URLs which are whitelisted so that the user does not # have to confirm giving access to their account to the URL. Any client @@ -1927,169 +1943,6 @@ sso: # #update_profile_information: true - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to prompt the user to choose an Identity Provider during - # login: 'sso_login_idp_picker.html'. - # - # This is only used if multiple SSO Identity Providers are configured. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL that the user will be redirected to after - # login. - # - # * server_name: the homeserver's name. - # - # * providers: a list of available Identity Providers. Each element is - # an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # The rendered HTML page should contain a form which submits its results - # back as a GET request, with the following query parameters: - # - # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed - # to the template) - # - # * idp: the 'idp_id' of the chosen IDP. - # - # * HTML page to prompt new users to enter a userid and confirm other - # details: 'sso_auth_account_details.html'. This is only shown if the - # SSO implementation (with any user_mapping_provider) does not return - # a localpart. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * idp: details of the SSO Identity Provider that the user logged in - # with: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * user_attributes: an object containing details about the user that - # we received from the IdP. 
May have the following attributes: - # - # * display_name: the user's display_name - # * emails: a list of email addresses - # - # The template should render a form which submits the following fields: - # - # * username: the localpart of the user's chosen user id - # - # * HTML page allowing the user to consent to the server's terms and - # conditions. This is only shown for new users, and only if - # `user_consent.require_at_registration` is set. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * user_id: the user's matrix proposed ID. - # - # * user_profile.display_name: the user's proposed display name, if any. - # - # * consent_version: the version of the terms that the user will be - # shown - # - # * terms_url: a link to the page showing the terms. - # - # The template should render a form which submits the following fields: - # - # * accepted_version: the version of the terms accepted by the user - # (ie, 'consent_version' from the input variables). - # - # * HTML page for a confirmation step before redirecting back to the client - # with the login token: 'sso_redirect_confirm.html'. - # - # When rendering, this template is given the following variables: - # - # * redirect_url: the URL the user is about to be redirected to. - # - # * display_url: the same as `redirect_url`, but with the query - # parameters stripped. The intention is to have a - # human-readable URL to show to users, not to use it as - # the final address to redirect to. - # - # * server_name: the homeserver's name. - # - # * new_user: a boolean indicating whether this is the user's first time - # logging in. - # - # * user_id: the user's matrix ID. - # - # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. - # None if the user has not set an avatar. - # - # * user_profile.display_name: the user's display name. None if the user - # has not set a display name. - # - # * HTML page which notifies the user that they are authenticating to confirm - # an operation on their account during the user interactive authentication - # process: 'sso_auth_confirm.html'. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL the user is about to be redirected to. - # - # * description: the operation which the user is being asked to confirm - # - # * idp: details of the Identity Provider that we will use to confirm - # the user's identity: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * HTML page shown after a successful user interactive authentication session: - # 'sso_auth_success.html'. - # - # Note that this page must include the JavaScript which notifies of a successful authentication - # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). - # - # This template has no additional variables. - # - # * HTML page shown after a user-interactive authentication session which - # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. - # - # When rendering, this template is given the following variables: - # * server_name: the homeserver's name. - # * user_id_to_verify: the MXID of the user that we are trying to - # validate. 
- # - # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) - # attempts to login: 'sso_account_deactivated.html'. - # - # This template has no additional variables. - # - # * HTML page to display to users if something goes wrong during the - # OpenID Connect authentication process: 'sso_error.html'. - # - # When rendering, this template is given two variables: - # * error: the technical name of the error - # * error_description: a human-readable message for the error - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # JSON web token integration. The following settings can be used to make # Synapse JSON web tokens for authentication, instead of its internal @@ -2220,6 +2073,9 @@ ui_auth: # Configuration for sending emails from Synapse. # +# Server admins can configure custom templates for email content. See +# https://matrix-org.github.io/synapse/latest/templates.html for more information. +# email: # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. # @@ -2296,49 +2152,6 @@ email: # #invite_client_location: https://app.element.io - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * The contents of email notifications of missed events: 'notif_mail.html' and - # 'notif_mail.txt'. - # - # * The contents of account expiry notice emails: 'notice_expiry.html' and - # 'notice_expiry.txt'. - # - # * The contents of password reset emails sent by the homeserver: - # 'password_reset.html' and 'password_reset.txt' - # - # * An HTML page that a user will see when they follow the link in the password - # reset email. The user will be asked to confirm the action before their - # password is reset: 'password_reset_confirmation.html' - # - # * HTML pages for success and failure that a user will see when they confirm - # the password reset flow using the page above: 'password_reset_success.html' - # and 'password_reset_failure.html' - # - # * The contents of address verification emails sent during registration: - # 'registration.html' and 'registration.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent during registration: - # 'registration_success.html' and 'registration_failure.html' - # - # * The contents of address verification emails sent when an address is added - # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent when an address is added - # to a Matrix account: 'add_threepid_success.html' and - # 'add_threepid_failure.html' - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # Subjects to use when sending emails from Synapse. 
# # The placeholder '%(app)s' will be replaced with the value of the 'app_name' diff --git a/docs/templates.md b/docs/templates.md new file mode 100644 index 0000000000..a240f58b54 --- /dev/null +++ b/docs/templates.md @@ -0,0 +1,239 @@ +# Templates + +Synapse uses parametrised templates to generate the content of emails it sends and +webpages it shows to users. + +By default, Synapse will use the templates listed [here](https://github.com/matrix-org/synapse/tree/master/synapse/res/templates). +Server admins can configure an additional directory for Synapse to look for templates +in, allowing them to specify custom templates: + +```yaml +templates: + custom_template_directory: /path/to/custom/templates/ +``` + +If this setting is not set, or the files named below are not found within the directory, +default templates from within the Synapse package will be used. + +Templates that are given variables when being rendered are rendered using [Jinja 2](https://jinja.palletsprojects.com/en/2.11.x/). +Templates rendered by Jinja 2 can also access two functions on top of the functions +already available as part of Jinja 2: + +```python +format_ts(value: int, format: str) -> str +``` + +Formats a timestamp in milliseconds. + +Example: `reason.last_sent_ts|format_ts("%c")` + +```python +mxc_to_http(value: str, width: int, height: int, resize_method: str = "crop") -> str +``` + +Turns a `mxc://` URL for media content into an HTTP(S) one using the homeserver's +`public_baseurl` configuration setting as the URL's base. + +Example: `message.sender_avatar_url|mxc_to_http(32,32)` + + +## Email templates + +Below are the templates Synapse will look for when generating the content of an email: + +* `notif_mail.html` and `notif_mail.txt`: The contents of email notifications of missed + events. + When rendering, this template is given the following variables: + * `user_display_name`: the display name for the user receiving the notification + * `unsubscribe_link`: the link users can click to unsubscribe from email notifications + * `summary_text`: a summary of the notification(s). The text used can be customised + by configuring the various settings in the `email.subjects` section of the + configuration file. + * `rooms`: a list of rooms containing events to include in the email. Each element is + an object with the following attributes: + * `title`: a human-readable name for the room + * `hash`: a hash of the ID of the room + * `invite`: a boolean, which is `True` if the room is an invite the user hasn't + accepted yet, `False` otherwise + * `notifs`: a list of events, or an empty list if `invite` is `True`. Each element + is an object with the following attributes: + * `link`: a `matrix.to` link to the event + * `ts`: the time in milliseconds at which the event was received + * `messages`: a list of messages containing one message before the event, the + message in the event, and one message after the event.
Each element is an + object with the following attributes: + * `event_type`: the type of the event + * `is_historical`: a boolean, which is `False` if the message is the one + that triggered the notification, `True` otherwise + * `id`: the ID of the event + * `ts`: the time in milliseconds at which the event was sent + * `sender_name`: the display name for the event's sender + * `sender_avatar_url`: the avatar URL (as a `mxc://` URL) for the event's + sender + * `sender_hash`: a hash of the user ID of the sender + * `link`: a `matrix.to` link to the room + * `reason`: information on the event that triggered the email to be sent. It's an + object with the following attributes: + * `room_id`: the ID of the room the event was sent in + * `room_name`: a human-readable name for the room the event was sent in + * `now`: the current time in milliseconds + * `received_at`: the time in milliseconds at which the event was received + * `delay_before_mail_ms`: the amount of time in milliseconds Synapse always waits + before ever emailing about a notification (to give the user a chance to respond + to other push or notice the window) + * `last_sent_ts`: the time in milliseconds at which a notification was last sent + for an event in this room + * `throttle_ms`: the minimum amount of time in milliseconds between two + notifications can be sent for this room +* `password_reset.html` and `password_reset.txt`: The contents of password reset emails + sent by the homeserver. + When rendering, these templates are given a `link` variable which contains the link the + user must click in order to reset their password. +* `registration.html` and `registration.txt`: The contents of address verification emails + sent during registration. + When rendering, these templates are given a `link` variable which contains the link the + user must click in order to validate their email address. +* `add_threepid.html` and `add_threepid.txt`: The contents of address verification emails + sent when an address is added to a Matrix account. + When rendering, these templates are given a `link` variable which contains the link the + user must click in order to validate their email address. + + +## HTML page templates for registration and password reset + +Below are the templates Synapse will look for when generating pages related to +registration and password reset: + +* `password_reset_confirmation.html`: An HTML page that a user will see when they follow + the link in the password reset email. The user will be asked to confirm the action + before their password is reset. + When rendering, this template is given the following variables: + * `sid`: the session ID for the password reset + * `token`: the token for the password reset + * `client_secret`: the client secret for the password reset +* `password_reset_success.html` and `password_reset_failure.html`: HTML pages for success + and failure that a user will see when they confirm the password reset flow using the + page above. + When rendering, `password_reset_success.html` is given no variable, and + `password_reset_failure.html` is given a `failure_reason`, which contains the reason + for the password reset failure. +* `registration_success.html` and `registration_failure.html`: HTML pages for success and + failure that a user will see when they follow the link in an address verification email + sent during registration. 
+ When rendering, `registration_success.html` is given no variable, and + `registration_failure.html` is given a `failure_reason`, which contains the reason + for the registration failure. +* `add_threepid_success.html` and `add_threepid_failure.html`: HTML pages for success and + failure that a user will see when they follow the link in an address verification email + sent when an address is added to a Matrix account. + When rendering, `add_threepid_success.html` is given no variable, and + `add_threepid_failure.html` is given a `failure_reason`, which contains the reason + the address could not be added. + + +## HTML page templates for Single Sign-On (SSO) + +Below are the templates Synapse will look for when generating pages related to SSO: + +* `sso_login_idp_picker.html`: HTML page to prompt the user to choose an + Identity Provider during login. + This is only used if multiple SSO Identity Providers are configured. + When rendering, this template is given the following variables: + * `redirect_url`: the URL that the user will be redirected to after + login. + * `server_name`: the homeserver's name. + * `providers`: a list of available Identity Providers. Each element is + an object with the following attributes: + * `idp_id`: unique identifier for the IdP + * `idp_name`: user-facing name for the IdP + * `idp_icon`: if specified in the IdP config, an MXC URI for an icon + for the IdP + * `idp_brand`: if specified in the IdP config, a textual identifier + for the brand of the IdP + The rendered HTML page should contain a form which submits its results + back as a GET request, with the following query parameters: + * `redirectUrl`: the client redirect URI (ie, the `redirect_url` passed + to the template) + * `idp`: the 'idp_id' of the chosen IDP. +* `sso_auth_account_details.html`: HTML page to prompt new users to enter a + userid and confirm other details. This is only shown if the + SSO implementation (with any `user_mapping_provider`) does not return + a localpart. + When rendering, this template is given the following variables: + * `server_name`: the homeserver's name. + * `idp`: details of the SSO Identity Provider that the user logged in + with: an object with the following attributes: + * `idp_id`: unique identifier for the IdP + * `idp_name`: user-facing name for the IdP + * `idp_icon`: if specified in the IdP config, an MXC URI for an icon + for the IdP + * `idp_brand`: if specified in the IdP config, a textual identifier + for the brand of the IdP + * `user_attributes`: an object containing details about the user that + we received from the IdP. May have the following attributes: + * `display_name`: the user's display name + * `emails`: a list of email addresses + The template should render a form which submits the following fields: + * `username`: the localpart of the user's chosen user id +* `sso_new_user_consent.html`: HTML page allowing the user to consent to the + server's terms and conditions. This is only shown for new users, and only if + `user_consent.require_at_registration` is set. + When rendering, this template is given the following variables: + * `server_name`: the homeserver's name. + * `user_id`: the user's proposed matrix ID. + * `user_profile.display_name`: the user's proposed display name, if any. + * `consent_version`: the version of the terms that the user will be + shown + * `terms_url`: a link to the page showing the terms.
+ The template should render a form which submits the following fields: + * `accepted_version`: the version of the terms accepted by the user + (ie, 'consent_version' from the input variables). +* `sso_redirect_confirm.html`: HTML page for a confirmation step before redirecting back + to the client with the login token. + When rendering, this template is given the following variables: + * `redirect_url`: the URL the user is about to be redirected to. + * `display_url`: the same as `redirect_url`, but with the query + parameters stripped. The intention is to have a + human-readable URL to show to users, not to use it as + the final address to redirect to. + * `server_name`: the homeserver's name. + * `new_user`: a boolean indicating whether this is the user's first time + logging in. + * `user_id`: the user's matrix ID. + * `user_profile.avatar_url`: an MXC URI for the user's avatar, if any. + `None` if the user has not set an avatar. + * `user_profile.display_name`: the user's display name. `None` if the user + has not set a display name. +* `sso_auth_confirm.html`: HTML page which notifies the user that they are authenticating + to confirm an operation on their account during the user interactive authentication + process. + When rendering, this template is given the following variables: + * `redirect_url`: the URL the user is about to be redirected to. + * `description`: the operation which the user is being asked to confirm + * `idp`: details of the Identity Provider that we will use to confirm + the user's identity: an object with the following attributes: + * `idp_id`: unique identifier for the IdP + * `idp_name`: user-facing name for the IdP + * `idp_icon`: if specified in the IdP config, an MXC URI for an icon + for the IdP + * `idp_brand`: if specified in the IdP config, a textual identifier + for the brand of the IdP +* `sso_auth_success.html`: HTML page shown after a successful user interactive + authentication session. + Note that this page must include the JavaScript which notifies of a successful + authentication (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). + This template has no additional variables. +* `sso_auth_bad_user.html`: HTML page shown after a user-interactive authentication + session which does not map correctly onto the expected user. + When rendering, this template is given the following variables: + * `server_name`: the homeserver's name. + * `user_id_to_verify`: the MXID of the user that we are trying to + validate. +* `sso_account_deactivated.html`: HTML page shown during single sign-on if a deactivated + user (according to Synapse's database) attempts to login. + This template has no additional variables. +* `sso_error.html`: HTML page to display to users if something goes wrong during the + OpenID Connect authentication process. + When rendering, this template is given two variables: + * `error`: the technical name of the error + * `error_description`: a human-readable message for the error diff --git a/docs/upgrade.md b/docs/upgrade.md index 8831c9d6cf..1c459d8e2b 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -112,6 +112,17 @@ environment variable. See [using a forward proxy with Synapse documentation](setup/forward_proxy.md) for details. +## Deprecation of `template_dir` + +The `template_dir` settings in the `sso`, `account_validity` and `email` sections of the +configuration file are now deprecated. 
Server admins should use the new +`templates.custom_template_directory` setting in the configuration file and use one single +custom template directory for all aforementioned features. Template file names remain +unchanged. See [the related documentation](https://matrix-org.github.io/synapse/latest/templates.html) +for more information and examples. + +We plan to remove support for these settings in October 2021. + # Upgrading to v1.39.0 diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 9acce5996e..52e63ab1f6 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -78,6 +78,11 @@ class AccountValidityConfig(Config): ) # Read and store template content + custom_template_directories = ( + self.root.server.custom_template_directory, + account_validity_template_dir, + ) + ( self.account_validity_account_renewed_template, self.account_validity_account_previously_renewed_template, @@ -88,5 +93,5 @@ class AccountValidityConfig(Config): "account_previously_renewed.html", invalid_token_template_filename, ], - (td for td in (account_validity_template_dir,) if td), + (td for td in custom_template_directories if td), ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index fc74b4a8b9..4477419196 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -258,7 +258,12 @@ class EmailConfig(Config): add_threepid_template_success_html, ], ( - td for td in (template_dir,) if td + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td ), # Filter out template_dir if not provided ) @@ -299,7 +304,14 @@ class EmailConfig(Config): self.email_notif_template_text, ) = self.read_templates( [notif_template_html, notif_template_text], - (td for td in (template_dir,) if td), + ( + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td + ), # Filter out template_dir if not provided ) self.email_notif_for_new_users = email_config.get( @@ -322,7 +334,14 @@ class EmailConfig(Config): self.account_validity_template_text, ) = self.read_templates( [expiry_template_html, expiry_template_text], - (td for td in (template_dir,) if td), + ( + td + for td in ( + self.root.server.custom_template_directory, + template_dir, + ) + if td + ), # Filter out template_dir if not provided ) subjects_config = email_config.get("subjects", {}) @@ -354,6 +373,9 @@ class EmailConfig(Config): """\ # Configuration for sending emails from Synapse. # + # Server admins can configure custom templates for email content. See + # https://matrix-org.github.io/synapse/latest/templates.html for more information. + # email: # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'. # @@ -430,49 +452,6 @@ class EmailConfig(Config): # #invite_client_location: https://app.element.io - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * The contents of email notifications of missed events: 'notif_mail.html' and - # 'notif_mail.txt'. - # - # * The contents of account expiry notice emails: 'notice_expiry.html' and - # 'notice_expiry.txt'. 
- # - # * The contents of password reset emails sent by the homeserver: - # 'password_reset.html' and 'password_reset.txt' - # - # * An HTML page that a user will see when they follow the link in the password - # reset email. The user will be asked to confirm the action before their - # password is reset: 'password_reset_confirmation.html' - # - # * HTML pages for success and failure that a user will see when they confirm - # the password reset flow using the page above: 'password_reset_success.html' - # and 'password_reset_failure.html' - # - # * The contents of address verification emails sent during registration: - # 'registration.html' and 'registration.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent during registration: - # 'registration_success.html' and 'registration_failure.html' - # - # * The contents of address verification emails sent when an address is added - # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt' - # - # * HTML pages for success and failure that a user will see when they follow - # the link in an address verification email sent when an address is added - # to a Matrix account: 'add_threepid_success.html' and - # 'add_threepid_failure.html' - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # Subjects to use when sending emails from Synapse. # # The placeholder '%%(app)s' will be replaced with the value of the 'app_name' diff --git a/synapse/config/server.py b/synapse/config/server.py index 187b4301a0..8494795919 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -710,6 +710,18 @@ class ServerConfig(Config): # Turn the list into a set to improve lookup speed. self.next_link_domain_whitelist = set(next_link_domain_whitelist) + templates_config = config.get("templates") or {} + if not isinstance(templates_config, dict): + raise ConfigError("The 'templates' section must be a dictionary") + + self.custom_template_directory = templates_config.get( + "custom_template_directory" + ) + if self.custom_template_directory is not None and not isinstance( + self.custom_template_directory, str + ): + raise ConfigError("'custom_template_directory' must be a string") + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -1284,6 +1296,19 @@ class ServerConfig(Config): # all domains. # #next_link_domain_whitelist: ["matrix.org"] + + # Templates to use when generating email or HTML page contents. + # + templates: + # Directory in which Synapse will try to find template files to use to generate + # email or HTML page contents. + # If not set, or a file is not found within the template directory, a default + # template from within the Synapse package will be used. + # + # See https://matrix-org.github.io/synapse/latest/templates.html for more + # information about using custom templates. 
+ # + #custom_template_directory: /path/to/custom/templates/ """ % locals() ) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 4b590e0535..fe1177ab81 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -45,6 +45,11 @@ class SSOConfig(Config): self.sso_template_dir = sso_config.get("template_dir") # Read templates from disk + custom_template_directories = ( + self.root.server.custom_template_directory, + self.sso_template_dir, + ) + ( self.sso_login_idp_picker_template, self.sso_redirect_confirm_template, @@ -63,7 +68,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - (td for td in (self.sso_template_dir,) if td), + (td for td in custom_template_directories if td), ) # These templates have no placeholders, so render them here @@ -94,6 +99,9 @@ class SSOConfig(Config): # Additional settings to use with single-sign on systems such as OpenID Connect, # SAML2 and CAS. # + # Server admins can configure custom templates for pages related to SSO. See + # https://matrix-org.github.io/synapse/latest/templates.html for more information. + # sso: # A list of client URLs which are whitelisted so that the user does not # have to confirm giving access to their account to the URL. Any client @@ -125,167 +133,4 @@ class SSOConfig(Config): # information when first signing in. Defaults to false. # #update_profile_information: true - - # Directory in which Synapse will try to find the template files below. - # If not set, or the files named below are not found within the template - # directory, default templates from within the Synapse package will be used. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to prompt the user to choose an Identity Provider during - # login: 'sso_login_idp_picker.html'. - # - # This is only used if multiple SSO Identity Providers are configured. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL that the user will be redirected to after - # login. - # - # * server_name: the homeserver's name. - # - # * providers: a list of available Identity Providers. Each element is - # an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # The rendered HTML page should contain a form which submits its results - # back as a GET request, with the following query parameters: - # - # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed - # to the template) - # - # * idp: the 'idp_id' of the chosen IDP. - # - # * HTML page to prompt new users to enter a userid and confirm other - # details: 'sso_auth_account_details.html'. This is only shown if the - # SSO implementation (with any user_mapping_provider) does not return - # a localpart. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. 
- # - # * idp: details of the SSO Identity Provider that the user logged in - # with: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * user_attributes: an object containing details about the user that - # we received from the IdP. May have the following attributes: - # - # * display_name: the user's display_name - # * emails: a list of email addresses - # - # The template should render a form which submits the following fields: - # - # * username: the localpart of the user's chosen user id - # - # * HTML page allowing the user to consent to the server's terms and - # conditions. This is only shown for new users, and only if - # `user_consent.require_at_registration` is set. - # - # When rendering, this template is given the following variables: - # - # * server_name: the homeserver's name. - # - # * user_id: the user's matrix proposed ID. - # - # * user_profile.display_name: the user's proposed display name, if any. - # - # * consent_version: the version of the terms that the user will be - # shown - # - # * terms_url: a link to the page showing the terms. - # - # The template should render a form which submits the following fields: - # - # * accepted_version: the version of the terms accepted by the user - # (ie, 'consent_version' from the input variables). - # - # * HTML page for a confirmation step before redirecting back to the client - # with the login token: 'sso_redirect_confirm.html'. - # - # When rendering, this template is given the following variables: - # - # * redirect_url: the URL the user is about to be redirected to. - # - # * display_url: the same as `redirect_url`, but with the query - # parameters stripped. The intention is to have a - # human-readable URL to show to users, not to use it as - # the final address to redirect to. - # - # * server_name: the homeserver's name. - # - # * new_user: a boolean indicating whether this is the user's first time - # logging in. - # - # * user_id: the user's matrix ID. - # - # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. - # None if the user has not set an avatar. - # - # * user_profile.display_name: the user's display name. None if the user - # has not set a display name. - # - # * HTML page which notifies the user that they are authenticating to confirm - # an operation on their account during the user interactive authentication - # process: 'sso_auth_confirm.html'. - # - # When rendering, this template is given the following variables: - # * redirect_url: the URL the user is about to be redirected to. - # - # * description: the operation which the user is being asked to confirm - # - # * idp: details of the Identity Provider that we will use to confirm - # the user's identity: an object with the following attributes: - # - # * idp_id: unique identifier for the IdP - # * idp_name: user-facing name for the IdP - # * idp_icon: if specified in the IdP config, an MXC URI for an icon - # for the IdP - # * idp_brand: if specified in the IdP config, a textual identifier - # for the brand of the IdP - # - # * HTML page shown after a successful user interactive authentication session: - # 'sso_auth_success.html'. 
- # - # Note that this page must include the JavaScript which notifies of a successful authentication - # (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback). - # - # This template has no additional variables. - # - # * HTML page shown after a user-interactive authentication session which - # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. - # - # When rendering, this template is given the following variables: - # * server_name: the homeserver's name. - # * user_id_to_verify: the MXID of the user that we are trying to - # validate. - # - # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) - # attempts to login: 'sso_account_deactivated.html'. - # - # This template has no additional variables. - # - # * HTML page to display to users if something goes wrong during the - # OpenID Connect authentication process: 'sso_error.html'. - # - # When rendering, this template is given two variables: - # * error: the technical name of the error - # * error_description: a human-readable message for the error - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" """ diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 82725853bc..2f99d31c42 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -91,6 +91,7 @@ class ModuleApi: self._state = hs.get_state_handler() self._clock: Clock = hs.get_clock() self._send_email_handler = hs.get_send_email_handler() + self.custom_template_dir = hs.config.server.custom_template_directory try: app_name = self._hs.config.email_app_name @@ -679,7 +680,7 @@ class ModuleApi: """ return self._hs.config.read_templates( filenames, - (td for td in (custom_template_directory,) if td), + (td for td in (self.custom_template_dir, custom_template_directory) if td), ) diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 488b97b32e..fc62a09b7f 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -46,6 +46,8 @@ class NewUserConsentResource(DirectServeHtmlResource): self._consent_version = hs.config.consent.user_consent_version def template_search_dirs(): + if hs.config.server.custom_template_directory: + yield hs.config.server.custom_template_directory if hs.config.sso.sso_template_dir: yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index ab24ec0a8e..c15b83c387 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -74,6 +74,8 @@ class AccountDetailsResource(DirectServeHtmlResource): self._sso_handler = hs.get_sso_handler() def template_search_dirs(): + if hs.config.server.custom_template_directory: + yield hs.config.server.custom_template_directory if hs.config.sso.sso_template_dir: yield hs.config.sso.sso_template_dir yield hs.config.sso.default_template_dir -- cgit 1.5.1 From c8132f4a31be2717976052424abeb1ed40e947c8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 17 Aug 2021 13:48:59 +0100 Subject: Build debs for bookworm (#10612) --- changelog.d/10612.misc | 1 + scripts-dev/build_debian_packages | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 
100644 changelog.d/10612.misc diff --git a/changelog.d/10612.misc b/changelog.d/10612.misc new file mode 100644 index 0000000000..c7a9457022 --- /dev/null +++ b/changelog.d/10612.misc @@ -0,0 +1 @@ +Build Debian packages for Debian 12 (Bookworm). diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index 6153cb225f..e9f89e38ef 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -20,8 +20,9 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional, Sequence DISTS = ( - "debian:buster", + "debian:buster", # oldstable: EOL 2022-08 "debian:bullseye", + "debian:bookworm", "debian:sid", "ubuntu:bionic", # 18.04 LTS (our EOL forced by Py36 on 2021-12-23) "ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14) -- cgit 1.5.1 From 84469bdac773ddb79cfc99f31bbac78d27450682 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 17 Aug 2021 14:02:50 +0100 Subject: Remove the unused public_room_list_stream (#10565) Co-authored-by: Patrick Cloke --- changelog.d/10565.misc | 1 + synapse/app/admin_cmd.py | 2 - synapse/app/generic_worker.py | 4 +- synapse/replication/slave/storage/room.py | 37 ----- synapse/replication/tcp/streams/__init__.py | 3 - synapse/replication/tcp/streams/_base.py | 25 ---- synapse/storage/databases/main/__init__.py | 4 +- synapse/storage/databases/main/room.py | 215 +++++----------------------- synapse/storage/schema/__init__.py | 7 +- 9 files changed, 48 insertions(+), 250 deletions(-) create mode 100644 changelog.d/10565.misc delete mode 100644 synapse/replication/slave/storage/room.py diff --git a/changelog.d/10565.misc b/changelog.d/10565.misc new file mode 100644 index 0000000000..06796b61ab --- /dev/null +++ b/changelog.d/10565.misc @@ -0,0 +1 @@ +Remove the unused public rooms replication stream. 
\ No newline at end of file diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 3234d9ebba..7396db93c6 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -38,7 +38,6 @@ from synapse.replication.slave.storage.groups import SlavedGroupServerStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore from synapse.server import HomeServer from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -58,7 +57,6 @@ class AdminCmdSlavedStore( SlavedPushRuleStore, SlavedEventStore, SlavedClientIpStore, - RoomStore, BaseSlavedStore, ): pass diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d7b425a7ab..845e6a8220 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -64,7 +64,6 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.pushers import SlavedPusherStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore -from synapse.replication.slave.storage.room import RoomStore from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client import ( account_data, @@ -114,6 +113,7 @@ from synapse.storage.databases.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.databases.main.presence import PresenceStore +from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore @@ -237,7 +237,7 @@ class GenericWorkerSlavedStore( ClientIpWorkerStore, SlavedEventStore, SlavedKeyStore, - RoomStore, + RoomWorkerStore, DirectoryStore, SlavedApplicationServiceStore, SlavedRegistrationStore, diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py deleted file mode 100644 index 8cc6de3f46..0000000000 --- a/synapse/replication/slave/storage/room.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from synapse.replication.tcp.streams import PublicRoomsStream -from synapse.storage.database import DatabasePool -from synapse.storage.databases.main.room import RoomWorkerStore - -from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker - - -class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - self._public_room_id_gen = SlavedIdTracker( - db_conn, "public_room_list_stream", "stream_id" - ) - - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == PublicRoomsStream.NAME: - self._public_room_id_gen.advance(instance_name, token) - - return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 4c0023c68a..f41eabd85e 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -32,7 +32,6 @@ from synapse.replication.tcp.streams._base import ( GroupServerStream, PresenceFederationStream, PresenceStream, - PublicRoomsStream, PushersStream, PushRulesStream, ReceiptsStream, @@ -57,7 +56,6 @@ STREAMS_MAP = { PushRulesStream, PushersStream, CachesStream, - PublicRoomsStream, DeviceListsStream, ToDeviceStream, FederationStream, @@ -79,7 +77,6 @@ __all__ = [ "PushRulesStream", "PushersStream", "CachesStream", - "PublicRoomsStream", "DeviceListsStream", "ToDeviceStream", "TagAccountDataStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 3716c41bea..9b905aba9d 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -447,31 +447,6 @@ class CachesStream(Stream): ) -class PublicRoomsStream(Stream): - """The public rooms list changed""" - - PublicRoomsStreamRow = namedtuple( - "PublicRoomsStreamRow", - ( - "room_id", # str - "visibility", # str - "appservice_id", # str, optional - "network_id", # str, optional - ), - ) - - NAME = "public_rooms" - ROW_TYPE = PublicRoomsStreamRow - - def __init__(self, hs): - store = hs.get_datastore() - super().__init__( - hs.get_instance_name(), - current_token_without_instance(store.get_current_public_room_stream_id), - store.get_all_new_public_rooms, - ) - - class DeviceListsStream(Stream): """Either a user has updated their devices or a remote server needs to be told about a device update. 
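With the slaved `RoomStore` gone, worker processes now mix `RoomWorkerStore` in directly, with no slaved stream-ID tracker for public rooms. A minimal sketch of the resulting pattern (the real `GenericWorkerSlavedStore` in the hunk above mixes in many more base classes; this two-base version is illustrative only):

```python
# Sketch of the worker-store pattern after this change: read-only room
# accessors come straight from RoomWorkerStore, and there is no
# public_room_list_stream tracker left to initialise.
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.storage.databases.main.room import RoomWorkerStore


class MinimalWorkerStore(RoomWorkerStore, BaseSlavedStore):
    """Illustrative only: exposes room read paths on a worker."""
```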
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 8d9f07111d..01b918e12e 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -127,9 +127,6 @@ class DataStore( self._clock = hs.get_clock() self.database_engine = database.engine - self._public_room_id_gen = StreamIdGenerator( - db_conn, "public_room_list_stream", "stream_id" - ) self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", @@ -170,6 +167,7 @@ class DataStore( sequence_name="cache_invalidation_stream_seq", writers=[], ) + else: self._cache_id_gen = None diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 443e5f3315..c7a1c1e8d9 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -890,55 +890,6 @@ class RoomWorkerStore(SQLBaseStore): return total_media_quarantined - async def get_all_new_public_rooms( - self, instance_name: str, last_id: int, current_id: int, limit: int - ) -> Tuple[List[Tuple[int, tuple]], int, bool]: - """Get updates for public rooms replication stream. - - Args: - instance_name: The writer we want to fetch updates from. Unused - here since there is only ever one writer. - last_id: The token to fetch updates from. Exclusive. - current_id: The token to fetch updates up to. Inclusive. - limit: The requested limit for the number of rows to return. The - function may return more or fewer rows. - - Returns: - A tuple consisting of: the updates, a token to use to fetch - subsequent updates, and whether we returned fewer rows than exists - between the requested tokens due to the limit. - - The token returned can be used in a subsequent call to this - function to get further updatees. - - The updates are a list of 2-tuples of stream ID and the row data - """ - if last_id == current_id: - return [], current_id, False - - def get_all_new_public_rooms(txn): - sql = """ - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """ - - txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] - limited = False - upto_token = current_id - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return await self.db_pool.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) - async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False ) -> Dict[str, dict]: @@ -1410,34 +1361,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): StoreError if the room could not be stored. 
""" try: - - def store_room_txn(txn, next_id): - self.db_pool.simple_insert_txn( - txn, - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - "room_version": room_version.identifier, - "has_auth_chain_index": True, - }, - ) - if is_public: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - }, - ) - - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "store_room_txn", store_room_txn, next_id - ) + await self.db_pool.simple_insert( + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + "room_version": room_version.identifier, + "has_auth_chain_index": True, + }, + desc="store_room", + ) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -1470,49 +1404,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - async def set_room_is_public(self, room_id, is_public): - def set_room_is_public_txn(txn, next_id): - self.db_pool.simple_update_one_txn( - txn, - table="rooms", - keyvalues={"room_id": room_id}, - updatevalues={"is_public": is_public}, - ) - - entries = self.db_pool.simple_select_list_txn( - txn, - table="public_room_list_stream", - keyvalues={ - "room_id": room_id, - "appservice_id": None, - "network_id": None, - }, - retcols=("stream_id", "visibility"), - ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": None, - "network_id": None, - }, - ) + async def set_room_is_public(self, room_id: str, is_public: bool) -> None: + await self.db_pool.simple_update_one( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"is_public": is_public}, + desc="set_room_is_public", + ) - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "set_room_is_public", set_room_is_public_txn, next_id - ) self.hs.get_notifier().on_new_replication_data() async def set_room_is_public_appservice( @@ -1533,68 +1432,33 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): list. """ - def set_room_is_public_appservice_txn(txn, next_id): - if is_public: - try: - self.db_pool.simple_insert_txn( - txn, - table="appservice_room_list", - values={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - except self.database_engine.module.IntegrityError: - # We've already inserted, nothing to do. 
- return - else: - self.db_pool.simple_delete_txn( - txn, - table="appservice_room_list", - keyvalues={ - "appservice_id": appservice_id, - "network_id": network_id, - "room_id": room_id, - }, - ) - - entries = self.db_pool.simple_select_list_txn( - txn, - table="public_room_list_stream", + if is_public: + await self.db_pool.simple_upsert( + table="appservice_room_list", keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, "room_id": room_id, + }, + values={}, + insertion_values={ "appservice_id": appservice_id, "network_id": network_id, + "room_id": room_id, }, - retcols=("stream_id", "visibility"), + desc="set_room_is_public_appservice_true", ) - - entries.sort(key=lambda r: r["stream_id"]) - - add_to_stream = True - if entries: - add_to_stream = bool(entries[-1]["visibility"]) != is_public - - if add_to_stream: - self.db_pool.simple_insert_txn( - txn, - table="public_room_list_stream", - values={ - "stream_id": next_id, - "room_id": room_id, - "visibility": is_public, - "appservice_id": appservice_id, - "network_id": network_id, - }, - ) - - async with self._public_room_id_gen.get_next() as next_id: - await self.db_pool.runInteraction( - "set_room_is_public_appservice", - set_room_is_public_appservice_txn, - next_id, + else: + await self.db_pool.simple_delete( + table="appservice_room_list", + keyvalues={ + "appservice_id": appservice_id, + "network_id": network_id, + "room_id": room_id, + }, + desc="set_room_is_public_appservice_false", ) + self.hs.get_notifier().on_new_replication_data() async def add_event_report( @@ -1787,9 +1651,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): "get_event_reports_paginate", _get_event_reports_paginate_txn ) - def get_current_public_room_stream_id(self): - return self._public_room_id_gen.get_current_token() - async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 7e0687e197..a5bc0ee8a5 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 62 +SCHEMA_VERSION = 63 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -25,6 +25,11 @@ for more information on how this works. Changes in SCHEMA_VERSION = 61: - The `user_stats_historical` and `room_stats_historical` tables are not written and are not read (previously, they were written but not read). + +Changes in SCHEMA_VERSION = 63: + - The `public_room_list_stream` table is not written nor read to + (previously, it was written and read to, but not for any significant purpose). 
+ https://github.com/matrix-org/synapse/pull/10565 """ -- cgit 1.5.1 From 1a9f531c79f8a043db6f151c5393f23aa5675800 Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Tue, 17 Aug 2021 14:22:45 +0100 Subject: Port the PresenceRouter module interface to the new generic interface (#10524) Port the PresenceRouter module interface to the new generic interface introduced in v1.37.0 --- changelog.d/10524.feature | 1 + docs/modules.md | 46 +++++++++ docs/presence_router_module.md | 6 ++ docs/sample_config.yaml | 14 --- synapse/app/_base.py | 2 + synapse/config/server.py | 15 +-- synapse/events/presence_router.py | 192 ++++++++++++++++++++++++++++------- synapse/module_api/__init__.py | 10 ++ tests/events/test_presence_router.py | 109 +++++++++++++++++++- 9 files changed, 326 insertions(+), 69 deletions(-) create mode 100644 changelog.d/10524.feature diff --git a/changelog.d/10524.feature b/changelog.d/10524.feature new file mode 100644 index 0000000000..288c9bd74e --- /dev/null +++ b/changelog.d/10524.feature @@ -0,0 +1 @@ +Port the PresenceRouter module interface to the new generic interface. \ No newline at end of file diff --git a/docs/modules.md b/docs/modules.md index 9a430390a4..ae8d6f5b73 100644 --- a/docs/modules.md +++ b/docs/modules.md @@ -282,6 +282,52 @@ the request is a server admin. Modules can modify the `request_content` (by e.g. adding events to its `initial_state`), or deny the room's creation by raising a `module_api.errors.SynapseError`. +#### Presence router callbacks + +Presence router callbacks allow module developers to specify additional users (local or remote) +to receive certain presence updates from local users. Presence router callbacks can be +registered using the module API's `register_presence_router_callbacks` method. + +The available presence router callbacks are: + +```python +async def get_users_for_states( + self, + state_updates: Iterable["synapse.api.UserPresenceState"], +) -> Dict[str, Set["synapse.api.UserPresenceState"]]: +``` +**Requires** `get_interested_users` to also be registered + +Called when processing updates to the presence state of one or more users. This callback can +be used to instruct the server to forward that presence state to specific users. The module +must return a dictionary that maps from Matrix user IDs (which can be local or remote) to the +`UserPresenceState` changes that they should be forwarded. + +Synapse will then attempt to send the specified presence updates to each user when possible. + +```python +async def get_interested_users( + self, + user_id: str +) -> Union[Set[str], "synapse.module_api.PRESENCE_ALL_USERS"] +``` +**Requires** `get_users_for_states` to also be registered + +Called when determining which users someone should be able to see the presence state of. This +callback should return complementary results to `get_users_for_state` or the presence information +may not be properly forwarded. + +The callback is given the Matrix user ID for a local user that is requesting presence data and +should return the Matrix user IDs of the users whose presence state they are allowed to +query. The returned users can be local or remote. + +Alternatively the callback can return `synapse.module_api.PRESENCE_ALL_USERS` +to indicate that the user should receive updates from all known users. 
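Before the worked example that follows, here is a minimal sketch of a module registering both callbacks together, as required above. The class name and `@monitor:example.org` are placeholders, not part of Synapse:

```python
# Sketch: a presence router module using the generic module interface.
# "@monitor:example.org" stands in for whichever user should see all presence.
from typing import Dict, Iterable, Set, Union

from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi, PRESENCE_ALL_USERS


class ExamplePresenceRouter:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # The two callbacks must be registered together (or not at all).
        api.register_presence_router_callbacks(
            get_users_for_states=self.get_users_for_states,
            get_interested_users=self.get_interested_users,
        )

    async def get_users_for_states(
        self, state_updates: Iterable[UserPresenceState]
    ) -> Dict[str, Set[UserPresenceState]]:
        # Forward every presence update to the monitoring user.
        return {"@monitor:example.org": set(state_updates)}

    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
        # The complementary direction: the monitoring user may see everyone.
        if user_id == "@monitor:example.org":
            return PRESENCE_ALL_USERS
        return set()
```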
+ +For example, if the user `@alice:example.org` is passed to this method, and the Set +`{"@bob:example.com", "@charlie:somewhere.org"}` is returned, this signifies that Alice +should receive presence updates sent by Bob and Charlie, regardless of whether these users +share a room. ### Porting an existing module that uses the old interface diff --git a/docs/presence_router_module.md b/docs/presence_router_module.md index 4a3e720240..face54fe2b 100644 --- a/docs/presence_router_module.md +++ b/docs/presence_router_module.md @@ -1,3 +1,9 @@ +
+<h2 style="color:red">
+This page of the Synapse documentation is now deprecated. For up to date
+documentation on setting up or writing a presence router module, please see
+<a href="modules.md">this page</a>.
+</h2>
+ # Presence Router Module Synapse supports configuring a module that can specify additional users diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index aeebcaf45f..6030e85a08 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -108,20 +108,6 @@ presence: # #enabled: false - # Presence routers are third-party modules that can specify additional logic - # to where presence updates from users are routed. - # - presence_router: - # The custom module's class. Uncomment to use a custom presence router module. - # - #module: "my_custom_router.PresenceRouter" - - # Configuration options of the custom module. Refer to your module's - # documentation for available options. - # - #config: - # example_option: 'something' - # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to # 'false'. Note that profile data is also available via the federation diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 50a02f51f5..39e28aff9f 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -37,6 +37,7 @@ from synapse.app import check_bind_error from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory +from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.logging.context import PreserveLoggingContext @@ -370,6 +371,7 @@ async def start(hs: "HomeServer"): load_legacy_spam_checkers(hs) load_legacy_third_party_event_rules(hs) + load_legacy_presence_router(hs) # If we've configured an expiry time for caches, start the background job now. setup_expire_lru_cache_entries(hs) diff --git a/synapse/config/server.py b/synapse/config/server.py index 187b4301a0..871422ea28 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -248,6 +248,7 @@ class ServerConfig(Config): self.use_presence = config.get("use_presence", True) # Custom presence router module + # This is the legacy way of configuring it (the config should now be put in the modules section) self.presence_router_module_class = None self.presence_router_config = None presence_router_config = presence_config.get("presence_router") @@ -858,20 +859,6 @@ class ServerConfig(Config): # #enabled: false - # Presence routers are third-party modules that can specify additional logic - # to where presence updates from users are routed. - # - presence_router: - # The custom module's class. Uncomment to use a custom presence router module. - # - #module: "my_custom_router.PresenceRouter" - - # Configuration options of the custom module. Refer to your module's - # documentation for available options. - # - #config: - # example_option: 'something' - # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to # 'false'. Note that profile data is also available via the federation diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index 6c37c8a7a4..eb4556cdc1 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -11,45 +11,115 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from typing import TYPE_CHECKING, Dict, Iterable, Set, Union +import logging +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Union, +) from synapse.api.presence import UserPresenceState +from synapse.util.async_helpers import maybe_awaitable if TYPE_CHECKING: from synapse.server import HomeServer +GET_USERS_FOR_STATES_CALLBACK = Callable[ + [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]] +] +GET_INTERESTED_USERS_CALLBACK = Callable[ + [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]] +] + +logger = logging.getLogger(__name__) + + +def load_legacy_presence_router(hs: "HomeServer"): + """Wrapper that loads a presence router module configured using the old + configuration, and registers the hooks they implement. + """ + + if hs.config.presence_router_module_class is None: + return + + module = hs.config.presence_router_module_class + config = hs.config.presence_router_config + api = hs.get_module_api() + + presence_router = module(config=config, module_api=api) + + # The known hooks. If a module implements a method which name appears in this set, + # we'll want to register it. + presence_router_methods = { + "get_users_for_states", + "get_interested_users", + } + + # All methods that the module provides should be async, but this wasn't enforced + # in the old module system, so we wrap them if needed + def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]: + # f might be None if the callback isn't implemented by the module. In this + # case we don't want to register a callback at all so we return None. + if f is None: + return None + + def run(*args, **kwargs): + # mypy doesn't do well across function boundaries so we need to tell it + # f is definitely not None. + assert f is not None + + return maybe_awaitable(f(*args, **kwargs)) + + return run + + # Register the hooks through the module API. + hooks = { + hook: async_wrapper(getattr(presence_router, hook, None)) + for hook in presence_router_methods + } + + api.register_presence_router_callbacks(**hooks) + class PresenceRouter: """ A module that the homeserver will call upon to help route user presence updates to - additional destinations. If a custom presence router is configured, calls will be - passed to that instead. + additional destinations. 
""" ALL_USERS = "ALL" def __init__(self, hs: "HomeServer"): - self.custom_presence_router = None + # Initially there are no callbacks + self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = [] + self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = [] - # Check whether a custom presence router module has been configured - if hs.config.presence_router_module_class: - # Initialise the module - self.custom_presence_router = hs.config.presence_router_module_class( - config=hs.config.presence_router_config, module_api=hs.get_module_api() + def register_presence_router_callbacks( + self, + get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, + get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, + ): + # PresenceRouter modules are required to implement both of these methods + # or neither of them as they are assumed to act in a complementary manner + paired_methods = [get_users_for_states, get_interested_users] + if paired_methods.count(None) == 1: + raise RuntimeError( + "PresenceRouter modules must register neither or both of the paired callbacks: " + "[get_users_for_states, get_interested_users]" ) - # Ensure the module has implemented the required methods - required_methods = ["get_users_for_states", "get_interested_users"] - for method_name in required_methods: - if not hasattr(self.custom_presence_router, method_name): - raise Exception( - "PresenceRouter module '%s' must implement all required methods: %s" - % ( - hs.config.presence_router_module_class.__name__, - ", ".join(required_methods), - ) - ) + # Append the methods provided to the lists of callbacks + if get_users_for_states is not None: + self._get_users_for_states_callbacks.append(get_users_for_states) + + if get_interested_users is not None: + self._get_interested_users_callbacks.append(get_interested_users) async def get_users_for_states( self, @@ -66,14 +136,40 @@ class PresenceRouter: A dictionary of user_id -> set of UserPresenceState, indicating which presence updates each user should receive. """ - if self.custom_presence_router is not None: - # Ask the custom module - return await self.custom_presence_router.get_users_for_states( - state_updates=state_updates - ) - # Don't include any extra destinations for presence updates - return {} + # Bail out early if we don't have any callbacks to run. + if len(self._get_users_for_states_callbacks) == 0: + # Don't include any extra destinations for presence updates + return {} + + users_for_states = {} + # run all the callbacks for get_users_for_states and combine the results + for callback in self._get_users_for_states_callbacks: + try: + result = await callback(state_updates) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue + + if not isinstance(result, Dict): + logger.warning( + "Wrong type returned by module API callback %s: %s, expected Dict", + callback, + result, + ) + continue + + for key, new_entries in result.items(): + if not isinstance(new_entries, Set): + logger.warning( + "Wrong type returned by module API callback %s: %s, expected Set", + callback, + new_entries, + ) + break + users_for_states.setdefault(key, set()).update(new_entries) + + return users_for_states async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]: """ @@ -92,12 +188,36 @@ class PresenceRouter: A set of user IDs to return presence updates for, or ALL_USERS to return all known updates. 
""" - if self.custom_presence_router is not None: - # Ask the custom module for interested users - return await self.custom_presence_router.get_interested_users( - user_id=user_id - ) - # A custom presence router is not defined. - # Don't report any additional interested users - return set() + # Bail out early if we don't have any callbacks to run. + if len(self._get_interested_users_callbacks) == 0: + # Don't report any additional interested users + return set() + + interested_users = set() + # run all the callbacks for get_interested_users and combine the results + for callback in self._get_interested_users_callbacks: + try: + result = await callback(user_id) + except Exception as e: + logger.warning("Failed to run module API callback %s: %s", callback, e) + continue + + # If one of the callbacks returns ALL_USERS then we can stop calling all + # of the other callbacks, since the set of interested_users is already as + # large as it can possibly be + if result == PresenceRouter.ALL_USERS: + return PresenceRouter.ALL_USERS + + if not isinstance(result, Set): + logger.warning( + "Wrong type returned by module API callback %s: %s, expected set", + callback, + result, + ) + continue + + # Add the new interested users to the set + interested_users.update(result) + + return interested_users diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 82725853bc..84bb7264a4 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -32,6 +32,7 @@ from twisted.internet import defer from twisted.web.resource import IResource from synapse.events import EventBase +from synapse.events.presence_router import PresenceRouter from synapse.http.client import SimpleHttpClient from synapse.http.server import ( DirectServeHtmlResource, @@ -57,6 +58,8 @@ This package defines the 'stable' API which can be used by extension modules whi are loaded into Synapse. """ +PRESENCE_ALL_USERS = PresenceRouter.ALL_USERS + __all__ = [ "errors", "make_deferred_yieldable", @@ -70,6 +73,7 @@ __all__ = [ "DirectServeHtmlResource", "DirectServeJsonResource", "ModuleApi", + "PRESENCE_ALL_USERS", ] logger = logging.getLogger(__name__) @@ -111,6 +115,7 @@ class ModuleApi: self._spam_checker = hs.get_spam_checker() self._account_validity_handler = hs.get_account_validity_handler() self._third_party_event_rules = hs.get_third_party_event_rules() + self._presence_router = hs.get_presence_router() ################################################################################# # The following methods should only be called during the module's initialisation. @@ -130,6 +135,11 @@ class ModuleApi: """Registers callbacks for third party event rules capabilities.""" return self._third_party_event_rules.register_third_party_rules_callbacks + @property + def register_presence_router_callbacks(self): + """Registers callbacks for presence router capabilities.""" + return self._presence_router.register_presence_router_callbacks + def register_web_resource(self, path: str, resource: IResource): """Registers a web resource to be served at the given path. 
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 6b87f571b8..3b3866bff8 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -17,7 +17,7 @@ from unittest.mock import Mock import attr from synapse.api.constants import EduTypes -from synapse.events.presence_router import PresenceRouter +from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState from synapse.module_api import ModuleApi @@ -34,7 +34,7 @@ class PresenceRouterTestConfig: users_who_should_receive_all_presence = attr.ib(type=List[str], default=[]) -class PresenceRouterTestModule: +class LegacyPresenceRouterTestModule: def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi): self._config = config self._module_api = module_api @@ -77,6 +77,53 @@ class PresenceRouterTestModule: return config +class PresenceRouterTestModule: + def __init__(self, config: PresenceRouterTestConfig, api: ModuleApi): + self._config = config + self._module_api = api + api.register_presence_router_callbacks( + get_users_for_states=self.get_users_for_states, + get_interested_users=self.get_interested_users, + ) + + async def get_users_for_states( + self, state_updates: Iterable[UserPresenceState] + ) -> Dict[str, Set[UserPresenceState]]: + users_to_state = { + user_id: set(state_updates) + for user_id in self._config.users_who_should_receive_all_presence + } + return users_to_state + + async def get_interested_users( + self, user_id: str + ) -> Union[Set[str], PresenceRouter.ALL_USERS]: + if user_id in self._config.users_who_should_receive_all_presence: + return PresenceRouter.ALL_USERS + + return set() + + @staticmethod + def parse_config(config_dict: dict) -> PresenceRouterTestConfig: + """Parse a configuration dictionary from the homeserver config, do + some validation and return a typed PresenceRouterConfig. + + Args: + config_dict: The configuration dictionary. + + Returns: + A validated config object. 
+ """ + # Initialise a typed config object + config = PresenceRouterTestConfig() + + config.users_who_should_receive_all_presence = config_dict.get( + "users_who_should_receive_all_presence" + ) + + return config + + class PresenceRouterTestCase(FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -86,9 +133,17 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): ] def make_homeserver(self, reactor, clock): - return self.setup_test_homeserver( + hs = self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), ) + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + + load_legacy_presence_router(hs) + + return hs def prepare(self, reactor, clock, homeserver): self.sync_handler = self.hs.get_sync_handler() @@ -98,7 +153,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): { "presence": { "presence_router": { - "module": __name__ + ".PresenceRouterTestModule", + "module": __name__ + ".LegacyPresenceRouterTestModule", "config": { "users_who_should_receive_all_presence": [ "@presence_gobbler:test", @@ -109,7 +164,28 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): "send_federation": True, } ) + def test_receiving_all_presence_legacy(self): + self.receiving_all_presence_test_body() + + @override_config( + { + "modules": [ + { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler:test", + ] + }, + }, + ], + "send_federation": True, + } + ) def test_receiving_all_presence(self): + self.receiving_all_presence_test_body() + + def receiving_all_presence_test_body(self): """Test that a user that does not share a room with another other can receive presence for them, due to presence routing. """ @@ -203,7 +279,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): { "presence": { "presence_router": { - "module": __name__ + ".PresenceRouterTestModule", + "module": __name__ + ".LegacyPresenceRouterTestModule", "config": { "users_who_should_receive_all_presence": [ "@presence_gobbler1:test", @@ -216,7 +292,30 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): "send_federation": True, } ) + def test_send_local_online_presence_to_with_module_legacy(self): + self.send_local_online_presence_to_with_module_test_body() + + @override_config( + { + "modules": [ + { + "module": __name__ + ".PresenceRouterTestModule", + "config": { + "users_who_should_receive_all_presence": [ + "@presence_gobbler1:test", + "@presence_gobbler2:test", + "@far_away_person:island", + ] + }, + }, + ], + "send_federation": True, + } + ) def test_send_local_online_presence_to_with_module(self): + self.send_local_online_presence_to_with_module_test_body() + + def send_local_online_presence_to_with_module_test_body(self): """Tests that send_local_presence_to_users sends local online presence to a set of specified local and remote users, with a custom PresenceRouter module enabled. 
""" -- cgit 1.5.1 From 703e3a9e853b7c2212045ec52eb6b2c6e370c6f9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 17 Aug 2021 14:33:16 +0100 Subject: Allow /createRoom to be run on workers (#10564) Fixes https://github.com/matrix-org/synapse/issues/7867 --- changelog.d/10564.feature | 1 + docs/workers.md | 1 + synapse/rest/client/room.py | 2 +- synapse/storage/databases/main/room.py | 68 +++++++++++++++++----------------- 4 files changed, 37 insertions(+), 35 deletions(-) create mode 100644 changelog.d/10564.feature diff --git a/changelog.d/10564.feature b/changelog.d/10564.feature new file mode 100644 index 0000000000..4de32240b2 --- /dev/null +++ b/changelog.d/10564.feature @@ -0,0 +1 @@ +Add support for routing `/createRoom` to workers. diff --git a/docs/workers.md b/docs/workers.md index d8672324c3..1657dfc759 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -214,6 +214,7 @@ expressions: ^/_matrix/federation/v1/send/ # Client API requests + ^/_matrix/client/(api/v1|r0|unstable)/createRoom$ ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/joined_members$ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$ diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index ed238b2141..c5c54564be 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1141,10 +1141,10 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): JoinedRoomsRestServlet(hs).register(http_server) RoomAliasListServlet(hs).register(http_server) SearchRestServlet(hs).register(http_server) + RoomCreateRestServlet(hs).register(http_server) # Some servlets only get registered for the main process. if not is_worker: - RoomCreateRestServlet(hs).register(http_server) RoomForgetRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index c7a1c1e8d9..f98b892598 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -73,6 +73,40 @@ class RoomWorkerStore(SQLBaseStore): self.config = hs.config + async def store_room( + self, + room_id: str, + room_creator_user_id: str, + is_public: bool, + room_version: RoomVersion, + ): + """Stores a room. + + Args: + room_id: The desired room ID, can be None. + room_creator_user_id: The user ID of the room creator. + is_public: True to indicate that this room should appear in + public room lists. + room_version: The version of the room + Raises: + StoreError if the room could not be stored. + """ + try: + await self.db_pool.simple_insert( + "rooms", + { + "room_id": room_id, + "creator": room_creator_user_id, + "is_public": is_public, + "room_version": room_version.identifier, + "has_auth_chain_index": True, + }, + desc="store_room", + ) + except Exception as e: + logger.error("store_room with room_id=%s failed: %s", room_id, e) + raise StoreError(500, "Problem creating room.") + async def get_room(self, room_id: str) -> dict: """Retrieve a room. @@ -1342,40 +1376,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - async def store_room( - self, - room_id: str, - room_creator_user_id: str, - is_public: bool, - room_version: RoomVersion, - ): - """Stores a room. - - Args: - room_id: The desired room ID, can be None. - room_creator_user_id: The user ID of the room creator. - is_public: True to indicate that this room should appear in - public room lists. 
- room_version: The version of the room - Raises: - StoreError if the room could not be stored. - """ - try: - await self.db_pool.simple_insert( - "rooms", - { - "room_id": room_id, - "creator": room_creator_user_id, - "is_public": is_public, - "room_version": room_version.identifier, - "has_auth_chain_index": True, - }, - desc="store_room", - ) - except Exception as e: - logger.error("store_room with room_id=%s failed: %s", room_id, e) - raise StoreError(500, "Problem creating room.") - async def maybe_store_room_on_outlier_membership( self, room_id: str, room_version: RoomVersion ): -- cgit 1.5.1 From 430241a1e9fd4fbb82d83958b61bbd66c9ba3505 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 17 Aug 2021 22:19:13 +0200 Subject: Remove deprecated Shutdown Room and Purge Room Admin API (#8830) --- changelog.d/8830.removal | 1 + docs/SUMMARY.md | 2 - docs/admin_api/purge_room.md | 21 ---- docs/admin_api/shutdown_room.md | 102 ------------------- docs/upgrade.md | 13 +++ synapse/rest/admin/__init__.py | 4 - synapse/rest/admin/purge_room_servlet.py | 58 ----------- synapse/rest/admin/rooms.py | 35 ------- tests/rest/admin/test_room.py | 162 ------------------------------- 9 files changed, 14 insertions(+), 384 deletions(-) create mode 100644 changelog.d/8830.removal delete mode 100644 docs/admin_api/purge_room.md delete mode 100644 docs/admin_api/shutdown_room.md delete mode 100644 synapse/rest/admin/purge_room_servlet.py diff --git a/changelog.d/8830.removal b/changelog.d/8830.removal new file mode 100644 index 0000000000..b3a93a9af2 --- /dev/null +++ b/changelog.d/8830.removal @@ -0,0 +1 @@ +Remove deprecated Shutdown Room and Purge Room Admin API. \ No newline at end of file diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 3d320a1c43..0714f151fd 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -51,12 +51,10 @@ - [Event Reports](admin_api/event_reports.md) - [Media](admin_api/media_admin_api.md) - [Purge History](admin_api/purge_history_api.md) - - [Purge Rooms](admin_api/purge_room.md) - [Register Users](admin_api/register_api.md) - [Manipulate Room Membership](admin_api/room_membership.md) - [Rooms](admin_api/rooms.md) - [Server Notices](admin_api/server_notices.md) - - [Shutdown Room](admin_api/shutdown_room.md) - [Statistics](admin_api/statistics.md) - [Users](admin_api/user_admin_api.md) - [Server Version](admin_api/version_api.md) diff --git a/docs/admin_api/purge_room.md b/docs/admin_api/purge_room.md deleted file mode 100644 index 54fea2db6d..0000000000 --- a/docs/admin_api/purge_room.md +++ /dev/null @@ -1,21 +0,0 @@ -Deprecated: Purge room API -========================== - -**The old Purge room API is deprecated and will be removed in a future release. -See the new [Delete Room API](rooms.md#delete-room-api) for more details.** - -This API will remove all trace of a room from your database. - -All local users must have left the room before it can be removed. - -The API is: - -``` -POST /_synapse/admin/v1/purge_room - -{ - "room_id": "!room:id" -} -``` - -You must authenticate using the access token of an admin user. diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md deleted file mode 100644 index 856a629487..0000000000 --- a/docs/admin_api/shutdown_room.md +++ /dev/null @@ -1,102 +0,0 @@ -# Deprecated: Shutdown room API - -**The old Shutdown room API is deprecated and will be removed in a future release. 
-See the new [Delete Room API](rooms.md#delete-room-api) for more details.** - -Shuts down a room, preventing new joins and moves local users and room aliases automatically -to a new room. The new room will be created with the user specified by the -`new_room_user_id` parameter as room administrator and will contain a message -explaining what happened. Users invited to the new room will have power level --10 by default, and thus be unable to speak. The old room's power levels will be changed to -disallow any further invites or joins. - -The local server will only have the power to move local user and room aliases to -the new room. Users on other servers will be unaffected. - -## API - -You will need to authenticate with an access token for an admin user. - -### URL - -`POST /_synapse/admin/v1/shutdown_room/{room_id}` - -### URL Parameters - -* `room_id` - The ID of the room (e.g `!someroom:example.com`) - -### JSON Body Parameters - -* `new_room_user_id` - Required. A string representing the user ID of the user that will admin - the new room that all users in the old room will be moved to. -* `room_name` - Optional. A string representing the name of the room that new users will be - invited to. -* `message` - Optional. A string containing the first message that will be sent as - `new_room_user_id` in the new room. Ideally this will clearly convey why the - original room was shut down. - -If not specified, the default value of `room_name` is "Content Violation -Notification". The default value of `message` is "Sharing illegal content on -othis server is not permitted and rooms in violation will be blocked." - -### Response Parameters - -* `kicked_users` - An integer number representing the number of users that - were kicked. -* `failed_to_kick_users` - An integer number representing the number of users - that were not kicked. -* `local_aliases` - An array of strings representing the local aliases that were migrated from - the old room to the new. -* `new_room_id` - A string representing the room ID of the new room. - -## Example - -Request: - -``` -POST /_synapse/admin/v1/shutdown_room/!somebadroom%3Aexample.com - -{ - "new_room_user_id": "@someuser:example.com", - "room_name": "Content Violation Notification", - "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service." -} -``` - -Response: - -``` -{ - "kicked_users": 5, - "failed_to_kick_users": 0, - "local_aliases": ["#badroom:example.com", "#evilsaloon:example.com], - "new_room_id": "!newroomid:example.com", -}, -``` - -## Undoing room shutdowns - -*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level, -the structure can and does change without notice. - -First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it -never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible -to recover at all: - -* If the room was invite-only, your users will need to be re-invited. -* If the room no longer has any members at all, it'll be impossible to rejoin. -* The first user to rejoin will have to do so via an alias on a different server. - -With all that being said, if you still want to try and recover the room: - -1. For safety reasons, shut down Synapse. -2. 
In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';` - * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`. - * The room ID is the same one supplied to the shutdown room API, not the Content Violation room. -3. Restart Synapse. - -You will have to manually handle, if you so choose, the following: - -* Aliases that would have been redirected to the Content Violation room. -* Users that would have been booted from the room (and will have been force-joined to the Content Violation room). -* Removal of the Content Violation room if desired. diff --git a/docs/upgrade.md b/docs/upgrade.md index 8831c9d6cf..3ac2387e2a 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,19 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.xx.0 + +## Removal of old Room Admin API + +The following admin APIs were deprecated in [Synapse 1.25](https://github.com/matrix-org/synapse/blob/v1.25.0/CHANGES.md#removal-warning) +(released on 2021-01-13) and have now been removed: + +- `POST /_synapse/admin/v1/purge_room` +- `POST /_synapse/admin/v1/shutdown_room/` + +Any scripts still using the above APIs should be converted to use the +[Delete Room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api). + # Upgrading to v1.xx.0 diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 8a91068092..2acaea3003 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -36,7 +36,6 @@ from synapse.rest.admin.event_reports import ( ) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo -from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, ForwardExtremitiesRestServlet, @@ -47,7 +46,6 @@ from synapse.rest.admin.rooms import ( RoomMembersRestServlet, RoomRestServlet, RoomStateRestServlet, - ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet @@ -221,7 +219,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - PurgeRoomServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) @@ -255,7 +252,6 @@ def register_servlets_for_client_rest_resource( PurgeHistoryRestServlet(hs).register(http_server) ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) - ShutdownRoomRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) DeleteGroupAdminRestServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py deleted file mode 100644 index 2365ff7a0f..0000000000 --- a/synapse/rest/admin/purge_room_servlet.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING, Tuple - -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) -from synapse.http.site import SynapseRequest -from synapse.rest.admin import assert_requester_is_admin -from synapse.rest.admin._base import admin_patterns -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class PurgeRoomServlet(RestServlet): - """Servlet which will remove all trace of a room from the database - - POST /_synapse/admin/v1/purge_room - { - "room_id": "!room:id" - } - - returns: - - {} - """ - - PATTERNS = admin_patterns("/purge_room$") - - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.auth = hs.get_auth() - self.pagination_handler = hs.get_pagination_handler() - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await assert_requester_is_admin(self.auth, request) - - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ("room_id",)) - - await self.pagination_handler.purge_room(body["room_id"]) - - return 200, {} diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 975c28b225..ad83d4b54c 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -46,41 +46,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ShutdownRoomRestServlet(RestServlet): - """Shuts down a room by removing all local users from the room and blocking - all future invites and joins to the room. Any local aliases will be repointed - to a new room created by `new_room_user_id` and kicked users will be auto - joined to the new room. - """ - - PATTERNS = admin_patterns("/shutdown_room/(?P[^/]+)") - - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.auth = hs.get_auth() - self.room_shutdown_handler = hs.get_room_shutdown_handler() - - async def on_POST( - self, request: SynapseRequest, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) - - content = parse_json_object_from_request(request) - assert_params_in_dict(content, ["new_room_user_id"]) - - ret = await self.room_shutdown_handler.shutdown_room( - room_id=room_id, - new_room_user_id=content["new_room_user_id"], - new_room_name=content.get("room_name"), - message=content.get("message"), - requester_user_id=requester.user.to_string(), - block=True, - ) - - return (200, ret) - - class DeleteRoomRestServlet(RestServlet): """Delete a room from server. 
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index c9d4731017..40e032df7f 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -29,123 +29,6 @@ from tests import unittest """Tests admin REST events for /rooms paths.""" -class ShutdownRoomTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - events.register_servlets, - room.register_servlets, - room.register_deprecated_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" - - consent_uri_builder = Mock() - consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" - self.event_creation_handler._consent_uri_builder = consent_uri_builder - - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - self.other_user = self.register_user("user", "pass") - self.other_user_token = self.login("user", "pass") - - # Mark the admin user as having consented - self.get_success(self.store.user_set_consent_version(self.admin_user, "1")) - - def test_shutdown_room_consent(self): - """Test that we can shutdown rooms with local users who have not - yet accepted the privacy policy. This used to fail when we tried to - force part the user from the old room. - """ - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Assert one user in room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([self.other_user], users_in_room) - - # Enable require consent to send events - self.event_creation_handler._block_events_without_consent_error = "Error" - - # Assert that the user is getting consent error - self.helper.send( - room_id, body="foo", tok=self.other_user_token, expect_code=403 - ) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/shutdown_room/" + room_id - channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert there is now no longer anyone in the room - users_in_room = self.get_success(self.store.get_users_in_room(room_id)) - self.assertEqual([], users_in_room) - - def test_shutdown_room_block_peek(self): - """Test that a world_readable room can no longer be peeked into after - it has been shut down. 
- """ - - self.event_creation_handler._block_events_without_consent_error = None - - room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token) - - # Enable world readable - url = "rooms/%s/state/m.room.history_visibility" % (room_id,) - channel = self.make_request( - "PUT", - url.encode("ascii"), - json.dumps({"history_visibility": "world_readable"}), - access_token=self.other_user_token, - ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the admin can still send shutdown - url = "/_synapse/admin/v1/shutdown_room/" + room_id - channel = self.make_request( - "POST", - url.encode("ascii"), - json.dumps({"new_room_user_id": self.admin_user}), - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Assert we can no longer peek into the room - self._assert_peek(room_id, expect_code=403) - - def _assert_peek(self, room_id, expect_code): - """Assert that the admin user can (or cannot) peek into the room.""" - - url = "rooms/%s/initialSync" % (room_id,) - channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - url = "events?timeout=0&room_id=" + room_id - channel = self.make_request( - "GET", url.encode("ascii"), access_token=self.admin_user_tok - ) - self.assertEqual( - expect_code, int(channel.result["code"]), msg=channel.result["body"] - ) - - @parameterized_class( ("method", "url_template"), [ @@ -557,51 +440,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): ) -class PurgeRoomTestCase(unittest.HomeserverTestCase): - """Test /purge_room admin API.""" - - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - - def test_purge_room(self): - room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) - - # All users have to have left the room. - self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) - - url = "/_synapse/admin/v1/purge_room" - channel = self.make_request( - "POST", - url.encode("ascii"), - {"room_id": room_id}, - access_token=self.admin_user_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Test that the following tables have been purged of all rows related to the room. - for table in PURGE_TABLES: - count = self.get_success( - self.store.db_pool.simple_select_one_onecol( - table=table, - keyvalues={"room_id": room_id}, - retcol="COUNT(*)", - desc="test_purge_room", - ) - ) - - self.assertEqual(count, 0, msg=f"Rows not purged in {table}") - - class RoomTestCase(unittest.HomeserverTestCase): """Test /room admin API.""" -- cgit 1.5.1 From 5581dd7bf7b1d1fb10d4852587d2712c8391c07c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Aug 2021 11:21:11 +0100 Subject: Allow modules to run looping call on all instances (#10638) By default the calls only ran on the worker configured to run background tasks. 
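For illustration, a minimal sketch of a module opting in to the new behaviour. The module class and
the task it schedules are hypothetical, and the method name is inferred to be
`ModuleApi.looping_background_call` (the hunk below starts mid-signature); the new
`run_on_all_instances` keyword is the flag added by this patch:

```
from synapse.module_api import ModuleApi


class ExampleModule:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Without run_on_all_instances=True, this task would only fire on
        # the single instance configured to run background tasks.
        self._api.looping_background_call(
            self._prune_local_cache,
            60 * 1000,  # interval between calls, in milliseconds
            desc="example_prune_local_cache",
            run_on_all_instances=True,
        )

    async def _prune_local_cache(self) -> None:
        # Every worker runs its own copy of this task, so it should only
        # touch per-instance state such as in-memory caches.
        ...
```

Each call is wrapped as a background process, so runs are reported in the background-process
metrics under the given `desc`.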
---
 changelog.d/10638.feature      | 1 +
 synapse/module_api/__init__.py | 9 ++++++++-
 2 files changed, 9 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/10638.feature

diff --git a/changelog.d/10638.feature b/changelog.d/10638.feature
new file mode 100644
index 0000000000..c1de91f334
--- /dev/null
+++ b/changelog.d/10638.feature
@@ -0,0 +1 @@
+Add option to allow modules to run periodic tasks on all instances, rather than just the one configured to run background tasks.

diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2f99d31c42..2d2ed229e2 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -604,10 +604,15 @@ class ModuleApi:
         msec: float,
         *args,
         desc: Optional[str] = None,
+        run_on_all_instances: bool = False,
         **kwargs,
     ):
         """Wraps a function as a background process and calls it repeatedly.
 
+        NOTE: Will only run on the instance that is configured to run
+        background processes (which is the main process by default), unless
+        `run_on_all_instances` is set.
+
        Waits `msec` initially before calling `f` for the first time.
 
         Args:
@@ -618,12 +623,14 @@ class ModuleApi:
             msec: How long to wait between calls in milliseconds.
             *args: Positional arguments to pass to function.
             desc: The background task's description. Default to the function's name.
+            run_on_all_instances: Whether to run this on all instances, rather
+                than just the instance configured to run background tasks.
             **kwargs: Key arguments to pass to function.
         """
         if desc is None:
             desc = f.__name__
 
-        if self._hs.config.run_background_tasks:
+        if self._hs.config.run_background_tasks or run_on_all_instances:
             self._clock.looping_call(
                 run_as_background_process,
                 msec,
-- 
cgit 1.5.1

From 6a5f8fbcdaf54b6b2a0562ad5bd97f382713fb96 Mon Sep 17 00:00:00 2001
From: Patrick Cloke 
Date: Wed, 18 Aug 2021 07:27:32 -0400
Subject: Use auto-attribs for attrs classes for sync. (#10630)

---
 changelog.d/10630.misc   |   1 +
 synapse/handlers/sync.py | 156 +++++++++++++++++++++++------------------------
 2 files changed, 79 insertions(+), 78 deletions(-)
 create mode 100644 changelog.d/10630.misc

diff --git a/changelog.d/10630.misc b/changelog.d/10630.misc
new file mode 100644
index 0000000000..7d01e00e48
--- /dev/null
+++ b/changelog.d/10630.misc
@@ -0,0 +1 @@
+Use auto-attribs for the attrs classes used in sync.

diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 590642f510..eba915819e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -86,20 +86,20 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
 SyncRequestKey = Tuple[Any, ...]
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class SyncConfig:
-    user = attr.ib(type=UserID)
-    filter_collection = attr.ib(type=FilterCollection)
-    is_guest = attr.ib(type=bool)
-    request_key = attr.ib(type=SyncRequestKey)
-    device_id = attr.ib(type=Optional[str])
+    user: UserID
+    filter_collection: FilterCollection
+    is_guest: bool
+    request_key: SyncRequestKey
+    device_id: Optional[str]
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class TimelineBatch:
-    prev_batch = attr.ib(type=StreamToken)
-    events = attr.ib(type=List[EventBase])
-    limited = attr.ib(type=bool)
+    prev_batch: StreamToken
+    events: List[EventBase]
+    limited: bool
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -113,16 +113,16 @@ class TimelineBatch:
 # if there are updates for it, which we check after the instance has been created.
# This should not be a big deal because we update the notification counts afterwards as # well anyway. -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class JoinedSyncResult: - room_id = attr.ib(type=str) - timeline = attr.ib(type=TimelineBatch) - state = attr.ib(type=StateMap[EventBase]) - ephemeral = attr.ib(type=List[JsonDict]) - account_data = attr.ib(type=List[JsonDict]) - unread_notifications = attr.ib(type=JsonDict) - summary = attr.ib(type=Optional[JsonDict]) - unread_count = attr.ib(type=int) + room_id: str + timeline: TimelineBatch + state: StateMap[EventBase] + ephemeral: List[JsonDict] + account_data: List[JsonDict] + unread_notifications: JsonDict + summary: Optional[JsonDict] + unread_count: int def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -138,12 +138,12 @@ class JoinedSyncResult: ) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class ArchivedSyncResult: - room_id = attr.ib(type=str) - timeline = attr.ib(type=TimelineBatch) - state = attr.ib(type=StateMap[EventBase]) - account_data = attr.ib(type=List[JsonDict]) + room_id: str + timeline: TimelineBatch + state: StateMap[EventBase] + account_data: List[JsonDict] def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -152,37 +152,37 @@ class ArchivedSyncResult: return bool(self.timeline or self.state or self.account_data) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class InvitedSyncResult: - room_id = attr.ib(type=str) - invite = attr.ib(type=EventBase) + room_id: str + invite: EventBase def __bool__(self) -> bool: """Invited rooms should always be reported to the client""" return True -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class KnockedSyncResult: - room_id = attr.ib(type=str) - knock = attr.ib(type=EventBase) + room_id: str + knock: EventBase def __bool__(self) -> bool: """Knocked rooms should always be reported to the client""" return True -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class GroupsSyncResult: - join = attr.ib(type=JsonDict) - invite = attr.ib(type=JsonDict) - leave = attr.ib(type=JsonDict) + join: JsonDict + invite: JsonDict + leave: JsonDict def __bool__(self) -> bool: return bool(self.join or self.invite or self.leave) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceLists: """ Attributes: @@ -190,27 +190,27 @@ class DeviceLists: left: List of user_ids whose devices we no longer track """ - changed = attr.ib(type=Collection[str]) - left = attr.ib(type=Collection[str]) + changed: Collection[str] + left: Collection[str] def __bool__(self) -> bool: return bool(self.changed or self.left) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined and left room IDs since last sync. 
""" - room_entries = attr.ib(type=List["RoomSyncResultBuilder"]) - invited = attr.ib(type=List[InvitedSyncResult]) - knocked = attr.ib(type=List[KnockedSyncResult]) - newly_joined_rooms = attr.ib(type=List[str]) - newly_left_rooms = attr.ib(type=List[str]) + room_entries: List["RoomSyncResultBuilder"] + invited: List[InvitedSyncResult] + knocked: List[KnockedSyncResult] + newly_joined_rooms: List[str] + newly_left_rooms: List[str] -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class SyncResult: """ Attributes: @@ -230,18 +230,18 @@ class SyncResult: groups: Group updates, if any """ - next_batch = attr.ib(type=StreamToken) - presence = attr.ib(type=List[JsonDict]) - account_data = attr.ib(type=List[JsonDict]) - joined = attr.ib(type=List[JoinedSyncResult]) - invited = attr.ib(type=List[InvitedSyncResult]) - knocked = attr.ib(type=List[KnockedSyncResult]) - archived = attr.ib(type=List[ArchivedSyncResult]) - to_device = attr.ib(type=List[JsonDict]) - device_lists = attr.ib(type=DeviceLists) - device_one_time_keys_count = attr.ib(type=JsonDict) - device_unused_fallback_key_types = attr.ib(type=List[str]) - groups = attr.ib(type=Optional[GroupsSyncResult]) + next_batch: StreamToken + presence: List[JsonDict] + account_data: List[JsonDict] + joined: List[JoinedSyncResult] + invited: List[InvitedSyncResult] + knocked: List[KnockedSyncResult] + archived: List[ArchivedSyncResult] + to_device: List[JsonDict] + device_lists: DeviceLists + device_one_time_keys_count: JsonDict + device_unused_fallback_key_types: List[str] + groups: Optional[GroupsSyncResult] def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -2160,7 +2160,7 @@ def _calculate_state( return {event_id_to_key[e]: e for e in state_ids} -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class SyncResultBuilder: """Used to help build up a new SyncResult for a user @@ -2182,23 +2182,23 @@ class SyncResultBuilder: to_device (list) """ - sync_config = attr.ib(type=SyncConfig) - full_state = attr.ib(type=bool) - since_token = attr.ib(type=Optional[StreamToken]) - now_token = attr.ib(type=StreamToken) - joined_room_ids = attr.ib(type=FrozenSet[str]) + sync_config: SyncConfig + full_state: bool + since_token: Optional[StreamToken] + now_token: StreamToken + joined_room_ids: FrozenSet[str] - presence = attr.ib(type=List[JsonDict], default=attr.Factory(list)) - account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list)) - joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list)) - invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list)) - knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list)) - archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list)) - groups = attr.ib(type=Optional[GroupsSyncResult], default=None) - to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) + presence: List[JsonDict] = attr.Factory(list) + account_data: List[JsonDict] = attr.Factory(list) + joined: List[JoinedSyncResult] = attr.Factory(list) + invited: List[InvitedSyncResult] = attr.Factory(list) + knocked: List[KnockedSyncResult] = attr.Factory(list) + archived: List[ArchivedSyncResult] = attr.Factory(list) + groups: Optional[GroupsSyncResult] = None + to_device: List[JsonDict] = attr.Factory(list) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class RoomSyncResultBuilder: """Stores information needed to create either a `JoinedSyncResult` or 
`ArchivedSyncResult`. @@ -2214,10 +2214,10 @@ class RoomSyncResultBuilder: upto_token: Latest point to return events from. """ - room_id = attr.ib(type=str) - rtype = attr.ib(type=str) - events = attr.ib(type=Optional[List[EventBase]]) - newly_joined = attr.ib(type=bool) - full_state = attr.ib(type=bool) - since_token = attr.ib(type=Optional[StreamToken]) - upto_token = attr.ib(type=StreamToken) + room_id: str + rtype: str + events: Optional[List[EventBase]] + newly_joined: bool + full_state: bool + since_token: Optional[StreamToken] + upto_token: StreamToken -- cgit 1.5.1 From 964f29cb6f4ff5892575ac170aa872822f2c1a0e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 18 Aug 2021 12:36:22 +0100 Subject: Refactor `on_receive_pdu` code (#10615) * drop room pdu linearizer sooner No point holding onto it while we recheck the db * move out `missing_prevs` calculation we're going to need `missing_prevs` whatever we do, so we may as well calculate it eagerly and just update it if it gets outdated. * Add another `if missing_prevs` condition this should be a no-op, since all the code inside the block already checks `if missing_prevs` * reorder if conditions This shouldn't change the logic at all. * Push down `min_depth` read No point reading it from the database unless we're going to use it. * Collect the sent_to_us_directly code together Move the remaining `sent_to_us_directly` code inside the `if sent_to_us_directly` block. * Properly separate the `not sent_to_us_directly` branch Since the only way this second block is now reachable is if we *didn't* go into the `sent_to_us_directly` branch, we can replace it with a simple `else`. * changelog --- changelog.d/10615.misc | 1 + synapse/handlers/federation.py | 271 +++++++++++++++++++++-------------------- 2 files changed, 138 insertions(+), 134 deletions(-) create mode 100644 changelog.d/10615.misc diff --git a/changelog.d/10615.misc b/changelog.d/10615.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10615.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a257d87fed..529d025c39 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -285,169 +285,172 @@ class FederationHandler(BaseHandler): # - Fetching any missing prev events to fill in gaps in the graph # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): - # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) - - logger.debug("min_depth: %d", min_depth) - prevs = set(pdu.prev_event_ids()) seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen - if min_depth is not None and pdu.depth > min_depth: - missing_prevs = prevs - seen - if sent_to_us_directly and missing_prevs: - # If we're missing stuff, ensure we only fetch stuff one - # at a time. - logger.info( - "Acquiring room lock to fetch %d missing prev_events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - with (await self._room_pdu_linearizer.queue(pdu.room_id)): + if missing_prevs: + if sent_to_us_directly: + # We only backfill backwards to the min depth. 
+ min_depth = await self.get_min_depth_for_context(pdu.room_id) + logger.debug("min_depth: %d", min_depth) + + if min_depth is not None and pdu.depth > min_depth: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. logger.info( - "Acquired room lock to fetch %d missing prev_events", + "Acquiring room lock to fetch %d missing prev_events: %s", len(missing_prevs), + shortstr(missing_prevs), ) - - try: - await self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth + with (await self._room_pdu_linearizer.queue(pdu.room_id)): + logger.info( + "Acquired room lock to fetch %d missing prev_events", + len(missing_prevs), ) - except Exception as e: - raise Exception( - "Error fetching missing prev_events for %s: %s" - % (event_id, e) - ) from e + + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) from e # Update the set of things we've seen after trying to # fetch the missing stuff seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + logger.info("Found all missing prev_events") + + if missing_prevs: + # since this event was pushed to us, it is possible for it to + # become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very + # bad. Further, if the event was pushed to us, there is no excuse + # for us not to have all the prev_events. (XXX: apart from + # min_depth?) + # + # We therefore reject any such events. + logger.warning( + "Rejecting: failed to fetch %d prev events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) - if not prevs - seen: - logger.info( - "Found all missing prev_events", - ) - - missing_prevs = prevs - seen - if missing_prevs: - # We've still not been able to get all of the prev_events for this event. - # - # In this case, we need to fall back to asking another server in the - # federation for the state at this event. That's ok provided we then - # resolve the state against other bits of the DAG before using it (which - # will ensure that you can't just take over a room by sending an event, - # withholding its prev_events, and declaring yourself to be an admin in - # the subsequent state request). - # - # Now, if we're pulling this event as a missing prev_event, then clearly - # this event is not going to become the only forward-extremity and we are - # guaranteed to resolve its state against our existing forward - # extremities, so that should be fine. - # - # On the other hand, if this event was pushed to us, it is possible for - # it to become the only forward-extremity in the room, and we would then - # trust its state to be the state for the whole room. This is very bad. - # Further, if the event was pushed to us, there is no excuse for us not to - # have all the prev_events. We therefore reject any such events. - # - # XXX this really feels like it could/should be merged with the above, - # but there is an interaction with min_depth that I'm not really - # following. - - if sent_to_us_directly: - logger.warning( - "Rejecting: failed to fetch %d prev events: %s", - len(missing_prevs), + else: + # We don't have all of the prev_events for this event. 
+ # + # In this case, we need to fall back to asking another server in the + # federation for the state at this event. That's ok provided we then + # resolve the state against other bits of the DAG before using it (which + # will ensure that you can't just take over a room by sending an event, + # withholding its prev_events, and declaring yourself to be an admin in + # the subsequent state request). + # + # Since we're pulling this event as a missing prev_event, then clearly + # this event is not going to become the only forward-extremity and we are + # guaranteed to resolve its state against our existing forward + # extremities, so that should be fine. + # + # XXX this really feels like it could/should be merged with the above, + # but there is an interaction with min_depth that I'm not really + # following. + logger.info( + "Event %s is missing prev_events %s: calculating state for a " + "backwards extremity", + event_id, shortstr(missing_prevs), ) - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) - logger.info( - "Event %s is missing prev_events %s: calculating state for a " - "backwards extremity", - event_id, - shortstr(missing_prevs), - ) + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. + event_map = {event_id: pdu} + try: + # Get the state of the events we know about + ours = await self.state_store.get_state_groups_ids( + room_id, seen + ) - # Calculate the state after each of the previous events, and - # resolve them to find the correct state at the current event. - event_map = {event_id: pdu} - try: - # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids(room_id, seen) - - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours - - # Ask the remote server for the states we don't - # know about - for p in missing_prevs: - logger.info("Requesting state after missing prev_event %s", p) - - with nested_logging_context(p): - # note that if any of the missing prevs share missing state or - # auth events, the requests to fetch those events are deduped - # by the get_pdu_cache in federation_client. - remote_state = ( - await self._get_state_after_missing_prev_event( - origin, room_id, p - ) + # state_maps is a list of mappings from (type, state_key) to event_id + state_maps: List[StateMap[str]] = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours + + # Ask the remote server for the states we don't + # know about + for p in missing_prevs: + logger.info( + "Requesting state after missing prev_event %s", p ) - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } - state_maps.append(remote_state_map) + with nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. 
+ remote_state = ( + await self._get_state_after_missing_prev_event( + origin, room_id, p + ) + ) + + remote_state_map = { + (x.type, x.state_key): x.event_id + for x in remote_state + } + state_maps.append(remote_state_map) - for x in remote_state: - event_map[x.event_id] = x + for x in remote_state: + event_map[x.event_id] = x - room_version = await self.store.get_room_version_id(room_id) - state_map = ( - await self._state_resolution_handler.resolve_events_with_store( + room_version = await self.store.get_room_version_id(room_id) + state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, event_map, state_res_store=StateResolutionStore(self.store), ) - ) - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, - ) - event_map.update(evs) + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = await self.store.get_events( + list(state_map.values()), + get_prev_content=False, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + event_map.update(evs) - state = [event_map[e] for e in state_map.values()] - except Exception: - logger.warning( - "Error attempting to resolve state at missing " "prev_events", - exc_info=True, - ) - raise FederationError( - "ERROR", - 403, - "We can't get valid state history.", - affected=event_id, - ) + state = [event_map[e] for e in state_map.values()] + except Exception: + logger.warning( + "Error attempting to resolve state at missing " + "prev_events", + exc_info=True, + ) + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=event_id, + ) # A second round of checks for all events. Check that the event passes auth # based on `auth_events`, this allows us to assert that the event would -- cgit 1.5.1 From eea28735958804cd2b0d54bd19e1e25e8570209d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 18 Aug 2021 12:38:37 +0100 Subject: fix broken link to upgrade notes (#10631) --- UPGRADE.rst | 2 +- changelog.d/10631.misc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10631.misc diff --git a/UPGRADE.rst b/UPGRADE.rst index 17ecd935fd..6c7f9cb18e 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -1,7 +1,7 @@ Upgrading Synapse ================= -This document has moved to the `Synapse documentation website `_. +This document has moved to the `Synapse documentation website `_. Please update your links. The markdown source is available in `docs/upgrade.md `_. diff --git a/changelog.d/10631.misc b/changelog.d/10631.misc new file mode 100644 index 0000000000..d2a4624d53 --- /dev/null +++ b/changelog.d/10631.misc @@ -0,0 +1 @@ +Fix a broken link to the upgrade notes. -- cgit 1.5.1 From 6e613a10d072c32e72d6b97b2d178bb840769f3e Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Wed, 18 Aug 2021 13:13:35 +0100 Subject: Display an error page during failure of fallback UIA. 
(#10561) --- changelog.d/10561.bugfix | 1 + docs/upgrade.md | 8 +++++++ synapse/handlers/auth.py | 23 +++++++++++-------- synapse/handlers/ui_auth/checkers.py | 10 +++++--- synapse/res/templates/recaptcha.html | 3 +++ synapse/res/templates/terms.html | 3 +++ synapse/rest/client/auth.py | 39 ++++++++++++++++++++------------ synapse/static/client/register/style.css | 6 ++++- 8 files changed, 65 insertions(+), 28 deletions(-) create mode 100644 changelog.d/10561.bugfix diff --git a/changelog.d/10561.bugfix b/changelog.d/10561.bugfix new file mode 100644 index 0000000000..2e4f53508c --- /dev/null +++ b/changelog.d/10561.bugfix @@ -0,0 +1 @@ +Display an error on User-Interactive Authentication fallback pages when authentication fails. Contributed by Callum Brown. diff --git a/docs/upgrade.md b/docs/upgrade.md index 3ac2387e2a..adbd16fda5 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -125,6 +125,14 @@ environment variable. See [using a forward proxy with Synapse documentation](setup/forward_proxy.md) for details. +## User-interactive authentication fallback templates can now display errors + +This may affect you if you make use of custom HTML templates for the +[reCAPTCHA](../synapse/res/templates/recaptcha.html) or +[terms](../synapse/res/templates/terms.html) fallback pages. + +The template is now provided an `error` variable if the authentication +process failed. See the default templates linked above for an example. # Upgrading to v1.39.0 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 161b3c933c..98d3d2d97f 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -627,23 +627,28 @@ class AuthHandler(BaseHandler): async def add_oob_auth( self, stagetype: str, authdict: Dict[str, Any], clientip: str - ) -> bool: + ) -> None: """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. + + Raises: + LoginError if the stagetype is unknown or the session is missing. + LoginError is raised by check_auth if authentication fails. """ if stagetype not in self.checkers: - raise LoginError(400, "", Codes.MISSING_PARAM) + raise LoginError( + 400, f"Unknown UIA stage type: {stagetype}", Codes.INVALID_PARAM + ) if "session" not in authdict: - raise LoginError(400, "", Codes.MISSING_PARAM) + raise LoginError(400, "Missing session ID", Codes.MISSING_PARAM) + # If authentication fails a LoginError is raised. Otherwise, store + # the successful result. result = await self.checkers[stagetype].check_auth(authdict, clientip) - if result: - await self.store.mark_ui_auth_stage_complete( - authdict["session"], stagetype, result - ) - return True - return False + await self.store.mark_ui_auth_stage_complete( + authdict["session"], stagetype, result + ) def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]: """ diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 5414ce77d8..270541cc76 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -49,7 +49,7 @@ class UserInteractiveAuthChecker: clientip: The IP address of the client. Raises: - SynapseError if authentication failed + LoginError if authentication failed. Returns: The result of authentication (to pass back to the client?) 
@@ -131,7 +131,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): ) if resp_body["success"]: return True - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED + ) class _BaseThreepidAuthChecker: @@ -191,7 +193,9 @@ class _BaseThreepidAuthChecker: raise AssertionError("Unrecognized threepid medium: %s" % (medium,)) if not threepid: - raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED + ) if threepid["medium"] != medium: raise LoginError( diff --git a/synapse/res/templates/recaptcha.html b/synapse/res/templates/recaptcha.html index 63944dc608..b3db06ef97 100644 --- a/synapse/res/templates/recaptcha.html +++ b/synapse/res/templates/recaptcha.html @@ -16,6 +16,9 @@ function captchaDone() {
 </head>
 <body>
 <form id="registrationForm" method="post" action="{{ myurl }}">
+    {% if error is defined %}
+    <p class="error"><strong>Error: {{ error }}</strong></p>
+    {% endif %}
     <div>
         <p>
        Hello! We need to prevent computer programs and other automated things from creating accounts on this server.
diff --git a/synapse/res/templates/terms.html b/synapse/res/templates/terms.html
index dfef9897ee..369ff446d2 100644
--- a/synapse/res/templates/terms.html
+++ b/synapse/res/templates/terms.html
@@ -8,6 +8,9 @@
 </head>
 <body>
 <form id="registrationForm" method="post" action="{{ myurl }}">
+    {% if error is defined %}
+    <p class="error"><strong>Error: {{ error }}</strong></p>
+    {% endif %}
     <div>
         <p>
Please click the button below if you agree to the privacy policy of this homeserver. diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 6ea1b50a62..73284e48ec 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING from synapse.api.constants import LoginType -from synapse.api.errors import SynapseError +from synapse.api.errors import LoginError, SynapseError from synapse.api.urls import CLIENT_API_PREFIX from synapse.http.server import respond_with_html from synapse.http.servlet import RestServlet, parse_string @@ -95,29 +95,32 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} - success = await self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, authdict, request.getClientIP() - ) - - if success: - html = self.success_template.render() - else: + try: + await self.auth_handler.add_oob_auth( + LoginType.RECAPTCHA, authdict, request.getClientIP() + ) + except LoginError as e: + # Authentication failed, let user try again html = self.recaptcha_template.render( session=session, myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.recaptcha_public_key, + error=e.msg, ) + else: + # No LoginError was raised, so authentication was successful + html = self.success_template.render() + elif stagetype == LoginType.TERMS: authdict = {"session": session} - success = await self.auth_handler.add_oob_auth( - LoginType.TERMS, authdict, request.getClientIP() - ) - - if success: - html = self.success_template.render() - else: + try: + await self.auth_handler.add_oob_auth( + LoginType.TERMS, authdict, request.getClientIP() + ) + except LoginError as e: + # Authentication failed, let user try again html = self.terms_template.render( session=session, terms_url="%s_matrix/consent?v=%s" @@ -127,10 +130,16 @@ class AuthRestServlet(RestServlet): ), myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), + error=e.msg, ) + else: + # No LoginError was raised, so authentication was successful + html = self.success_template.render() + elif stagetype == LoginType.SSO: # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") + else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/static/client/register/style.css b/synapse/static/client/register/style.css index 5a7b6eebf2..8a39b5d0f5 100644 --- a/synapse/static/client/register/style.css +++ b/synapse/static/client/register/style.css @@ -57,4 +57,8 @@ textarea, input { background-color: #f8f8f8; border: 1px #ccc solid; -} \ No newline at end of file +} + +.error { + color: red; +} -- cgit 1.5.1 From 3692f7fd33ec2a28991ab325a46df5e7eba1f056 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 18 Aug 2021 13:25:12 +0100 Subject: Mount /_synapse/admin/v1/users/{userId}/media admin API on media workers only (#10628) Co-authored-by: Patrick Cloke --- changelog.d/10628.feature | 1 + docs/upgrade.md | 6 ++ docs/workers.md | 4 +- synapse/rest/admin/__init__.py | 2 - synapse/rest/admin/media.py | 165 ++++++++++++++++++++++++++++++++++++++++- synapse/rest/admin/users.py | 160 --------------------------------------- 6 files changed, 173 insertions(+), 165 deletions(-) create mode 100644 changelog.d/10628.feature diff --git a/changelog.d/10628.feature b/changelog.d/10628.feature new file mode 100644 index 
0000000000..708cb9b599 --- /dev/null +++ b/changelog.d/10628.feature @@ -0,0 +1 @@ +Admin API to delete several media for a specific user. Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 1c459d8e2b..99e32034c8 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -123,6 +123,12 @@ for more information and examples. We plan to remove support for these settings in October 2021. +## `/_synapse/admin/v1/users/{userId}/media` must be handled by media workers + +The [media repository worker documentation](https://matrix-org.github.io/synapse/latest/workers.html#synapseappmedia_repository) +has been updated to reflect that calls to `/_synapse/admin/v1/users/{userId}/media` +must now be handled by media repository workers. This is due to the new `DELETE` method +of this endpoint modifying the media store. # Upgrading to v1.39.0 diff --git a/docs/workers.md b/docs/workers.md index 1657dfc759..2e63f03452 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -426,10 +426,12 @@ Handles the media repository. It can handle all endpoints starting with: ^/_synapse/admin/v1/user/.*/media.*$ ^/_synapse/admin/v1/media/.*$ ^/_synapse/admin/v1/quarantine_media/.*$ + ^/_synapse/admin/v1/users/.*/media$ You should also set `enable_media_repo: False` in the shared configuration file to stop the main synapse running background jobs related to managing the -media repository. +media repository. Note that doing so will prevent the main process from being +able to handle the above endpoints. In the `media_repository` worker configuration file, configure the http listener to expose the `media` resource. For example: diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 8a91068092..d5862a4da4 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -61,7 +61,6 @@ from synapse.rest.admin.users import ( SearchUsersRestServlet, ShadowBanRestServlet, UserAdminServlet, - UserMediaRestServlet, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -225,7 +224,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) - UserMediaRestServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 5f0555039d..8ce443049e 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -18,14 +18,15 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_boolean, parse_integer +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, ) -from synapse.types import JsonDict +from synapse.storage.databases.main.media_repository import MediaSortOrder +from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -314,6 +315,165 @@ class DeleteMediaByDateSize(RestServlet): return 200, {"deleted_media": deleted_media, "total": total} +class UserMediaRestServlet(RestServlet): + """ + 
Gets information about all uploaded local media for a specific `user_id`.
+    With DELETE request you can delete all this media.
+
+    Example:
+        http://localhost:8008/_synapse/admin/v1/users/@user:server/media
+
+    Args:
+        The parameters `from` and `limit` are required for pagination.
+        By default, a `limit` of 100 is used.
+    Returns:
+        A list of media and an integer representing the total number of
+        media that exist given for this user
+    """
+
+    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.is_mine = hs.is_mine
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+        self.media_repository = hs.get_media_repository()
+
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
+        await assert_requester_is_admin(self.auth, request)
+
+        if not self.is_mine(UserID.from_string(user_id)):
+            raise SynapseError(400, "Can only look up local users")
+
+        user = await self.store.get_user_by_id(user_id)
+        if user is None:
+            raise NotFoundError("Unknown user")
+
+        start = parse_integer(request, "from", default=0)
+        limit = parse_integer(request, "limit", default=100)
+
+        if start < 0:
+            raise SynapseError(
+                400,
+                "Query parameter from must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        if limit < 0:
+            raise SynapseError(
+                400,
+                "Query parameter limit must be a string representing a positive integer.",
+                errcode=Codes.INVALID_PARAM,
+            )
+
+        # If neither `order_by` nor `dir` is set, set the default order
+        # to newest media is on top for backward compatibility.
+        if b"order_by" not in request.args and b"dir" not in request.args:
+            order_by = MediaSortOrder.CREATED_TS.value
+            direction = "b"
+        else:
+            order_by = parse_string(
+                request,
+                "order_by",
+                default=MediaSortOrder.CREATED_TS.value,
+                allowed_values=(
+                    MediaSortOrder.MEDIA_ID.value,
+                    MediaSortOrder.UPLOAD_NAME.value,
+                    MediaSortOrder.CREATED_TS.value,
+                    MediaSortOrder.LAST_ACCESS_TS.value,
+                    MediaSortOrder.MEDIA_LENGTH.value,
+                    MediaSortOrder.MEDIA_TYPE.value,
+                    MediaSortOrder.QUARANTINED_BY.value,
+                    MediaSortOrder.SAFE_FROM_QUARANTINE.value,
+                ),
+            )
+            direction = parse_string(
+                request, "dir", default="f", allowed_values=("f", "b")
+            )
+
+        media, total = await self.store.get_local_media_by_user_paginate(
+            start, limit, user_id, order_by, direction
+        )
+
+        ret = {"media": media, "total": total}
+        if (start + limit) < total:
+            ret["next_token"] = start + len(media)
+
+        return 200, ret
+
+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
+        # This will always be set by the time Twisted calls us.
+ assert request.args is not None + + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only look up local users") + + user = await self.store.get_user_by_id(user_id) + if user is None: + raise NotFoundError("Unknown user") + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + # If neither `order_by` nor `dir` is set, set the default order + # to newest media is on top for backward compatibility. + if b"order_by" not in request.args and b"dir" not in request.args: + order_by = MediaSortOrder.CREATED_TS.value + direction = "b" + else: + order_by = parse_string( + request, + "order_by", + default=MediaSortOrder.CREATED_TS.value, + allowed_values=( + MediaSortOrder.MEDIA_ID.value, + MediaSortOrder.UPLOAD_NAME.value, + MediaSortOrder.CREATED_TS.value, + MediaSortOrder.LAST_ACCESS_TS.value, + MediaSortOrder.MEDIA_LENGTH.value, + MediaSortOrder.MEDIA_TYPE.value, + MediaSortOrder.QUARANTINED_BY.value, + MediaSortOrder.SAFE_FROM_QUARANTINE.value, + ), + ) + direction = parse_string( + request, "dir", default="f", allowed_values=("f", "b") + ) + + media, _ = await self.store.get_local_media_by_user_paginate( + start, limit, user_id, order_by, direction + ) + + deleted_media, total = await self.media_repository.delete_local_media_ids( + ([row["media_id"] for row in media]) + ) + + return 200, {"deleted_media": deleted_media, "total": total} + + def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) -> None: """ Media repo specific APIs. @@ -328,3 +488,4 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) ListMediaInRoom(hs).register(http_server) DeleteMediaByID(hs).register(http_server) DeleteMediaByDateSize(hs).register(http_server) + UserMediaRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 93193b0864..3c8a0c6883 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -35,7 +35,6 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.rest.client._base import client_patterns -from synapse.storage.databases.main.media_repository import MediaSortOrder from synapse.storage.databases.main.stats import UserSortOrder from synapse.types import JsonDict, UserID @@ -851,165 +850,6 @@ class PushersRestServlet(RestServlet): return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} -class UserMediaRestServlet(RestServlet): - """ - Gets information about all uploaded local media for a specific `user_id`. - With DELETE request you can delete all this media. - - Example: - http://localhost:8008/_synapse/admin/v1/users/@user:server/media - - Args: - The parameters `from` and `limit` are required for pagination. - By default, a `limit` of 100 is used. 
-    Returns:
-        A list of media and an integer representing the total number of
-        media that exist given for this user
-    """
-
-    PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
-
-    def __init__(self, hs: "HomeServer"):
-        self.is_mine = hs.is_mine
-        self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.media_repository = hs.get_media_repository()
-
-    async def on_GET(
-        self, request: SynapseRequest, user_id: str
-    ) -> Tuple[int, JsonDict]:
-        # This will always be set by the time Twisted calls us.
-        assert request.args is not None
-
-        await assert_requester_is_admin(self.auth, request)
-
-        if not self.is_mine(UserID.from_string(user_id)):
-            raise SynapseError(400, "Can only look up local users")
-
-        user = await self.store.get_user_by_id(user_id)
-        if user is None:
-            raise NotFoundError("Unknown user")
-
-        start = parse_integer(request, "from", default=0)
-        limit = parse_integer(request, "limit", default=100)
-
-        if start < 0:
-            raise SynapseError(
-                400,
-                "Query parameter from must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        if limit < 0:
-            raise SynapseError(
-                400,
-                "Query parameter limit must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        # If neither `order_by` nor `dir` is set, set the default order
-        # to newest media is on top for backward compatibility.
- if b"order_by" not in request.args and b"dir" not in request.args: - order_by = MediaSortOrder.CREATED_TS.value - direction = "b" - else: - order_by = parse_string( - request, - "order_by", - default=MediaSortOrder.CREATED_TS.value, - allowed_values=( - MediaSortOrder.MEDIA_ID.value, - MediaSortOrder.UPLOAD_NAME.value, - MediaSortOrder.CREATED_TS.value, - MediaSortOrder.LAST_ACCESS_TS.value, - MediaSortOrder.MEDIA_LENGTH.value, - MediaSortOrder.MEDIA_TYPE.value, - MediaSortOrder.QUARANTINED_BY.value, - MediaSortOrder.SAFE_FROM_QUARANTINE.value, - ), - ) - direction = parse_string( - request, "dir", default="f", allowed_values=("f", "b") - ) - - media, _ = await self.store.get_local_media_by_user_paginate( - start, limit, user_id, order_by, direction - ) - - deleted_media, total = await self.media_repository.delete_local_media_ids( - ([row["media_id"] for row in media]) - ) - - return 200, {"deleted_media": deleted_media, "total": total} - - class UserTokenRestServlet(RestServlet): """An admin API for logging in as a user. -- cgit 1.5.1 From bec01c075829730cf467572e2fcf93e15372b0e9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 18 Aug 2021 09:22:07 -0400 Subject: Convert room member storage tuples to attrs. (#10629) Instead of using namedtuples. This helps with asserting type hints and code completion. --- changelog.d/10629.misc | 1 + synapse/handlers/initial_sync.py | 2 +- synapse/handlers/sync.py | 18 +++++----- synapse/storage/databases/main/roommember.py | 8 +++-- synapse/storage/databases/main/user_directory.py | 2 +- synapse/storage/roommember.py | 43 ++++++++++++++++-------- tests/replication/slave/storage/test_events.py | 9 +++-- 7 files changed, 54 insertions(+), 29 deletions(-) create mode 100644 changelog.d/10629.misc diff --git a/changelog.d/10629.misc b/changelog.d/10629.misc new file mode 100644 index 0000000000..cca1eb6c57 --- /dev/null +++ b/changelog.d/10629.misc @@ -0,0 +1 @@ +Convert room member storage tuples to `attrs` classes. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index e1c544a3c9..4e8f7f1d85 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -151,7 +151,7 @@ class InitialSyncHandler(BaseHandler): limit = 10 async def handle_room(event: RoomsForUser): - d = { + d: JsonDict = { "room_id": event.room_id, "membership": event.membership, "visibility": ( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index eba915819e..b7b299961f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -701,7 +701,7 @@ class SyncHandler: name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) - summary = {} + summary: JsonDict = {} empty_ms = MemberSummary([], 0) # TODO: only send these when they change. @@ -2076,21 +2076,23 @@ class SyncHandler: # If the membership's stream ordering is after the given stream # ordering, we need to go and work out if the user was in the room # before. 
- for room_id, event_pos in joined_rooms: - if not event_pos.persisted_after(room_key): - joined_room_ids.add(room_id) + for joined_room in joined_rooms: + if not joined_room.event_pos.persisted_after(room_key): + joined_room_ids.add(joined_room.room_id) continue - logger.info("User joined room after current token: %s", room_id) + logger.info("User joined room after current token: %s", joined_room.room_id) extrems = ( await self.store.get_forward_extremities_for_room_at_stream_ordering( - room_id, event_pos.stream + joined_room.room_id, joined_room.event_pos.stream ) ) - users_in_room = await self.state.get_current_users_in_room(room_id, extrems) + users_in_room = await self.state.get_current_users_in_room( + joined_room.room_id, extrems + ) if user_id in users_in_room: - joined_room_ids.add(room_id) + joined_room_ids.add(joined_room.room_id) return frozenset(joined_room_ids) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e8157ba3d4..c2f6b9d63d 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -307,7 +307,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached() - async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser: + async def get_invited_rooms_for_local_user( + self, user_id: str + ) -> List[RoomsForUser]: """Get all the rooms the *local* user is invited to. Args: @@ -522,7 +524,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): _get_users_server_still_shares_room_with_txn, ) - async def get_rooms_for_user(self, user_id: str, on_invalidate=None): + async def get_rooms_for_user( + self, user_id: str, on_invalidate=None + ) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 9d28d69ac7..65dde67ae9 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -365,7 +365,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return False async def update_profile_in_user_dir( - self, user_id: str, display_name: str, avatar_url: str + self, user_id: str, display_name: Optional[str], avatar_url: Optional[str] ) -> None: """ Update or add a user's profile in the user directory. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index c34fbf21bc..0ff66debdf 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -14,25 +14,40 @@ # limitations under the License. 
import logging -from collections import namedtuple +from typing import List, Optional, Tuple + +import attr + +from synapse.types import PersistedEventPosition logger = logging.getLogger(__name__) -RoomsForUser = namedtuple( - "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering") -) +@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +class RoomsForUser: + room_id: str + sender: str + membership: str + event_id: str + stream_ordering: int + + +@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +class GetRoomsForUserWithStreamOrdering: + room_id: str + event_pos: PersistedEventPosition -GetRoomsForUserWithStreamOrdering = namedtuple( - "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos") -) +@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +class ProfileInfo: + avatar_url: Optional[str] + display_name: Optional[str] -# We store this using a namedtuple so that we save about 3x space over using a -# dict. -ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name")) -# "members" points to a truncated list of (user_id, event_id) tuples for users of -# a given membership type, suitable for use in calculating heroes for a room. -# "count" points to the total numberr of users of a given membership type. -MemberSummary = namedtuple("MemberSummary", ("members", "count")) +@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +class MemberSummary: + # A truncated list of (user_id, event_id) tuples for users of a given + # membership type, suitable for use in calculating heroes for a room. + members: List[Tuple[str, str]] + # The total number of users of a given membership type. + count: int diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index db80a0bdbd..8d1b0606c4 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -20,7 +20,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.storage.roommember import RoomsForUser +from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition from tests.server import FakeTransport @@ -216,7 +216,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_rooms_for_user_with_stream_ordering", (USER_ID_2,), - {(ROOM_ID, expected_pos)}, + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, ) def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self): @@ -305,7 +305,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): expected_pos = PersistedEventPosition( "master", j2.internal_metadata.stream_ordering ) - self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)}) + self.assertEqual( + joined_rooms, + {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)}, + ) event_id = 0 -- cgit 1.5.1 From 49cb7eae97bbe8916a5a3ec4f9d030f6304cd76c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Aug 2021 15:52:11 +0100 Subject: 1.41.0rc1 --- CHANGES.md | 79 +++++++++++++++++++++++++++++++++++++++++++++++ changelog.d/10119.misc | 1 - changelog.d/10129.bugfix | 1 - changelog.d/10394.feature | 1 - changelog.d/10435.feature | 1 - changelog.d/10443.doc | 1 - 
changelog.d/10475.feature | 1 - changelog.d/10498.feature | 1 - changelog.d/10504.misc | 1 - changelog.d/10507.misc | 1 - changelog.d/10513.feature | 1 - changelog.d/10518.feature | 1 - changelog.d/10527.misc | 1 - changelog.d/10529.misc | 1 - changelog.d/10530.misc | 1 - changelog.d/10532.bugfix | 1 - changelog.d/10537.misc | 1 - changelog.d/10538.feature | 1 - changelog.d/10539.misc | 1 - changelog.d/10541.bugfix | 1 - changelog.d/10542.misc | 1 - changelog.d/10546.feature | 1 - changelog.d/10549.feature | 1 - changelog.d/10550.bugfix | 1 - changelog.d/10551.doc | 1 - changelog.d/10552.misc | 1 - changelog.d/10558.feature | 1 - changelog.d/10560.feature | 1 - changelog.d/10563.misc | 1 - changelog.d/10564.feature | 1 - changelog.d/10565.misc | 1 - changelog.d/10569.feature | 1 - changelog.d/10570.feature | 1 - changelog.d/10572.misc | 1 - changelog.d/10573.misc | 1 - changelog.d/10574.feature | 1 - changelog.d/10575.feature | 1 - changelog.d/10576.misc | 1 - changelog.d/10578.feature | 1 - changelog.d/10579.feature | 1 - changelog.d/10580.bugfix | 1 - changelog.d/10583.feature | 1 - changelog.d/10587.misc | 1 - changelog.d/10588.removal | 1 - changelog.d/10590.misc | 1 - changelog.d/10591.misc | 1 - changelog.d/10592.bugfix | 1 - changelog.d/10596.removal | 1 - changelog.d/10598.feature | 1 - changelog.d/10599.doc | 1 - changelog.d/10600.misc | 1 - changelog.d/10602.feature | 1 - changelog.d/10606.bugfix | 1 - changelog.d/10611.bugfix | 1 - changelog.d/10612.misc | 1 - changelog.d/10620.misc | 1 - changelog.d/10623.bugfix | 1 - changelog.d/10628.feature | 1 - changelog.d/10631.misc | 1 - changelog.d/10638.feature | 1 - changelog.d/9581.feature | 1 - debian/changelog | 6 ++++ synapse/__init__.py | 2 +- 63 files changed, 86 insertions(+), 61 deletions(-) delete mode 100644 changelog.d/10119.misc delete mode 100644 changelog.d/10129.bugfix delete mode 100644 changelog.d/10394.feature delete mode 100644 changelog.d/10435.feature delete mode 100644 changelog.d/10443.doc delete mode 100644 changelog.d/10475.feature delete mode 100644 changelog.d/10498.feature delete mode 100644 changelog.d/10504.misc delete mode 100644 changelog.d/10507.misc delete mode 100644 changelog.d/10513.feature delete mode 100644 changelog.d/10518.feature delete mode 100644 changelog.d/10527.misc delete mode 100644 changelog.d/10529.misc delete mode 100644 changelog.d/10530.misc delete mode 100644 changelog.d/10532.bugfix delete mode 100644 changelog.d/10537.misc delete mode 100644 changelog.d/10538.feature delete mode 100644 changelog.d/10539.misc delete mode 100644 changelog.d/10541.bugfix delete mode 100644 changelog.d/10542.misc delete mode 100644 changelog.d/10546.feature delete mode 100644 changelog.d/10549.feature delete mode 100644 changelog.d/10550.bugfix delete mode 100644 changelog.d/10551.doc delete mode 100644 changelog.d/10552.misc delete mode 100644 changelog.d/10558.feature delete mode 100644 changelog.d/10560.feature delete mode 100644 changelog.d/10563.misc delete mode 100644 changelog.d/10564.feature delete mode 100644 changelog.d/10565.misc delete mode 100644 changelog.d/10569.feature delete mode 100644 changelog.d/10570.feature delete mode 100644 changelog.d/10572.misc delete mode 100644 changelog.d/10573.misc delete mode 100644 changelog.d/10574.feature delete mode 100644 changelog.d/10575.feature delete mode 100644 changelog.d/10576.misc delete mode 100644 changelog.d/10578.feature delete mode 100644 changelog.d/10579.feature delete mode 100644 changelog.d/10580.bugfix delete mode 100644 
changelog.d/10583.feature delete mode 100644 changelog.d/10587.misc delete mode 100644 changelog.d/10588.removal delete mode 100644 changelog.d/10590.misc delete mode 100644 changelog.d/10591.misc delete mode 100644 changelog.d/10592.bugfix delete mode 100644 changelog.d/10596.removal delete mode 100644 changelog.d/10598.feature delete mode 100644 changelog.d/10599.doc delete mode 100644 changelog.d/10600.misc delete mode 100644 changelog.d/10602.feature delete mode 100644 changelog.d/10606.bugfix delete mode 100644 changelog.d/10611.bugfix delete mode 100644 changelog.d/10612.misc delete mode 100644 changelog.d/10620.misc delete mode 100644 changelog.d/10623.bugfix delete mode 100644 changelog.d/10628.feature delete mode 100644 changelog.d/10631.misc delete mode 100644 changelog.d/10638.feature delete mode 100644 changelog.d/9581.feature diff --git a/CHANGES.md b/CHANGES.md index 0e5e052951..b96ac701d5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,82 @@ +Synapse 1.41.0rc1 (2021-08-18) +============================== + +Features +-------- + +- Add `get_userinfo_by_id` method to ModuleApi. ([\#9581](https://github.com/matrix-org/synapse/issues/9581)) +- Initial local support for [MSC3266](https://github.com/matrix-org/synapse/pull/10394), Room Summary over the unstable `/rooms/{roomIdOrAlias}/summary` API. ([\#10394](https://github.com/matrix-org/synapse/issues/10394)) +- Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. ([\#10435](https://github.com/matrix-org/synapse/issues/10435)) +- Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. ([\#10475](https://github.com/matrix-org/synapse/issues/10475)) +- Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of MSC2716). ([\#10498](https://github.com/matrix-org/synapse/issues/10498)) +- Add a configuration setting for the time a `/sync` response is cached for. ([\#10513](https://github.com/matrix-org/synapse/issues/10513)) +- The default logging handler for new installations is now `PeriodicallyFlushingMemoryHandler`, a buffered logging handler which periodically flushes itself. ([\#10518](https://github.com/matrix-org/synapse/issues/10518)) +- Add support for new redaction rules for historical events specified in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#10538](https://github.com/matrix-org/synapse/issues/10538)) +- Add a setting to disable TLS when sending email. ([\#10546](https://github.com/matrix-org/synapse/issues/10546)) +- Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10549](https://github.com/matrix-org/synapse/issues/10549), [\#10560](https://github.com/matrix-org/synapse/issues/10560), [\#10569](https://github.com/matrix-org/synapse/issues/10569), [\#10574](https://github.com/matrix-org/synapse/issues/10574), [\#10575](https://github.com/matrix-org/synapse/issues/10575), [\#10579](https://github.com/matrix-org/synapse/issues/10579), [\#10583](https://github.com/matrix-org/synapse/issues/10583)) +- Admin API to delete several media for a specific user. Contributed by @dklimpel. ([\#10558](https://github.com/matrix-org/synapse/issues/10558), [\#10628](https://github.com/matrix-org/synapse/issues/10628)) +- Add support for routing `/createRoom` to workers. 
([\#10564](https://github.com/matrix-org/synapse/issues/10564)) +- Update the Synapse Grafana dashboard. ([\#10570](https://github.com/matrix-org/synapse/issues/10570)) +- Add an admin API (`GET /_synapse/admin/username_available`) to check if a username is available (regardless of registration settings). ([\#10578](https://github.com/matrix-org/synapse/issues/10578)) +- Allow editing a user's `external_ids` via the "Edit User" admin API. Contributed by @dklimpel. ([\#10598](https://github.com/matrix-org/synapse/issues/10598)) +- The Synapse manhole no longer needs coroutines to be wrapped in `defer.ensureDeferred`. ([\#10602](https://github.com/matrix-org/synapse/issues/10602)) +- Add option to allow modules to run periodic tasks on all instances, rather than just the one configured to run background tasks. ([\#10638](https://github.com/matrix-org/synapse/issues/10638)) + + +Bugfixes +-------- + +- Add some clarification to the sample config file. Contributed by @Kentokamoto. ([\#10129](https://github.com/matrix-org/synapse/issues/10129)) +- Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. ([\#10532](https://github.com/matrix-org/synapse/issues/10532)) +- Fix exceptions in logs when failing to get remote room list. ([\#10541](https://github.com/matrix-org/synapse/issues/10541)) +- Fix longstanding bug which caused the user "status" to be reset when the user went offline. Contributed by @dklimpel. ([\#10550](https://github.com/matrix-org/synapse/issues/10550)) +- Allow public rooms to be previewed in the spaces summary APIs from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10580](https://github.com/matrix-org/synapse/issues/10580)) +- Fix a bug introduced in v1.37.1 where an error could occur in the asyncronous processing of PDUs when the queue was empty. ([\#10592](https://github.com/matrix-org/synapse/issues/10592)) +- Fix errors on /sync when read receipt data is a string. Only affects homeservers with the experimental flag for [MSC2285](https://github.com/matrix-org/matrix-doc/pull/2285) enabled. Contributed by @SimonBrandner. ([\#10606](https://github.com/matrix-org/synapse/issues/10606)) +- Additional validation for the spaces summary API to avoid errors like `ValueError: Stop argument for islice() must be None or an integer`. The missing validation has existed since v1.31.0. ([\#10611](https://github.com/matrix-org/synapse/issues/10611)) +- Revert behaviour introduced in v1.38.0 that strips `org.matrix.msc2732.device_unused_fallback_key_types` from `/sync` when its value is empty. This field should instead always be present according to [MSC2732](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2732-olm-fallback-keys.md). ([\#10623](https://github.com/matrix-org/synapse/issues/10623)) + + +Improved Documentation +---------------------- + +- Add documentation for configuring a forward proxy. ([\#10443](https://github.com/matrix-org/synapse/issues/10443)) +- Updated the reverse proxy documentation to highlight the homeserver configuration that is needed to make Synapse aware that it is intentionally reverse proxied. ([\#10551](https://github.com/matrix-org/synapse/issues/10551)) +- Update CONTRIBUTING.md to fix index links and the instructions for SyTest in docker.
([\#10599](https://github.com/matrix-org/synapse/issues/10599)) + + +Deprecations and Removals +------------------------- + +- No longer build `.dev` packages for Ubuntu 20.10 LTS Groovy Gorilla, which has now EOLed. ([\#10588](https://github.com/matrix-org/synapse/issues/10588)) +- The `template_dir` configuration settings in the `sso`, `account_validity` and `email` sections of the configuration file are now deprecated in favour of the global `templates.custom_template_directory` setting. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. ([\#10596](https://github.com/matrix-org/synapse/issues/10596)) + + +Internal Changes +---------------- + +- Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. ([\#10119](https://github.com/matrix-org/synapse/issues/10119)) +- Reduce errors in PostgreSQL logs due to concurrent serialization errors. ([\#10504](https://github.com/matrix-org/synapse/issues/10504)) +- Include room ID in ignored EDU log messages. Contributed by @ilmari. ([\#10507](https://github.com/matrix-org/synapse/issues/10507)) +- Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10527](https://github.com/matrix-org/synapse/issues/10527), [\#10530](https://github.com/matrix-org/synapse/issues/10530)) +- Fix CI to not break when run against branches rather than pull requests. ([\#10529](https://github.com/matrix-org/synapse/issues/10529)) +- Mark all events stemming from the MSC2716 `/batch_send` endpoint as historical. ([\#10537](https://github.com/matrix-org/synapse/issues/10537)) +- Clean up some of the federation event authentication code for clarity. ([\#10539](https://github.com/matrix-org/synapse/issues/10539), [\#10591](https://github.com/matrix-org/synapse/issues/10591)) +- Convert `Transaction` and `Edu` objects to attrs. ([\#10542](https://github.com/matrix-org/synapse/issues/10542)) +- Update `/batch_send` endpoint to only return `state_events` created by the `state_events_from_before` passed in. ([\#10552](https://github.com/matrix-org/synapse/issues/10552)) +- Update contributing.md to warn against rebasing an open PR. ([\#10563](https://github.com/matrix-org/synapse/issues/10563)) +- Remove the unused public rooms replication stream. ([\#10565](https://github.com/matrix-org/synapse/issues/10565)) +- Clarify error message when failing to join a restricted room. ([\#10572](https://github.com/matrix-org/synapse/issues/10572)) +- Remove references to BuildKite in favour of GitHub Actions. ([\#10573](https://github.com/matrix-org/synapse/issues/10573)) +- Move `/batch_send` endpoint defined by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) to the `/v2_alpha` directory. ([\#10576](https://github.com/matrix-org/synapse/issues/10576)) +- Allow multiple custom directories in `read_templates`. ([\#10587](https://github.com/matrix-org/synapse/issues/10587)) +- Re-organize the `synapse.federation.transport.server` module to create smaller files. ([\#10590](https://github.com/matrix-org/synapse/issues/10590)) +- Flatten the `synapse.rest.client` package by moving the contents of `v1` and `v2_alpha` into the parent. ([\#10600](https://github.com/matrix-org/synapse/issues/10600)) +- Build Debian packages for Debian 12 (Bookworm). ([\#10612](https://github.com/matrix-org/synapse/issues/10612)) +- Fix up a couple of links to the database schema documentation. 
([\#10620](https://github.com/matrix-org/synapse/issues/10620)) +- Fix a broken link to the upgrade notes. ([\#10631](https://github.com/matrix-org/synapse/issues/10631)) + + Synapse 1.40.0 (2021-08-10) =========================== diff --git a/changelog.d/10119.misc b/changelog.d/10119.misc deleted file mode 100644 index f70dc6496f..0000000000 --- a/changelog.d/10119.misc +++ /dev/null @@ -1 +0,0 @@ -Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. diff --git a/changelog.d/10129.bugfix b/changelog.d/10129.bugfix deleted file mode 100644 index 292676ec8d..0000000000 --- a/changelog.d/10129.bugfix +++ /dev/null @@ -1 +0,0 @@ -Add some clarification to the sample config file. Contributed by @Kentokamoto. diff --git a/changelog.d/10394.feature b/changelog.d/10394.feature deleted file mode 100644 index c8bbc5a740..0000000000 --- a/changelog.d/10394.feature +++ /dev/null @@ -1 +0,0 @@ -Initial local support for [MSC3266](https://github.com/matrix-org/synapse/pull/10394), Room Summary over the unstable `/rooms/{roomIdOrAlias}/summary` API. diff --git a/changelog.d/10435.feature b/changelog.d/10435.feature deleted file mode 100644 index f93ef5b415..0000000000 --- a/changelog.d/10435.feature +++ /dev/null @@ -1 +0,0 @@ -Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. diff --git a/changelog.d/10443.doc b/changelog.d/10443.doc deleted file mode 100644 index 3588e5487f..0000000000 --- a/changelog.d/10443.doc +++ /dev/null @@ -1 +0,0 @@ -Add documentation for configuring a forward proxy. diff --git a/changelog.d/10475.feature b/changelog.d/10475.feature deleted file mode 100644 index 52eab11b03..0000000000 --- a/changelog.d/10475.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. \ No newline at end of file diff --git a/changelog.d/10498.feature b/changelog.d/10498.feature deleted file mode 100644 index 5df896572d..0000000000 --- a/changelog.d/10498.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of MSC2716). diff --git a/changelog.d/10504.misc b/changelog.d/10504.misc deleted file mode 100644 index 1479a5022d..0000000000 --- a/changelog.d/10504.misc +++ /dev/null @@ -1 +0,0 @@ -Reduce errors in PostgreSQL logs due to concurrent serialization errors. diff --git a/changelog.d/10507.misc b/changelog.d/10507.misc deleted file mode 100644 index 5dfd116e60..0000000000 --- a/changelog.d/10507.misc +++ /dev/null @@ -1 +0,0 @@ -Include room ID in ignored EDU log messages. Contributed by @ilmari. diff --git a/changelog.d/10513.feature b/changelog.d/10513.feature deleted file mode 100644 index 153b2df7b2..0000000000 --- a/changelog.d/10513.feature +++ /dev/null @@ -1 +0,0 @@ -Add a configuration setting for the time a `/sync` response is cached for. diff --git a/changelog.d/10518.feature b/changelog.d/10518.feature deleted file mode 100644 index 112e4d105c..0000000000 --- a/changelog.d/10518.feature +++ /dev/null @@ -1 +0,0 @@ -The default logging handler for new installations is now `PeriodicallyFlushingMemoryHandler`, a buffered logging handler which periodically flushes itself.
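As an illustration of the buffered-handler idea described in the `PeriodicallyFlushingMemoryHandler` entry above: a handler that buffers records like a memory handler but also flushes on a timer, so quiet servers still emit logs promptly rather than holding a half-full buffer indefinitely. This is a stdlib-only sketch, not Synapse's actual implementation; the constructor signature and the 5-second default period are assumptions.

```python
# A sketch of a memory handler that additionally flushes on a timer.
# Illustrative only -- not Synapse's implementation; names and defaults
# here are assumed.
import logging
import logging.handlers
import threading


class PeriodicallyFlushingMemoryHandler(logging.handlers.MemoryHandler):
    """Buffers records like MemoryHandler, but also flushes every `period` seconds."""

    def __init__(self, capacity: int, target: logging.Handler, period: float = 5.0):
        super().__init__(capacity, target=target)
        self._period = period
        self._schedule_flush()

    def _schedule_flush(self) -> None:
        # Daemon timer, so the handler never blocks interpreter shutdown.
        timer = threading.Timer(self._period, self._flush_and_reschedule)
        timer.daemon = True
        timer.start()

    def _flush_and_reschedule(self) -> None:
        self.flush()  # push any buffered records to the target handler
        self._schedule_flush()
```

Without the timer, a low-volume server only flushes when the buffer hits capacity or a record reaches the flush level, which can delay log output for a long time.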
diff --git a/changelog.d/10527.misc b/changelog.d/10527.misc deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10527.misc +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10529.misc b/changelog.d/10529.misc deleted file mode 100644 index 4caf22523c..0000000000 --- a/changelog.d/10529.misc +++ /dev/null @@ -1 +0,0 @@ -Fix CI to not break when run against branches rather than pull requests. diff --git a/changelog.d/10530.misc b/changelog.d/10530.misc deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10530.misc +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10532.bugfix b/changelog.d/10532.bugfix deleted file mode 100644 index d95e3d9b59..0000000000 --- a/changelog.d/10532.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. diff --git a/changelog.d/10537.misc b/changelog.d/10537.misc deleted file mode 100644 index c9e045300c..0000000000 --- a/changelog.d/10537.misc +++ /dev/null @@ -1 +0,0 @@ -Mark all events stemming from the MSC2716 `/batch_send` endpoint as historical. diff --git a/changelog.d/10538.feature b/changelog.d/10538.feature deleted file mode 100644 index 120c8e8ca0..0000000000 --- a/changelog.d/10538.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for new redaction rules for historical events specified in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). diff --git a/changelog.d/10539.misc b/changelog.d/10539.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/10539.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10541.bugfix b/changelog.d/10541.bugfix deleted file mode 100644 index bb946e0920..0000000000 --- a/changelog.d/10541.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix exceptions in logs when failing to get remote room list. diff --git a/changelog.d/10542.misc b/changelog.d/10542.misc deleted file mode 100644 index 44b70b4730..0000000000 --- a/changelog.d/10542.misc +++ /dev/null @@ -1 +0,0 @@ -Convert `Transaction` and `Edu` objects to attrs. diff --git a/changelog.d/10546.feature b/changelog.d/10546.feature deleted file mode 100644 index 7709d010b3..0000000000 --- a/changelog.d/10546.feature +++ /dev/null @@ -1 +0,0 @@ -Add a setting to disable TLS when sending email. diff --git a/changelog.d/10549.feature b/changelog.d/10549.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10549.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10550.bugfix b/changelog.d/10550.bugfix deleted file mode 100644 index 2e1b7c8bbb..0000000000 --- a/changelog.d/10550.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix longstanding bug which caused the user "status" to be reset when the user went offline. Contributed by @dklimpel. 
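Several changes in this series replace `namedtuple`s with `attrs` classes (`Transaction`/`Edu` in the entry above, and the room member storage tuples earlier in the series). A minimal before/after sketch of the pattern; `ExampleInfo` is a hypothetical class, not one of Synapse's:

```python
# Before: namedtuple fields carry no type information for mypy or IDEs.
from collections import namedtuple
from typing import Optional

import attr

ExampleInfoTuple = namedtuple("ExampleInfoTuple", ("avatar_url", "display_name"))


# After: slotted, frozen (immutable and hashable) and fully annotated.
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ExampleInfo:
    avatar_url: Optional[str]
    display_name: Optional[str]


info = ExampleInfo(avatar_url=None, display_name="Alice")
assert info.display_name == "Alice"
```

Note that `weakref_slot=True` on a slotted attrs class adds a `__weakref__` slot to every instance; the follow-up commit later in this series flips it to `False`, saving a pointer per instance at the cost of weak-reference support.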
diff --git a/changelog.d/10551.doc b/changelog.d/10551.doc deleted file mode 100644 index 4a2b0785bf..0000000000 --- a/changelog.d/10551.doc +++ /dev/null @@ -1 +0,0 @@ -Updated the reverse proxy documentation to highlight the homeserver configuration that is needed to make Synapse aware that it is intentionally reverse proxied. diff --git a/changelog.d/10552.misc b/changelog.d/10552.misc deleted file mode 100644 index fc5f6aea5f..0000000000 --- a/changelog.d/10552.misc +++ /dev/null @@ -1 +0,0 @@ -Update `/batch_send` endpoint to only return `state_events` created by the `state_events_from_before` passed in. diff --git a/changelog.d/10558.feature b/changelog.d/10558.feature deleted file mode 100644 index 1f461bc70a..0000000000 --- a/changelog.d/10558.feature +++ /dev/null @@ -1 +0,0 @@ -Admin API to delete several media for a specific user. Contributed by @dklimpel. diff --git a/changelog.d/10560.feature b/changelog.d/10560.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10560.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10563.misc b/changelog.d/10563.misc deleted file mode 100644 index 8e4e90c8f4..0000000000 --- a/changelog.d/10563.misc +++ /dev/null @@ -1 +0,0 @@ -Update contributing.md to warn against rebasing an open PR. diff --git a/changelog.d/10564.feature b/changelog.d/10564.feature deleted file mode 100644 index 4de32240b2..0000000000 --- a/changelog.d/10564.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for routing `/createRoom` to workers. diff --git a/changelog.d/10565.misc b/changelog.d/10565.misc deleted file mode 100644 index 06796b61ab..0000000000 --- a/changelog.d/10565.misc +++ /dev/null @@ -1 +0,0 @@ -Remove the unused public rooms replication stream. \ No newline at end of file diff --git a/changelog.d/10569.feature b/changelog.d/10569.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10569.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10570.feature b/changelog.d/10570.feature deleted file mode 100644 index bd432685b3..0000000000 --- a/changelog.d/10570.feature +++ /dev/null @@ -1 +0,0 @@ -Update the Synapse Grafana dashboard. diff --git a/changelog.d/10572.misc b/changelog.d/10572.misc deleted file mode 100644 index 008d7be444..0000000000 --- a/changelog.d/10572.misc +++ /dev/null @@ -1 +0,0 @@ -Clarify error message when failing to join a restricted room. diff --git a/changelog.d/10573.misc b/changelog.d/10573.misc deleted file mode 100644 index fc9b1a2f70..0000000000 --- a/changelog.d/10573.misc +++ /dev/null @@ -1 +0,0 @@ -Remove references to BuildKite in favour of GitHub Actions. \ No newline at end of file diff --git a/changelog.d/10574.feature b/changelog.d/10574.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10574.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10575.feature b/changelog.d/10575.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10575.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
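Many of the entries above add pagination to the spaces summary per [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). The general shape of consuming such a token-paginated endpoint is sketched below; the `rooms` and `next_batch` field names are illustrative assumptions, not the MSC2946 wire format:

```python
from typing import Any, Callable, Dict, Iterator, Optional


def paginate(fetch_page: Callable[[Optional[str]], Dict[str, Any]]) -> Iterator[Any]:
    """Yield every row from a token-paginated endpoint.

    `fetch_page` takes a pagination token (None for the first page) and
    returns one decoded response page.
    """
    token: Optional[str] = None
    while True:
        page = fetch_page(token)
        yield from page.get("rooms", [])
        token = page.get("next_batch")
        if not token:
            # No token on the final page means everything has been seen.
            return
```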
diff --git a/changelog.d/10576.misc b/changelog.d/10576.misc deleted file mode 100644 index f9f9c9a6fd..0000000000 --- a/changelog.d/10576.misc +++ /dev/null @@ -1 +0,0 @@ -Move `/batch_send` endpoint defined by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) to the `/v2_alpha` directory. diff --git a/changelog.d/10578.feature b/changelog.d/10578.feature deleted file mode 100644 index 02397f0009..0000000000 --- a/changelog.d/10578.feature +++ /dev/null @@ -1 +0,0 @@ -Add an admin API (`GET /_synapse/admin/username_available`) to check if a username is available (regardless of registration settings). \ No newline at end of file diff --git a/changelog.d/10579.feature b/changelog.d/10579.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10579.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10580.bugfix b/changelog.d/10580.bugfix deleted file mode 100644 index f8da7382b7..0000000000 --- a/changelog.d/10580.bugfix +++ /dev/null @@ -1 +0,0 @@ -Allow public rooms to be previewed in the spaces summary APIs from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10583.feature b/changelog.d/10583.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10583.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10587.misc b/changelog.d/10587.misc deleted file mode 100644 index 4c6167977c..0000000000 --- a/changelog.d/10587.misc +++ /dev/null @@ -1 +0,0 @@ -Allow multiple custom directories in `read_templates`. diff --git a/changelog.d/10588.removal b/changelog.d/10588.removal deleted file mode 100644 index 90c4b5cee2..0000000000 --- a/changelog.d/10588.removal +++ /dev/null @@ -1 +0,0 @@ -No longer build `.dev` packages for Ubuntu 20.10 LTS Groovy Gorilla, which has now EOLed. \ No newline at end of file diff --git a/changelog.d/10590.misc b/changelog.d/10590.misc deleted file mode 100644 index 62fec717da..0000000000 --- a/changelog.d/10590.misc +++ /dev/null @@ -1 +0,0 @@ -Re-organize the `synapse.federation.transport.server` module to create smaller files. diff --git a/changelog.d/10591.misc b/changelog.d/10591.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/10591.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10592.bugfix b/changelog.d/10592.bugfix deleted file mode 100644 index efcdab1136..0000000000 --- a/changelog.d/10592.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug introduced in v1.37.1 where an error could occur in the asyncronous processing of PDUs when the queue was empty. diff --git a/changelog.d/10596.removal b/changelog.d/10596.removal deleted file mode 100644 index e69f632db4..0000000000 --- a/changelog.d/10596.removal +++ /dev/null @@ -1 +0,0 @@ -The `template_dir` configuration settings in the `sso`, `account_validity` and `email` sections of the configuration file are now deprecated in favour of the global `templates.custom_template_directory` setting. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. 
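Regarding the "Allow multiple custom directories in `read_templates`" entry above: Synapse renders its templates with Jinja2, whose filesystem loader takes an ordered search path, so supporting several directories amounts to passing a list. A sketch under assumed paths, not the actual `read_templates` code:

```python
from typing import List

import jinja2


def build_template_env(custom_dirs: List[str]) -> jinja2.Environment:
    # Earlier directories take precedence; the packaged default directory
    # acts as the fallback. All paths here are illustrative assumptions.
    search_path = list(custom_dirs) + ["/usr/share/synapse/templates"]
    return jinja2.Environment(
        loader=jinja2.FileSystemLoader(search_path),
        autoescape=True,
    )


env = build_template_env(["/etc/synapse/templates-a", "/etc/synapse/templates-b"])
```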
diff --git a/changelog.d/10598.feature b/changelog.d/10598.feature deleted file mode 100644 index 92c159118b..0000000000 --- a/changelog.d/10598.feature +++ /dev/null @@ -1 +0,0 @@ -Allow editing a user's `external_ids` via the "Edit User" admin API. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/10599.doc b/changelog.d/10599.doc deleted file mode 100644 index 66e72078f0..0000000000 --- a/changelog.d/10599.doc +++ /dev/null @@ -1 +0,0 @@ -Update CONTRIBUTING.md to fix index links and the instructions for SyTest in docker. diff --git a/changelog.d/10600.misc b/changelog.d/10600.misc deleted file mode 100644 index 489dc20b11..0000000000 --- a/changelog.d/10600.misc +++ /dev/null @@ -1 +0,0 @@ -Flatten the `synapse.rest.client` package by moving the contents of `v1` and `v2_alpha` into the parent. diff --git a/changelog.d/10602.feature b/changelog.d/10602.feature deleted file mode 100644 index ab18291a20..0000000000 --- a/changelog.d/10602.feature +++ /dev/null @@ -1 +0,0 @@ -The Synapse manhole no longer needs coroutines to be wrapped in `defer.ensureDeferred`. diff --git a/changelog.d/10606.bugfix b/changelog.d/10606.bugfix deleted file mode 100644 index bab9fd2a61..0000000000 --- a/changelog.d/10606.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix errors on /sync when read receipt data is a string. Only affects homeservers with the experimental flag for [MSC2285](https://github.com/matrix-org/matrix-doc/pull/2285) enabled. Contributed by @SimonBrandner. diff --git a/changelog.d/10611.bugfix b/changelog.d/10611.bugfix deleted file mode 100644 index ecbe408b47..0000000000 --- a/changelog.d/10611.bugfix +++ /dev/null @@ -1 +0,0 @@ -Additional validation for the spaces summary API to avoid errors like `ValueError: Stop argument for islice() must be None or an integer`. The missing validation has existed since v1.31.0. diff --git a/changelog.d/10612.misc b/changelog.d/10612.misc deleted file mode 100644 index c7a9457022..0000000000 --- a/changelog.d/10612.misc +++ /dev/null @@ -1 +0,0 @@ -Build Debian packages for Debian 12 (Bookworm). diff --git a/changelog.d/10620.misc b/changelog.d/10620.misc deleted file mode 100644 index 8b29668a1f..0000000000 --- a/changelog.d/10620.misc +++ /dev/null @@ -1 +0,0 @@ -Fix up a couple of links to the database schema documentation. diff --git a/changelog.d/10623.bugfix b/changelog.d/10623.bugfix deleted file mode 100644 index 759fba3513..0000000000 --- a/changelog.d/10623.bugfix +++ /dev/null @@ -1 +0,0 @@ -Revert behaviour introduced in v1.38.0 that strips `org.matrix.msc2732.device_unused_fallback_key_types` from `/sync` when its value is empty. This field should instead always be present according to [MSC2732](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2732-olm-fallback-keys.md). \ No newline at end of file diff --git a/changelog.d/10628.feature b/changelog.d/10628.feature deleted file mode 100644 index 708cb9b599..0000000000 --- a/changelog.d/10628.feature +++ /dev/null @@ -1 +0,0 @@ -Admin API to delete several media for a specific user. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/10631.misc b/changelog.d/10631.misc deleted file mode 100644 index d2a4624d53..0000000000 --- a/changelog.d/10631.misc +++ /dev/null @@ -1 +0,0 @@ -Fix a broken link to the upgrade notes. 
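The spaces-summary validation entry above quotes the exact failure mode: `itertools.islice()` accepts only `None` or an integer as its stop argument, so an unvalidated query parameter reaching it raises a `ValueError`. A generic reproduction and guard, not Synapse's servlet code:

```python
import itertools


def first_n(items, limit):
    # islice(range(10), "5") would raise:
    #   ValueError: Stop argument for islice() must be None or an integer
    # so reject anything that is not None or a non-negative int up front.
    if limit is not None and (not isinstance(limit, int) or limit < 0):
        raise ValueError("limit must be a non-negative integer")
    return list(itertools.islice(items, limit))


print(first_n(range(10), 5))     # [0, 1, 2, 3, 4]
print(first_n(range(10), None))  # the full range, since None means "no stop"
```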
diff --git a/changelog.d/10638.feature b/changelog.d/10638.feature deleted file mode 100644 index c1de91f334..0000000000 --- a/changelog.d/10638.feature +++ /dev/null @@ -1 +0,0 @@ -Add option to allow modules to run periodic tasks on all instances, rather than just the one configured to run background tasks. diff --git a/changelog.d/9581.feature b/changelog.d/9581.feature deleted file mode 100644 index fa1949cd4b..0000000000 --- a/changelog.d/9581.feature +++ /dev/null @@ -1 +0,0 @@ -Add `get_userinfo_by_id` method to ModuleApi. diff --git a/debian/changelog b/debian/changelog index e101423fe4..68f309b0b2 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.41.0~rc1) stable; urgency=medium + + * New synapse release 1.41.0~rc1. + + -- Synapse Packaging team Wed, 18 Aug 2021 15:52:00 +0100 + matrix-synapse-py3 (1.40.0) stable; urgency=medium * New synapse release 1.40.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index 919293cd80..6ada20a77f 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.40.0" +__version__ = "1.41.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From e328d8ffd93247f5b6be9a1fd1e5dea9b99149e7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Aug 2021 15:54:50 +0100 Subject: Update changelog --- CHANGES.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index b96ac701d5..01766af39c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,7 +8,7 @@ Features - Initial local support for [MSC3266](https://github.com/matrix-org/synapse/pull/10394), Room Summary over the unstable `/rooms/{roomIdOrAlias}/summary` API. ([\#10394](https://github.com/matrix-org/synapse/issues/10394)) - Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. ([\#10435](https://github.com/matrix-org/synapse/issues/10435)) - Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. ([\#10475](https://github.com/matrix-org/synapse/issues/10475)) -- Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of MSC2716). ([\#10498](https://github.com/matrix-org/synapse/issues/10498)) +- Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716)). ([\#10498](https://github.com/matrix-org/synapse/issues/10498)) - Add a configuration setting for the time a `/sync` response is cached for. ([\#10513](https://github.com/matrix-org/synapse/issues/10513)) - The default logging handler for new installations is now `PeriodicallyFlushingMemoryHandler`, a buffered logging handler which periodically flushes itself. ([\#10518](https://github.com/matrix-org/synapse/issues/10518)) - Add support for new redaction rules for historical events specified in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716). ([\#10538](https://github.com/matrix-org/synapse/issues/10538)) @@ -31,7 +31,7 @@ Bugfixes - Fix exceptions in logs when failing to get remote room list. 
([\#10541](https://github.com/matrix-org/synapse/issues/10541)) - Fix longstanding bug which caused the user "status" to be reset when the user went offline. Contributed by @dklimpel. ([\#10550](https://github.com/matrix-org/synapse/issues/10550)) - Allow public rooms to be previewed in the spaces summary APIs from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10580](https://github.com/matrix-org/synapse/issues/10580)) -- Fix a bug introduced in v1.37.1 where an error could occur in the asyncronous processing of PDUs when the queue was empty. ([\#10592](https://github.com/matrix-org/synapse/issues/10592)) +- Fix a bug introduced in v1.37.1 where an error could occur in the asynchronous processing of PDUs when the queue was empty. ([\#10592](https://github.com/matrix-org/synapse/issues/10592)) - Fix errors on /sync when read receipt data is a string. Only affects homeservers with the experimental flag for [MSC2285](https://github.com/matrix-org/matrix-doc/pull/2285) enabled. Contributed by @SimonBrandner. ([\#10606](https://github.com/matrix-org/synapse/issues/10606)) - Additional validation for the spaces summary API to avoid errors like `ValueError: Stop argument for islice() must be None or an integer`. The missing validation has existed since v1.31.0. ([\#10611](https://github.com/matrix-org/synapse/issues/10611)) - Revert behaviour introduced in v1.38.0 that strips `org.matrix.msc2732.device_unused_fallback_key_types` from `/sync` when its value is empty. This field should instead always be present according to [MSC2732](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2732-olm-fallback-keys.md). ([\#10623](https://github.com/matrix-org/synapse/issues/10623)) @@ -48,7 +48,7 @@ Improved Documentation Deprecations and Removals ------------------------- -- No longer build `.dev` packages for Ubuntu 20.10 LTS Groovy Gorilla, which has now EOLed. ([\#10588](https://github.com/matrix-org/synapse/issues/10588)) +- No longer build `.deb` packages for Ubuntu 20.10 LTS Groovy Gorilla, which has now EOLed. ([\#10588](https://github.com/matrix-org/synapse/issues/10588)) - The `template_dir` configuration settings in the `sso`, `account_validity` and `email` sections of the configuration file are now deprecated in favour of the global `templates.custom_template_directory` setting. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. ([\#10596](https://github.com/matrix-org/synapse/issues/10596)) @@ -60,7 +60,7 @@ Internal Changes - Include room ID in ignored EDU log messages. Contributed by @ilmari. ([\#10507](https://github.com/matrix-org/synapse/issues/10507)) - Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10527](https://github.com/matrix-org/synapse/issues/10527), [\#10530](https://github.com/matrix-org/synapse/issues/10530)) - Fix CI to not break when run against branches rather than pull requests. ([\#10529](https://github.com/matrix-org/synapse/issues/10529)) -- Mark all events stemming from the MSC2716 `/batch_send` endpoint as historical. ([\#10537](https://github.com/matrix-org/synapse/issues/10537)) +- Mark all events stemming from the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint as historical. ([\#10537](https://github.com/matrix-org/synapse/issues/10537)) - Clean up some of the federation event authentication code for clarity. 
([\#10539](https://github.com/matrix-org/synapse/issues/10539), [\#10591](https://github.com/matrix-org/synapse/issues/10591)) - Convert `Transaction` and `Edu` objects to attrs. ([\#10542](https://github.com/matrix-org/synapse/issues/10542)) - Update `/batch_send` endpoint to only return `state_events` created by the `state_events_from_before` passed in. ([\#10552](https://github.com/matrix-org/synapse/issues/10552)) -- cgit 1.5.1 From d9856d9150c05577282cb52f548de7737bb1e454 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 18 Aug 2021 11:00:37 -0400 Subject: Fix weakref_slot parameter for room member storage attrs. (#10642) Follow-up to #10629 which set it to true, not false. --- changelog.d/10642.misc | 1 + synapse/storage/roommember.py | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10642.misc diff --git a/changelog.d/10642.misc b/changelog.d/10642.misc new file mode 100644 index 0000000000..cca1eb6c57 --- /dev/null +++ b/changelog.d/10642.misc @@ -0,0 +1 @@ +Convert room member storage tuples to `attrs` classes. diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 0ff66debdf..9fad67ce48 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -23,7 +23,7 @@ from synapse.types import PersistedEventPosition logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class RoomsForUser: room_id: str sender: str @@ -32,19 +32,19 @@ class RoomsForUser: stream_ordering: int -@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class GetRoomsForUserWithStreamOrdering: room_id: str event_pos: PersistedEventPosition -@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class ProfileInfo: avatar_url: Optional[str] display_name: Optional[str] -@attr.s(slots=True, frozen=True, weakref_slot=True, auto_attribs=True) +@attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) class MemberSummary: # A truncated list of (user_id, event_id) tuples for users of a given # membership type, suitable for use in calculating heroes for a room. -- cgit 1.5.1 From b9c35586a4fadae271b3fefb90a3108f74e9e3d5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Aug 2021 16:59:36 +0100 Subject: Update docs/upgrade.md with new version --- docs/upgrade.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/upgrade.md b/docs/upgrade.md index 99e32034c8..e5d386b02f 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -86,7 +86,7 @@ process, for example: ``` -# Upgrading to v1.xx.0 +# Upgrading to v1.41.0 ## Add support for routing outbound HTTP requests via a proxy for federation -- cgit 1.5.1 From 0c3565da4cdbe53646ae0bc737900526a1d3df67 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 18 Aug 2021 19:53:20 +0200 Subject: Additional type hints for the proxy agent and SRV resolver modules. 
(#10608) --- changelog.d/10608.misc | 1 + mypy.ini | 3 +++ synapse/http/additional_resource.py | 13 +++++++--- synapse/http/federation/srv_resolver.py | 45 ++++++++++++++++++--------------- synapse/http/proxyagent.py | 4 +-- 5 files changed, 41 insertions(+), 25 deletions(-) create mode 100644 changelog.d/10608.misc diff --git a/changelog.d/10608.misc b/changelog.d/10608.misc new file mode 100644 index 0000000000..875bdd2fd0 --- /dev/null +++ b/changelog.d/10608.misc @@ -0,0 +1 @@ +Improve type hints for the proxy agent and SRV resolver modules. Contributed by @dklimpel. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index e1b9405daa..107f4de76c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,10 +28,13 @@ files = synapse/federation, synapse/groups, synapse/handlers, + synapse/http/additional_resource.py, synapse/http/client.py, synapse/http/federation/matrix_federation_agent.py, + synapse/http/federation/srv_resolver.py, synapse/http/federation/well_known_resolver.py, synapse/http/matrixfederationclient.py, + synapse/http/proxyagent.py, synapse/http/servlet.py, synapse/http/server.py, synapse/http/site.py, diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py index 55ea97a07f..9a2684aca4 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py @@ -12,8 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + +from twisted.web.server import Request + from synapse.http.server import DirectServeJsonResource +if TYPE_CHECKING: + from synapse.server import HomeServer + class AdditionalResource(DirectServeJsonResource): """Resource wrapper for additional_resources @@ -25,7 +32,7 @@ class AdditionalResource(DirectServeJsonResource): and exception handling. """ - def __init__(self, hs, handler): + def __init__(self, hs: "HomeServer", handler): """Initialise AdditionalResource The ``handler`` should return a deferred which completes when it has @@ -33,14 +40,14 @@ class AdditionalResource(DirectServeJsonResource): ``request.write()``, and call ``request.finish()``. Args: - hs (synapse.server.HomeServer): homeserver + hs: homeserver handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): function to be called to handle the request. """ super().__init__() self._handler = handler - def _async_render(self, request): + def _async_render(self, request: Request): # Cheekily pass the result straight through, so we don't need to worry # if its an awaitable or not. return self._handler(request) diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index b8ed4ec905..f68646fd0d 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -16,7 +16,7 @@ import logging import random import time -from typing import List +from typing import Callable, Dict, List import attr @@ -28,35 +28,35 @@ from synapse.logging.context import make_deferred_yieldable logger = logging.getLogger(__name__) -SERVER_CACHE = {} +SERVER_CACHE: Dict[bytes, List["Server"]] = {} -@attr.s(slots=True, frozen=True) +@attr.s(auto_attribs=True, slots=True, frozen=True) class Server: """ Our record of an individual server which can be tried to reach a destination. 
Attributes: - host (bytes): target hostname - port (int): - priority (int): - weight (int): - expires (int): when the cache should expire this record - in *seconds* since + host: target hostname + port: + priority: + weight: + expires: when the cache should expire this record - in *seconds* since the epoch """ - host = attr.ib() - port = attr.ib() - priority = attr.ib(default=0) - weight = attr.ib(default=0) - expires = attr.ib(default=0) + host: bytes + port: int + priority: int = 0 + weight: int = 0 + expires: int = 0 -def _sort_server_list(server_list): +def _sort_server_list(server_list: List[Server]) -> List[Server]: """Given a list of SRV records sort them into priority order and shuffle each priority with the given weight. """ - priority_map = {} + priority_map: Dict[int, List[Server]] = {} for server in server_list: priority_map.setdefault(server.priority, []).append(server) @@ -103,11 +103,16 @@ class SrvResolver: Args: dns_client (twisted.internet.interfaces.IResolver): twisted resolver impl - cache (dict): cache object - get_time (callable): clock implementation. Should return seconds since the epoch + cache: cache object + get_time: clock implementation. Should return seconds since the epoch """ - def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time): + def __init__( + self, + dns_client=client, + cache: Dict[bytes, List[Server]] = SERVER_CACHE, + get_time: Callable[[], float] = time.time, + ): self._dns_client = dns_client self._cache = cache self._get_time = get_time @@ -116,7 +121,7 @@ class SrvResolver: """Look up a SRV record Args: - service_name (bytes): record to look up + service_name: record to look up Returns: a list of the SRV records, or an empty list if none found @@ -158,7 +163,7 @@ class SrvResolver: and answers[0].payload and answers[0].payload.target == dns.Name(b".") ): - raise ConnectError("Service %s unavailable" % service_name) + raise ConnectError(f"Service {service_name!r} unavailable") servers = [] diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index a3f31452d0..6fd88bde20 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -173,7 +173,7 @@ class ProxyAgent(_AgentBase): raise ValueError(f"Invalid URI {uri!r}") parsed_uri = URI.fromBytes(uri) - pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) + pool_key = f"{parsed_uri.scheme!r}{parsed_uri.host!r}{parsed_uri.port}" request_path = parsed_uri.originForm should_skip_proxy = False @@ -199,7 +199,7 @@ class ProxyAgent(_AgentBase): ) # Cache *all* connections under the same key, since we are only # connecting to a single destination, the proxy: - pool_key = ("http-proxy", self.http_proxy_endpoint) + pool_key = "http-proxy" endpoint = self.http_proxy_endpoint request_path = uri elif ( -- cgit 1.5.1 From 220f901229a506a82aedc51c5923768bf935ea4f Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 19 Aug 2021 11:25:05 +0200 Subject: Remove not needed database updates in modify user admin API (#10627) --- changelog.d/10627.misc | 1 + docs/admin_api/user_admin_api.md | 8 +++- synapse/rest/admin/users.py | 55 ++++++++++++++--------- synapse/storage/databases/main/registration.py | 25 ++++++++--- tests/rest/admin/test_user.py | 62 ++++++++++++++++++++++++-- 5 files changed, 118 insertions(+), 33 deletions(-) create mode 100644 changelog.d/10627.misc diff --git a/changelog.d/10627.misc b/changelog.d/10627.misc new file mode 100644 index 0000000000..e6d314976e --- /dev/null +++ 
b/changelog.d/10627.misc @@ -0,0 +1 @@ +Remove not needed database updates in modify user admin API. \ No newline at end of file diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 6a9335d6ec..60dc913915 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -21,11 +21,15 @@ It returns a JSON body like the following: "threepids": [ { "medium": "email", - "address": "" + "address": "", + "added_at": 1586458409743, + "validated_at": 1586458409743 }, { "medium": "email", - "address": "" + "address": "", + "added_at": 1586458409743, + "validated_at": 1586458409743 } ], "avatar_url": "", diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 3c8a0c6883..c1a1ba645e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet): if not isinstance(deactivate, bool): raise SynapseError(400, "'deactivated' parameter is not of type boolean") - # convert into List[Tuple[str, str]] + # convert List[Dict[str, str]] into Set[Tuple[str, str]] if external_ids is not None: - new_external_ids = [] - for external_id in external_ids: - new_external_ids.append( - (external_id["auth_provider"], external_id["external_id"]) - ) + new_external_ids = { + (external_id["auth_provider"], external_id["external_id"]) + for external_id in external_ids + } + + # convert List[Dict[str, str]] into Set[Tuple[str, str]] + if threepids is not None: + new_threepids = { + (threepid["medium"], threepid["address"]) for threepid in threepids + } if user: # modify user if "displayname" in body: @@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet): ) if threepids is not None: - # remove old threepids from user - old_threepids = await self.store.user_get_threepids(user_id) - for threepid in old_threepids: + # get changed threepids (added and removed) + # convert List[Dict[str, Any]] into Set[Tuple[str, str]] + cur_threepids = { + (threepid["medium"], threepid["address"]) + for threepid in await self.store.user_get_threepids(user_id) + } + add_threepids = new_threepids - cur_threepids + del_threepids = cur_threepids - new_threepids + + # remove old threepids + for medium, address in del_threepids: try: await self.auth_handler.delete_threepid( - user_id, threepid["medium"], threepid["address"], None + user_id, medium, address, None ) except Exception: logger.exception("Failed to remove threepids") raise SynapseError(500, "Failed to remove threepids") - # add new threepids to user + # add new threepids current_time = self.hs.get_clock().time_msec() - for threepid in threepids: + for medium, address in add_threepids: await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], current_time + user_id, medium, address, current_time ) if external_ids is not None: # get changed external_ids (added and removed) - cur_external_ids = await self.store.get_external_ids_by_user(user_id) - add_external_ids = set(new_external_ids) - set(cur_external_ids) - del_external_ids = set(cur_external_ids) - set(new_external_ids) + cur_external_ids = set( + await self.store.get_external_ids_by_user(user_id) + ) + add_external_ids = new_external_ids - cur_external_ids + del_external_ids = cur_external_ids - new_external_ids # remove old external_ids for auth_provider, external_id in del_external_ids: @@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet): if threepids is not None: current_time = self.hs.get_clock().time_msec() - for threepid in threepids: + for 
medium, address in new_threepids: await self.auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], current_time + user_id, medium, address, current_time ) if ( self.hs.config.email_enable_notifs @@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet): kind="email", app_id="m.email", app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], + device_display_name=address, + pushkey=address, lang=None, # We don't know a user's language here data={}, ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c67bea81c6..469dd53e0c 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) return user_id - def get_user_id_by_threepid_txn(self, txn, medium, address): + def get_user_id_by_threepid_txn( + self, txn, medium: str, address: str + ) -> Optional[str]: """Returns user id from threepid Args: txn (cursor): - medium (str): threepid medium e.g. email - address (str): threepid address e.g. me@example.com + medium: threepid medium e.g. email + address: threepid address e.g. me@example.com Returns: - str|None: user id or None if no user id/threepid mapping exists + user id, or None if no user id/threepid mapping exists """ ret = self.db_pool.simple_select_one_txn( txn, @@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return ret["user_id"] return None - async def user_add_threepid(self, user_id, medium, address, validated_at, added_at): + async def user_add_threepid( + self, + user_id: str, + medium: str, + address: str, + validated_at: int, + added_at: int, + ) -> None: await self.db_pool.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, ) - async def user_get_threepids(self, user_id): + async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]: return await self.db_pool.simple_select_list( "user_threepids", {"user_id": user_id}, @@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "user_get_threepids", ) - async def user_delete_threepid(self, user_id, medium, address) -> None: + async def user_delete_threepid( + self, user_id: str, medium: str, address: str + ) -> None: await self.db_pool.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ef77275238..ee204c404b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual(1, len(channel.json_body["threepids"])) self.assertEqual( "external_id1", channel.json_body["external_ids"][0]["external_id"] ) self.assertEqual( "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"] ) + self.assertEqual(1, len(channel.json_body["external_ids"])) self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self._check_fields(channel.json_body) @@ -1676,18 +1678,53 @@ class 
UserRestTestCase(unittest.HomeserverTestCase): Test setting threepid for an other user. """ - # Delete old and add new threepid to user + # Add two threepids to user channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]}, + content={ + "threepids": [ + {"medium": "email", "address": "bob1@bob.bob"}, + {"medium": "email", "address": "bob2@bob.bob"}, + ], + }, ) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) + # result does not always have the same sort order, therefore it becomes sorted + sorted_result = sorted( + channel.json_body["threepids"], key=lambda k: k["address"] + ) + self.assertEqual("email", sorted_result[0]["medium"]) + self.assertEqual("bob1@bob.bob", sorted_result[0]["address"]) + self.assertEqual("email", sorted_result[1]["medium"]) + self.assertEqual("bob2@bob.bob", sorted_result[1]["address"]) + self._check_fields(channel.json_body) + + # Set a new and remove a threepid + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={ + "threepids": [ + {"medium": "email", "address": "bob2@bob.bob"}, + {"medium": "email", "address": "bob3@bob.bob"}, + ], + }, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) + self._check_fields(channel.json_body) # Get user channel = self.make_request( @@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["threepids"])) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"]) + self.assertEqual("email", channel.json_body["threepids"][1]["medium"]) + self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"]) + self._check_fields(channel.json_body) + + # Remove threepids + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"threepids": []}, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(0, len(channel.json_body["threepids"])) + self._check_fields(channel.json_body) def test_set_external_id(self): """ @@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(2, len(channel.json_body["external_ids"])) self.assertEqual( channel.json_body["external_ids"], [ -- cgit 1.5.1 From b5fef6054a3d5763a031440ad4c56fc97b12d2aa Mon Sep 17 00:00:00 2001 From: Dirk Klimpel 
<5740567+dklimpel@users.noreply.github.com> Date: Thu, 19 Aug 2021 11:40:40 +0200 Subject: Support MSC3283: Expose `enable_set_displayname` in capabilities (#10452) --- changelog.d/10452.feature | 1 + synapse/config/experimental.py | 3 + synapse/rest/client/capabilities.py | 11 +++ tests/rest/client/v2_alpha/test_capabilities.py | 109 +++++++++++++++++++----- 4 files changed, 101 insertions(+), 23 deletions(-) create mode 100644 changelog.d/10452.feature diff --git a/changelog.d/10452.feature b/changelog.d/10452.feature new file mode 100644 index 0000000000..f332b383e3 --- /dev/null +++ b/changelog.d/10452.feature @@ -0,0 +1 @@ +Add support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): Expose enable_set_displayname in capabilities. \ No newline at end of file diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b918fb15b0..a85d8a5dc6 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -39,5 +39,8 @@ class ExperimentalConfig(Config): # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) + # MSC3283 (set displayname, avatar_url and change 3pid capabilities) + self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False) + # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 88e3aac797..093549512e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -61,6 +61,17 @@ class CapabilitiesRestServlet(RestServlet): "org.matrix.msc3244.room_capabilities" ] = MSC3244_CAPABILITIES + if self.config.experimental.msc3283_enabled: + response["capabilities"]["org.matrix.msc3283.set_displayname"] = { + "enabled": self.config.enable_set_displayname + } + response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { + "enabled": self.config.enable_set_avatar_url + } + response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { + "enabled": self.config.enable_3pid_changes + } + return 200, response diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index ad83b3d2ff..ac31e5ceaf 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -30,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.url = b"/_matrix/client/r0/capabilities" hs = self.setup_test_homeserver() - self.store = hs.get_datastore() self.config = hs.config self.auth_handler = hs.get_auth_handler() return hs + def prepare(self, reactor, clock, hs): + self.localpart = "user" + self.password = "pass" + self.user = self.register_user(self.localpart, self.password) + def test_check_auth_required(self): channel = self.make_request("GET", self.url) self.assertEqual(channel.code, 401) def test_get_room_version_capabilities(self): - self.register_user("user", "pass") - access_token = self.login("user", "pass") + access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] @@ -57,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): ) def test_get_change_password_capabilities_password_login(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) - access_token = 
self.login(user, password) + access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] @@ -70,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"localdb_enabled": False}}) def test_get_change_password_capabilities_localdb_disabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -87,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"password_config": {"enabled": False}}) def test_get_change_password_capabilities_password_disabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -102,13 +96,85 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) + def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): + """Test that per default msc3283 is disabled server returns `m.change_password`.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) + self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) + self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) + + @override_config({"experimental_features": {"msc3283_enabled": True}}) + def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): + """Test if msc3283 is enabled server returns capabilities.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertTrue(capabilities["m.change_password"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + + @override_config( + { + "enable_set_displayname": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_displayname_capabilities_displayname_disabled(self): + """Test if set displayname is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) + + @override_config( + { + "enable_set_avatar_url": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): + """Test if set avatar_url is disabled 
that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) + + @override_config( + { + "enable_3pid_changes": False, + "experimental_features": {"msc3283_enabled": True}, + } + ) + def test_change_3pid_capabilities_3pid_disabled(self): + """Test if change 3pid is disabled that the server responds it.""" + access_token = self.login(self.localpart, self.password) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + def test_get_does_not_include_msc3244_fields_by_default(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) @@ -122,12 +188,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): @override_config({"experimental_features": {"msc3244_enabled": True}}) def test_get_does_include_msc3244_fields_when_enabled(self): - localpart = "user" - password = "pass" - user = self.register_user(localpart, password) access_token = self.get_success( self.auth_handler.get_access_token_for_user_id( - user, device_id=None, valid_until_ms=None + self.user, device_id=None, valid_until_ms=None ) ) -- cgit 1.5.1 From ce6819a7015847084208ff09758c5a13c1c4c429 Mon Sep 17 00:00:00 2001 From: John-Scott Atlakson <24574+jsma@users.noreply.github.com> Date: Thu, 19 Aug 2021 03:16:00 -0700 Subject: Fix typo in release notes (#10646) Ubuntu 20.10 was not an LTS release Signed-off-by: John-Scott Atlakson 24574+jsma@users.noreply.github.com --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 01766af39c..cad9423ebd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -48,7 +48,7 @@ Improved Documentation Deprecations and Removals ------------------------- -- No longer build `.deb` packages for Ubuntu 20.10 LTS Groovy Gorilla, which has now EOLed. ([\#10588](https://github.com/matrix-org/synapse/issues/10588)) +- No longer build `.deb` packages for Ubuntu 20.10 Groovy Gorilla, which has now EOLed. ([\#10588](https://github.com/matrix-org/synapse/issues/10588)) - The `template_dir` configuration settings in the `sso`, `account_validity` and `email` sections of the configuration file are now deprecated in favour of the global `templates.custom_template_directory` setting. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. ([\#10596](https://github.com/matrix-org/synapse/issues/10596)) -- cgit 1.5.1 From 000aa89be63c27092998eca03c97eaead21404cd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 19 Aug 2021 11:12:55 -0400 Subject: Do not include rooms with an unknown room version in a sync response. (#10644) A user will still see this room if it is in a local cache, but it will not reappear if clearing the cache and reloading. 
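The guard itself is a one-liner in the sync handler's room-list loop. A minimal, self-contained sketch of the idea, where toy event dicts stand in for Synapse's membership rows and the version set is abbreviated:

    # Standalone sketch of the new guard in the sync handler. The constant
    # and event shape below are simplified stand-ins for Synapse's own
    # KNOWN_ROOM_VERSIONS map and membership event objects.
    KNOWN_ROOM_VERSIONS = {"1", "2", "3", "4", "5", "6", "7", "8"}

    def rooms_to_sync(membership_events):
        """Yield only membership events whose room version we understand."""
        for event in membership_events:
            if event["room_version_id"] not in KNOWN_ROOM_VERSIONS:
                # Unknown version: leave the room out of /sync entirely,
                # rather than returning events we cannot interpret.
                continue
            yield event

    events = [
        {"room_id": "!a:test", "room_version_id": "6"},
        {"room_id": "!b:test", "room_version_id": "unknown-room-version"},
    ]
    assert [e["room_id"] for e in rooms_to_sync(events)] == ["!a:test"]

As the test in the diff below shows, making this work end to end also means threading the room version through the membership query, which is why `RoomsForUser` gains a `room_version_id` field.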
--- changelog.d/10644.bugfix | 1 + mypy.ini | 1 + synapse/handlers/sync.py | 7 +- synapse/storage/databases/main/roommember.py | 8 +- synapse/storage/roommember.py | 1 + tests/handlers/test_sync.py | 137 +++++++++++++++++++++++-- tests/replication/slave/storage/test_events.py | 1 + 7 files changed, 145 insertions(+), 11 deletions(-) create mode 100644 changelog.d/10644.bugfix diff --git a/changelog.d/10644.bugfix b/changelog.d/10644.bugfix new file mode 100644 index 0000000000..d88a81fd82 --- /dev/null +++ b/changelog.d/10644.bugfix @@ -0,0 +1 @@ +Rooms with unsupported room versions are no longer returned via `/sync`. diff --git a/mypy.ini b/mypy.ini index 107f4de76c..90ade37b3f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -90,6 +90,7 @@ files = tests/test_utils, tests/handlers/test_password_providers.py, tests/handlers/test_room_summary.py, + tests/handlers/test_sync.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_itertools.py, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b7b299961f..2203c45dcc 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1,5 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018, 2019 New Vector Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,6 +30,7 @@ from prometheus_client import Counter from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span @@ -1843,6 +1843,9 @@ class SyncHandler: knocked = [] for event in room_list: + if event.room_version_id not in KNOWN_ROOM_VERSIONS: + continue + if event.membership == Membership.JOIN: room_entries.append( RoomSyncResultBuilder( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index c2f6b9d63d..c58a4b8690 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -386,9 +386,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) sql = """ - SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering + SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version FROM local_current_membership AS c INNER JOIN events AS e USING (room_id, event_id) + INNER JOIN rooms AS r USING (room_id) WHERE user_id = ? AND %s @@ -397,7 +398,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)] + results = [RoomsForUser(*r) for r in txn] return results @@ -447,7 +448,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): Returns: Returns the rooms the user is in currently, along with the stream - ordering of the most recent join for that user and room. + ordering of the most recent join for that user and room, along with + the room version of the room. 
""" return await self.db_pool.runInteraction( "get_rooms_for_user_with_stream_ordering", diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 9fad67ce48..2500381b7b 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -30,6 +30,7 @@ class RoomsForUser: membership: str event_id: str stream_ordering: int + room_version_id: str @attr.s(slots=True, frozen=True, weakref_slot=False, auto_attribs=True) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 84f05f6c58..339c039914 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION +from synapse.api.room_versions import RoomVersions from synapse.handlers.sync import SyncConfig +from synapse.rest import admin +from synapse.rest.client import knock, login, room +from synapse.server import HomeServer from synapse.types import UserID, create_requester import tests.unittest @@ -24,8 +31,14 @@ import tests.utils class SyncTestCase(tests.unittest.HomeserverTestCase): """Tests Sync Handler.""" - def prepare(self, reactor, clock, hs): - self.hs = hs + servlets = [ + admin.register_servlets, + knock.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs: HomeServer): self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() @@ -68,12 +81,124 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): ) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_unknown_room_version(self): + """ + A room with an unknown room version should not break sync (and should be excluded). + """ + inviter = self.register_user("creator", "pass", admin=True) + inviter_tok = self.login("@creator:test", "pass") + + user = self.register_user("user", "pass") + tok = self.login("user", "pass") + + # Do an initial sync on a different device. + requester = create_requester(user) + initial_result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user, device_id="dev") + ) + ) + + # Create a room as the user. + joined_room = self.helper.create_room_as(user, tok=tok) + + # Invite the user to the room as someone else. + invite_room = self.helper.create_room_as(inviter, tok=inviter_tok) + self.helper.invite(invite_room, targ=user, tok=inviter_tok) + + knock_room = self.helper.create_room_as( + inviter, room_version=RoomVersions.V7.identifier, tok=inviter_tok + ) + self.helper.send_state( + knock_room, + EventTypes.JoinRules, + {"join_rule": JoinRules.KNOCK}, + tok=inviter_tok, + ) + channel = self.make_request( + "POST", + "/_matrix/client/r0/knock/%s" % (knock_room,), + b"{}", + tok, + ) + self.assertEquals(200, channel.code, channel.result) + + # The rooms should appear in the sync response. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user) + ) + ) + self.assertIn(joined_room, [r.room_id for r in result.joined]) + self.assertIn(invite_room, [r.room_id for r in result.invited]) + self.assertIn(knock_room, [r.room_id for r in result.knocked]) + + # Test a incremental sync (by providing a since_token). 
+ result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config(user, device_id="dev"), + since_token=initial_result.next_batch, + ) + ) + self.assertIn(joined_room, [r.room_id for r in result.joined]) + self.assertIn(invite_room, [r.room_id for r in result.invited]) + self.assertIn(knock_room, [r.room_id for r in result.knocked]) + + # Poke the database and update the room version to an unknown one. + for room_id in (joined_room, invite_room, knock_room): + self.get_success( + self.hs.get_datastores().main.db_pool.simple_update( + "rooms", + keyvalues={"room_id": room_id}, + updatevalues={"room_version": "unknown-room-version"}, + desc="updated-room-version", + ) + ) + + # Blow away caches (supported room versions can only change due to a restart). + self.get_success( + self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + ) + self.store._get_event_cache.clear() + + # The rooms should be excluded from the sync response. + # Get a new request key. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, sync_config=generate_sync_config(user) + ) + ) + self.assertNotIn(joined_room, [r.room_id for r in result.joined]) + self.assertNotIn(invite_room, [r.room_id for r in result.invited]) + self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) + + # The rooms should also not be in an incremental sync. + result = self.get_success( + self.sync_handler.wait_for_sync_for_user( + requester, + sync_config=generate_sync_config(user, device_id="dev"), + since_token=initial_result.next_batch, + ) + ) + self.assertNotIn(joined_room, [r.room_id for r in result.joined]) + self.assertNotIn(invite_room, [r.room_id for r in result.invited]) + self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) + + +_request_key = 0 + -def generate_sync_config(user_id: str) -> SyncConfig: +def generate_sync_config( + user_id: str, device_id: Optional[str] = "device_id" +) -> SyncConfig: + """Generate a sync config (with a unique request key).""" + global _request_key + _request_key += 1 return SyncConfig( - user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]), + user=UserID.from_string(user_id), filter_collection=DEFAULT_FILTER_COLLECTION, is_guest=False, - request_key="request_key", - device_id="device_id", + request_key=("request_key", _request_key), + device_id=device_id, ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 8d1b0606c4..b25a06b427 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -150,6 +150,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "invite", event.event_id, event.internal_metadata.stream_ordering, + RoomVersions.V1.identifier, ) ], ) -- cgit 1.5.1 From 50af1efe4be8c93ee1fd642b60ab66d32317827b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 19 Aug 2021 17:31:40 +0100 Subject: Extract `_resolve_state_at_missing_prevs` (#10624) This is a follow-up to #10615: it takes the code that constructs the state at a backwards extremity, and extracts it to a separate method. 
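The extracted helper's contract — return `None` when every prev_event is already known, otherwise hand back the resolved state at the event — is easiest to see in miniature. The sketch below is a toy model under stated assumptions: plain dicts stand in for state maps, a stubbed fetch stands in for the `/state_ids` round trip, and a last-writer-wins merge stands in for real state resolution:

    import asyncio

    # Toy local state: a mapping from (type, state_key) to event_id.
    LOCAL_STATE = {("m.room.create", ""): "$create", ("m.room.member", "@a:x"): "$join_a"}

    async def fetch_remote_state_map(dest, room_id, prev_event_id):
        # Stub for asking `dest` for the state after a missing prev_event.
        return {("m.room.member", "@b:y"): "$join_b"}

    async def resolve_state_at_missing_prevs(dest, event, seen_event_ids):
        missing = set(event["prev_events"]) - seen_event_ids
        if not missing:
            return None  # all prevs known: caller uses the stored state
        state_maps = [dict(LOCAL_STATE)]
        for prev_id in missing:
            state_maps.append(await fetch_remote_state_map(dest, event["room_id"], prev_id))
        resolved = {}
        for state_map in state_maps:  # stand-in for resolve_events_with_store
            resolved.update(state_map)
        return resolved

    state = asyncio.run(resolve_state_at_missing_prevs(
        "remote.example", {"room_id": "!r:x", "prev_events": ["$gap"]}, {"$create"}))
    print(state)

The key safety property, spelled out in the method's docstring below, is that the returned state is only ever *resolved against* the existing forward extremities, never trusted outright.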
--- changelog.d/10624.misc | 1 + synapse/handlers/federation.py | 229 ++++++++++++++++++++++------------------- 2 files changed, 125 insertions(+), 105 deletions(-) create mode 100644 changelog.d/10624.misc diff --git a/changelog.d/10624.misc b/changelog.d/10624.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10624.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 529d025c39..de86918b7b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -285,12 +285,12 @@ class FederationHandler(BaseHandler): # - Fetching any missing prev events to fill in gaps in the graph # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): - prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen + if sent_to_us_directly: + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen - if missing_prevs: - if sent_to_us_directly: + if missing_prevs: # We only backfill backwards to the min depth. min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("min_depth: %d", min_depth) @@ -351,106 +351,8 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) - else: - # We don't have all of the prev_events for this event. - # - # In this case, we need to fall back to asking another server in the - # federation for the state at this event. That's ok provided we then - # resolve the state against other bits of the DAG before using it (which - # will ensure that you can't just take over a room by sending an event, - # withholding its prev_events, and declaring yourself to be an admin in - # the subsequent state request). - # - # Since we're pulling this event as a missing prev_event, then clearly - # this event is not going to become the only forward-extremity and we are - # guaranteed to resolve its state against our existing forward - # extremities, so that should be fine. - # - # XXX this really feels like it could/should be merged with the above, - # but there is an interaction with min_depth that I'm not really - # following. - logger.info( - "Event %s is missing prev_events %s: calculating state for a " - "backwards extremity", - event_id, - shortstr(missing_prevs), - ) - - # Calculate the state after each of the previous events, and - # resolve them to find the correct state at the current event. - event_map = {event_id: pdu} - try: - # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids( - room_id, seen - ) - - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours - - # Ask the remote server for the states we don't - # know about - for p in missing_prevs: - logger.info( - "Requesting state after missing prev_event %s", p - ) - - with nested_logging_context(p): - # note that if any of the missing prevs share missing state or - # auth events, the requests to fetch those events are deduped - # by the get_pdu_cache in federation_client. 
- remote_state = ( - await self._get_state_after_missing_prev_event( - origin, room_id, p - ) - ) - - remote_state_map = { - (x.type, x.state_key): x.event_id - for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x - - room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), - ) - - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] - except Exception: - logger.warning( - "Error attempting to resolve state at missing " - "prev_events", - exc_info=True, - ) - raise FederationError( - "ERROR", - 403, - "We can't get valid state history.", - affected=event_id, - ) + else: + state = await self._resolve_state_at_missing_prevs(origin, pdu) # A second round of checks for all events. Check that the event passes auth # based on `auth_events`, this allows us to assert that the event would @@ -1493,6 +1395,123 @@ class FederationHandler(BaseHandler): event_infos, ) + async def _resolve_state_at_missing_prevs( + self, dest: str, event: EventBase + ) -> Optional[Iterable[EventBase]]: + """Calculate the state at an event with missing prev_events. + + This is used when we have pulled a batch of events from a remote server, and + still don't have all the prev_events. + + If we already have all the prev_events for `event`, this method does nothing. + + Otherwise, the missing prevs become new backwards extremities, and we fall back + to asking the remote server for the state after each missing `prev_event`, + and resolving across them. + + That's ok provided we then resolve the state against other bits of the DAG + before using it - in other words, that the received event `event` is not going + to become the only forwards_extremity in the room (which will ensure that you + can't just take over a room by sending an event, withholding its prev_events, + and declaring yourself to be an admin in the subsequent state request). + + In other words: we should only call this method if `event` has been *pulled* + as part of a batch of missing prev events, or similar. + + Params: + dest: the remote server to ask for state at the missing prevs. Typically, + this will be the server we got `event` from. + event: an event to check for missing prevs. + + Returns: + if we already had all the prev events, `None`. Otherwise, returns a list of + the events in the state at `event`. + """ + room_id = event.room_id + event_id = event.event_id + + prevs = set(event.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + return None + + logger.info( + "Event %s is missing prev_events %s: calculating state for a " + "backwards extremity", + event_id, + shortstr(missing_prevs), + ) + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. 
+ event_map = {event_id: event} + try: + # Get the state of the events we know about + ours = await self.state_store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + state_maps: List[StateMap[str]] = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours + + # Ask the remote server for the states we don't + # know about + for p in missing_prevs: + logger.info("Requesting state after missing prev_event %s", p) + + with nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. + remote_state = await self._get_state_after_missing_prev_event( + dest, room_id, p + ) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = await self.store.get_room_version_id(room_id) + state_map = await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = await self.store.get_events( + list(state_map.values()), + get_prev_content=False, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + event_map.update(evs) + + state = [event_map[e] for e in state_map.values()] + except Exception: + logger.warning( + "Error attempting to resolve state at missing prev_events", + exc_info=True, + ) + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=event_id, + ) + return state + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event -- cgit 1.5.1 From e81d62009ee091d69d56bca702ba0edd39becb1c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 19 Aug 2021 18:05:12 +0100 Subject: Split `on_receive_pdu` in half (#10640) Here we split on_receive_pdu into two functions (on_receive_pdu and process_pulled_event), rather than having both cases in the same method. There's a tiny bit of overlap, but not that much. --- changelog.d/10640.misc | 1 + synapse/federation/federation_server.py | 4 +- synapse/handlers/federation.py | 236 +++++++++++++++++++------------- tests/test_federation.py | 10 +- 4 files changed, 142 insertions(+), 109 deletions(-) create mode 100644 changelog.d/10640.misc diff --git a/changelog.d/10640.misc b/changelog.d/10640.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10640.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. 
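Before the diff, the shape of the split in miniature: events *pushed* to us over `/send` must have any missing prev_events fetched and are rejected otherwise, while events we *pulled* (backfill, `get_missing_events`) fall back to resolving state at the gap. All names below are illustrative stubs, not Synapse's actual signatures:

    import asyncio

    async def fetch_missing_prev_events(origin, pdu):
        return set()  # stub: pretend get_missing_events filled every gap

    async def resolve_state_at_missing_prevs(origin, event):
        return None  # stub: pretend no prevs were missing

    async def process_event(origin, event, state):
        print(f"persisting {event['event_id']} from {origin}")

    async def on_receive_pdu(origin, pdu):
        # Pushed to us over /send: there is no excuse for missing prevs,
        # so reject rather than trusting the event's claimed state.
        still_missing = await fetch_missing_prev_events(origin, pdu)
        if still_missing:
            raise PermissionError(f"{origin} is withholding prev_events")
        await process_event(origin, pdu, state=None)

    async def process_pulled_event(origin, event):
        # Pulled via backfill/get_missing_events: resolve state at the gap
        # instead of triggering another round of get_missing_events.
        state = await resolve_state_at_missing_prevs(origin, event)
        await process_event(origin, event, state=state)

    asyncio.run(on_receive_pdu("remote.example", {"event_id": "$e1"}))

Keeping the two paths separate also lets `on_receive_pdu` drop its `sent_to_us_directly` flag, as the diff below does.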
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index afd8f8580a..e1b58d40c5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1005,9 +1005,7 @@ class FederationServer(FederationBase): async with lock: logger.info("handling received PDU: %s", event) try: - await self.handler.on_receive_pdu( - origin, event, sent_to_us_directly=True - ) + await self.handler.on_receive_pdu(origin, event) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index de86918b7b..246df43501 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -203,18 +203,13 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - async def on_receive_pdu( - self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False - ) -> None: - """Process a PDU received via a federation /send/ transaction, or - via backfill of missing prev_events + async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: + """Process a PDU received via a federation /send/ transaction Args: origin: server which initiated the /send/ transaction. Will be used to fetch missing events or state. pdu: received PDU - sent_to_us_directly: True if this event was pushed to us; False if - we pulled it as the result of a missing prev_event. """ room_id = pdu.room_id @@ -276,8 +271,6 @@ class FederationHandler(BaseHandler): ) return None - state = None - # Check that the event passes auth based on the state at the event. This is # done for events that are to be added to the timeline (non-outliers). # @@ -285,83 +278,72 @@ class FederationHandler(BaseHandler): # - Fetching any missing prev events to fill in gaps in the graph # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): - if sent_to_us_directly: - prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if missing_prevs: - # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) - logger.debug("min_depth: %d", min_depth) - - if min_depth is not None and pdu.depth > min_depth: - # If we're missing stuff, ensure we only fetch stuff one - # at a time. + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if missing_prevs: + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) + logger.debug("min_depth: %d", min_depth) + + if min_depth is not None and pdu.depth > min_depth: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. 
+ logger.info( + "Acquiring room lock to fetch %d missing prev_events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( - "Acquiring room lock to fetch %d missing prev_events: %s", + "Acquired room lock to fetch %d missing prev_events", len(missing_prevs), - shortstr(missing_prevs), ) - with (await self._room_pdu_linearizer.queue(pdu.room_id)): - logger.info( - "Acquired room lock to fetch %d missing prev_events", - len(missing_prevs), + + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) from e - try: - await self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) - except Exception as e: - raise Exception( - "Error fetching missing prev_events for %s: %s" - % (event_id, e) - ) from e - - # Update the set of things we've seen after trying to - # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - logger.info("Found all missing prev_events") - - if missing_prevs: - # since this event was pushed to us, it is possible for it to - # become the only forward-extremity in the room, and we would then - # trust its state to be the state for the whole room. This is very - # bad. Further, if the event was pushed to us, there is no excuse - # for us not to have all the prev_events. (XXX: apart from - # min_depth?) - # - # We therefore reject any such events. - logger.warning( - "Rejecting: failed to fetch %d prev events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen - else: - state = await self._resolve_state_at_missing_prevs(origin, pdu) + if not missing_prevs: + logger.info("Found all missing prev_events") + + if missing_prevs: + # since this event was pushed to us, it is possible for it to + # become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very + # bad. Further, if the event was pushed to us, there is no excuse + # for us not to have all the prev_events. (XXX: apart from + # min_depth?) + # + # We therefore reject any such events. + logger.warning( + "Rejecting: failed to fetch %d prev events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) - # A second round of checks for all events. Check that the event passes auth - # based on `auth_events`, this allows us to assert that the event would - # have been allowed at some point. If an event passes this check its OK - # for it to be used as part of a returned `/state` request, as either - # a) we received the event as part of the original join and so trust it, or - # b) we'll do a state resolution with existing state before it becomes - # part of the "current state", which adds more protection. 
- await self._process_received_pdu(origin, pdu, state=state) + await self._process_received_pdu(origin, pdu, state=None) async def _get_missing_events_for_pdu( self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int @@ -461,24 +443,7 @@ class FederationHandler(BaseHandler): return logger.info("Got %d prev_events", len(missing_events)) - - # We want to sort these by depth so we process them and - # tell clients about them in order. - missing_events.sort(key=lambda x: x.depth) - - for ev in missing_events: - logger.info("Handling received prev_event %s", ev) - with nested_logging_context(ev.event_id): - try: - await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) - except FederationError as e: - if e.code == 403: - logger.warning( - "Received prev_event %s failed history check.", - ev.event_id, - ) - else: - raise + await self._process_pulled_events(origin, missing_events) async def _get_state_for_room( self, @@ -1395,6 +1360,81 @@ class FederationHandler(BaseHandler): event_infos, ) + async def _process_pulled_events( + self, origin: str, events: Iterable[EventBase] + ) -> None: + """Process a batch of events we have pulled from a remote server + + Pulls in any events required to auth the events, persists the received events, + and notifies clients, if appropriate. + + Assumes the events have already had their signatures and hashes checked. + + Params: + origin: The server we received these events from + events: The received events. + """ + + # We want to sort these by depth so we process them and + # tell clients about them in order. + sorted_events = sorted(events, key=lambda x: x.depth) + + for ev in sorted_events: + with nested_logging_context(ev.event_id): + await self._process_pulled_event(origin, ev) + + async def _process_pulled_event(self, origin: str, event: EventBase) -> None: + """Process a single event that we have pulled from a remote server + + Pulls in any events required to auth the event, persists the received event, + and notifies clients, if appropriate. + + Assumes the event has already had its signatures and hashes checked. + + This is somewhat equivalent to on_receive_pdu, but applies somewhat different + logic in the case that we are missing prev_events (in particular, it just + requests the state at that point, rather than triggering a get_missing_events) - + so is appropriate when we have pulled the event from a remote server, rather + than having it pushed to us. + + Params: + origin: The server we received this event from + events: The received event + """ + logger.info("Processing pulled event %s", event) + + # these should not be outliers. 
+ assert not event.internal_metadata.is_outlier() + + event_id = event.event_id + + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + try: + self._sanity_check_event(event) + except SynapseError as err: + logger.warning("Event %s failed sanity check: %s", event_id, err) + return + + try: + state = await self._resolve_state_at_missing_prevs(origin, event) + await self._process_received_pdu(origin, event, state=state) + except FederationError as e: + if e.code == 403: + logger.warning("Pulled event %s failed history check.", event_id) + else: + raise + async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase ) -> Optional[Iterable[EventBase]]: @@ -1780,7 +1820,7 @@ class FederationHandler(BaseHandler): p, ) with nested_logging_context(p.event_id): - await self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e diff --git a/tests/test_federation.py b/tests/test_federation.py index 3785799f46..348fcb72a7 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Send the join, it should return None (which is not an error) self.assertEqual( - self.get_success( - self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True - ) - ), + self.get_success(self.handler.on_receive_pdu("test.serv", join_event)), None, ) @@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): with LoggingContext("test-context"): failure = self.get_failure( - self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True - ), + self.handler.on_receive_pdu("test.serv", lying_event), FederationError, ) -- cgit 1.5.1 From 5cda75fedef3dd02d3b456231be0a1b4bff2a31a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 20 Aug 2021 07:17:50 -0400 Subject: Set room version 8 as preferred for restricted rooms. (#10571) --- changelog.d/10571.feature | 1 + synapse/api/room_versions.py | 2 +- synapse/config/experimental.py | 2 +- tests/rest/client/v2_alpha/test_capabilities.py | 4 ++-- 4 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10571.feature diff --git a/changelog.d/10571.feature b/changelog.d/10571.feature new file mode 100644 index 0000000000..0da318cd5b --- /dev/null +++ b/changelog.d/10571.feature @@ -0,0 +1 @@ +Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version for restricted rooms. 
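For context, MSC3244 advertises per-capability version hints under `m.room_versions`; with the feature now on by default, a client querying the capabilities endpoint should see the `restricted` entry gain a preferred version. An illustrative expected payload as a Python literal (example values only, not an exhaustive transcription of the MSC):

    # Illustrative shape of the capability entry after this change; the
    # exact keys come from MSC3244 and the values here are examples.
    expected = {
        "m.room_versions": {
            "default": "6",
            "org.matrix.msc3244.room_capabilities": {
                "restricted": {
                    "preferred": "8",  # was None before this patch
                    "support": ["8"],  # versions with MSC3083 join rules
                },
            },
        },
    }

The one-line change below simply swaps `None` for `RoomVersions.V8` in the capability definition and flips the experimental-config default to `True`.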
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 11280c4462..8abcdfd4fd 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -293,7 +293,7 @@ MSC3244_CAPABILITIES = { ), RoomVersionCapability( "restricted", - None, + RoomVersions.V8, lambda room_version: room_version.msc3083_join_rules, ), ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b918fb15b0..907df9591a 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -37,7 +37,7 @@ class ExperimentalConfig(Config): self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) # MSC3244 (room version capabilities) - self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) + self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index ad83b3d2ff..13b3c5f499 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -102,7 +102,8 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_does_not_include_msc3244_fields_by_default(self): + @override_config({"experimental_features": {"msc3244_enabled": False}}) + def test_get_does_not_include_msc3244_fields_when_disabled(self): localpart = "user" password = "pass" user = self.register_user(localpart, password) @@ -120,7 +121,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] ) - @override_config({"experimental_features": {"msc3244_enabled": True}}) def test_get_does_include_msc3244_fields_when_enabled(self): localpart = "user" password = "pass" -- cgit 1.5.1 From ee3b2ac59a5645cb213b3c11c473613b80fe1e0f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 20 Aug 2021 15:47:03 +0100 Subject: Validate device_keys for C-S /keys/query requests (#10593) * Validate device_keys for C-S /keys/query requests Closes #10354 A small, not particularly critical fix. I'm interested in seeing if we can find a more systematic approach though. #8445 is the place for any discussion. --- changelog.d/10593.bugfix | 1 + synapse/api/errors.py | 8 ++++ synapse/rest/client/keys.py | 16 ++++++- tests/rest/client/v2_alpha/test_keys.py | 77 +++++++++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10593.bugfix create mode 100644 tests/rest/client/v2_alpha/test_keys.py diff --git a/changelog.d/10593.bugfix b/changelog.d/10593.bugfix new file mode 100644 index 0000000000..492e58a7a8 --- /dev/null +++ b/changelog.d/10593.bugfix @@ -0,0 +1 @@ +Reject Client-Server /keys/query requests which provide device_ids incorrectly. 
\ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index dc662bca83..9480f448d7 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -147,6 +147,14 @@ class SynapseError(CodeMessageException): return cs_error(self.msg, self.errcode) +class InvalidAPICallError(SynapseError): + """You called an existing API endpoint, but fed that endpoint + invalid or incomplete data.""" + + def __init__(self, msg: str): + super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON) + + class ProxiedRequestError(SynapseError): """An error from a general matrix endpoint, eg. from a proxied Matrix API call. diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index d0d9d30d40..012491f597 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -15,8 +15,9 @@ # limitations under the License. import logging +from typing import Any -from synapse.api.errors import SynapseError +from synapse.api.errors import InvalidAPICallError, SynapseError from synapse.http.servlet import ( RestServlet, parse_integer, @@ -163,6 +164,19 @@ class KeyQueryServlet(RestServlet): device_id = requester.device_id timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) + + device_keys = body.get("device_keys") + if not isinstance(device_keys, dict): + raise InvalidAPICallError("'device_keys' must be a JSON object") + + def is_list_of_strings(values: Any) -> bool: + return isinstance(values, list) and all(isinstance(v, str) for v in values) + + if any(not is_list_of_strings(keys) for keys in device_keys.values()): + raise InvalidAPICallError( + "'device_keys' values must be a list of strings", + ) + result = await self.e2e_keys_handler.query_devices( body, timeout, user_id, device_id ) diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py new file mode 100644 index 0000000000..80a4e728ff --- /dev/null +++ b/tests/rest/client/v2_alpha/test_keys.py @@ -0,0 +1,77 @@ +from http import HTTPStatus + +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import keys, login + +from tests import unittest + + +class KeyQueryTestCase(unittest.HomeserverTestCase): + servlets = [ + keys.register_servlets, + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def test_rejects_device_id_ice_key_outside_of_list(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: "device_id1", + }, + }, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_rejects_device_key_given_as_map_to_bool(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: { + "device_id1": True, + }, + }, + }, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_requires_device_key(self): + """`device_keys` is required. 
We should complain if it's missing.""" + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + {}, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) -- cgit 1.5.1 From 7862d704fdececf5a3a10466e2e7aee81fdf10f4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 20 Aug 2021 16:31:02 +0100 Subject: Follow-up: format changelog, add licence (#10593) Merged before approval; these comments from @clokep on that PR. --- changelog.d/10593.bugfix | 2 +- tests/rest/client/v2_alpha/test_keys.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/changelog.d/10593.bugfix b/changelog.d/10593.bugfix index 492e58a7a8..af910bfa4d 100644 --- a/changelog.d/10593.bugfix +++ b/changelog.d/10593.bugfix @@ -1 +1 @@ -Reject Client-Server /keys/query requests which provide device_ids incorrectly. \ No newline at end of file +Reject Client-Server `/keys/query` requests which provide `device_ids` incorrectly. diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py index 80a4e728ff..d7fa635eae 100644 --- a/tests/rest/client/v2_alpha/test_keys.py +++ b/tests/rest/client/v2_alpha/test_keys.py @@ -1,3 +1,17 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + from http import HTTPStatus from synapse.api.errors import Codes -- cgit 1.5.1 From f499dc38bcdf0cec83b873923af3d5310538759e Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 20 Aug 2021 17:43:26 +0200 Subject: Simplify tests for the device admin rest API. (#10664) By replacing duplicated code with parameterized tests and avoiding unnecessary dumping of JSON data. --- changelog.d/10664.misc | 1 + tests/rest/admin/test_device.py | 99 ++++++++--------------------------------- 2 files changed, 19 insertions(+), 81 deletions(-) create mode 100644 changelog.d/10664.misc diff --git a/changelog.d/10664.misc b/changelog.d/10664.misc new file mode 100644 index 0000000000..cebd5e9a96 --- /dev/null +++ b/changelog.d/10664.misc @@ -0,0 +1 @@ +Simplify tests for device admin rest API. \ No newline at end of file diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index c4afe5c3d9..a3679be205 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json import urllib.parse +from parameterized import parameterized + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client import login @@ -45,49 +46,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.other_user_device_id, ) - def test_no_auth(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get a device of an user without authentication. """ - channel = self.make_request("GET", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("PUT", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("DELETE", self.url, b"{}") + channel = self.make_request(method, self.url, b"{}") self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_requester_is_no_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ channel = self.make_request( - "GET", - self.url, - access_token=self.other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "PUT", - self.url, - access_token=self.other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, self.url, access_token=self.other_user_token, ) @@ -95,7 +70,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_user_does_not_exist(self, method: str): """ Tests that a lookup for a user that does not exist returns a 404 """ @@ -105,7 +81,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", + method, url, access_token=self.admin_user_tok, ) @@ -113,25 +89,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - def test_user_is_not_local(self): + @parameterized.expand(["GET", "PUT", "DELETE"]) + def test_user_is_not_local(self, method: str): """ Tests that a lookup for a user that is not a local returns a 400 """ @@ -141,25 +100,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - 
self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) - - channel = self.make_request( - "PUT", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only lookup local users", channel.json_body["error"]) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) @@ -219,12 +160,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1) } - body = json.dumps(update) channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content=update, ) self.assertEqual(400, channel.code, msg=channel.json_body) @@ -275,12 +215,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase): Tests a normal successful update of display name """ # Set new display_name - body = json.dumps({"display_name": "new displayname"}) channel = self.make_request( "PUT", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"display_name": "new displayname"}, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -529,12 +468,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): """ Tests that a remove of a device that does not exist returns 200. """ - body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]}) channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"devices": ["unknown_device1", "unknown_device2"]}, ) # Delete unknown devices returns status 200 @@ -560,12 +498,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase): device_ids.append(str(d["device_id"])) # Delete devices - body = json.dumps({"devices": device_ids}) channel = self.make_request( "POST", self.url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"devices": device_ids}, ) self.assertEqual(200, channel.code, msg=channel.json_body) -- cgit 1.5.1 From ecd823d766fecdd7e1c7163073c097a0084122e2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 20 Aug 2021 17:50:44 +0100 Subject: Flatten tests/rest/client/{v1,v2_alpha} too (#10667) --- changelog.d/10667.misc | 1 + mypy.ini | 4 +- tests/rest/client/test_account.py | 1000 +++++++++ tests/rest/client/test_auth.py | 717 +++++++ tests/rest/client/test_capabilities.py | 212 ++ tests/rest/client/test_directory.py | 154 ++ tests/rest/client/test_events.py | 156 ++ tests/rest/client/test_filter.py | 111 + tests/rest/client/test_keys.py | 91 + tests/rest/client/test_login.py | 1345 ++++++++++++ tests/rest/client/test_password_policy.py | 177 ++ tests/rest/client/test_presence.py | 76 + tests/rest/client/test_profile.py | 270 +++ tests/rest/client/test_push_rule_attrs.py | 414 ++++ tests/rest/client/test_register.py | 746 +++++++ tests/rest/client/test_relations.py | 724 +++++++ tests/rest/client/test_report_event.py | 82 + tests/rest/client/test_rooms.py | 2150 ++++++++++++++++++++ tests/rest/client/test_sendtodevice.py | 200 ++ tests/rest/client/test_shared_rooms.py | 142 ++ tests/rest/client/test_sync.py | 686 +++++++ tests/rest/client/test_typing.py | 121 ++ tests/rest/client/test_upgrade_room.py | 159 ++ tests/rest/client/utils.py | 654 ++++++ tests/rest/client/v1/__init__.py | 13 - tests/rest/client/v1/test_directory.py | 154 -- tests/rest/client/v1/test_events.py | 156 -- 
tests/rest/client/v1/test_login.py | 1345 ------------ tests/rest/client/v1/test_presence.py | 76 - tests/rest/client/v1/test_profile.py | 270 --- tests/rest/client/v1/test_push_rule_attrs.py | 414 ---- tests/rest/client/v1/test_rooms.py | 2150 -------------------- tests/rest/client/v1/test_typing.py | 121 -- tests/rest/client/v1/utils.py | 654 ------ tests/rest/client/v2_alpha/__init__.py | 0 tests/rest/client/v2_alpha/test_account.py | 1000 --------- tests/rest/client/v2_alpha/test_auth.py | 717 ------- tests/rest/client/v2_alpha/test_capabilities.py | 212 -- tests/rest/client/v2_alpha/test_filter.py | 111 - tests/rest/client/v2_alpha/test_keys.py | 91 - tests/rest/client/v2_alpha/test_password_policy.py | 177 -- tests/rest/client/v2_alpha/test_register.py | 746 ------- tests/rest/client/v2_alpha/test_relations.py | 724 ------- tests/rest/client/v2_alpha/test_report_event.py | 82 - tests/rest/client/v2_alpha/test_sendtodevice.py | 200 -- tests/rest/client/v2_alpha/test_shared_rooms.py | 142 -- tests/rest/client/v2_alpha/test_sync.py | 686 ------- tests/rest/client/v2_alpha/test_upgrade_room.py | 159 -- tests/unittest.py | 2 +- 49 files changed, 10391 insertions(+), 10403 deletions(-) create mode 100644 changelog.d/10667.misc create mode 100644 tests/rest/client/test_account.py create mode 100644 tests/rest/client/test_auth.py create mode 100644 tests/rest/client/test_capabilities.py create mode 100644 tests/rest/client/test_directory.py create mode 100644 tests/rest/client/test_events.py create mode 100644 tests/rest/client/test_filter.py create mode 100644 tests/rest/client/test_keys.py create mode 100644 tests/rest/client/test_login.py create mode 100644 tests/rest/client/test_password_policy.py create mode 100644 tests/rest/client/test_presence.py create mode 100644 tests/rest/client/test_profile.py create mode 100644 tests/rest/client/test_push_rule_attrs.py create mode 100644 tests/rest/client/test_register.py create mode 100644 tests/rest/client/test_relations.py create mode 100644 tests/rest/client/test_report_event.py create mode 100644 tests/rest/client/test_rooms.py create mode 100644 tests/rest/client/test_sendtodevice.py create mode 100644 tests/rest/client/test_shared_rooms.py create mode 100644 tests/rest/client/test_sync.py create mode 100644 tests/rest/client/test_typing.py create mode 100644 tests/rest/client/test_upgrade_room.py create mode 100644 tests/rest/client/utils.py delete mode 100644 tests/rest/client/v1/__init__.py delete mode 100644 tests/rest/client/v1/test_directory.py delete mode 100644 tests/rest/client/v1/test_events.py delete mode 100644 tests/rest/client/v1/test_login.py delete mode 100644 tests/rest/client/v1/test_presence.py delete mode 100644 tests/rest/client/v1/test_profile.py delete mode 100644 tests/rest/client/v1/test_push_rule_attrs.py delete mode 100644 tests/rest/client/v1/test_rooms.py delete mode 100644 tests/rest/client/v1/test_typing.py delete mode 100644 tests/rest/client/v1/utils.py delete mode 100644 tests/rest/client/v2_alpha/__init__.py delete mode 100644 tests/rest/client/v2_alpha/test_account.py delete mode 100644 tests/rest/client/v2_alpha/test_auth.py delete mode 100644 tests/rest/client/v2_alpha/test_capabilities.py delete mode 100644 tests/rest/client/v2_alpha/test_filter.py delete mode 100644 tests/rest/client/v2_alpha/test_keys.py delete mode 100644 tests/rest/client/v2_alpha/test_password_policy.py delete mode 100644 tests/rest/client/v2_alpha/test_register.py delete mode 100644 tests/rest/client/v2_alpha/test_relations.py 
delete mode 100644 tests/rest/client/v2_alpha/test_report_event.py delete mode 100644 tests/rest/client/v2_alpha/test_sendtodevice.py delete mode 100644 tests/rest/client/v2_alpha/test_shared_rooms.py delete mode 100644 tests/rest/client/v2_alpha/test_sync.py delete mode 100644 tests/rest/client/v2_alpha/test_upgrade_room.py diff --git a/changelog.d/10667.misc b/changelog.d/10667.misc new file mode 100644 index 0000000000..c92846ae26 --- /dev/null +++ b/changelog.d/10667.misc @@ -0,0 +1 @@ +Flatten the `tests.synapse.rests` package by moving the contents of `v1` and `v2_alpha` into the parent. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 90ade37b3f..b17872211e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -91,8 +91,8 @@ files = tests/handlers/test_password_providers.py, tests/handlers/test_room_summary.py, tests/handlers/test_sync.py, - tests/rest/client/v1/test_login.py, - tests/rest/client/v2_alpha/test_auth.py, + tests/rest/client/test_login.py, + tests/rest/client/test_auth.py, tests/util/test_itertools.py, tests/util/test_stream_change_cache.py diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py new file mode 100644 index 0000000000..b946fca8b3 --- /dev/null +++ b/tests/rest/client/test_account.py @@ -0,0 +1,1000 @@ +# Copyright 2015-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import re +from email.parser import Parser +from typing import Optional + +import pkg_resources + +import synapse.rest.admin +from synapse.api.constants import LoginType, Membership +from synapse.api.errors import Codes, HttpResponseException +from synapse.appservice import ApplicationService +from synapse.rest.client import account, login, register, room +from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource + +from tests import unittest +from tests.server import FakeSite, make_request +from tests.unittest import override_config + + +class PasswordResetTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. 
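+        # (The SMTP settings below are dummies; the test stubs out the actual
+        # send and just records each outgoing message in self.email_attempts.)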
+ config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + + return hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.submit_token_resource = PasswordResetSubmitTokenResource(hs) + + def test_basic_password_reset(self): + """Test basic password reset flow""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_email(self): + """Test that we ratelimit /requestToken for the same email.""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test1@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + def reset(ip): + client_secret = "foobar" + session_id = self._request_token(email, client_secret, ip) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + self.email_attempts.clear() + + # We expect to be able to make three requests before getting rate + # limited. 
+ # + # We change IPs to ensure that we're not being ratelimited due to the + # same IP + reset("127.0.0.1") + reset("127.0.0.2") + reset("127.0.0.3") + + with self.assertRaises(HttpResponseException) as cm: + reset("127.0.0.4") + + self.assertEqual(cm.exception.code, 429) + + def test_basic_password_reset_canonicalise_email(self): + """Test basic password reset flow + Request password reset with different spelling + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email_profile = "test@example.com" + email_passwort_reset = "TEST@EXAMPLE.COM" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email_profile, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email_passwort_reset, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + # Assert we can log in with the new password + self.login("kermit", new_password) + + # Assert we can't log in with the old password + self.attempt_wrong_password_login("kermit", old_password) + + def test_cant_reset_password_without_clicking_link(self): + """Test that we do actually need to click the link in the email""" + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to reset password without clicking the link + self._reset_password(new_password, session_id, client_secret, expected_code=401) + + # Assert we can log in with the old password + self.login("kermit", old_password) + + # Assert we can't log in with the new password + self.attempt_wrong_password_login("kermit", new_password) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + client_secret = "foobar" + session_id = "weasle" + + # Attempt to reset password without even requesting an email + self._reset_password(new_password, session_id, client_secret, expected_code=401) + + # Assert we can log in with the old password + self.login("kermit", old_password) + + # Assert we can't log in with the new password + self.attempt_wrong_password_login("kermit", new_password) + + @unittest.override_config({"request_token_inhibit_3pid_errors": True}) + def test_password_reset_bad_email_inhibit_error(self): + """Test that triggering a password reset with an email address that isn't bound + to an account doesn't leak the lack of binding for that address if configured + that way. 
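+        (The request still returns a session ID either way, so a caller
+        cannot tell whether the address was actually bound to an account.)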
+ """ + self.register_user("kermit", "monkey") + self.login("kermit", "monkey") + + email = "test@example.com" + + client_secret = "foobar" + session_id = self._request_token(email, client_secret) + + self.assertIsNotNone(session_id) + + def _request_token(self, email, client_secret, ip="127.0.0.1"): + channel = self.make_request( + "POST", + b"account/password/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + client_ip=ip, + ) + + if channel.code != 200: + raise HttpResponseException( + channel.code, + channel.result["reason"], + channel.result["body"], + ) + + return channel.json_body["sid"] + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + # Load the password reset confirmation page + channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), + "GET", + path, + shorthand=False, + ) + + self.assertEquals(200, channel.code, channel.result) + + # Now POST to the same endpoint, mimicking the same behaviour as clicking the + # password reset confirm button + + # Confirm the password reset + channel = make_request( + self.reactor, + FakeSite(self.submit_token_resource), + "POST", + path, + content=b"", + shorthand=False, + content_is_form=True, + ) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") + break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) + + def _reset_password( + self, new_password, session_id, client_secret, expected_code=200 + ): + channel = self.make_request( + "POST", + b"account/password", + { + "new_password": new_password, + "auth": { + "type": LoginType.EMAIL_IDENTITY, + "threepid_creds": { + "client_secret": client_secret, + "sid": session_id, + }, + }, + }, + ) + self.assertEquals(expected_code, channel.code, channel.result) + + +class DeactivateTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def test_deactivate_account(self): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + self.deactivate(user_id, tok) + + store = self.hs.get_datastore() + + # Check that the user has been marked as deactivated. + self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) + + # Check that this access token has been invalidated. + channel = self.make_request("GET", "account/whoami", access_token=tok) + self.assertEqual(channel.code, 401) + + def test_pending_invites(self): + """Tests that deactivating a user rejects every pending invite for them.""" + store = self.hs.get_datastore() + + inviter_id = self.register_user("inviter", "test") + inviter_tok = self.login("inviter", "test") + + invitee_id = self.register_user("invitee", "test") + invitee_tok = self.login("invitee", "test") + + # Make @inviter:test invite @invitee:test in a new room. 
+ room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok) + self.helper.invite( + room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok + ) + + # Make sure the invite is here. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 1, pending_invites) + self.assertEqual(pending_invites[0].room_id, room_id, pending_invites) + + # Deactivate @invitee:test. + self.deactivate(invitee_id, invitee_tok) + + # Check that the invite isn't there anymore. + pending_invites = self.get_success( + store.get_invited_rooms_for_local_user(invitee_id) + ) + self.assertEqual(len(pending_invites), 0, pending_invites) + + # Check that the membership of @invitee:test in the room is now "leave". + memberships = self.get_success( + store.get_rooms_for_local_user_where_membership_is( + invitee_id, [Membership.LEAVE] + ) + ) + self.assertEqual(len(memberships), 1, memberships) + self.assertEqual(memberships[0].room_id, room_id, memberships) + + def deactivate(self, user_id, tok): + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "test", + }, + "erase": False, + } + ) + channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.assertEqual(channel.code, 200) + + +class WhoamiTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + register.register_servlets, + ] + + def test_GET_whoami(self): + device_id = "wouldgohere" + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test", device_id=device_id) + + whoami = self.whoami(tok) + self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) + + def test_GET_whoami_appservices(self): + user_id = "@as:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": user_id, "exclusive": True}]}, + sender=user_id, + ) + self.hs.get_datastore().services_cache.append(appservice) + + whoami = self.whoami(as_token) + self.assertEqual(whoami, {"user_id": user_id}) + self.assertFalse(hasattr(whoami, "device_id")) + + def whoami(self, tok): + channel = self.make_request("GET", "account/whoami", {}, access_token=tok) + self.assertEqual(channel.code, 200) + return channel.json_body + + +class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Email config. 
+ config["email"] = { + "enable_notifs": False, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "https://example.com" + + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + + return self.hs + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.user_id = self.register_user("kermit", "test") + self.user_id_tok = self.login("kermit", "test") + self.email = "test@example.com" + self.url_3pid = b"account/3pid" + + def test_add_valid_email(self): + self.get_success(self._add_email(self.email, self.email)) + + def test_add_valid_email_second_time(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + self.email, + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) + + def test_add_valid_email_second_time_canonicalise(self): + self.get_success(self._add_email(self.email, self.email)) + self.get_success( + self._request_token_invalid_email( + "TEST@EXAMPLE.COM", + expected_errcode=Codes.THREEPID_IN_USE, + expected_error="Email is already in use", + ) + ) + + def test_add_email_no_at(self): + self.get_success( + self._request_token_invalid_email( + "address-without-at.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_two_at(self): + self.get_success( + self._request_token_invalid_email( + "foo@foo@test.bar", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_bad_format(self): + self.get_success( + self._request_token_invalid_email( + "user@bad.example.net@good.example.com", + expected_errcode=Codes.UNKNOWN, + expected_error="Unable to parse email address", + ) + ) + + def test_add_email_domain_to_lower(self): + self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) + + def test_add_email_domain_with_umlaut(self): + self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) + + def test_add_email_address_casefold(self): + self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) + + def test_address_trim(self): + self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_ip(self): + """Tests that adding emails is ratelimited by IP""" + + # We expect to be able to set three emails before getting ratelimited. 
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) + self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) + self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + + with self.assertRaises(HttpResponseException) as cm: + self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + + self.assertEqual(cm.exception.code, 429) + + def test_add_email_if_disabled(self): + """Test adding email to profile when doing so is disallowed""" + self.hs.config.enable_3pid_changes = False + + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email(self): + """Test deleting an email from profile""" + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_delete_email_if_disabled(self): + """Test deleting an email from profile when disallowed""" + self.hs.config.enable_3pid_changes = False + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=self.user_id, + medium="email", + address=self.email, + validated_at=0, + added_at=0, + ) + ) + + channel = self.make_request( + "POST", + b"account/3pid/delete", + {"medium": "email", "address": self.email}, + access_token=self.user_id_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) + + def test_cant_add_email_without_clicking_link(self): + """Test that we do actually need to click the link in the email""" + client_secret = "foobar" + session_id = self._request_token(self.email, client_secret) + + self.assertEquals(len(self.email_attempts), 1) + + # Attempt to add email without clicking the link + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": 
client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + def test_no_valid_token(self): + """Test that we do actually need to request a token and can't just + make a session up. + """ + client_secret = "foobar" + session_id = "weasle" + + # Attempt to add email without even requesting an email + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertFalse(channel.json_body["threepids"]) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link(self): + """Tests a valid next_link parameter value with no whitelist (good case)""" + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/good/site", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_exotic_protocol(self): + """Tests using a esoteric protocol as a next_link parameter value. + Someone may be hosting a client on IPFS etc. 
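+        (Unusual schemes are accepted when no whitelist is configured;
+        file: URIs are still rejected, as the next test shows.)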
+ """ + self._request_token( + "something@example.com", + "some_secret", + next_link="some-protocol://abcdefghijklmopqrstuvwxyz", + expect_code=200, + ) + + @override_config({"next_link_domain_whitelist": None}) + def test_next_link_file_uri(self): + """Tests next_link parameters cannot be file URI""" + # Attempt to use a next_link value that points to the local disk + self._request_token( + "something@example.com", + "some_secret", + next_link="file:///host/path", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) + def test_next_link_domain_whitelist(self): + """Tests next_link parameters must fit the whitelist if provided""" + + # Ensure not providing a next_link parameter still works + self._request_token( + "something@example.com", + "some_secret", + next_link=None, + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/some/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.org/some/also/good/page", + expect_code=200, + ) + + self._request_token( + "something@example.com", + "some_secret", + next_link="https://bad.example.org/some/bad/page", + expect_code=400, + ) + + @override_config({"next_link_domain_whitelist": []}) + def test_empty_next_link_domain_whitelist(self): + """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially + disallowed + """ + self._request_token( + "something@example.com", + "some_secret", + next_link="https://example.com/a/page", + expect_code=400, + ) + + def _request_token( + self, + email: str, + client_secret: str, + next_link: Optional[str] = None, + expect_code: int = 200, + ) -> str: + """Request a validation token to add an email address to a user's account + + Args: + email: The email address to validate + client_secret: A secret string + next_link: A link to redirect the user to after validation + expect_code: Expected return code of the call + + Returns: + The ID of the new threepid validation session + """ + body = {"client_secret": client_secret, "email": email, "send_attempt": 1} + if next_link: + body["next_link"] = next_link + + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + body, + ) + + if channel.code != expect_code: + raise HttpResponseException( + channel.code, + channel.result["reason"], + channel.result["body"], + ) + + return channel.json_body.get("sid") + + def _request_token_invalid_email( + self, + email, + expected_errcode, + expected_error, + client_secret="foobar", + ): + channel = self.make_request( + "POST", + b"account/3pid/email/requestToken", + {"client_secret": client_secret, "email": email, "send_attempt": 1}, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(expected_errcode, channel.json_body["errcode"]) + self.assertEqual(expected_error, channel.json_body["error"]) + + def _validate_token(self, link): + # Remove the host + path = link.replace("https://example.com", "") + + channel = self.make_request("GET", path, shorthand=False) + self.assertEquals(200, channel.code, channel.result) + + def _get_link_from_email(self): + assert self.email_attempts, "No emails have been sent" + + raw_msg = self.email_attempts[-1].decode("UTF-8") + mail = Parser().parsestr(raw_msg) + + text = None + for part in mail.walk(): + if part.get_content_type() == "text/plain": + text = part.get_payload(decode=True).decode("UTF-8") 
+ break + + if not text: + self.fail("Could not find text portion of email to parse") + + match = re.search(r"https://example.com\S+", text) + assert match, "Could not find link in email" + + return match.group(0) + + def _add_email(self, request_email, expected_email): + """Test adding an email to profile""" + previous_email_attempts = len(self.email_attempts) + + client_secret = "foobar" + session_id = self._request_token(request_email, client_secret) + + self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) + link = self._get_link_from_email() + + self._validate_token(link) + + channel = self.make_request( + "POST", + b"/_matrix/client/unstable/account/3pid/add", + { + "client_secret": client_secret, + "sid": session_id, + "auth": { + "type": "m.login.password", + "user": self.user_id, + "password": "test", + }, + }, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # Get user + channel = self.make_request( + "GET", + self.url_3pid, + access_token=self.user_id_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) + + threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} + self.assertIn(expected_email, threepids) diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py new file mode 100644 index 0000000000..e2fcbdc63a --- /dev/null +++ b/tests/rest/client/test_auth.py @@ -0,0 +1,717 @@ +# Copyright 2018 New Vector +# Copyright 2020-2021 The Matrix.org Foundation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
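
# Illustrative sketch (not part of this patch): the user-interactive auth
# (UIA) fallback loop that FallbackAuthTests below drives, modeled with an
# in-memory session store. The stage names follow the Matrix spec; the
# helpers, store, and session ID are invented for the example.
from typing import Dict, Set, Tuple

SESSIONS: Dict[str, Set[str]] = {}  # session ID -> completed stages
FLOW = ["m.login.recaptcha", "m.login.dummy"]


def register(auth: dict) -> Tuple[int, dict]:
    # With no auth (or an incomplete session) the server answers 401 and
    # hands back a session ID plus the flows still to be completed.
    session = auth.get("session", "sess-1")
    done = SESSIONS.setdefault(session, set())
    if set(FLOW) <= done:
        return 200, {"user_id": "@user:test"}
    return 401, {"session": session, "flows": [{"stages": FLOW}]}


def complete_fallback_stage(session: str, stage: str) -> None:
    # Models the POST to auth/<stage>/fallback/web?session=<session>.
    SESSIONS[session].add(stage)


code, body = register({})
assert code == 401
for stage in FLOW:
    complete_fallback_stage(body["session"], stage)
code, body = register({"session": body["session"]})
assert code == 200 and body["user_id"] == "@user:test"
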
+from typing import Optional, Union + +from twisted.internet.defer import succeed + +import synapse.rest.admin +from synapse.api.constants import LoginType +from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.rest.client import account, auth, devices, login, register +from synapse.rest.synapse.client import build_synapse_client_resource_tree +from synapse.types import JsonDict, UserID + +from tests import unittest +from tests.handlers.test_oidc import HAS_OIDC +from tests.rest.client.utils import TEST_OIDC_CONFIG +from tests.server import FakeChannel +from tests.unittest import override_config, skip_unless + + +class DummyRecaptchaChecker(UserInteractiveAuthChecker): + def __init__(self, hs): + super().__init__(hs) + self.recaptcha_attempts = [] + + def check_auth(self, authdict, clientip): + self.recaptcha_attempts.append((authdict, clientip)) + return succeed(True) + + +class FallbackAuthTests(unittest.HomeserverTestCase): + + servlets = [ + auth.register_servlets, + register.register_servlets, + ] + hijack_auth = False + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + + config["enable_registration_captcha"] = True + config["recaptcha_public_key"] = "brokencake" + config["registrations_require_3pid"] = [] + + hs = self.setup_test_homeserver(config=config) + return hs + + def prepare(self, reactor, clock, hs): + self.recaptcha_checker = DummyRecaptchaChecker(hs) + auth_handler = hs.get_auth_handler() + auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker + + def register(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Make a register request.""" + channel = self.make_request("POST", "register", body) + + self.assertEqual(channel.code, expected_response) + return channel + + def recaptcha( + self, + session: str, + expected_post_response: int, + post_session: Optional[str] = None, + ) -> None: + """Get and respond to a fallback recaptcha. Returns the second request.""" + if post_session is None: + post_session = session + + channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.assertEqual(channel.code, 200) + + channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + post_session + + "&g-recaptcha-response=a", + ) + self.assertEqual(channel.code, expected_post_response) + + # The recaptcha handler is called with the response given + attempts = self.recaptcha_checker.recaptcha_attempts + self.assertEqual(len(attempts), 1) + self.assertEqual(attempts[0][0]["response"], "a") + + def test_fallback_captcha(self): + """Ensure that fallback auth via a captcha works.""" + # Returns a 401 as per the spec + channel = self.register( + 401, + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Complete the recaptcha step. + self.recaptcha(session, 200) + + # also complete the dummy auth + self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}}) + + # Now we should have fulfilled a complete auth flow, including + # the recaptcha fallback step, we can then send a + # request to the register API with the session in the authdict. + channel = self.register(200, {"auth": {"session": session}}) + + # We're given a registered user. 
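+        # (The full flow was: register -> 401 + session -> recaptcha fallback
+        # -> dummy auth -> register again with just the session.)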
+ self.assertEqual(channel.json_body["user_id"], "@user:test") + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + # Make the initial request to register. (Later on a different password + # will be used.) + # Returns a 401 as per the spec + channel = self.register( + 401, {"username": "user", "type": "m.login.password", "password": "bar"} + ) + + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + # Attempt to complete the recaptcha step with an unknown session. + # This results in an error. + self.recaptcha(session, 400, session + "unknown") + + +class UIAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + devices.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + + def default_config(self): + config = super().default_config() + + # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns + # False, so synapse will see the requested uri as http://..., so using http in + # the public_baseurl stops Synapse trying to redirect to https. + config["public_baseurl"] = "http://synapse.test" + + if HAS_OIDC: + # we enable OIDC as a way of testing SSO flows + oidc_config = {} + oidc_config.update(TEST_OIDC_CONFIG) + oidc_config["allow_existing_users"] = True + config["oidc_config"] = oidc_config + + return config + + def create_resource_dict(self): + resource_dict = super().create_resource_dict() + resource_dict.update(build_synapse_client_resource_tree(self.hs)) + return resource_dict + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + self.device_id = "dev1" + self.user_tok = self.login("test", self.user_pass, self.device_id) + + def delete_device( + self, + access_token: str, + device: str, + expected_response: int, + body: Union[bytes, JsonDict] = b"", + ) -> FakeChannel: + """Delete an individual device.""" + channel = self.make_request( + "DELETE", + "devices/" + device, + body, + access_token=access_token, + ) + + # Ensure the response is sane. + self.assertEqual(channel.code, expected_response) + + return channel + + def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: + """Delete 1 or more devices.""" + # Note that this uses the delete_devices endpoint so that we can modify + # the payload half-way through some tests. + channel = self.make_request( + "POST", + "delete_devices", + body, + access_token=self.user_tok, + ) + + # Ensure the response is sane. + self.assertEqual(channel.code, expected_response) + + return channel + + def test_ui_auth(self): + """ + Test user interactive authentication outside of registration. + """ + # Attempt to delete this device. + # Returns a 401 as per the spec + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. 
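+        # The auth dict carries the password stage plus the session ID from
+        # the initial 401 response.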
+ self.delete_device( + self.user_tok, + self.device_id, + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_grandfathered_identifier(self): + """Check behaviour without "identifier" dict + + Synapse used to require clients to submit a "user" field for m.login.password + UIA - check that still works. + """ + + channel = self.delete_device(self.user_tok, self.device_id, 401) + session = channel.json_body["session"] + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + self.device_id, + 200, + { + "auth": { + "type": "m.login.password", + "user": self.user, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_can_change_body(self): + """ + The client dict can be modified during the user interactive authentication session. + + Note that it is not spec compliant to modify the client dict during a + user interactive authentication session, but many clients currently do. + + When Synapse is updated to be spec compliant, the call to re-use the + session ID should be rejected. + """ + # Create a second login. + self.login("test", self.user_pass, "dev2") + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_devices(401, {"devices": [self.device_id]}) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. + self.delete_devices( + 200, + { + "devices": ["dev2"], + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + def test_cannot_change_uri(self): + """ + The initial requested URI cannot be modified during the user interactive authentication session. + """ + # Create a second login. + self.login("test", self.user_pass, "dev2") + + # Attempt to delete the first device. + # Returns a 401 as per the spec + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow, but try to delete the + # second device. This results in an error. + # + # This makes use of the fact that the device ID is embedded into the URL. + self.delete_device( + self.user_tok, + "dev2", + 403, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + @unittest.override_config({"ui_auth": {"session_timeout": "5s"}}) + def test_can_reuse_session(self): + """ + The session can be reused if configured. + + Compare to test_cannot_change_uri. + """ + # Create a second and third login. + self.login("test", self.user_pass, "dev2") + self.login("test", self.user_pass, "dev3") + + # Attempt to delete a device. This works since the user just logged in. + self.delete_device(self.user_tok, "dev2", 200) + + # Move the clock forward past the validation timeout. + self.reactor.advance(6) + + # Deleting another devices throws the user into UI auth. 
+ channel = self.delete_device(self.user_tok, "dev3", 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + "dev3", + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + # Make another request, but try to delete the first device. This works + # due to re-using the previous session. + # + # Note that *no auth* information is provided, not even a session iD! + self.delete_device(self.user_tok, self.device_id, 200) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_via_sso(self): + """Test a successful UI Auth flow via SSO + + This includes: + * hitting the UIA SSO redirect endpoint + * checking it serves a confirmation page which links to the OIDC provider + * calling back to the synapse oidc callback + * checking that the original operation succeeds + """ + + # log the user in + remote_user_id = UserID.from_string(self.user).localpart + login_resp = self.helper.login_via_oidc(remote_user_id) + self.assertEqual(login_resp["user_id"], self.user) + + # initiate a UI Auth process by attempting to delete the device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + # check that SSO is offered + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + + # run the UIA-via-SSO flow + session_id = channel.json_body["session"] + channel = self.helper.auth_via_oidc( + {"sub": remote_user_id}, ui_auth_session_id=session_id + ) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + + # and now the delete request should succeed. + self.delete_device( + self.user_tok, + self.device_id, + 200, + body={"auth": {"session": session_id}}, + ) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_does_not_offer_password_for_sso_user(self): + login_resp = self.helper.login_via_oidc("username") + user_tok = login_resp["access_token"] + device_id = login_resp["device_id"] + + # now call the device deletion API: we should get the option to auth with SSO + # and not password. 
+ channel = self.delete_device(user_tok, device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) + + def test_does_not_offer_sso_for_password_user(self): + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.password"]}]) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_offers_both_flows_for_upgraded_user(self): + """A user that had a password and then logged in with SSO should get both flows""" + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + # we have no particular expectations of ordering here + self.assertIn({"stages": ["m.login.password"]}, flows) + self.assertIn({"stages": ["m.login.sso"]}, flows) + self.assertEqual(len(flows), 2) + + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config({"oidc_config": TEST_OIDC_CONFIG}) + def test_ui_auth_fails_for_incorrect_sso_user(self): + """If the user tries to authenticate with the wrong SSO user, they get an error""" + # log the user in + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + # start a UI Auth flow by attempting to delete a device + channel = self.delete_device(self.user_tok, self.device_id, 401) + + flows = channel.json_body["flows"] + self.assertIn({"stages": ["m.login.sso"]}, flows) + session_id = channel.json_body["session"] + + # do the OIDC auth, but auth as the wrong user + channel = self.helper.auth_via_oidc( + {"sub": "wrong_user"}, ui_auth_session_id=session_id + ) + + # that should return a failure message + self.assertSubstring("We were unable to validate", channel.text_body) + + # ... and the delete op should now fail with a 403 + self.delete_device( + self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} + ) + + +class RefreshAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + hijack_auth = False + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + + def test_login_issue_refresh_token(self): + """ + A login response should include a refresh_token only if asked. + """ + # Test login + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + + login_without_refresh = self.make_request( + "POST", "/_matrix/client/r0/login", body + ) + self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertNotIn("refresh_token", login_without_refresh.json_body) + + login_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertIn("refresh_token", login_with_refresh.json_body) + self.assertIn("expires_in_ms", login_with_refresh.json_body) + + def test_register_issue_refresh_token(self): + """ + A register response should include a refresh_token only if asked. 
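+        (Asked for via the org.matrix.msc2918.refresh_token=true query
+        parameter, as with login above.)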
+ """ + register_without_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register", + { + "username": "test2", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual( + register_without_refresh.code, 200, register_without_refresh.result + ) + self.assertNotIn("refresh_token", register_without_refresh.json_body) + + register_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", + { + "username": "test3", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertIn("refresh_token", register_with_refresh.json_body) + self.assertIn("expires_in_ms", register_with_refresh.json_body) + + def test_token_refresh(self): + """ + A refresh token can be used to issue a new access token. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertIn("access_token", refresh_response.json_body) + self.assertIn("refresh_token", refresh_response.json_body) + self.assertIn("expires_in_ms", refresh_response.json_body) + + # The access and refresh tokens should be different from the original ones after refresh + self.assertNotEqual( + login_response.json_body["access_token"], + refresh_response.json_body["access_token"], + ) + self.assertNotEqual( + login_response.json_body["refresh_token"], + refresh_response.json_body["refresh_token"], + ) + + @override_config({"access_token_lifetime": "1m"}) + def test_refresh_token_expiration(self): + """ + The access token should have some time as specified in the config. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + self.assertApproximates( + login_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertApproximates( + refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + def test_refresh_token_invalidation(self): + """Refresh tokens are invalidated after first use of the next token. 
+ + A refresh token is considered invalid if: + - it was already used at least once + - and either + - the next access token was used + - the next refresh token was used + + The chain of tokens goes like this: + + login -|-> first_refresh -> third_refresh (fails) + |-> second_refresh -> fifth_refresh + |-> fourth_refresh (fails) + """ + + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + # This first refresh should work properly + first_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + first_refresh_response.code, 200, first_refresh_response.result + ) + + # This one as well, since the token in the first one was never used + second_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + second_refresh_response.code, 200, second_refresh_response.result + ) + + # This one should not, since the token from the first refresh is not valid anymore + third_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": first_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + third_refresh_response.code, 401, third_refresh_response.result + ) + + # The associated access token should also be invalid + whoami_response = self.make_request( + "GET", + "/_matrix/client/r0/account/whoami", + access_token=first_refresh_response.json_body["access_token"], + ) + self.assertEqual(whoami_response.code, 401, whoami_response.result) + + # But all other tokens should work (they will expire after some time) + for access_token in [ + second_refresh_response.json_body["access_token"], + login_response.json_body["access_token"], + ]: + whoami_response = self.make_request( + "GET", "/_matrix/client/r0/account/whoami", access_token=access_token + ) + self.assertEqual(whoami_response.code, 200, whoami_response.result) + + # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail + fourth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fourth_refresh_response.code, 403, fourth_refresh_response.result + ) + + # But refreshing from the last valid refresh token still works + fifth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": second_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fifth_refresh_response.code, 200, fifth_refresh_response.result + ) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py new file mode 100644 index 0000000000..ac31e5ceaf --- /dev/null +++ b/tests/rest/client/test_capabilities.py @@ -0,0 +1,212 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse.rest.admin
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.rest.client import capabilities, login
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class CapabilitiesTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        capabilities.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        self.url = b"/_matrix/client/r0/capabilities"
+        hs = self.setup_test_homeserver()
+        self.config = hs.config
+        self.auth_handler = hs.get_auth_handler()
+        return hs
+
+    def prepare(self, reactor, clock, hs):
+        self.localpart = "user"
+        self.password = "pass"
+        self.user = self.register_user(self.localpart, self.password)
+
+    def test_check_auth_required(self):
+        channel = self.make_request("GET", self.url)
+
+        self.assertEqual(channel.code, 401)
+
+    def test_get_room_version_capabilities(self):
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        for room_version in capabilities["m.room_versions"]["available"].keys():
+            self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
+
+        self.assertEqual(
+            self.config.default_room_version.identifier,
+            capabilities["m.room_versions"]["default"],
+        )
+
+    def test_get_change_password_capabilities_password_login(self):
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+
+    @override_config({"password_config": {"localdb_enabled": False}})
+    def test_get_change_password_capabilities_localdb_disabled(self):
+        access_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["m.change_password"]["enabled"])
+
+    @override_config({"password_config": {"enabled": False}})
+    def test_get_change_password_capabilities_password_disabled(self):
+        access_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["m.change_password"]["enabled"])
+
+    def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
+        """Test that, with MSC3283 disabled (the default), the server advertises only `m.change_password`."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+        self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
+        self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
+        self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
+
+    @override_config({"experimental_features": {"msc3283_enabled": True}})
+    def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
+        """Test that, with MSC3283 enabled, the server advertises the new capabilities."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+    @override_config(
+        {
+            "enable_set_displayname": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_get_set_displayname_capabilities_displayname_disabled(self):
+        """Test that the server reports it when changing the display name is disabled."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+
+    @override_config(
+        {
+            "enable_set_avatar_url": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+        """Test that the server reports it when changing the avatar URL is disabled."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+
+    @override_config(
+        {
+            "enable_3pid_changes": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_change_3pid_capabilities_3pid_disabled(self):
+        """Test that the server reports it when 3PID changes are disabled."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+    def test_get_does_not_include_msc3244_fields_by_default(self):
+        access_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertNotIn(
+            "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
+        )
+
+    @override_config({"experimental_features": {"msc3244_enabled": True}})
+    def test_get_does_include_msc3244_fields_when_enabled(self):
+        access_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.user,
device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + for details in capabilities["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ].values(): + if details["preferred"] is not None: + self.assertTrue( + details["preferred"] in KNOWN_ROOM_VERSIONS, + str(details["preferred"]), + ) + + self.assertGreater(len(details["support"]), 0) + for room_version in details["support"]: + self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py new file mode 100644 index 0000000000..d2181ea907 --- /dev/null +++ b/tests/rest/client/test_directory.py @@ -0,0 +1,154 @@ +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.rest import admin +from synapse.rest.client import directory, login, room +from synapse.types import RoomAlias +from synapse.util.stringutils import random_string + +from tests import unittest +from tests.unittest import override_config + + +class DirectoryTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["require_membership_for_aliases"] = True + + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + self.user = self.register_user("user", "test") + self.user_tok = self.login("user", "test") + + def test_state_event_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_state_event(403) + + def test_directory_endpoint_not_in_room(self): + self.ensure_user_left_room() + self.set_alias_via_directory(403) + + def test_state_event_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_state_event(400, alias_length=256) + + def test_directory_in_room_too_long(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(400, alias_length=256) + + @override_config({"default_room_version": 5}) + def test_state_event_user_in_v5_room(self): + """Test that a regular user can add alias events before room v6""" + self.ensure_user_joined_room() + self.set_alias_via_state_event(200) + + @override_config({"default_room_version": 6}) + def test_state_event_v6_room(self): + """Test that a regular user can *not* add alias events from room v6""" + self.ensure_user_joined_room() + self.set_alias_via_state_event(403) + + def test_directory_in_room(self): + self.ensure_user_joined_room() + self.set_alias_via_directory(200) + + def 
test_room_creation_too_long(self):
+        url = "/_matrix/client/r0/createRoom"
+
+        # We deliberately use a localpart under the length threshold so
+        # that we can make sure that the check is done on the whole alias.
+        data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
+        request_data = json.dumps(data)
+        channel = self.make_request(
+            "POST", url, request_data, access_token=self.user_tok
+        )
+        self.assertEqual(channel.code, 400, channel.result)
+
+    def test_room_creation(self):
+        url = "/_matrix/client/r0/createRoom"
+
+        # Check with an alias of allowed length. There should already be
+        # a test that ensures it works in test_register.py, but let's be
+        # as cautious as possible here.
+        data = {"room_alias_name": random_string(5)}
+        request_data = json.dumps(data)
+        channel = self.make_request(
+            "POST", url, request_data, access_token=self.user_tok
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+    def set_alias_via_state_event(self, expected_code, alias_length=5):
+        url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
+            self.room_id,
+            self.hs.hostname,
+        )
+
+        data = {"aliases": [self.random_alias(alias_length)]}
+        request_data = json.dumps(data)
+
+        channel = self.make_request(
+            "PUT", url, request_data, access_token=self.user_tok
+        )
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+    def set_alias_via_directory(self, expected_code, alias_length=5):
+        url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
+        data = {"room_id": self.room_id}
+        request_data = json.dumps(data)
+
+        channel = self.make_request(
+            "PUT", url, request_data, access_token=self.user_tok
+        )
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+    def random_alias(self, length):
+        return RoomAlias(random_string(length), self.hs.hostname).to_string()
+
+    def ensure_user_left_room(self):
+        self.ensure_membership("leave")
+
+    def ensure_user_joined_room(self):
+        self.ensure_membership("join")
+
+    def ensure_membership(self, membership):
+        try:
+            if membership == "leave":
+                self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok)
+            if membership == "join":
+                self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok)
+        except AssertionError:
+            # We don't care if the request fails (e.g. because the user was
+            # never in the room): all we need is for the user to end up in
+            # the desired membership state.
+            pass
diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py
new file mode 100644
index 0000000000..a90294003e
--- /dev/null
+++ b/tests/rest/client/test_events.py
@@ -0,0 +1,156 @@
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
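+# (GET /events is the old v1 event-stream API; its implementation now lives in
+# the r0 stack, as noted in test_stream_basic_permissions below.)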
+ +""" Tests REST events for /events paths.""" + +from unittest.mock import Mock + +import synapse.rest.admin +from synapse.rest.client import events, login, room + +from tests import unittest + + +class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): + """Tests event streaming (GET /events).""" + + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["enable_registration_captcha"] = False + config["enable_registration"] = True + config["auto_join_rooms"] = [] + + hs = self.setup_test_homeserver(config=config) + + hs.get_federation_handler = Mock() + + return hs + + def prepare(self, reactor, clock, hs): + + # register an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + # register a 2nd account + self.other_user = self.register_user("other2", "pass") + self.other_token = self.login(self.other_user, "pass") + + def test_stream_basic_permissions(self): + # invalid token, expect 401 + # note: this is in violation of the original v1 spec, which expected + # 403. However, since the v1 spec no longer exists and the v1 + # implementation is now part of the r0 implementation, the newer + # behaviour is used instead to be consistent with the r0 spec. + # see issue #2602 + channel = self.make_request( + "GET", "/events?access_token=%s" % ("invalid" + self.token,) + ) + self.assertEquals(channel.code, 401, msg=channel.result) + + # valid token, expect content + channel = self.make_request( + "GET", "/events?access_token=%s&timeout=0" % (self.token,) + ) + self.assertEquals(channel.code, 200, msg=channel.result) + self.assertTrue("chunk" in channel.json_body) + self.assertTrue("start" in channel.json_body) + self.assertTrue("end" in channel.json_body) + + def test_stream_room_permissions(self): + room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) + self.helper.send(room_id, tok=self.other_token) + + # invited to room (expect no content for room) + self.helper.invite( + room_id, src=self.other_user, targ=self.user_id, tok=self.other_token + ) + + # valid token, expect content + channel = self.make_request( + "GET", "/events?access_token=%s&timeout=0" % (self.token,) + ) + self.assertEquals(channel.code, 200, msg=channel.result) + + # We may get a presence event for ourselves down + self.assertEquals( + 0, + len( + [ + c + for c in channel.json_body["chunk"] + if not ( + c.get("type") == "m.presence" + and c["content"].get("user_id") == self.user_id + ) + ] + ), + ) + + # joined room (expect all content for room) + self.helper.join(room=room_id, user=self.user_id, tok=self.token) + + # left to room (expect no content for room) + + def TODO_test_stream_items(self): + # new user, no content + + # join room, expect 1 item (join) + + # send message, expect 2 items (join,send) + + # set topic, expect 3 items (join,send,topic) + + # someone else join room, expect 4 (join,send,topic,join) + + # someone else send message, expect 5 (join,send.topic,join,send) + + # someone else set topic, expect 6 (join,send,topic,join,send,topic) + pass + + +class GetEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + events.register_servlets, + room.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, hs, reactor, clock): + + # register 
an account + self.user_id = self.register_user("sid1", "pass") + self.token = self.login(self.user_id, "pass") + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + def test_get_event_via_events(self): + resp = self.helper.send(self.room_id, tok=self.token) + event_id = resp["event_id"] + + channel = self.make_request( + "GET", + "/events/" + event_id, + access_token=self.token, + ) + self.assertEquals(channel.code, 200, msg=channel.result) diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py new file mode 100644 index 0000000000..475c6bed3d --- /dev/null +++ b/tests/rest/client/test_filter.py @@ -0,0 +1,111 @@ +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import Codes +from synapse.rest.client import filter + +from tests import unittest + +PATH_PREFIX = "/_matrix/client/v2_alpha" + + +class FilterTestCase(unittest.HomeserverTestCase): + + user_id = "@apple:test" + hijack_auth = True + EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} + EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' + servlets = [filter.register_servlets] + + def prepare(self, reactor, clock, hs): + self.filtering = hs.get_filtering() + self.store = hs.get_datastore() + + def test_add_filter(self): + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.user_id), + self.EXAMPLE_FILTER_JSON, + ) + + self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.json_body, {"filter_id": "0"}) + filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) + self.pump() + self.assertEquals(filter.result, self.EXAMPLE_FILTER) + + def test_add_filter_for_other_user(self): + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), + self.EXAMPLE_FILTER_JSON, + ) + + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + + def test_add_filter_non_local_user(self): + _is_mine = self.hs.is_mine + self.hs.is_mine = lambda target_user: False + channel = self.make_request( + "POST", + "/_matrix/client/r0/user/%s/filter" % (self.user_id), + self.EXAMPLE_FILTER_JSON, + ) + + self.hs.is_mine = _is_mine + self.assertEqual(channel.result["code"], b"403") + self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) + + def test_get_filter(self): + filter_id = defer.ensureDeferred( + self.filtering.add_user_filter( + user_localpart="apple", user_filter=self.EXAMPLE_FILTER + ) + ) + self.reactor.advance(1) + filter_id = filter_id.result + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) + ) + + self.assertEqual(channel.result["code"], b"200") + self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) + + def test_get_filter_non_existant(self): + channel = self.make_request( + "GET", 
"/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"404") + self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) + + # Currently invalid params do not have an appropriate errcode + # in errors.py + def test_get_filter_invalid_id(self): + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"400") + + # No ID also returns an invalid_id error + def test_get_filter_no_id(self): + channel = self.make_request( + "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) + ) + + self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py new file mode 100644 index 0000000000..d7fa635eae --- /dev/null +++ b/tests/rest/client/test_keys.py @@ -0,0 +1,91 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from http import HTTPStatus + +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import keys, login + +from tests import unittest + + +class KeyQueryTestCase(unittest.HomeserverTestCase): + servlets = [ + keys.register_servlets, + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def test_rejects_device_id_ice_key_outside_of_list(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: "device_id1", + }, + }, + alice_token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_rejects_device_key_given_as_map_to_bool(self): + self.register_user("alice", "wonderland") + alice_token = self.login("alice", "wonderland") + bob = self.register_user("bob", "uncle") + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + bob: { + "device_id1": True, + }, + }, + }, + alice_token, + ) + + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.BAD_JSON, + channel.result, + ) + + def test_requires_device_key(self): + """`device_keys` is required. 
We should complain if it's missing."""
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {},
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
new file mode 100644
index 0000000000..5b2243fe52
--- /dev/null
+++ b/tests/rest/client/test_login.py
@@ -0,0 +1,1345 @@
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import urllib.parse
+from typing import Any, Dict, List, Optional, Union
+from unittest.mock import Mock
+from urllib.parse import urlencode
+
+import pymacaroons
+
+from twisted.web.resource import Resource
+
+import synapse.rest.admin
+from synapse.appservice import ApplicationService
+from synapse.rest.client import devices, login, logout, register
+from synapse.rest.client.account import WhoamiRestServlet
+from synapse.rest.synapse.client import build_synapse_client_resource_tree
+from synapse.types import create_requester
+
+from tests import unittest
+from tests.handlers.test_oidc import HAS_OIDC
+from tests.handlers.test_saml import has_saml2
+from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.test_utils.html_parsers import TestHtmlParser
+from tests.unittest import HomeserverTestCase, override_config, skip_unless
+
+try:
+    import jwt
+
+    HAS_JWT = True
+except ImportError:
+    HAS_JWT = False
+
+
+# synapse server name: used to populate public_baseurl in some tests
+SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
+
+# public_baseurl for some tests. It uses an http:// scheme because
+# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
+# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
+# https://....
+BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
+
+# CAS server used in some tests
+CAS_SERVER = "https://fake.test"
+
+# just enough to tell pysaml2 where to redirect to
+SAML_SERVER = "https://test.saml.server/idp/sso"
+TEST_SAML_METADATA = """
+<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata">
+  <md:IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
+    <md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
+  </md:IDPSSODescriptor>
+</md:EntityDescriptor>
+""" % {
+    "SAML_SERVER": SAML_SERVER,
+}
+
+LOGIN_URL = b"/_matrix/client/r0/login"
+TEST_URL = b"/_matrix/client/r0/account/whoami"
+
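+# NB: TEST_URL (the whoami endpoint) is used throughout this file as a cheap
+# authenticated request for checking whether an access token is still valid.
+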
+# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
+
+class LoginRestServletTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        logout.register_servlets,
+        devices.register_servlets,
+        lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        self.hs = self.setup_test_homeserver()
+        self.hs.config.enable_registration = True
+        self.hs.config.registrations_require_3pid = []
+        self.hs.config.auto_join_rooms = []
+        self.hs.config.enable_registration_captcha = False
+
+        return self.hs
+
+    @override_config(
+        {
+            "rc_login": {
+                "address": {"per_second": 0.17, "burst_count": 5},
+                # Prevent the account login ratelimiter from raising first
+                #
+                # This is normally covered by the default test homeserver config
+                # which sets these values to 10000, but as we're overriding the entire
+                # rc_login dict here, we need to set this manually as well
+                "account": {"per_second": 10000, "burst_count": 10000},
+            }
+        }
+    )
+    def test_POST_ratelimiting_per_address(self):
+        # Create different users so we're sure not to be bothered by the per-user
+        # ratelimiter.
+        for i in range(0, 6):
+            self.register_user("kermit" + str(i), "monkey")
+
+        for i in range(0, 6):
+            params = {
+                "type": "m.login.password",
+                "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
+                "password": "monkey",
+            }
+            channel = self.make_request(b"POST", LOGIN_URL, params)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Since we're ratelimiting at 0.17 requests/second (i.e. one request
+        # every ~6s once the burst allowance is used up), retry_after_ms
+        # should be lower than 6000ms.
+        self.assertTrue(retry_after_ms < 6000)
+
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+        params = {
+            "type": "m.login.password",
+            "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
+            "password": "monkey",
+        }
+        channel = self.make_request(b"POST", LOGIN_URL, params)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    @override_config(
+        {
+            "rc_login": {
+                "account": {"per_second": 0.17, "burst_count": 5},
+                # Prevent the address login ratelimiter from raising first
+                #
+                # This is normally covered by the default test homeserver config
+                # which sets these values to 10000, but as we're overriding the entire
+                # rc_login dict here, we need to set this manually as well
+                "address": {"per_second": 10000, "burst_count": 10000},
+            }
+        }
+    )
+    def test_POST_ratelimiting_per_account(self):
+        self.register_user("kermit", "monkey")
+
+        for i in range(0, 6):
+            params = {
+                "type": "m.login.password",
+                "identifier": {"type": "m.id.user", "user": "kermit"},
+                "password": "monkey",
+            }
+            channel = self.make_request(b"POST", LOGIN_URL, params)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Since we're ratelimiting at 0.17 requests/second (i.e. one request
+        # every ~6s once the burst allowance is used up), retry_after_ms
+        # should be lower than 6000ms.
+        self.assertTrue(retry_after_ms < 6000)
+
+        self.reactor.advance(retry_after_ms / 1000.0)
+
+        params = {
+            "type": "m.login.password",
+            "identifier": {"type": "m.id.user", "user": "kermit"},
+            "password": "monkey",
+        }
+        channel = self.make_request(b"POST", LOGIN_URL, params)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    @override_config(
+        {
+            "rc_login": {
+                # Prevent the address login ratelimiter from raising first
+                #
+                # This is normally covered by the default test homeserver config
+                # which sets these values to 10000, but as we're overriding the entire
+                # rc_login dict here, we need to set this manually as well
+                "address": {"per_second": 10000, "burst_count": 10000},
+                "failed_attempts": {"per_second": 0.17, "burst_count": 5},
+            }
+        }
+    )
+    def test_POST_ratelimiting_per_account_failed_attempts(self):
+        self.register_user("kermit", "monkey")
+
+        for i in range(0, 6):
+            params = {
+                "type": "m.login.password",
+                "identifier": {"type": "m.id.user", "user": "kermit"},
+                "password": "notamonkey",
+            }
+            channel = self.make_request(b"POST", LOGIN_URL, params)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"403", channel.result)
+
+        # Since we're ratelimiting at 0.17 requests/second (i.e. one request
+        # every ~6s once the burst allowance is used up), retry_after_ms
+        # should be lower than 6000ms.
+ self.assertTrue(retry_after_ms < 6000) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "notamonkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_soft_logout(self): + self.register_user("kermit", "monkey") + + # we shouldn't be able to make requests without an access token + channel = self.make_request(b"GET", TEST_URL) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") + + # log in as normal + params = { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.code, 200, channel.result) + access_token = channel.json_body["access_token"] + device_id = channel.json_body["device_id"] + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # + # test behaviour after deleting the expired device + # + + # we now log in as a different device + access_token_2 = self.login("kermit", "monkey") + + # more requests with the expired token should still return a soft-logout + self.reactor.advance(3600) + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # ... 
but if we delete that device, it will be a proper logout + self._delete_device(access_token_2, "kermit", "monkey", device_id) + + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], False) + + def _delete_device(self, access_token, user_id, password, device_id): + """Perform the UI-Auth to delete a device""" + channel = self.make_request( + b"DELETE", "devices/" + device_id, access_token=access_token + ) + self.assertEquals(channel.code, 401, channel.result) + # check it's a UI-Auth fail + self.assertEqual( + set(channel.json_body.keys()), + {"flows", "params", "session"}, + channel.result, + ) + + auth = { + "type": "m.login.password", + # https://github.com/matrix-org/synapse/issues/5665 + # "identifier": {"type": "m.id.user", "user": user_id}, + "user": user_id, + "password": password, + "session": channel.json_body["session"], + } + + channel = self.make_request( + b"DELETE", + "devices/" + device_id, + access_token=access_token, + content={"auth": auth}, + ) + self.assertEquals(channel.code, 200, channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard logout this session + channel = self.make_request(b"POST", "/logout", access_token=access_token) + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"session_lifetime": "24h"}) + def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): + self.register_user("kermit", "monkey") + + # log in as normal + access_token = self.login("kermit", "monkey") + + # we should now be able to make requests with the access token + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 200, channel.result) + + # time passes + self.reactor.advance(24 * 3600) + + # ... 
and we should be soft-logouted + channel = self.make_request(b"GET", TEST_URL, access_token=access_token) + self.assertEquals(channel.code, 401, channel.result) + self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") + self.assertEquals(channel.json_body["soft_logout"], True) + + # Now try to hard log out all of the user's sessions + channel = self.make_request(b"POST", "/logout/all", access_token=access_token) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") +class MultiSSOTestCase(unittest.HomeserverTestCase): + """Tests for homeservers with multiple SSO providers enabled""" + + servlets = [ + login.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + + config["public_baseurl"] = BASE_URL + + config["cas_config"] = { + "enabled": True, + "server_url": CAS_SERVER, + "service_url": "https://matrix.goodserver.com:8448", + } + + config["saml2_config"] = { + "sp_config": { + "metadata": {"inline": [TEST_SAML_METADATA]}, + # use the XMLSecurity backend to avoid relying on xmlsec1 + "crypto_backend": "XMLSecurity", + }, + } + + # default OIDC provider + config["oidc_config"] = TEST_OIDC_CONFIG + + # additional OIDC providers + config["oidc_providers"] = [ + { + "idp_id": "idp1", + "idp_name": "IDP1", + "discover": False, + "issuer": "https://issuer1", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://issuer1/auth", + "token_endpoint": "https://issuer1/token", + "userinfo_endpoint": "https://issuer1/userinfo", + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.sub }}"} + }, + } + ] + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d.update(build_synapse_client_resource_tree(self.hs)) + return d + + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flow_types = [ + "m.login.cas", + "m.login.sso", + "m.login.token", + "m.login.password", + ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] + + self.assertCountEqual( + [f["type"] for f in channel.json_body["flows"]], expected_flow_types + ) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results: Dict[str, Any] = {} + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + 
self.assertCountEqual(flow_results.values(), expected_flows) + + def test_multi_sso_redirect(self): + """/login/sso/redirect should redirect to an identity picker""" + # first hit the redirect url, which should redirect to our idp picker + channel = self._make_sso_redirect_request(False, None) + self.assertEqual(channel.code, 302, channel.result) + uri = channel.headers.getRawHeaders("Location")[0] + + # hitting that picker should give us some HTML + channel = self.make_request("GET", uri) + self.assertEqual(channel.code, 200, channel.result) + + # parse the form to check it has fields assumed elsewhere in this class + html = channel.result["body"].decode("utf-8") + p = TestHtmlParser() + p.feed(html) + p.close() + + # there should be a link for each href + returned_idps: List[str] = [] + for link in p.links: + path, query = link.split("?", 1) + self.assertEqual(path, "pick_idp") + params = urllib.parse.parse_qs(query) + self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) + returned_idps.append(params["idp"][0]) + + self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) + + def test_multi_sso_redirect_to_cas(self): + """If CAS is chosen, should redirect to the CAS server""" + + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", + shorthand=False, + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + cas_uri = location_headers[0] + cas_uri_path, cas_uri_query = cas_uri.split("?", 1) + + # it should redirect us to the login page of the cas server + self.assertEqual(cas_uri_path, CAS_SERVER + "/login") + + # check that the redirectUrl is correctly encoded in the service param - ie, the + # place that CAS will redirect to + cas_uri_params = urllib.parse.parse_qs(cas_uri_query) + service_uri = cas_uri_params["service"][0] + _, service_uri_query = service_uri.split("?", 1) + service_uri_params = urllib.parse.parse_qs(service_uri_query) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) + + def test_multi_sso_redirect_to_saml(self): + """If SAML is chosen, should redirect to the SAML server""" + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=saml", + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + saml_uri = location_headers[0] + saml_uri_path, saml_uri_query = saml_uri.split("?", 1) + + # it should redirect us to the login page of the SAML server + self.assertEqual(saml_uri_path, SAML_SERVER) + + # the RelayState is used to carry the client redirect url + saml_uri_params = urllib.parse.parse_qs(saml_uri_query) + relay_state_param = saml_uri_params["RelayState"][0] + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) + + def test_login_via_oidc(self): + """If OIDC is chosen, should redirect to the OIDC auth endpoint""" + + # pick the default OIDC provider + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=oidc", + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + oidc_uri = location_headers[0] + oidc_uri_path, oidc_uri_query = 
oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + # ... and should have set a cookie including the redirect url + cookie_headers = channel.headers.getRawHeaders("Set-Cookie") + assert cookie_headers + cookies: Dict[str, str] = {} + for h in cookie_headers: + key, value = h.split(";")[0].split("=", maxsplit=1) + cookies[key] = value + + oidc_session_cookie = cookies["oidc_session"] + macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) + self.assertEqual( + self._get_value_from_macaroon(macaroon, "client_redirect_url"), + TEST_CLIENT_REDIRECT_URL, + ) + + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + content_type_headers = channel.headers.getRawHeaders("Content-Type") + assert content_type_headers + self.assertTrue(content_type_headers[-1].startswith("text/html")) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + login_token = params[2][1] + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") + + def test_multi_sso_redirect_to_unknown(self): + """An unknown IdP should cause a 400""" + channel = self.make_request( + "GET", + "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self._make_sso_redirect_request(False, "xxx") + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self._make_sso_redirect_request(False, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_msc2858_redirect_to_oidc(self): + """Test the unstable API""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to use the MSC2858 endpoint but MSC2858 is 
+    def test_client_idp_redirect_msc2858_disabled(self):
+        """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
+        channel = self._make_sso_redirect_request(True, "oidc")
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+    def _make_sso_redirect_request(
+        self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
+    ):
+        """Send a request to /_matrix/client/r0/login/sso/redirect
+
+        ... or the unstable equivalent
+
+        ... possibly specifying an IDP provider
+        """
+        endpoint = (
+            "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect"
+            if unstable_endpoint
+            else "/_matrix/client/r0/login/sso/redirect"
+        )
+        if idp_prov is not None:
+            endpoint += "/" + idp_prov
+        endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+
+        return self.make_request(
+            "GET",
+            endpoint,
+            custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)],
+        )
+
+    @staticmethod
+    def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+        prefix = key + " = "
+        for caveat in macaroon.caveats:
+            if caveat.caveat_id.startswith(prefix):
+                return caveat.caveat_id[len(prefix) :]
+        raise ValueError("No %s caveat in macaroon" % (key,))
+
+
+class CASTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        self.base_url = "https://matrix.goodserver.com/"
+        self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
+
+        config = self.default_config()
+        config["public_baseurl"] = (
+            config.get("public_baseurl") or "https://matrix.goodserver.com:8448"
+        )
+        config["cas_config"] = {
+            "enabled": True,
+            "server_url": CAS_SERVER,
+        }
+
+        cas_user_id = "username"
+        self.user_id = "@%s:test" % cas_user_id
+
+        async def get_raw(uri, args):
+            """Return an example response payload from a call to the `/proxyValidate`
+            endpoint of a CAS server, copied from
+            https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
+
+            This needs to be returned by an async function (as opposed to set as the
+            mock's return value) because the corresponding Synapse code awaits on it.
+            """
+            return (
+                """
+                <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
+                  <cas:authenticationSuccess>
+                    <cas:user>%s</cas:user>
+                    <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
+                    <cas:proxies>
+                      <cas:proxy>https://proxy2/pgtUrl</cas:proxy>
+                      <cas:proxy>https://proxy1/pgtUrl</cas:proxy>
+                    </cas:proxies>
+                  </cas:authenticationSuccess>
+                </cas:serviceResponse>
+            """
+                % cas_user_id
+            ).encode("utf-8")
+
+        mocked_http_client = Mock(spec=["get_raw"])
+        mocked_http_client.get_raw.side_effect = get_raw
+
+        self.hs = self.setup_test_homeserver(
+            config=config,
+            proxied_http_client=mocked_http_client,
+        )
+
+        return self.hs
+
+    def prepare(self, reactor, clock, hs):
+        self.deactivate_account_handler = hs.get_deactivate_account_handler()
+
+    def test_cas_redirect_confirm(self):
+        """Tests that the SSO login flow serves a confirmation page before redirecting a
+        user to the redirect URL.
+        """
+        base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
+        redirect_url = "https://dodgy-site.com/"
+
+        url_parts = list(urllib.parse.urlparse(base_url))
+        query = dict(urllib.parse.parse_qsl(url_parts[4]))
+        query.update({"redirectUrl": redirect_url})
+        query.update({"ticket": "ticket"})
+        url_parts[4] = urllib.parse.urlencode(query)
+        cas_ticket_url = urllib.parse.urlunparse(url_parts)
+
+        # Get Synapse to call the fake CAS and serve the template.
+        channel = self.make_request("GET", cas_ticket_url)
+
+        # Test that the response is HTML.
+ self.assertEqual(channel.code, 200, channel.result) + content_type_header_value = "" + for header in channel.result.get("headers", []): + if header[0] == b"Content-Type": + content_type_header_value = header[1].decode("utf8") + + self.assertTrue(content_type_header_value.startswith("text/html")) + + # Test that the body isn't empty. + self.assertTrue(len(channel.result["body"]) > 0) + + # And that it contains our redirect link + self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) + + @override_config( + { + "sso": { + "client_whitelist": [ + "https://legit-site.com/", + "https://other-site.com/", + ] + } + } + ) + def test_cas_redirect_whitelisted(self): + """Tests that the SSO login flow serves a redirect to a whitelisted url""" + self._test_redirect("https://legit-site.com/") + + @override_config({"public_baseurl": "https://example.com"}) + def test_cas_redirect_login_fallback(self): + self._test_redirect("https://example.com/_matrix/static/client/login") + + def _test_redirect(self, redirect_url): + """Tests that the SSO login flow serves a redirect for the given redirect URL.""" + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + channel = self.make_request("GET", cas_ticket_url) + + self.assertEqual(channel.code, 302) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) + + @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) + def test_deactivated_user(self): + """Logging in as a deactivated account should error.""" + redirect_url = "https://legit-site.com/" + + # First login (to create the user). + self._test_redirect(redirect_url) + + # Deactivate the account. + self.get_success( + self.deactivate_account_handler.deactivate_account( + self.user_id, False, create_requester(self.user_id) + ) + ) + + # Request the CAS ticket. + cas_ticket_url = ( + "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" + % (urllib.parse.quote(redirect_url)) + ) + + # Get Synapse to call the fake CAS and serve the template. + channel = self.make_request("GET", cas_ticket_url) + + # Because the user is deactivated they are served an error template. + self.assertEqual(channel.code, 403) + self.assertIn(b"SSO account deactivated", channel.result["body"]) + + +@skip_unless(HAS_JWT, "requires jwt") +class JWTTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + jwt_secret = "secret" + jwt_algorithm = "HS256" + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + self.hs.config.jwt_enabled = True + self.hs.config.jwt_secret = self.jwt_secret + self.hs.config.jwt_algorithm = self.jwt_algorithm + return self.hs + + def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: + # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. 
+ result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) + if isinstance(result, bytes): + return result.decode("ascii") + return result + + def jwt_login(self, *args): + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + channel = self.make_request(b"POST", LOGIN_URL, params) + return channel + + def test_login_jwt_valid_registered(self): + self.register_user("kermit", "monkey") + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_valid_unregistered(self): + channel = self.jwt_login({"sub": "frog"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@frog:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, "notsecret") + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) + + def test_login_jwt_expired(self): + channel = self.jwt_login({"sub": "frog", "exp": 864000}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Signature has expired" + ) + + def test_login_jwt_not_before(self): + now = int(time.time()) + channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: The token is not yet valid (nbf)", + ) + + def test_login_no_sub(self): + channel = self.jwt_login({"username": "root"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.json_body["error"], "Invalid JWT") + + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "issuer": "test-issuer", + } + } + ) + def test_login_iss(self): + """Test validating the issuer claim.""" + # A valid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid issuer. + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid issuer" + ) + + # Not providing an issuer. 
+ channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "iss" claim', + ) + + def test_login_iss_no_config(self): + """Test providing an issuer claim without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + @override_config( + { + "jwt_config": { + "jwt_enabled": True, + "secret": jwt_secret, + "algorithm": jwt_algorithm, + "audiences": ["test-audience"], + } + } + ) + def test_login_aud(self): + """Test validating the audience claim.""" + # A valid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + # An invalid audience. + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + + # Not providing an audience. + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + 'JWT validation failed: Token is missing the "aud" claim', + ) + + def test_login_aud_no_config(self): + """Test providing an audience without requiring it in the configuration.""" + channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], "JWT validation failed: Invalid audience" + ) + + def test_login_no_token(self): + params = {"type": "org.matrix.login.jwt"} + channel = self.make_request(b"POST", LOGIN_URL, params) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") + + +# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use +# RSS256, with a public key configured in synapse as "jwt_secret", and tokens +# signed by the private key. +@skip_unless(HAS_JWT, "requires jwt") +class JWTPubKeyTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + ] + + # This key's pubkey is used as the jwt_secret setting of synapse. Valid + # tokens are signed by this and validated using the pubkey. It is generated + # with `openssl genrsa 512` (not a secure way to generate real keys, but + # good enough for tests!) 
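+ # (512-bit RSA is trivially breakable today; real deployments should use
+ # keys of at least 2048 bits.)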
+ jwt_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB",
+ "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk",
+ "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/",
+ "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq",
+ "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN",
+ "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA",
+ "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ # Generated with `openssl rsa -in foo.key -pubout`, with the above
+ # private key placed in foo.key (jwt_privatekey).
+ jwt_pubkey = "\n".join(
+ [
+ "-----BEGIN PUBLIC KEY-----",
+ "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7",
+ "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==",
+ "-----END PUBLIC KEY-----",
+ ]
+ )
+
+ # This key is used to sign tokens that shouldn't be accepted by synapse.
+ # Generated just like jwt_privatekey.
+ bad_privatekey = "\n".join(
+ [
+ "-----BEGIN RSA PRIVATE KEY-----",
+ "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv",
+ "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L",
+ "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY",
+ "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I",
+ "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb",
+ "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0",
+ "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m",
+ "-----END RSA PRIVATE KEY-----",
+ ]
+ )
+
+ def make_homeserver(self, reactor, clock):
+ self.hs = self.setup_test_homeserver()
+ self.hs.config.jwt_enabled = True
+ self.hs.config.jwt_secret = self.jwt_pubkey
+ self.hs.config.jwt_algorithm = "RS256"
+ return self.hs
+
+ def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str:
+ # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
+ result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") + if isinstance(result, bytes): + return result.decode("ascii") + return result + + def jwt_login(self, *args): + params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} + channel = self.make_request(b"POST", LOGIN_URL, params) + return channel + + def test_login_jwt_valid(self): + channel = self.jwt_login({"sub": "kermit"}) + self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.json_body["user_id"], "@kermit:test") + + def test_login_jwt_invalid_signature(self): + channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) + self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + self.assertEqual( + channel.json_body["error"], + "JWT validation failed: Signature verification failed", + ) + + +AS_USER = "as_user_alice" + + +class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + register.register_servlets, + ] + + def register_as_user(self, username): + self.make_request( + b"POST", + "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), + {"username": username}, + ) + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + + self.service = ApplicationService( + id="unique_identifier", + token="some_token", + hostname="example.com", + sender="@asbot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + self.another_service = ApplicationService( + id="another__identifier", + token="another_token", + hostname="example.com", + sender="@as2bot:example.com", + namespaces={ + ApplicationService.NS_USERS: [ + {"regex": r"@as2_user.*", "exclusive": False} + ], + ApplicationService.NS_ROOMS: [], + ApplicationService.NS_ALIASES: [], + }, + ) + + self.hs.get_datastore().services_cache.append(self.service) + self.hs.get_datastore().services_cache.append(self.another_service) + return self.hs + + def test_login_appservice_user(self): + """Test that an appservice user can use /login""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_user_bot(self): + """Test that the appservice bot can use /login""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": self.service.sender}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_login_appservice_wrong_user(self): + """Test that non-as users cannot login with the as token""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.service.token + ) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_wrong_as(self): + """Test that as users cannot login with wrong 
as token""" + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request( + b"POST", LOGIN_URL, params, access_token=self.another_service.token + ) + + self.assertEquals(channel.result["code"], b"403", channel.result) + + def test_login_appservice_no_token(self): + """Test that users must provide a token when using the appservice + login method + """ + self.register_as_user(AS_USER) + + params = { + "type": login.LoginRestServlet.APPSERVICE_TYPE, + "identifier": {"type": "m.id.user", "user": AS_USER}, + } + channel = self.make_request(b"POST", LOGIN_URL, params) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + +@skip_unless(HAS_OIDC, "requires OIDC") +class UsernamePickerTestCase(HomeserverTestCase): + """Tests for the username picker flow of SSO login""" + + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + + config["oidc_config"] = {} + config["oidc_config"].update(TEST_OIDC_CONFIG) + config["oidc_config"]["user_mapping_provider"] = { + "config": {"display_name_template": "{{ user.displayname }}"} + } + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://x"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d.update(build_synapse_client_resource_tree(self.hs)) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + + # do the start of the login flow + channel = self.helper.auth_via_oidc( + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL + ) + + # that should redirect to the username picker + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + picker_url = location_headers[0] + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") + + # ... with a username_mapping_session cookie + cookies: Dict[str, str] = {} + channel.extract_cookies(cookies) + self.assertIn("username_mapping_session", cookies) + session_id = cookies["username_mapping_session"] + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, + username_mapping_sessions, + "session id not found in map", + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # to the completion page + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=picker_url, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", "username_mapping_session=" + session_id), + # old versions of twisted don't do form-parsing without a valid + # content-length header. 
+ ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + assert location_headers + + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # fish the login token out of the returned redirect uri + login_token = params[2][1] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py new file mode 100644 index 0000000000..3cf5871899 --- /dev/null +++ b/tests/rest/client/test_password_policy.py @@ -0,0 +1,177 @@ +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.constants import LoginType +from synapse.api.errors import Codes +from synapse.rest import admin +from synapse.rest.client import account, login, password_policy, register + +from tests import unittest + + +class PasswordPolicyTestCase(unittest.HomeserverTestCase): + """Tests the password policy feature and its compliance with MSC2000. + + When validating a password, Synapse does the necessary checks in this order: + + 1. Password is long enough + 2. Password contains digit(s) + 3. Password contains symbol(s) + 4. Password contains uppercase letter(s) + 5. Password contains lowercase letter(s) + + For each test below that checks whether a password triggers the right error code, + that test provides a password good enough to pass the previous tests, but not the + one it is currently testing (nor any test that comes afterward). 
+ """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + password_policy.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.register_url = "/_matrix/client/r0/register" + self.policy = { + "enabled": True, + "minimum_length": 10, + "require_digit": True, + "require_symbol": True, + "require_lowercase": True, + "require_uppercase": True, + } + + config = self.default_config() + config["password_config"] = { + "policy": self.policy, + } + + hs = self.setup_test_homeserver(config=config) + return hs + + def test_get_policy(self): + """Tests if the /password_policy endpoint returns the configured policy.""" + + channel = self.make_request("GET", "/_matrix/client/r0/password_policy") + + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual( + channel.json_body, + { + "m.minimum_length": 10, + "m.require_digit": True, + "m.require_symbol": True, + "m.require_lowercase": True, + "m.require_uppercase": True, + }, + channel.result, + ) + + def test_password_too_short(self): + request_data = json.dumps({"username": "kermit", "password": "shorty"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_TOO_SHORT, + channel.result, + ) + + def test_password_no_digit(self): + request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_DIGIT, + channel.result, + ) + + def test_password_no_symbol(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_SYMBOL, + channel.result, + ) + + def test_password_no_uppercase(self): + request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_UPPERCASE, + channel.result, + ) + + def test_password_no_lowercase(self): + request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) + channel = self.make_request("POST", self.register_url, request_data) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.PASSWORD_NO_LOWERCASE, + channel.result, + ) + + def test_password_compliant(self): + request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) + channel = self.make_request("POST", self.register_url, request_data) + + # Getting a 401 here means the password has passed validation and the server has + # responded with a list of registration flows. + self.assertEqual(channel.code, 401, channel.result) + + def test_password_change(self): + """This doesn't test every possible use case, only that hitting /account/password + triggers the password validation code. 
+ """ + compliant_password = "C0mpl!antpassword" + not_compliant_password = "notcompliantpassword" + + user_id = self.register_user("kermit", compliant_password) + tok = self.login("kermit", compliant_password) + + request_data = json.dumps( + { + "new_password": not_compliant_password, + "auth": { + "password": compliant_password, + "type": LoginType.PASSWORD, + "user": user_id, + }, + } + ) + channel = self.make_request( + "POST", + "/_matrix/client/r0/account/password", + request_data, + access_token=tok, + ) + + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py new file mode 100644 index 0000000000..1d152352d1 --- /dev/null +++ b/tests/rest/client/test_presence.py @@ -0,0 +1,76 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock + +from twisted.internet import defer + +from synapse.handlers.presence import PresenceHandler +from synapse.rest.client import presence +from synapse.types import UserID + +from tests import unittest + + +class PresenceTestCase(unittest.HomeserverTestCase): + """Tests presence REST API.""" + + user_id = "@sid:red" + + user = UserID.from_string(user_id) + servlets = [presence.register_servlets] + + def make_homeserver(self, reactor, clock): + + presence_handler = Mock(spec=PresenceHandler) + presence_handler.set_state.return_value = defer.succeed(None) + + hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + presence_handler=presence_handler, + ) + + return hs + + def test_put_presence(self): + """ + PUT to the status endpoint with use_presence enabled will call + set_state on the presence handler. + """ + self.hs.config.use_presence = True + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) + + @unittest.override_config({"use_presence": False}) + def test_put_presence_disabled(self): + """ + PUT to the status endpoint with use_presence disabled will NOT call + set_state on the presence handler. + """ + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, 200) + self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py new file mode 100644 index 0000000000..2860579c2e --- /dev/null +++ b/tests/rest/client/test_profile.py @@ -0,0 +1,270 @@ +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /profile paths.""" +from synapse.rest import admin +from synapse.rest.client import login, profile, room + +from tests import unittest + + +class ProfileTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + self.hs = self.setup_test_homeserver() + return self.hs + + def prepare(self, reactor, clock, hs): + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + self.other = self.register_user("other", "pass", displayname="Bob") + + def test_get_displayname(self): + res = self._get_displayname() + self.assertEqual(res, "owner") + + def test_set_displayname(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_displayname() + self.assertEqual(res, "test") + + def test_set_displayname_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_displayname_too_long(self): + """Attempts to set a stupid displayname should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": "test" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_displayname() + self.assertEqual(res, "owner") + + def test_get_displayname_other(self): + res = self._get_displayname(self.other) + self.assertEquals(res, "Bob") + + def test_set_displayname_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.other,), + content={"displayname": "test"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def test_get_avatar_url(self): + res = self._get_avatar_url() + self.assertIsNone(res) + + def test_set_avatar_url(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_avatar_url() + self.assertEqual(res, "http://my.server/pic.gif") + + def test_set_avatar_url_noauth(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif"}, + ) + self.assertEqual(channel.code, 401, channel.result) + + def test_set_avatar_url_too_long(self): + """Attempts to set a stupid avatar_url should get a 400""" + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.owner,), + content={"avatar_url": "http://my.server/pic.gif" * 100}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + res = self._get_avatar_url() + self.assertIsNone(res) + + def 
test_get_avatar_url_other(self): + res = self._get_avatar_url(self.other) + self.assertIsNone(res) + + def test_set_avatar_url_other(self): + channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.other,), + content={"avatar_url": "http://my.server/pic.gif"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + + def _get_displayname(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/displayname" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["displayname"] + + def _get_avatar_url(self, name=None): + channel = self.make_request( + "GET", "/profile/%s/avatar_url" % (name or self.owner,) + ) + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body.get("avatar_url") + + +class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + config = self.default_config() + config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, hs): + # User owning the requested profile. + self.owner = self.register_user("owner", "pass") + self.owner_tok = self.login("owner", "pass") + self.profile_url = "/profile/%s" % (self.owner) + + # User requesting the profile. + self.requester = self.register_user("requester", "pass") + self.requester_tok = self.login("requester", "pass") + + self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) + + def test_no_auth(self): + self.try_fetch_profile(401) + + def test_not_in_shared_room(self): + self.ensure_requester_left_room() + + self.try_fetch_profile(403, access_token=self.requester_tok) + + def test_in_shared_room(self): + self.ensure_requester_left_room() + + self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) + + self.try_fetch_profile(200, self.requester_tok) + + def try_fetch_profile(self, expected_code, access_token=None): + self.request_profile(expected_code, access_token=access_token) + + self.request_profile( + expected_code, url_suffix="/displayname", access_token=access_token + ) + + self.request_profile( + expected_code, url_suffix="/avatar_url", access_token=access_token + ) + + def request_profile(self, expected_code, url_suffix="", access_token=None): + channel = self.make_request( + "GET", self.profile_url + url_suffix, access_token=access_token + ) + self.assertEqual(channel.code, expected_code, channel.result) + + def ensure_requester_left_room(self): + try: + self.helper.leave( + room=self.room_id, user=self.requester, tok=self.requester_tok + ) + except AssertionError: + # We don't care whether the leave request didn't return a 200 (e.g. + # if the user isn't already in the room), because we only want to + # make sure the user isn't in the room. 
+ pass
+
+
+class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ profile.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ config["require_auth_for_profile_requests"] = True
+ config["limit_profile_requests_to_users_who_share_rooms"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+
+ return self.hs
+
+ def prepare(self, reactor, clock, hs):
+ # User requesting the profile.
+ self.requester = self.register_user("requester", "pass")
+ self.requester_tok = self.login("requester", "pass")
+
+ def test_can_lookup_own_profile(self):
+ """Tests that a user can lookup their own profile without having to be in a room
+ if 'require_auth_for_profile_requests' is set to true in the server's config.
+ """
+ channel = self.make_request(
+ "GET", "/profile/" + self.requester, access_token=self.requester_tok
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ "/profile/" + self.requester + "/displayname",
+ access_token=self.requester_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ channel = self.make_request(
+ "GET",
+ "/profile/" + self.requester + "/avatar_url",
+ access_token=self.requester_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
new file mode 100644
index 0000000000..d0ce91ccd9
--- /dev/null
+++ b/tests/rest/client/test_push_rule_attrs.py
@@ -0,0 +1,414 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import synapse
+from synapse.api.errors import Codes
+from synapse.rest.client import login, push_rule, room
+
+from tests.unittest import HomeserverTestCase
+
+
+class PushRuleAttributesTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ push_rule.register_servlets,
+ ]
+ hijack_auth = False
+
+ def test_enabled_on_creation(self):
+ """
+ Tests the GET and PUT of push rules' `enabled` endpoints.
+ Tests that a rule is enabled upon creation.
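+ (The case where a disabled rule with the same ruleId previously
+ existed is covered by test_enabled_on_recreation below.)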
+ """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_on_recreation(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is enabled upon creation, even if a rule with that + ruleId existed previously and was disabled. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # disable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule disabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_disable(self): + """ + Tests the GET and PUT of push rules' `enabled` endpoints. + Tests that a rule is disabled and enabled when we ask for it. 
+ """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # disable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": False}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule disabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], False) + + # re-enable the rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # check rule enabled + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["enabled"], True) + + def test_enabled_404_when_get_non_existent(self): + """ + Tests that `enabled` gives 404 when the rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET enabled for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 200) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_get_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_rule(self): + """ + Tests that `enabled` gives 404 when we put to a rule that doesn't exist. 
+ """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_enabled_404_when_put_non_existent_server_rule(self): + """ + Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/enabled", + {"enabled": True}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_get(self): + """ + Tests that `actions` gives you what you expect on a fresh rule. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] + ) + + def test_actions_put(self): + """ + Tests that PUT on actions updates the value you'd get from GET. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # change the rule actions + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 200) + + # GET actions for that new rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/actions", access_token=token + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["actions"], ["dont_notify"]) + + def test_actions_404_when_get_non_existent(self): + """ + Tests that `actions` gives 404 when the rule doesn't exist. 
+ """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + body = { + "conditions": [ + {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} + ], + "actions": ["notify", {"set_tweak": "highlight"}], + } + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + # PUT a new rule + channel = self.make_request( + "PUT", "/pushrules/global/override/best.friend", body, access_token=token + ) + self.assertEqual(channel.code, 200) + + # DELETE the rule + channel = self.make_request( + "DELETE", "/pushrules/global/override/best.friend", access_token=token + ) + self.assertEqual(channel.code, 200) + + # check 404 for deleted rule + channel = self.make_request( + "GET", "/pushrules/global/override/best.friend/enabled", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_get_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when the server-default rule doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # check 404 for never-heard-of rule + channel = self.make_request( + "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_rule(self): + """ + Tests that `actions` gives 404 when putting to a rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/best.friend/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_actions_404_when_put_non_existent_server_rule(self): + """ + Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. + """ + self.register_user("user", "pass") + token = self.login("user", "pass") + + # enable & check 404 for never-heard-of rule + channel = self.make_request( + "PUT", + "/pushrules/global/override/.m.muahahah/actions", + {"actions": ["dont_notify"]}, + access_token=token, + ) + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py new file mode 100644 index 0000000000..fecda037a5 --- /dev/null +++ b/tests/rest/client/test_register.py @@ -0,0 +1,746 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import datetime +import json +import os + +import pkg_resources + +import synapse.rest.admin +from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.errors import Codes +from synapse.appservice import ApplicationService +from synapse.rest.client import account, account_validity, login, logout, register, sync + +from tests import unittest +from tests.unittest import override_config + + +class RegisterRestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + register.register_servlets, + synapse.rest.admin.register_servlets, + ] + url = b"/_matrix/client/r0/register" + + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config + + def test_POST_appservice_registration_valid(self): + user_id = "@as_user_kermit:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps( + {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"200", channel.result) + det_data = {"user_id": user_id, "home_server": self.hs.hostname} + self.assertDictContainsSubset(det_data, channel.json_body) + + def test_POST_appservice_registration_no_type(self): + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps({"username": "as_user_kermit"}) + + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"400", channel.result) + + def test_POST_appservice_registration_invalid(self): + self.appservice = None # no application service exists + request_data = json.dumps( + {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} + ) + channel = self.make_request( + b"POST", self.url + b"?access_token=i_am_an_app_service", request_data + ) + + self.assertEquals(channel.result["code"], b"401", channel.result) + + def test_POST_bad_password(self): + request_data = json.dumps({"username": "kermit", "password": 666}) + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals(channel.json_body["error"], "Invalid password") + + def test_POST_bad_username(self): + request_data = json.dumps({"username": 777, "password": "monkey"}) + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"400", channel.result) + self.assertEquals(channel.json_body["error"], "Invalid username") + + def test_POST_user_valid(self): + user_id = "@kermit:test" + device_id = "frogfone" + params = { + "username": "kermit", + "password": "monkey", + "device_id": device_id, + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + + det_data = { + "user_id": user_id, + "home_server": self.hs.hostname, + "device_id": 
device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + @override_config({"enable_registration": False}) + def test_POST_disabled_registration(self): + request_data = json.dumps({"username": "kermit", "password": "monkey"}) + self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) + + channel = self.make_request(b"POST", self.url, request_data) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals(channel.json_body["error"], "Registration has been disabled") + self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") + + def test_POST_guest_registration(self): + self.hs.config.macaroon_secret_key = "test" + self.hs.config.allow_guest_access = True + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + def test_POST_disabled_guest_registration(self): + self.hs.config.allow_guest_access = False + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals(channel.json_body["error"], "Guest access is disabled") + + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) + def test_POST_ratelimiting_guest(self): + for i in range(0, 6): + url = self.url + b"?kind=guest" + channel = self.make_request(b"POST", url, b"{}") + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"200", channel.result) + + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) + def test_POST_ratelimiting(self): + for i in range(0, 6): + params = { + "username": "kermit" + str(i), + "password": "monkey", + "device_id": "frogfone", + "auth": {"type": LoginType.DUMMY}, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") + + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_advertised_flows(self): + channel = self.make_request(b"POST", self.url, b"{}") + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + + # with the stock config, we only expect the dummy flow + self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) + + @unittest.override_config( + { + "public_baseurl": "https://test_server", + "enable_registration_captcha": True, + "user_consent": { + "version": "1", + "template_dir": "/", + "require_at_registration": True, + }, + "account_threepid_delegates": { + "email": "https://id_server", + "msisdn": "https://id_server", + }, + } + ) + def 
test_advertised_flows_captcha_and_terms_and_3pids(self):
+ channel = self.make_request(b"POST", self.url, b"{}")
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ self.assertCountEqual(
+ [
+ ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
+ [
+ "m.login.recaptcha",
+ "m.login.terms",
+ "m.login.msisdn",
+ "m.login.email.identity",
+ ],
+ ],
+ (f["stages"] for f in flows),
+ )
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://test_server",
+ "registrations_require_3pid": ["email"],
+ "disable_msisdn_registration": True,
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_advertised_flows_no_msisdn_email_required(self):
+ channel = self.make_request(b"POST", self.url, b"{}")
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with msisdn registration disabled and email required, we expect only
+ # the email flow
+ self.assertCountEqual(
+ [["m.login.email.identity"]], (f["stages"] for f in flows)
+ )
+
+ @unittest.override_config(
+ {
+ "request_token_inhibit_3pid_errors": True,
+ "public_baseurl": "https://test_server",
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_request_token_existing_email_inhibit_error(self):
+ """Test that requesting a token via this endpoint doesn't leak existing
+ associations if configured that way.
+ """
+ user_id = self.register_user("kermit", "monkey")
+ self.login("kermit", "monkey")
+
+ email = "test@example.com"
+
+ # Add a threepid
+ self.get_success(
+ self.hs.get_datastore().user_add_threepid(
+ user_id=user_id,
+ medium="email",
+ address=email,
+ validated_at=0,
+ added_at=0,
+ )
+ )
+
+ channel = self.make_request(
+ "POST",
+ b"register/email/requestToken",
+ {"client_secret": "foobar", "email": email, "send_attempt": 1},
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ self.assertIsNotNone(channel.json_body.get("sid"))
+
+ @unittest.override_config(
+ {
+ "public_baseurl": "https://test_server",
+ "email": {
+ "smtp_host": "mail_server",
+ "smtp_port": 2525,
+ "notif_from": "sender@host",
+ },
+ }
+ )
+ def test_reject_invalid_email(self):
+ """Check that bad emails are rejected"""
+
+ # Test for email with multiple @
+ channel = self.make_request(
+ "POST",
+ b"register/email/requestToken",
+ {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
+ )
+ self.assertEquals(400, channel.code, channel.result)
+ # Check the error message to ensure that we're not failing due to a bug in the test.
+ self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for email with no @ + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": "email", "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + # Test for super long email + email = "a@" + "a" * 1000 + channel = self.make_request( + "POST", + b"register/email/requestToken", + {"client_secret": "foobar", "email": email, "send_attempt": 1}, + ) + self.assertEquals(400, channel.code, channel.result) + self.assertEquals( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, + ) + + +class AccountValidityTestCase(unittest.HomeserverTestCase): + + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + logout.register_servlets, + account_validity.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # Test for account expiring after a week. + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + } + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_validity_period(self): + self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + channel = self.make_request(b"GET", "/sync", access_token=tok) + + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_manual_renewal(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) + + # If we register the admin user at the beginning of the test, it will + # expire at the same time as the normal user and the renewal request + # will be denied. + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = {"user_id": user_id} + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. 
+ channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_manual_expire(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # The specific endpoint doesn't matter, all we need is an authenticated + # endpoint. + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"403", channel.result) + self.assertEquals( + channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result + ) + + def test_logging_out_expired_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + self.register_user("admin", "adminpassword", admin=True) + admin_tok = self.login("admin", "adminpassword") + + url = "/_synapse/admin/v1/account_validity/validity" + params = { + "user_id": user_id, + "expiration_ts": 0, + "enable_renewal_emails": False, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Try to log the user out + channel = self.make_request(b"POST", "/logout", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Log the user in again (allowed for expired accounts) + tok = self.login("kermit", "monkey") + + # Try to log out all of the user's sessions + channel = self.make_request(b"POST", "/logout/all", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + +class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): + + servlets = [ + register.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + sync.register_servlets, + account_validity.register_servlets, + account.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + # Test for account expiring after a week and renewal emails being sent 2 + # days before expiry. + config["enable_registration"] = True + config["account_validity"] = { + "enabled": True, + "period": 604800000, # Time in ms for 1 week + "renew_at": 172800000, # Time in ms for 2 days + "renew_by_email_enabled": True, + "renew_email_subject": "Renew your account", + "account_renewed_html_path": "account_renewed.html", + "invalid_token_html_path": "invalid_token.html", + } + + # Email config. 
+ + config["email"] = { + "enable_notifs": True, + "template_dir": os.path.abspath( + pkg_resources.resource_filename("synapse", "res/templates") + ), + "expiry_template_html": "notice_expiry.html", + "expiry_template_text": "notice_expiry.txt", + "notif_template_html": "notif_mail.html", + "notif_template_text": "notif_mail.txt", + "smtp_host": "127.0.0.1", + "smtp_port": 20, + "require_transport_security": False, + "smtp_user": None, + "smtp_pass": None, + "notif_from": "test@example.com", + } + config["public_baseurl"] = "aaa" + + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + + self.store = self.hs.get_datastore() + + return self.hs + + def test_renewal_email(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + # Move 5 days forward. This should trigger a renewal email to be sent. + self.reactor.advance(datetime.timedelta(days=5).total_seconds()) + self.assertEqual(len(self.email_attempts), 1) + + # Retrieving the URL from the email is too much pain for now, so we + # retrieve the token from the DB. + renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect on a successful renewal. + expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) + expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 1 day forward. Try to renew with the same token again. + url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"200", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when reusing a + # token. The account expiration date should not have changed. + expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + # Move 3 days forward. If the renewal failed, every authed request with + # our access token should be denied from now, otherwise they should + # succeed. + self.reactor.advance(datetime.timedelta(days=3).total_seconds()) + channel = self.make_request(b"GET", "/sync", access_token=tok) + self.assertEquals(channel.result["code"], b"200", channel.result) + + def test_renewal_invalid_token(self): + # Hit the renewal endpoint with an invalid token and check that it behaves as + # expected, i.e. that it responds with 404 Not Found and the correct HTML. 
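+        # For reference, the renewal-token states these tests exercise (a
+        # summary of the assertions above and below, not a statement of the
+        # wider API contract):
+        #   unknown token      -> 404 + the invalid_token template
+        #   valid, first use   -> 200 + the account_renewed template, expiry moved
+        #   valid, reused      -> 200 + the previously_renewed template, expiry kept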
+ url = "/_matrix/client/unstable/account_validity/renew?token=123" + channel = self.make_request(b"GET", url) + self.assertEquals(channel.result["code"], b"404", channel.result) + + # Check that we're getting HTML back. + content_type = channel.headers.getRawHeaders(b"Content-Type") + self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) + + # Check that the HTML we're getting is the one we expect when using an + # invalid/unknown token. + expected_html = ( + self.hs.config.account_validity.account_validity_invalid_token_template.render() + ) + self.assertEqual( + channel.result["body"], expected_html.encode("utf8"), channel.result + ) + + def test_manual_email_send(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + def test_deactivated_user(self): + self.email_attempts = [] + + (user_id, tok) = self.create_user() + + request_data = json.dumps( + { + "auth": { + "type": "m.login.password", + "user": user_id, + "password": "monkey", + }, + "erase": False, + } + ) + channel = self.make_request( + "POST", "account/deactivate", request_data, access_token=tok + ) + self.assertEqual(channel.code, 200) + + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + self.assertEqual(len(self.email_attempts), 0) + + def create_user(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.get_clock().time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + return user_id, tok + + def test_manual_email_send_expired_account(self): + user_id = self.register_user("kermit", "monkey") + tok = self.login("kermit", "monkey") + + # We need to manually add an email address otherwise the handler will do + # nothing. + now = self.hs.get_clock().time_msec() + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address="kermit@example.com", + validated_at=now, + added_at=now, + ) + ) + + # Make the account expire. + self.reactor.advance(datetime.timedelta(days=8).total_seconds()) + + # Ignore all emails sent by the automatic background task and only focus on the + # ones sent manually. + self.email_attempts = [] + + # Test that we're still able to manually trigger a mail to be sent. + channel = self.make_request( + b"POST", + "/_matrix/client/unstable/account_validity/send_mail", + access_token=tok, + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.assertEqual(len(self.email_attempts), 1) + + +class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): + + servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] + + def make_homeserver(self, reactor, clock): + self.validity_period = 10 + self.max_delta = self.validity_period * 10.0 / 100.0 + + config = self.default_config() + + config["enable_registration"] = True + config["account_validity"] = {"enabled": False} + + self.hs = self.setup_test_homeserver(config=config) + + # We need to set these directly, instead of in the homeserver config dict above. 
+        # This is due to account validity-related config options not being read by
+        # Synapse when account_validity.enabled is False.
+        self.hs.get_datastore()._account_validity_period = self.validity_period
+        self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
+
+        self.store = self.hs.get_datastore()
+
+        return self.hs
+
+    def test_background_job(self):
+        """
+        Tests that the account validity startup job sets an expiration date on
+        users that are missing one, and that setting startup_job_max_delta makes
+        that date fall within the allowed range, i.e. between
+        now + validity_period - max_delta and now + validity_period.
+        """
+        user_id = self.register_user("kermit_delta", "user")
+
+        self.hs.config.account_validity.startup_job_max_delta = self.max_delta
+
+        now_ms = self.hs.get_clock().time_msec()
+        self.get_success(self.store._set_expiration_date_when_missing())
+
+        res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
+
+        self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
+        self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
new file mode 100644
index 0000000000..02b5e9a8d0
--- /dev/null
+++ b/tests/rest/client/test_relations.py
@@ -0,0 +1,724 @@
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import json
+import urllib
+from typing import Optional
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.rest import admin
+from synapse.rest.client import login, register, relations, room
+
+from tests import unittest
+
+
+class RelationsTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        relations.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+        register.register_servlets,
+        admin.register_servlets_for_client_rest_resource,
+    ]
+    hijack_auth = False
+
+    def make_homeserver(self, reactor, clock):
+        # We need to enable msc1849 support for aggregations
+        config = self.default_config()
+        config["experimental_msc1849_support_enabled"] = True
+
+        # We enable frozen dicts as relations/edits change event contents, so we
+        # want to test that we don't modify the events in the caches.
+        config["use_frozen_dicts"] = True
+
+        return self.setup_test_homeserver(config=config)
+
+    def prepare(self, reactor, clock, hs):
+        self.user_id, self.user_token = self._create_user("alice")
+        self.user2_id, self.user2_token = self._create_user("bob")
+
+        self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
+        self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
+        res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
+        self.parent_id = res["event_id"]
+
+    def test_send_relation(self):
+        """Tests that sending a relation using the new /send_relation API
+        creates the right shape of event.
+        """
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/event/%s" % (self.room, event_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assert_dict(
+            {
+                "type": "m.reaction",
+                "sender": self.user_id,
+                "content": {
+                    "m.relates_to": {
+                        "event_id": self.parent_id,
+                        "key": "👍",
+                        "rel_type": RelationTypes.ANNOTATION,
+                    }
+                },
+            },
+            channel.json_body,
+        )
+
+    def test_deny_membership(self):
+        """Test that we deny relations on membership events"""
+        channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member)
+        self.assertEquals(400, channel.code, channel.json_body)
+
+    def test_deny_double_react(self):
+        """Test that we deny a second annotation with the same key from the same user"""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+        self.assertEquals(400, channel.code, channel.json_body)
+
+    def test_basic_paginate_relations(self):
+        """Tests that calling the pagination API correctly returns the latest relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
+            % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the full
+        # relation event we sent above.
+        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
+            channel.json_body["chunk"][0],
+        )
+
+        # We also expect to get the original event (the id of which is self.parent_id)
+        self.assertEquals(
+            channel.json_body["original_event"]["event_id"], self.parent_id
+        )
+
+        # Make sure next_batch has something in it that looks like it could be a
+        # valid token.
+        self.assertIsInstance(
+            channel.json_body.get("next_batch"), str, channel.json_body
+        )
+
+    def test_repeated_paginate_relations(self):
+        """Test that if we paginate using a limit and tokens then we get the
+        expected events.
+ """ + + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = None + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation_pagination_groups(self): + """Test that we can paginate annotation groups correctly.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + + idx += 1 + idx %= len(access_tokens) + + prev_token = None + found_groups = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" + % (self.room, self.parent_id, from_token), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEquals(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self): + """Test that we can paginate within an annotation group.""" + + # We need to create ten separate users to send each reaction. 
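+        # (Each user may only annotate an event once with a given key, per the
+        # double-react test above, which is why a fresh user is needed for
+        # every 👍 sent here.)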
+ access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="👍", + access_token=access_tokens[idx], + ) + self.assertEquals(200, channel.code, channel.json_body) + expected_event_ids.append(channel.json_body["event_id"]) + + idx += 1 + + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self.assertEquals(200, channel.code, channel.json_body) + + prev_token = None + found_event_ids = [] + encoded_key = urllib.parse.quote_plus("👍".encode()) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s" + "/aggregations/%s/%s/m.reaction/%s?limit=1%s" + % ( + self.room, + self.parent_id, + RelationTypes.ANNOTATION, + encoded_key, + from_token, + ), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_aggregation(self): + """Test that annotations get correctly aggregated.""" + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + ) + + def test_aggregation_redactions(self): + """Test that annotations get correctly aggregated after a redaction.""" + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + to_redact_event_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Now lets redact one of the 'a' reactions + channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), + access_token=self.user_token, + content={}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s" + % (self.room, self.parent_id), + 
access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body, + {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, + ) + + def test_aggregation_must_be_annotation(self): + """Test that aggregations must be annotations.""" + + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" + % (self.room, self.parent_id, RelationTypes.REPLACE), + access_token=self.user_token, + ) + self.assertEquals(400, channel.code, channel.json_body) + + def test_aggregation_get_event(self): + """Test that annotations and references get correctly bundled when + getting the parent event. + """ + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + self.assertEquals(200, channel.code, channel.json_body) + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + self.assertEquals(200, channel.code, channel.json_body) + reply_2 = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + RelationTypes.REFERENCE: { + "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] + }, + }, + ) + + def test_edit(self): + """Test that a simple edit works.""" + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, self.parent_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEquals(channel.json_body["content"], new_body) + + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + + def test_multi_edit(self): + """Test that multiple edits, including attempts by people who + shouldn't be allowed, are correctly handled. 
+        """
+
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message",
+            content={
+                "msgtype": "m.text",
+                "body": "Wibble",
+                "m.new_content": {"msgtype": "m.text", "body": "First edit"},
+            },
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        edit_event_id = channel.json_body["event_id"]
+
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message.WRONG_TYPE",
+            content={
+                "msgtype": "m.text",
+                "body": "Wibble",
+                "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"},
+            },
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/event/%s" % (self.room, self.parent_id),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        self.assertEquals(channel.json_body["content"], new_body)
+
+        relations_dict = channel.json_body["unsigned"].get("m.relations")
+        self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+        m_replace_dict = relations_dict[RelationTypes.REPLACE]
+        for key in ["event_id", "sender", "origin_server_ts"]:
+            self.assertIn(key, m_replace_dict)
+
+        self.assert_dict(
+            {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+        )
+
+    def test_edit_reply(self):
+        """Test that editing a reply works."""
+
+        # Create a reply to edit.
+        channel = self._send_relation(
+            RelationTypes.REFERENCE,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "A reply!"},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        reply = channel.json_body["event_id"]
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+            parent_id=reply,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        edit_event_id = channel.json_body["event_id"]
+
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/event/%s" % (self.room, reply),
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to see the new body in the dict, as well as the reference
+        # metadata still intact.
+        self.assertDictContainsSubset(new_body, channel.json_body["content"])
+        self.assertDictContainsSubset(
+            {
+                "m.relates_to": {
+                    "event_id": self.parent_id,
+                    "key": None,
+                    "rel_type": "m.reference",
+                }
+            },
+            channel.json_body["content"],
+        )
+
+        # We expect that the edit relation appears in the unsigned relations
+        # section.
+        relations_dict = channel.json_body["unsigned"].get("m.relations")
+        self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+        m_replace_dict = relations_dict[RelationTypes.REPLACE]
+        for key in ["event_id", "sender", "origin_server_ts"]:
+            self.assertIn(key, m_replace_dict)
+
+        self.assert_dict(
+            {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+        )
+
+    def test_relations_redaction_redacts_edits(self):
+        """Test that edits of an event are redacted when the original event
+        is redacted.
+ """ + # Send a new event + res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) + original_event_id = res["event_id"] + + # Add a relation + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + parent_id=original_event_id, + content={ + "msgtype": "m.text", + "body": "Wibble", + "m.new_content": {"msgtype": "m.text", "body": "First edit"}, + }, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check the relation is returned + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEquals(len(channel.json_body["chunk"]), 1) + + # Redact the original event + channel = self.make_request( + "PUT", + "/rooms/%s/redact/%s/%s" + % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), + access_token=self.user_token, + content="{}", + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Try to check for remaining m.replace relations + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check that no relations are returned + self.assertIn("chunk", channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + + def test_aggregations_redaction_prevents_access_to_aggregations(self): + """Test that annotations of an event are redacted when the original event + is redacted. + """ + # Send a new event + res = self.helper.send(self.room, body="Hello!", tok=self.user_token) + original_event_id = res["event_id"] + + # Add a relation + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Redact the original + channel = self.make_request( + "PUT", + "/rooms/%s/redact/%s/%s" + % ( + self.room, + original_event_id, + "test_aggregations_redaction_prevents_access_to_aggregations", + ), + access_token=self.user_token, + content="{}", + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Check that aggregations returns zero + channel = self.make_request( + "GET", + "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" + % (self.room, original_event_id), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertIn("chunk", channel.json_body) + self.assertEquals(channel.json_body["chunk"], []) + + def _send_relation( + self, + relation_type, + event_type, + key=None, + content: Optional[dict] = None, + access_token=None, + parent_id=None, + ): + """Helper function to send a relation pointing at `self.parent_id` + + Args: + relation_type (str): One of `RelationTypes` + event_type (str): The type of the event to create + parent_id (str): The event_id this relation relates to. If None, then self.parent_id + key (str|None): The aggregation key used for m.annotation relation + type. + content(dict|None): The content of the created event. 
+ access_token (str|None): The access token used to send the relation, + defaults to `self.user_token` + + Returns: + FakeChannel + """ + if not access_token: + access_token = self.user_token + + query = "" + if key: + query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) + + original_id = parent_id if parent_id else self.parent_id + + channel = self.make_request( + "POST", + "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" + % (self.room, original_id, relation_type, event_type, query), + json.dumps(content or {}).encode("utf-8"), + access_token=access_token, + ) + return channel + + def _create_user(self, localpart): + user_id = self.register_user(localpart, "abc123") + access_token = self.login(localpart, "abc123") + + return user_id, access_token diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py new file mode 100644 index 0000000000..ee6b0b9ebf --- /dev/null +++ b/tests/rest/client/test_report_event.py @@ -0,0 +1,82 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.rest.client import login, report_event, room + +from tests import unittest + + +class ReportEventTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) + resp = self.helper.send(self.room_id, tok=self.admin_user_tok) + self.event_id = resp["event_id"] + self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" + + def test_reason_str_and_score_int(self): + data = {"reason": "this makes me sad", "score": -100} + self._assert_status(200, data) + + def test_no_reason(self): + data = {"score": 0} + self._assert_status(200, data) + + def test_no_score(self): + data = {"reason": "this makes me sad"} + self._assert_status(200, data) + + def test_no_reason_and_no_score(self): + data = {} + self._assert_status(200, data) + + def test_reason_int_and_score_str(self): + data = {"reason": 10, "score": "string"} + self._assert_status(400, data) + + def test_reason_zero_and_score_blank(self): + data = {"reason": 0, "score": ""} + self._assert_status(400, data) + + def test_reason_and_score_null(self): + data = {"reason": None, "score": None} + self._assert_status(400, data) + + def _assert_status(self, response_status, data): + channel = self.make_request( + "POST", + self.report_path, + json.dumps(data), + access_token=self.other_user_tok, + ) + self.assertEqual( + response_status, 
int(channel.result["code"]), msg=channel.result["body"] + ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py new file mode 100644 index 0000000000..0c9cbb9aff --- /dev/null +++ b/tests/rest/client/test_rooms.py @@ -0,0 +1,2150 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /rooms paths.""" + +import json +from typing import Iterable +from unittest.mock import Mock, call +from urllib import parse as urlparse + +from twisted.internet import defer + +import synapse.rest.admin +from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException +from synapse.handlers.pagination import PurgeStatus +from synapse.rest import admin +from synapse.rest.client import account, directory, login, profile, room +from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.util.stringutils import random_string + +from tests import unittest +from tests.test_utils import make_awaitable + +PATH_PREFIX = b"/_matrix/client/api/v1" + + +class RoomBase(unittest.HomeserverTestCase): + rmcreator_id = None + + servlets = [room.register_servlets, room.register_deprecated_servlets] + + def make_homeserver(self, reactor, clock): + + self.hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + ) + + self.hs.get_federation_handler = Mock() + self.hs.get_federation_handler.return_value.maybe_backfill = Mock( + return_value=make_awaitable(None) + ) + + async def _insert_client_ip(*args, **kwargs): + return None + + self.hs.get_datastore().insert_client_ip = _insert_client_ip + + return self.hs + + +class RoomPermissionsTestCase(RoomBase): + """Tests room permissions.""" + + user_id = "@sid1:red" + rmcreator_id = "@notme:red" + + def prepare(self, reactor, clock, hs): + + self.helper.auth_user_id = self.rmcreator_id + # create some rooms under the name rmcreator_id + self.uncreated_rmid = "!aa:test" + self.created_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=False + ) + self.created_public_rmid = self.helper.create_room_as( + self.rmcreator_id, is_public=True + ) + + # send a message in one of the rooms + self.created_rmid_msg_path = ( + "rooms/%s/send/m.room.message/a1" % (self.created_rmid) + ).encode("ascii") + channel = self.make_request( + "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' + ) + self.assertEquals(200, channel.code, channel.result) + + # set topic for public room + channel = self.make_request( + "PUT", + ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), + b'{"topic":"Public Room Topic"}', + ) + self.assertEquals(200, channel.code, channel.result) + + # auth as user_id now + self.helper.auth_user_id = self.user_id + + def test_can_do_action(self): + msg_content = 
b'{"msgtype":"m.text","body":"hello"}' + + seq = iter(range(100)) + + def send_msg_path(): + return "/rooms/%s/send/m.room.message/mid%s" % ( + self.created_rmid, + str(next(seq)), + ) + + # send message in uncreated room, expect 403 + channel = self.make_request( + "PUT", + "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), + msg_content, + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room not joined (no state), expect 403 + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room and invited, expect 403 + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id + ) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # send message in created room and joined, expect 200 + self.helper.join(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # send message in created room and left, expect 403 + self.helper.leave(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", send_msg_path(), msg_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_topic_perms(self): + topic_content = b'{"topic":"My Topic Name"}' + topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid + + # set/get topic in uncreated room, expect 403 + channel = self.make_request( + "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set/get topic in created PRIVATE room not joined, expect 403 + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + channel = self.make_request("GET", topic_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set topic in created PRIVATE room and invited, expect 403 + self.helper.invite( + room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id + ) + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # get topic in created PRIVATE room and invited, expect 403 + channel = self.make_request("GET", topic_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set/get topic in created PRIVATE room and joined, expect 200 + self.helper.join(room=self.created_rmid, user=self.user_id) + + # Only room ops can set topic by default + self.helper.auth_user_id = self.rmcreator_id + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.helper.auth_user_id = self.user_id + + channel = self.make_request("GET", topic_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) + + # set/get topic in created PRIVATE room and left, expect 403 + self.helper.leave(room=self.created_rmid, user=self.user_id) + channel = self.make_request("PUT", topic_path, topic_content) + self.assertEquals(403, 
channel.code, msg=channel.result["body"]) + channel = self.make_request("GET", topic_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # get topic in PUBLIC room, not joined, expect 403 + channel = self.make_request( + "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + # set topic in PUBLIC room, not joined, expect 403 + channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.topic" % self.created_public_rmid, + topic_content, + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def _test_get_membership( + self, room=None, members: Iterable = frozenset(), expect_code=None + ): + for member in members: + path = "/rooms/%s/state/m.room.member/%s" % (room, member) + channel = self.make_request("GET", path) + self.assertEquals(expect_code, channel.code) + + def test_membership_basic_room_perms(self): + # === room does not exist === + room = self.uncreated_rmid + # get membership of self, get membership of other, uncreated room + # expect all 403s + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # trying to invite people to this room should 403 + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + + # set [invite/join/left] of self, set [invite/join/left] of other, + # expect all 404s because room doesn't exist on any server + for usr in [self.user_id, self.rmcreator_id]: + self.helper.join(room=room, user=usr, expect_code=404) + self.helper.leave(room=room, user=usr, expect_code=404) + + def test_membership_private_room_perms(self): + room = self.created_rmid + # get membership of self, get membership of other, private room + invite + # expect all 403s + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # get membership of self, get membership of other, private room + joined + # expect all 200s + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + # get membership of self, get membership of other, private room + left + # expect all 200s + self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + def test_membership_public_room_perms(self): + room = self.created_public_rmid + # get membership of self, get membership of other, public room + invite + # expect 403 + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 + ) + + # get membership of self, get membership of other, public room + joined + # expect all 200s + self.helper.join(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + # get membership of self, get membership of other, public room + left + # expect 200. 
+ self.helper.leave(room=room, user=self.user_id) + self._test_get_membership( + members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 + ) + + def test_invited_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + + # set [invite/join/left] of other user, expect 403s + self.helper.invite( + room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.JOIN, + expect_code=403, + ) + self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + + def test_joined_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + + # set invited of self, expect 403 + self.helper.invite( + room=room, src=self.user_id, targ=self.user_id, expect_code=403 + ) + + # set joined of self, expect 200 (NOOP) + self.helper.join(room=room, user=self.user_id) + + other = "@burgundy:red" + # set invited of other, expect 200 + self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) + + # set joined of other, expect 403 + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.JOIN, + expect_code=403, + ) + + # set left of other, expect 403 + self.helper.change_membership( + room=room, + src=self.user_id, + targ=other, + membership=Membership.LEAVE, + expect_code=403, + ) + + # set left of self, expect 200 + self.helper.leave(room=room, user=self.user_id) + + def test_leave_permissions(self): + room = self.created_rmid + self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) + self.helper.join(room=room, user=self.user_id) + self.helper.leave(room=room, user=self.user_id) + + # set [invite/join/left] of self, set [invite/join/left] of other, + # expect all 403s + for usr in [self.user_id, self.rmcreator_id]: + self.helper.change_membership( + room=room, + src=self.user_id, + targ=usr, + membership=Membership.INVITE, + expect_code=403, + ) + + self.helper.change_membership( + room=room, + src=self.user_id, + targ=usr, + membership=Membership.JOIN, + expect_code=403, + ) + + # It is always valid to LEAVE if you've already left (currently.) 
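+        # (The call below targets the room creator, i.e. it is an attempted
+        # kick by a user who has already left, hence the expected 403; leaving
+        # yourself again would be the no-op case the comment above describes.)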
+ self.helper.change_membership( + room=room, + src=self.user_id, + targ=self.rmcreator_id, + membership=Membership.LEAVE, + expect_code=403, + ) + + +class RoomsMemberListTestCase(RoomBase): + """Tests /rooms/$room_id/members/list REST events.""" + + user_id = "@sid1:red" + + def test_get_member_list(self): + room_id = self.helper.create_room_as(self.user_id) + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_room(self): + channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_permission(self): + room_id = self.helper.create_room_as("@some_other_guy:red") + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_mixed_memberships(self): + room_creator = "@some_other_guy:red" + room_id = self.helper.create_room_as(room_creator) + room_path = "/rooms/%s/members" % room_id + self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) + # can't see list if you're just invited. + channel = self.make_request("GET", room_path) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + self.helper.join(room=room_id, user=self.user_id) + # can see list now joined + channel = self.make_request("GET", room_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + self.helper.leave(room=room_id, user=self.user_id) + # can see old list once left + channel = self.make_request("GET", room_path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + +class RoomsCreateTestCase(RoomBase): + """Tests /rooms and /rooms/$room_id REST events.""" + + user_id = "@sid1:red" + + def test_post_room_no_keys(self): + # POST with no config keys, expect new room id + channel = self.make_request("POST", "/createRoom", "{}") + + self.assertEquals(200, channel.code, channel.result) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_visibility_key(self): + # POST with visibility config key, expect new room id + channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_custom_key(self): + # POST with custom config keys, expect new room id + channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_known_and_unknown_keys(self): + # POST with custom + known config keys, expect new room id + channel = self.make_request( + "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' + ) + self.assertEquals(200, channel.code) + self.assertTrue("room_id" in channel.json_body) + + def test_post_room_invalid_content(self): + # POST with invalid content / paths, expect 400 + channel = self.make_request("POST", "/createRoom", b'{"visibili') + self.assertEquals(400, channel.code) + + channel = self.make_request("POST", "/createRoom", b'["hello"]') + self.assertEquals(400, channel.code) + + def test_post_room_invitees_invalid_mxid(self): + # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 + # Note the trailing space in the MXID here! 
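+        # (Roughly speaking, a well-formed MXID is "@" + localpart + ":" +
+        # domain with no whitespace anywhere, so the trailing space should
+        # cause the whole /createRoom request to be rejected with a 400.)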
+ channel = self.make_request( + "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' + ) + self.assertEquals(400, channel.code) + + @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) + def test_post_room_invitees_ratelimit(self): + """Test that invites sent when creating a room are ratelimited by a RateLimiter, + which ratelimits them correctly, including by not limiting when the requester is + exempt from ratelimiting. + """ + + # Build the request's content. We use local MXIDs because invites over federation + # are more difficult to mock. + content = json.dumps( + { + "invite": [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + } + ).encode("utf8") + + # Test that the invites are correctly ratelimited. + channel = self.make_request("POST", "/createRoom", content) + self.assertEqual(400, channel.code) + self.assertEqual( + "Cannot invite so many users at once", + channel.json_body["error"], + ) + + # Add the current user to the ratelimit overrides, allowing them no ratelimiting. + self.get_success( + self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) + ) + + # Test that the invites aren't ratelimited anymore. + channel = self.make_request("POST", "/createRoom", content) + self.assertEqual(200, channel.code) + + +class RoomTopicTestCase(RoomBase): + """Tests /rooms/$room_id/topic REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + # create the room + self.room_id = self.helper.create_room_as(self.user_id) + self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) + + def test_invalid_puts(self): + # missing keys or invalid json + channel = self.make_request("PUT", self.path, "{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, '{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, '{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request( + "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' + ) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, "text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", self.path, "") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # valid key, wrong type + content = '{"topic":["Topic name"]}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_topic(self): + # nothing should be there + channel = self.make_request("GET", self.path) + self.assertEquals(404, channel.code, msg=channel.result["body"]) + + # valid put + content = '{"topic":"Topic name"}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # valid get + channel = self.make_request("GET", self.path) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) + + def test_rooms_topic_with_extra_keys(self): + # valid put with extra keys + content = '{"topic":"Seasons","subtopic":"Summer"}' + channel = self.make_request("PUT", self.path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # valid get + channel = self.make_request("GET", self.path) + self.assertEquals(200, channel.code, 
msg=channel.result["body"]) + self.assert_dict(json.loads(content), channel.json_body) + + +class RoomMemberStateTestCase(RoomBase): + """Tests /rooms/$room_id/members/$user_id/state REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_invalid_puts(self): + path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) + # missing keys or invalid json + channel = self.make_request("PUT", path, "{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, '{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, '{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, "text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, "") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # valid keys, wrong types + content = '{"membership":["%s","%s","%s"]}' % ( + Membership.INVITE, + Membership.JOIN, + Membership.LEAVE, + ) + channel = self.make_request("PUT", path, content.encode("ascii")) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_members_self(self): + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.user_id, + ) + + # valid join message (NOOP since we made the room) + content = '{"membership":"%s"}' % Membership.JOIN + channel = self.make_request("PUT", path, content.encode("ascii")) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + expected_response = {"membership": Membership.JOIN} + self.assertEquals(expected_response, channel.json_body) + + def test_rooms_members_other(self): + self.other_id = "@zzsid1:red" + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.other_id, + ) + + # valid invite message + content = '{"membership":"%s"}' % Membership.INVITE + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) + + def test_rooms_members_other_custom_keys(self): + self.other_id = "@zzsid1:red" + path = "/rooms/%s/state/m.room.member/%s" % ( + urlparse.quote(self.room_id), + self.other_id, + ) + + # valid invite message with custom key + content = '{"membership":"%s","invite_text":"%s"}' % ( + Membership.INVITE, + "Join us!", + ) + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + channel = self.make_request("GET", path, None) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + self.assertEquals(json.loads(content), channel.json_body) + + +class RoomInviteRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, 
"burst_count": 3}}} + ) + def test_invites_by_rooms_ratelimit(self): + """Tests that invites in a room are actually rate-limited.""" + room_id = self.helper.create_room_as(self.user_id) + + for i in range(3): + self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) + + self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_users_ratelimit(self): + """Tests that invites to a specific user are actually rate-limited.""" + + for _ in range(3): + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red") + + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) + + +class RoomJoinRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit(self): + """Tests that local joins are actually rate-limited.""" + for _ in range(3): + self.helper.create_room_as(self.user_id) + + self.helper.create_room_as(self.user_id, expect_code=429) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_profile_change(self): + """Tests that sending a profile update into all of the user's joined rooms isn't + rate-limited by the rate-limiter on joins.""" + + # Create and join as many rooms as the rate-limiting config allows in a second. + room_ids = [ + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + self.helper.create_room_as(self.user_id), + ] + # Let some time for the rate-limiter to forget about our multi-join. + self.reactor.advance(2) + # Add one to make sure we're joined to more rooms than the config allows us to + # join in a second. + room_ids.append(self.helper.create_room_as(self.user_id)) + + # Create a profile for the user, since it hasn't been done on registration. + store = self.hs.get_datastore() + self.get_success( + store.create_profile(UserID.from_string(self.user_id).localpart) + ) + + # Update the display name for the user. + path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id + channel = self.make_request("PUT", path, {"displayname": "John Doe"}) + self.assertEquals(channel.code, 200, channel.json_body) + + # Check that all the rooms have been sent a profile update into. + for room_id in room_ids: + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( + room_id, + self.user_id, + ) + + channel = self.make_request("GET", path) + self.assertEquals(channel.code, 200) + + self.assertIn("displayname", channel.json_body) + self.assertEquals(channel.json_body["displayname"], "John Doe") + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_join_local_ratelimit_idempotent(self): + """Tests that the room join endpoints remain idempotent despite rate-limiting + on room joins.""" + room_id = self.helper.create_room_as(self.user_id) + + # Let's test both paths to be sure. 
+ paths_to_test = [ + "/_matrix/client/r0/rooms/%s/join", + "/_matrix/client/r0/join/%s", + ] + + for path in paths_to_test: + # Make sure we send more requests than the rate-limiting config would allow + # if all of these requests ended up joining the user to a room. + for _ in range(4): + channel = self.make_request("POST", path % room_id, {}) + self.assertEquals(channel.code, 200) + + @unittest.override_config( + { + "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}, + "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"], + "autocreate_auto_join_rooms": True, + }, + ) + def test_autojoin_rooms(self): + user_id = self.register_user("testuser", "password") + + # Check that the new user successfully joined the four rooms + rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 4) + + +class RoomMessagesTestCase(RoomBase): + """Tests /rooms/$room_id/messages/$user_id/$msg_id REST events.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_invalid_puts(self): + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) + # missing keys or invalid json + channel = self.make_request("PUT", path, b"{}") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'{"_name":"bo"}') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'{"nao') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b"text only") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + channel = self.make_request("PUT", path, b"") + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + def test_rooms_messages_sent(self): + path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) + + content = b'{"body":"test","msgtype":{"type":"a"}}' + channel = self.make_request("PUT", path, content) + self.assertEquals(400, channel.code, msg=channel.result["body"]) + + # custom message types + content = b'{"body":"test","msgtype":"test.custom.text"}' + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # m.text message type + path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) + content = b'{"body":"test2","msgtype":"m.text"}' + channel = self.make_request("PUT", path, content) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + +class RoomInitialSyncTestCase(RoomBase): + """Tests /rooms/$room_id/initialSync.""" + + user_id = "@sid1:red" + + def prepare(self, reactor, clock, hs): + # create the room + self.room_id = self.helper.create_room_as(self.user_id) + + def test_initial_sync(self): + channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) + self.assertEquals(200, channel.code) + + self.assertEquals(self.room_id, channel.json_body["room_id"]) + self.assertEquals("join", channel.json_body["membership"]) + + # Room state is easier to assert on if we unpack it into a dict + state = {} + for event in channel.json_body["state"]: + if "state_key" not in event: + continue + t = event["type"] + if t not in state: + state[t] = [] + 
state[t].append(event)
+
+        self.assertTrue("m.room.create" in state)
+
+        self.assertTrue("messages" in channel.json_body)
+        self.assertTrue("chunk" in channel.json_body["messages"])
+        self.assertTrue("end" in channel.json_body["messages"])
+
+        self.assertTrue("presence" in channel.json_body)
+
+        presence_by_user = {
+            e["content"]["user_id"]: e for e in channel.json_body["presence"]
+        }
+        self.assertTrue(self.user_id in presence_by_user)
+        self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
+
+
+class RoomMessageListTestCase(RoomBase):
+    """Tests /rooms/$room_id/messages REST events."""
+
+    user_id = "@sid1:red"
+
+    def prepare(self, reactor, clock, hs):
+        self.room_id = self.helper.create_room_as(self.user_id)
+
+    def test_topo_token_is_accepted(self):
+        token = "t1-0_0_0_0_0_0_0_0_0"
+        channel = self.make_request(
+            "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
+        )
+        self.assertEquals(200, channel.code)
+        self.assertTrue("start" in channel.json_body)
+        self.assertEquals(token, channel.json_body["start"])
+        self.assertTrue("chunk" in channel.json_body)
+        self.assertTrue("end" in channel.json_body)
+
+    def test_stream_token_is_accepted_for_fwd_pagination(self):
+        token = "s0_0_0_0_0_0_0_0_0"
+        channel = self.make_request(
+            "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
+        )
+        self.assertEquals(200, channel.code)
+        self.assertTrue("start" in channel.json_body)
+        self.assertEquals(token, channel.json_body["start"])
+        self.assertTrue("chunk" in channel.json_body)
+        self.assertTrue("end" in channel.json_body)
+
+    def test_room_messages_purge(self):
+        store = self.hs.get_datastore()
+        pagination_handler = self.hs.get_pagination_handler()
+
+        # Send a first message in the room, which will be removed by the purge.
+        first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
+        first_token = self.get_success(
+            store.get_topological_token_for_event(first_event_id)
+        )
+        first_token_str = self.get_success(first_token.to_string(store))
+
+        # Send a second message in the room, which won't be removed, and which we'll
+        # use as the marker to purge events before.
+        second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
+        second_token = self.get_success(
+            store.get_topological_token_for_event(second_event_id)
+        )
+        second_token_str = self.get_success(second_token.to_string(store))
+
+        # Send a third event in the room to ensure we don't fall under any edge case
+        # due to our marker being the latest forward extremity in the room.
+        self.helper.send(self.room_id, "message 3")
+
+        # Check that we get the first and second message when querying /messages.
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
+            % (
+                self.room_id,
+                second_token_str,
+                json.dumps({"types": [EventTypes.Message]}),
+            ),
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        chunk = channel.json_body["chunk"]
+        self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
+
+        # Purge every event before the second event.
+        purge_id = random_string(16)
+        pagination_handler._purges_by_id[purge_id] = PurgeStatus()
+        self.get_success(
+            pagination_handler._purge_history(
+                purge_id=purge_id,
+                room_id=self.room_id,
+                token=second_token_str,
+                delete_local_events=True,
+            )
+        )
+
+        # Check that we only get the second message through /messages now that the first
+        # has been purged.
+ channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % ( + self.room_id, + second_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) + + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % ( + self.room_id, + first_token_str, + json.dumps({"types": [EventTypes.Message]}), + ), + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) + + +class RoomSearchTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + user_id = True + hijack_auth = False + + def prepare(self, reactor, clock, hs): + + # Register the user who does the searching + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + # Register the user who sends the message + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + # Create a room + self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) + + # Invite the other person + self.helper.invite( + room=self.room, + src=self.user_id, + tok=self.access_token, + targ=self.other_user_id, + ) + + # The other user joins + self.helper.join( + room=self.room, user=self.other_user_id, tok=self.other_access_token + ) + + def test_finds_message(self): + """ + The search functionality will search for content in messages if asked to + do so. + """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": {"keys": ["content.body"], "search_term": "Hi"} + } + }, + ) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # No context was requested, so we should get none. + self.assertEqual(results["results"][0]["context"], {}) + + def test_include_context(self): + """ + When event_context includes include_profile, profile information will be + included in the search response. 
+ """ + # The other user sends some messages + self.helper.send(self.room, body="Hi!", tok=self.other_access_token) + self.helper.send(self.room, body="There!", tok=self.other_access_token) + + channel = self.make_request( + "POST", + "/search?access_token=%s" % (self.access_token,), + { + "search_categories": { + "room_events": { + "keys": ["content.body"], + "search_term": "Hi", + "event_context": {"include_profile": True}, + } + } + }, + ) + + # Check we get the results we expect -- one search result, of the sent + # messages + self.assertEqual(channel.code, 200) + results = channel.json_body["search_categories"]["room_events"] + self.assertEqual(results["count"], 1) + self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") + + # We should get context info, like the two users, and the display names. + context = results["results"][0]["context"] + self.assertEqual(len(context["profile_info"].keys()), 2) + self.assertEqual( + context["profile_info"][self.other_user_id]["displayname"], "otheruser" + ) + + +class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + + self.url = b"/_matrix/client/r0/publicRooms" + + config = self.default_config() + config["allow_public_rooms_without_auth"] = False + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def test_restricted_no_auth(self): + channel = self.make_request("GET", self.url) + self.assertEqual(channel.code, 401, channel.result) + + def test_restricted_auth(self): + self.register_user("user", "pass") + tok = self.login("user", "pass") + + channel = self.make_request("GET", self.url, access_token=tok) + self.assertEqual(channel.code, 200, channel.result) + + +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fallback to local filtering if a remote server + doesn't support search. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters. 
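+        # Note: passing an iterable as `side_effect` makes the mock work
+        # through it call by call; an exception instance in the iterable is
+        # raised, any other value is returned. So the first call below raises
+        # the 404 and the second returns the empty result.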
+ self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + +class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["allow_per_room_profiles"] = False + self.hs = self.setup_test_homeserver(config=config) + + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + + # Set a profile for the test user + self.displayname = "test user" + data = {"displayname": self.displayname} + request_data = json.dumps(data) + channel = self.make_request( + "PUT", + "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), + request_data, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_per_room_profile_forbidden(self): + data = {"membership": "join", "displayname": "other test user"} + request_data = json.dumps(data) + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" + % (self.room_id, self.user_id), + request_data, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res_displayname = channel.json_body["content"]["displayname"] + self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. 
+ """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/join", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/leave", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/kick", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/ban", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/unban", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/invite", + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + channel = self.make_request( + "POST", + f"/_matrix/client/r0/rooms/{self.room_id}/leave", + content={"reason": reason}, + access_token=self.second_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) + + +class LabelsTestCase(unittest.HomeserverTestCase): + servlets = [ + 
synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+        profile.register_servlets,
+    ]
+
+    # Filter that should only catch messages with the label "#fun".
+    FILTER_LABELS = {
+        "types": [EventTypes.Message],
+        "org.matrix.labels": ["#fun"],
+    }
+    # Filter that should only catch messages without the label "#fun".
+    FILTER_NOT_LABELS = {
+        "types": [EventTypes.Message],
+        "org.matrix.not_labels": ["#fun"],
+    }
+    # Filter that should only catch messages with the label "#work" but without the label
+    # "#notfun".
+    FILTER_LABELS_NOT_LABELS = {
+        "types": [EventTypes.Message],
+        "org.matrix.labels": ["#work"],
+        "org.matrix.not_labels": ["#notfun"],
+    }
+
+    def prepare(self, reactor, clock, homeserver):
+        self.user_id = self.register_user("test", "test")
+        self.tok = self.login("test", "test")
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+    def test_context_filter_labels(self):
+        """Test that we can filter by a label on a /context request."""
+        event_id = self._send_labelled_messages_in_room()
+
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/context/%s?filter=%s"
+            % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        events_before = channel.json_body["events_before"]
+
+        self.assertEqual(
+            len(events_before), 1, [event["content"] for event in events_before]
+        )
+        self.assertEqual(
+            events_before[0]["content"]["body"], "with right label", events_before[0]
+        )
+
+        events_after = channel.json_body["events_after"]
+
+        self.assertEqual(
+            len(events_after), 1, [event["content"] for event in events_after]
+        )
+        self.assertEqual(
+            events_after[0]["content"]["body"], "with right label", events_after[0]
+        )
+
+    def test_context_filter_not_labels(self):
+        """Test that we can filter by the absence of a label on a /context request."""
+        event_id = self._send_labelled_messages_in_room()
+
+        channel = self.make_request(
+            "GET",
+            "/rooms/%s/context/%s?filter=%s"
+            % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        events_before = channel.json_body["events_before"]
+
+        self.assertEqual(
+            len(events_before), 1, [event["content"] for event in events_before]
+        )
+        self.assertEqual(
+            events_before[0]["content"]["body"], "without label", events_before[0]
+        )
+
+        events_after = channel.json_body["events_after"]
+
+        self.assertEqual(
+            len(events_after), 2, [event["content"] for event in events_after]
+        )
+        self.assertEqual(
+            events_after[0]["content"]["body"], "with wrong label", events_after[0]
+        )
+        self.assertEqual(
+            events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
+        )
+
+    def test_context_filter_labels_not_labels(self):
+        """Test that we can filter by both a label and the absence of another label on a
+        /context request.
+ """ + event_id = self._send_labelled_messages_in_room() + + channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. 
+ """ + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), + ) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 2, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 4, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), + 1, + [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + Returns: + The ID of the event to use if we're testing filtering on /context. 
+ """ + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + # Return this event's ID when we test filtering in /context requests. + event_id = res["event_id"] + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + return event_id + + +class ContextTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + account.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + self.room_id = self.helper.create_room_as( + self.user_id, tok=self.tok, is_public=False + ) + + self.other_user_id = self.register_user("user2", "password") + self.other_tok = self.login("user2", "password") + + self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) + + def test_erased_sender(self): + """Test that an erasure request results in the requester's events being hidden + from any new member of the room. + """ + + # Send a bunch of events in the room. + + self.helper.send(self.room_id, "message 1", tok=self.tok) + self.helper.send(self.room_id, "message 2", tok=self.tok) + event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] + self.helper.send(self.room_id, "message 4", tok=self.tok) + self.helper.send(self.room_id, "message 5", tok=self.tok) + + # Check that we can still see the messages before the erasure request. 
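+        # The inline `filter` query parameter on the /context request below
+        # restricts the results to m.room.message events, so membership and
+        # other state events do not pad out events_before/events_after.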
+ + channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertEqual( + events_before[0].get("content", {}).get("body"), + "message 2", + events_before[0], + ) + self.assertEqual( + events_before[1].get("content", {}).get("body"), + "message 1", + events_before[1], + ) + + self.assertEqual( + channel.json_body["event"].get("content", {}).get("body"), + "message 3", + channel.json_body["event"], + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertEqual( + events_after[0].get("content", {}).get("body"), + "message 4", + events_after[0], + ) + self.assertEqual( + events_after[1].get("content", {}).get("body"), + "message 5", + events_after[1], + ) + + # Deactivate the first account and erase the user's data. + + deactivate_account_handler = self.hs.get_deactivate_account_handler() + self.get_success( + deactivate_account_handler.deactivate_account( + self.user_id, True, create_requester(self.user_id) + ) + ) + + # Invite another user in the room. This is needed because messages will be + # pruned only if the user wasn't a member of the room when the messages were + # sent. + + invited_user_id = self.register_user("user3", "password") + invited_tok = self.login("user3", "password") + + self.helper.invite( + self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok + ) + self.helper.join(self.room_id, invited_user_id, tok=invited_tok) + + # Check that a user that joined the room after the erasure request can't see + # the messages anymore. 
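+        # Note that the events themselves are still returned; it is their
+        # content that should have been erased (i.e. served as empty dicts).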
+ + channel = self.make_request( + "GET", + '/rooms/%s/context/%s?filter={"types":["m.room.message"]}' + % (self.room_id, event_id), + access_token=invited_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual(len(events_before), 2, events_before) + self.assertDictEqual(events_before[0].get("content"), {}, events_before[0]) + self.assertDictEqual(events_before[1].get("content"), {}, events_before[1]) + + self.assertDictEqual( + channel.json_body["event"].get("content"), {}, channel.json_body["event"] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual(len(events_after), 2, events_after) + self.assertDictEqual(events_after[0].get("content"), {}, events_after[0]) + self.assertEqual(events_after[1].get("content"), {}, events_after[1]) + + +class RoomAliasListTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.room_owner = self.register_user("room_owner", "test") + self.room_owner_tok = self.login("room_owner", "test") + + self.room_id = self.helper.create_room_as( + self.room_owner, tok=self.room_owner_tok + ) + + def test_no_aliases(self): + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(res["aliases"], []) + + def test_not_in_room(self): + self.register_user("user", "test") + user_tok = self.login("user", "test") + res = self._get_aliases(user_tok, expected_code=403) + self.assertEqual(res["errcode"], "M_FORBIDDEN") + + def test_admin_user(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.register_user("user", "test", admin=True) + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def test_with_aliases(self): + alias1 = self._random_alias() + alias2 = self._random_alias() + + self._set_alias_via_directory(alias1) + self._set_alias_via_directory(alias2) + + res = self._get_aliases(self.room_owner_tok) + self.assertEqual(set(res["aliases"]), {alias1, alias2}) + + def test_peekable_room(self): + alias1 = self._random_alias() + self._set_alias_via_directory(alias1) + + self.helper.send_state( + self.room_id, + EventTypes.RoomHistoryVisibility, + body={"history_visibility": "world_readable"}, + tok=self.room_owner_tok, + ) + + self.register_user("user", "test") + user_tok = self.login("user", "test") + + res = self._get_aliases(user_tok) + self.assertEqual(res["aliases"], [alias1]) + + def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict: + """Calls the endpoint under test. 
Returns the JSON response object."""
+        channel = self.make_request(
+            "GET",
+            "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,),
+            access_token=access_token,
+        )
+        self.assertEqual(channel.code, expected_code, channel.result)
+        res = channel.json_body
+        self.assertIsInstance(res, dict)
+        if expected_code == 200:
+            self.assertIsInstance(res["aliases"], list)
+        return res
+
+    def _random_alias(self) -> str:
+        return RoomAlias(random_string(5), self.hs.hostname).to_string()
+
+    def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+        url = "/_matrix/client/r0/directory/room/" + alias
+        data = {"room_id": self.room_id}
+        request_data = json.dumps(data)
+
+        channel = self.make_request(
+            "PUT", url, request_data, access_token=self.room_owner_tok
+        )
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+
+class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        directory.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, homeserver):
+        self.room_owner = self.register_user("room_owner", "test")
+        self.room_owner_tok = self.login("room_owner", "test")
+
+        self.room_id = self.helper.create_room_as(
+            self.room_owner, tok=self.room_owner_tok
+        )
+
+        self.alias = "#alias:test"
+        self._set_alias_via_directory(self.alias)
+
+    def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+        url = "/_matrix/client/r0/directory/room/" + alias
+        data = {"room_id": self.room_id}
+        request_data = json.dumps(data)
+
+        channel = self.make_request(
+            "PUT", url, request_data, access_token=self.room_owner_tok
+        )
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+    def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
+        """Calls the endpoint under test. Returns the JSON response object."""
+        channel = self.make_request(
+            "GET",
+            "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+            access_token=self.room_owner_tok,
+        )
+        self.assertEqual(channel.code, expected_code, channel.result)
+        res = channel.json_body
+        self.assertIsInstance(res, dict)
+        return res
+
+    def _set_canonical_alias(self, content: JsonDict, expected_code: int = 200) -> JsonDict:
+        """Calls the endpoint under test. Returns the JSON response object."""
+        channel = self.make_request(
+            "PUT",
+            "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+            json.dumps(content),
+            access_token=self.room_owner_tok,
+        )
+        self.assertEqual(channel.code, expected_code, channel.result)
+        res = channel.json_body
+        self.assertIsInstance(res, dict)
+        return res
+
+    def test_canonical_alias(self):
+        """Test a basic alias message."""
+        # There is no canonical alias to start with.
+        self._get_canonical_alias(expected_code=404)
+
+        # Create an alias.
+        self._set_canonical_alias({"alias": self.alias})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias})
+
+        # Now remove the alias.
+        self._set_canonical_alias({})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {})
+
+    def test_alt_aliases(self):
+        """Test a canonical alias message with alt_aliases."""
+        # Create an alias.
+        self._set_canonical_alias({"alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alt_aliases": [self.alias]})
+
+        # Now remove the alt_aliases.
+        self._set_canonical_alias({})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {})
+
+    def test_alias_alt_aliases(self):
+        """Test a canonical alias message with an alias and alt_aliases."""
+        # Create an alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Now remove the alias and alt_aliases.
+        self._set_canonical_alias({})
+
+        # There is an alias event, but it is empty.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {})
+
+    def test_partial_modify(self):
+        """Test removing only the alt_aliases."""
+        # Create an alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Now remove the alt_aliases.
+        self._set_canonical_alias({"alias": self.alias})
+
+        # The canonical alias remains, but the alt_aliases have been removed.
+        res = self._get_canonical_alias()
+        self.assertEqual(res, {"alias": self.alias})
+
+    def test_add_alias(self):
+        """Test adding an alias to the alt_aliases list."""
+        # Create an additional alias.
+        second_alias = "#second:test"
+        self._set_alias_via_directory(second_alias)
+
+        # Add the canonical alias.
+        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+        # Then add the second alias.
+        self._set_canonical_alias(
+            {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+        )
+
+        # Canonical alias now exists!
+        res = self._get_canonical_alias()
+        self.assertEqual(
+            res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+        )
+
+    def test_bad_data(self):
+        """Invalid data for alt_aliases should cause errors."""
+        self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
+
+    def test_bad_alias(self):
+        """An alias which does not point to the room raises a SynapseError."""
+        self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
+        self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
new file mode 100644
index 0000000000..6db7062a8e
--- /dev/null
+++ b/tests/rest/client/test_sendtodevice.py
@@ -0,0 +1,200 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from synapse.rest import admin +from synapse.rest.client import login, sendtodevice, sync + +from tests.unittest import HomeserverTestCase, override_config + + +class SendToDeviceTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + sendtodevice.register_servlets, + sync.register_servlets, + ] + + def test_user_to_user(self): + """A to-device message from one user to another should get delivered""" + + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send the message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # check it appears + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + expected_result = { + "events": [ + { + "sender": user1, + "type": "m.test", + "content": test_msg, + } + ] + } + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should re-appear if we do another sync + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["to_device"], expected_result) + + # it should *not* appear if we do an incremental sync + sync_token = channel.json_body["next_batch"] + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_local_room_key_request(self): + """m.room_key_request has special-casing; test from local user""" + user1 = self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # send three messages + for i in range(3): + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", + content={"messages": {user2: {"d2": {"idx": i}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # now sync: we should get two of the three + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={"messages": {user2: {"d2": {"idx": 3}}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + # ... 
which should arrive + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, + ) + + @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) + def test_remote_room_key_request(self): + """m.room_key_request has special-casing; test from remote user""" + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + federation_registry = self.hs.get_federation_registry() + + # send three messages + for i in range(3): + self.get_success( + federation_registry.on_edu( + "m.direct_to_device", + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": i}}}, + "message_id": f"{i}", + }, + ) + ) + + # now sync: we should get two of the three + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 2) + for i in range(2): + self.assertEqual( + msgs[i], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": i}, + }, + ) + sync_token = channel.json_body["next_batch"] + + # ... time passes + self.reactor.advance(1) + + # and we can send more messages + self.get_success( + federation_registry.on_edu( + "m.direct_to_device", + "remote_server", + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "messages": {user2: {"d2": {"idx": 3}}}, + "message_id": "3", + }, + ) + ) + + # ... which should arrive + channel = self.make_request( + "GET", f"/sync?since={sync_token}", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + msgs = channel.json_body["to_device"]["events"] + self.assertEqual(len(msgs), 1) + self.assertEqual( + msgs[0], + { + "sender": "@user:remote_server", + "type": "m.room_key_request", + "content": {"idx": 3}, + }, + ) diff --git a/tests/rest/client/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py new file mode 100644 index 0000000000..283eccd53f --- /dev/null +++ b/tests/rest/client/test_shared_rooms.py @@ -0,0 +1,142 @@ +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import synapse.rest.admin +from synapse.rest.client import login, room, shared_rooms + +from tests import unittest +from tests.server import FakeChannel + + +class UserSharedRoomsTest(unittest.HomeserverTestCase): + """ + Tests the UserSharedRoomsServlet. 
+ """ + + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + shared_rooms.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["update_user_directory"] = True + return self.setup_test_homeserver(config=config) + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.handler = hs.get_user_directory_handler() + + def _get_shared_rooms(self, token, other_user) -> FakeChannel: + return self.make_request( + "GET", + "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" + % other_user, + access_token=token, + ) + + def test_shared_room_list_public(self): + """ + A room should show up in the shared list of rooms between two users + if it is public. + """ + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) + + def test_shared_room_list_private(self): + """ + A room should show up in the shared list of rooms between two users + if it is private. + """ + self._check_shared_rooms_with( + room_one_is_public=False, room_two_is_public=False + ) + + def test_shared_room_list_mixed(self): + """ + The shared room list between two users should contain both public and private + rooms. + """ + self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) + + def _check_shared_rooms_with( + self, room_one_is_public: bool, room_two_is_public: bool + ): + """Checks that shared public or private rooms between two users appear in + their shared room lists + """ + u1 = self.register_user("user1", "pass") + u1_token = self.login(u1, "pass") + u2 = self.register_user("user2", "pass") + u2_token = self.login(u2, "pass") + + # Create a room. user1 invites user2, who joins + room_id_one = self.helper.create_room_as( + u1, is_public=room_one_is_public, tok=u1_token + ) + self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_one, user=u2, tok=u2_token) + + # Check shared rooms from user1's perspective. + # We should see the one room in common + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 1) + self.assertEquals(channel.json_body["joined"][0], room_id_one) + + # Create another room and invite user2 to it + room_id_two = self.helper.create_room_as( + u1, is_public=room_two_is_public, tok=u1_token + ) + self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) + self.helper.join(room_id_two, user=u2, tok=u2_token) + + # Check shared rooms again. We should now see both rooms. + channel = self._get_shared_rooms(u1_token, u2) + self.assertEquals(200, channel.code, channel.result) + self.assertEquals(len(channel.json_body["joined"]), 2) + for room_id_id in channel.json_body["joined"]: + self.assertIn(room_id_id, [room_id_one, room_id_two]) + + def test_shared_room_list_after_leave(self): + """ + A room should no longer be considered shared if the other + user has left it. 
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+
+        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room, user=u2, tok=u2_token)
+
+        # Assert that the room initially appears in the shared rooms list
+        channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 1)
+        self.assertEquals(channel.json_body["joined"][0], room)
+
+        self.helper.leave(room, user=u1, tok=u1_token)
+
+        # Check user1's view of shared rooms with user2
+        channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 0)
+
+        # Check user2's view of shared rooms with user1
+        channel = self._get_shared_rooms(u2_token, u1)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
new file mode 100644
index 0000000000..95be369d4b
--- /dev/null
+++ b/tests/rest/client/test_sync.py
@@ -0,0 +1,686 @@
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json + +import synapse.rest.admin +from synapse.api.constants import ( + EventContentFields, + EventTypes, + ReadReceiptEventFields, + RelationTypes, +) +from synapse.rest.client import knock, login, read_marker, receipts, room, sync + +from tests import unittest +from tests.federation.transport.test_knocking import ( + KnockingStrippedStateEventHelperMixin, +) +from tests.server import TimedOutException +from tests.unittest import override_config + + +class FilterTestCase(unittest.HomeserverTestCase): + + user_id = "@apple:test" + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_argless(self): + channel = self.make_request("GET", "/sync") + + self.assertEqual(channel.code, 200) + self.assertIn("next_batch", channel.json_body) + + +class SyncFilterTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + def test_sync_filter_labels(self): + """Test that we can filter by a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_sync_filter_not_labels(self): + """Test that we can filter by the absence of a label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 3, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) + self.assertEqual( + events[2]["content"]["body"], "with two wrong labels", events[2] + ) + + def test_sync_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label.""" + sync_filter = json.dumps( + { + "room": { + "timeline": { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + } + } + ) + + events = self._test_sync_filter_labels(sync_filter) + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def _test_sync_filter_labels(self, sync_filter): + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + room_id = self.helper.create_room_as(user_id, tok=tok) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=tok, + ) + + self.helper.send_event( + 
room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=tok, + ) + + self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=tok, + ) + + channel = self.make_request( + "GET", "/sync?filter=%s" % sync_filter, access_token=tok + ) + self.assertEqual(channel.code, 200, channel.result) + + return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] + + +class SyncTypingTests(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + user_id = True + hijack_auth = False + + def test_sync_backwards_typing(self): + """ + If the typing serial goes backwards and the typing handler is then reset + (such as when the master restarts and sets the typing serial to 0), we + do not incorrectly return typing information that had a serial greater + than the now-reset serial. + """ + typing_url = "/rooms/%s/typing/%s?access_token=%s" + sync_url = "/sync?timeout=3000000&access_token=%s&since=%s" + + # Register the user who gets notified + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Register the user who sends the message + other_user_id = self.register_user("otheruser", "pass") + other_access_token = self.login("otheruser", "pass") + + # Create a room + room = self.helper.create_room_as(user_id, tok=access_token) + + # Invite the other person + self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) + + # The other user joins + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # The other user sends some messages + self.helper.send(room, body="Hi!", tok=other_access_token) + self.helper.send(room, body="There!", tok=other_access_token) + + # Start typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Stop typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": false}', + ) + self.assertEquals(200, channel.code) + + # Start typing. + channel = self.make_request( + "PUT", + typing_url % (room, other_user_id, other_access_token), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + # Should return immediately + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Reset typing serial back to 0, as if the master had. + typing = self.hs.get_typing_handler() + typing._latest_room_serial = 0 + + # Since it checks the state token, we need some state to update to + # invalidate the stream token. + self.helper.send(room, body="There!", tok=other_access_token) + + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # This should time out! 
But it does not, because our stream token is + # ahead, and therefore it's saying the typing (that we've actually + # already seen) is new, since it's got a token above our new, now-reset + # stream token. + channel = self.make_request("GET", sync_url % (access_token, next_batch)) + self.assertEquals(200, channel.code) + next_batch = channel.json_body["next_batch"] + + # Clear the typing information, so that it doesn't think everything is + # in the future. + typing._reset() + + # Now it SHOULD fail as it never completes! + with self.assertRaises(TimedOutException): + self.make_request("GET", sync_url % (access_token, next_batch)) + + +class SyncKnockTestCase( + unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin +): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + knock.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to create the room to knock on). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll knock on. + self.room_id = self.helper.create_room_as( + self.user_id, + is_public=False, + room_version="7", + tok=self.tok, + ) + + # Register the second user (used to knock on the room). + self.knocker = self.register_user("knocker", "monkey") + self.knocker_tok = self.login("knocker", "monkey") + + # Perform an initial sync for the knocking user. + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + # Set up some room state to test with. 
+ self.expected_room_state = self.send_example_state_events_to_room( + hs, self.room_id, self.user_id + ) + + @override_config({"experimental_features": {"msc2403_enabled": True}}) + def test_knock_room_state(self): + """Tests that /sync returns state from a room after knocking on it.""" + # Knock on a room + channel = self.make_request( + "POST", + "/_matrix/client/r0/knock/%s" % (self.room_id,), + b"{}", + self.knocker_tok, + ) + self.assertEquals(200, channel.code, channel.result) + + # We expect to see the knock event in the stripped room state later + self.expected_room_state[EventTypes.Member] = { + "content": {"membership": "knock", "displayname": "knocker"}, + "state_key": "@knocker:test", + } + + # Check that /sync includes stripped state from the room + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.knocker_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Extract the stripped room state events from /sync + knock_entry = channel.json_body["rooms"]["knock"] + room_state_events = knock_entry[self.room_id]["knock_state"]["events"] + + # Validate that the knock membership event came last + self.assertEqual(room_state_events[-1]["type"], EventTypes.Member) + + # Validate the stripped room state events + self.check_knock_room_state_against_room_state( + room_state_events, self.expected_room_state + ) + + +class ReadReceiptsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + receipts.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Join the second user + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + + @override_config({"experimental_features": {"msc2285_enabled": True}}) + def test_hidden_read_receipts(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt to tell the server the first user's message was read + body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + body, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that the first user can't see the other user's hidden read receipt + self.assertEqual(self._get_read_receipt(), None) + + def test_read_receipt_with_empty_body(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt for this message with an empty body + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + def _get_read_receipt(self): + """Syncs and returns the read receipt.""" + + # Checks if event is a read receipt + def is_read_receipt(event): + return event["type"] == "m.receipt" + + # Sync + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + 
self.assertEqual(channel.code, 200) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + # Return the read receipt + ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ + "ephemeral" + ]["events"] + return next(filter(is_read_receipt, ephemeral_events), None) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that hidden read receipts don't break unread counts + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + body, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. 
+        self.helper.send_state(
+            self.room_id,
+            "m.room.name",
+            {"name": "my super room"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(1)
+
+        # Check that room topic changes increase the unread counter.
+        self.helper.send_state(
+            self.room_id,
+            "m.room.topic",
+            {"topic": "welcome!!!"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(2)
+
+        # Check that encrypted messages increase the unread counter.
+        self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
+        self._check_unread_count(3)
+
+        # Check that custom events with a body increase the unread counter.
+        self.helper.send_event(
+            self.room_id,
+            "org.matrix.custom_type",
+            {"body": "hello"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that edits don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "body": "hello",
+                "msgtype": "m.text",
+                "m.relates_to": {"rel_type": RelationTypes.REPLACE},
+            },
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that notices don't increase the unread counter.
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={"body": "hello", "msgtype": "m.notice"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(4)
+
+        # Check that tombstone events increase the unread counter.
+        self.helper.send_state(
+            self.room_id,
+            EventTypes.Tombstone,
+            {"replacement_room": "!someroom:test"},
+            tok=self.tok2,
+        )
+        self._check_unread_count(5)
+
+    def _check_unread_count(self, expected_count: int):
+        """Syncs and compares the unread count with the expected value."""
+
+        channel = self.make_request(
+            "GET",
+            self.url % self.next_batch,
+            access_token=self.tok,
+        )
+
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        room_entry = channel.json_body["rooms"]["join"][self.room_id]
+        self.assertEqual(
+            room_entry["org.matrix.msc2654.unread_count"],
+            expected_count,
+            room_entry,
+        )
+
+        # Store the next batch for the next request.
+        self.next_batch = channel.json_body["next_batch"]
+
+
+class SyncCacheTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def test_noop_sync_does_not_tightloop(self):
+        """If the sync times out, we shouldn't cache the result
+
+        Essentially a regression test for #8518.
+        """
+        self.user_id = self.register_user("kermit", "monkey")
+        self.tok = self.login("kermit", "monkey")
+
+        # we should immediately get an initial sync response
+        channel = self.make_request("GET", "/sync", access_token=self.tok)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # now, make an incremental sync request, with a timeout
+        next_batch = channel.json_body["next_batch"]
+        channel = self.make_request(
+            "GET",
+            f"/sync?since={next_batch}&timeout=10000",
+            access_token=self.tok,
+            await_result=False,
+        )
+        # that should block for 10 seconds
+        with self.assertRaises(TimedOutException):
+            channel.await_result(timeout_ms=9900)
+        channel.await_result(timeout_ms=200)
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        # we expect the next_batch in the result to be the same as before
+        self.assertEqual(channel.json_body["next_batch"], next_batch)
+
+        # another incremental sync should also block.
+ channel = self.make_request( + "GET", + f"/sync?since={next_batch}&timeout=10000", + access_token=self.tok, + await_result=False, + ) + # that should block for 10 seconds + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=9900) + channel.await_result(timeout_ms=200) + self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py new file mode 100644 index 0000000000..b54b004733 --- /dev/null +++ b/tests/rest/client/test_typing.py @@ -0,0 +1,121 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests REST events for /rooms paths.""" + +from unittest.mock import Mock + +from synapse.rest.client import room +from synapse.types import UserID + +from tests import unittest + +PATH_PREFIX = "/_matrix/client/api/v1" + + +class RoomTypingTestCase(unittest.HomeserverTestCase): + """Tests /rooms/$room_id/typing/$user_id REST API.""" + + user_id = "@sid:red" + + user = UserID.from_string(user_id) + servlets = [room.register_servlets] + + def make_homeserver(self, reactor, clock): + + hs = self.setup_test_homeserver( + "red", + federation_http_client=None, + federation_client=Mock(), + ) + + self.event_source = hs.get_event_sources().sources["typing"] + + hs.get_federation_handler = Mock() + + async def get_user_by_access_token(token=None, allow_guest=False): + return { + "user": UserID.from_string(self.auth_user_id), + "token_id": 1, + "is_guest": False, + } + + hs.get_auth().get_user_by_access_token = get_user_by_access_token + + async def _insert_client_ip(*args, **kwargs): + return None + + hs.get_datastore().insert_client_ip = _insert_client_ip + + return hs + + def prepare(self, reactor, clock, hs): + self.room_id = self.helper.create_room_as(self.user_id) + # Need another user to make notifications actually work + self.helper.join(self.room_id, user="@jim:red") + + def test_set_typing(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + self.assertEquals(self.event_source.get_current_key(), 1) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) + self.assertEquals( + events[0], + [ + { + "type": "m.typing", + "room_id": self.room_id, + "content": {"user_ids": [self.user_id]}, + } + ], + ) + + def test_set_not_typing(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": false}', + ) + self.assertEquals(200, channel.code) + + def test_typing_timeout(self): + channel = self.make_request( + "PUT", + "/rooms/%s/typing/%s" % (self.room_id, self.user_id), + b'{"typing": true, "timeout": 30000}', + ) + self.assertEquals(200, channel.code) + + self.assertEquals(self.event_source.get_current_key(), 1) + + self.reactor.advance(36) + + 
self.assertEquals(self.event_source.get_current_key(), 2)
+
+        channel = self.make_request(
+            "PUT",
+            "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
+            b'{"typing": true, "timeout": 30000}',
+        )
+        self.assertEquals(200, channel.code)
+
+        self.assertEquals(self.event_source.get_current_key(), 3)
diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
new file mode 100644
index 0000000000..72f976d8e2
--- /dev/null
+++ b/tests/rest/client/test_upgrade_room.py
@@ -0,0 +1,159 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.rest import admin
+from synapse.rest.client import login, room, room_upgrade_rest_servlet
+
+from tests import unittest
+from tests.server import FakeChannel
+
+
+class UpgradeRoomTest(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        room_upgrade_rest_servlet.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.handler = hs.get_user_directory_handler()
+
+        self.creator = self.register_user("creator", "pass")
+        self.creator_token = self.login(self.creator, "pass")
+
+        self.other = self.register_user("user", "pass")
+        self.other_token = self.login(self.other, "pass")
+
+        self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token)
+        self.helper.join(self.room_id, self.other, tok=self.other_token)
+
+    def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel:
+        # We never want a cached response.
+        self.reactor.advance(5 * 60 + 1)
+
+        return self.make_request(
+            "POST",
+            "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id,
+            # This will upgrade a room to the same version, but that's fine.
+            content={"new_version": DEFAULT_ROOM_VERSION},
+            access_token=token or self.creator_token,
+        )
+
+    def test_upgrade(self):
+        """
+        Upgrading a room should work fine.
+        """
+        channel = self._upgrade_room()
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertIn("replacement_room", channel.json_body)
+
+    def test_not_in_room(self):
+        """
+        Upgrading a room should fail if the user is not in the room.
+        """
+        # The user isn't in the room.
+        roomless = self.register_user("roomless", "pass")
+        roomless_token = self.login(roomless, "pass")
+
+        channel = self._upgrade_room(roomless_token)
+        self.assertEquals(403, channel.code, channel.result)
+
+    def test_power_levels(self):
+        """
+        Another user can upgrade the room if their power level is increased.
+        """
+        # The other user doesn't have the proper power level.
+        channel = self._upgrade_room(self.other_token)
+        self.assertEquals(403, channel.code, channel.result)
+
+        # Increase the power levels so that this user can upgrade.
+ power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users"][self.other] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_user_default(self): + """ + Another user can upgrade the room if the default power level for users is increased. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["users_default"] = 100 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + def test_power_levels_tombstone(self): + """ + Another user can upgrade the room if they can send the tombstone event. + """ + # The other user doesn't have the proper power level. + channel = self._upgrade_room(self.other_token) + self.assertEquals(403, channel.code, channel.result) + + # Increase the power levels so that this user can upgrade. + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + power_levels["events"]["m.room.tombstone"] = 0 + self.helper.send_state( + self.room_id, + "m.room.power_levels", + body=power_levels, + tok=self.creator_token, + ) + + # The upgrade should succeed! + channel = self._upgrade_room(self.other_token) + self.assertEquals(200, channel.code, channel.result) + + power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.creator_token, + ) + self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py new file mode 100644 index 0000000000..954ad1a1fd --- /dev/null +++ b/tests/rest/client/utils.py @@ -0,0 +1,654 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
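The three power-level tests above all turn on the same server-side check: an upgrade is permitted only if the requester could send an m.room.tombstone state event into the old room. A minimal sketch of that evaluation under the standard power-level rules (the spec logic, not Synapse's actual code path; the helper name is ours):

def can_send_tombstone(power_levels: dict, user_id: str) -> bool:
    """Whether user_id may send m.room.tombstone under these power levels."""
    # A user's level falls back to users_default when not listed explicitly.
    user_level = power_levels.get("users", {}).get(
        user_id, power_levels.get("users_default", 0)
    )
    # The level required for a state event falls back to state_default.
    required_level = power_levels.get("events", {}).get(
        "m.room.tombstone", power_levels.get("state_default", 50)
    )
    return user_level >= required_level

Each test flips a different input of this comparison: the "users" map, "users_default", or events["m.room.tombstone"].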
+ +import json +import re +import time +import urllib.parse +from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union +from unittest.mock import patch + +import attr + +from twisted.web.resource import Resource +from twisted.web.server import Site + +from synapse.api.constants import Membership +from synapse.types import JsonDict + +from tests.server import FakeChannel, FakeSite, make_request +from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser + + +@attr.s +class RestHelper: + """Contains extra helper functions to quickly and clearly perform a given + REST action, which isn't the focus of the test. + """ + + hs = attr.ib() + site = attr.ib(type=Site) + auth_user_id = attr.ib() + + def create_room_as( + self, + room_creator: Optional[str] = None, + is_public: bool = True, + room_version: Optional[str] = None, + tok: Optional[str] = None, + expect_code: int = 200, + extra_content: Optional[Dict] = None, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ) -> str: + """ + Create a room. + + Args: + room_creator: The user ID to create the room with. + is_public: If True, the `visibility` parameter will be set to the + default (public). Otherwise, the `visibility` parameter will be set + to "private". + room_version: The room version to create the room as. Defaults to Synapse's + default room version. + tok: The access token to use in the request. + expect_code: The expected HTTP response code. + + Returns: + The ID of the newly created room. + """ + temp_id = self.auth_user_id + self.auth_user_id = room_creator + path = "/_matrix/client/r0/createRoom" + content = extra_content or {} + if not is_public: + content["visibility"] = "private" + if room_version: + content["room_version"] = room_version + if tok: + path = path + "?access_token=%s" % tok + + channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + path, + json.dumps(content).encode("utf8"), + custom_headers=custom_headers, + ) + + assert channel.result["code"] == b"%d" % expect_code, channel.result + self.auth_user_id = temp_id + + if expect_code == 200: + return channel.json_body["room_id"] + + def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=src, + targ=targ, + tok=tok, + membership=Membership.INVITE, + expect_code=expect_code, + ) + + def join(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.JOIN, + expect_code=expect_code, + ) + + def leave(self, room=None, user=None, expect_code=200, tok=None): + self.change_membership( + room=room, + src=user, + targ=user, + tok=tok, + membership=Membership.LEAVE, + expect_code=expect_code, + ) + + def change_membership( + self, + room: str, + src: str, + targ: str, + membership: str, + extra_data: Optional[dict] = None, + tok: Optional[str] = None, + expect_code: int = 200, + ) -> None: + """ + Send a membership state event into a room. + + Args: + room: The ID of the room to send to + src: The mxid of the event sender + targ: The mxid of the event's target. 
The state key + membership: The type of membership event + extra_data: Extra information to include in the content of the event + tok: The user access token to use + expect_code: The expected HTTP response code + """ + temp_id = self.auth_user_id + self.auth_user_id = src + + path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + if tok: + path = path + "?access_token=%s" % tok + + data = {"membership": membership} + data.update(extra_data or {}) + + channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(data).encode("utf8"), + ) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + self.auth_user_id = temp_id + + def send( + self, + room_id, + body=None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): + if body is None: + body = "body_text_here" + + content = {"msgtype": "m.text", "body": body} + + return self.send_event( + room_id, + "m.room.message", + content, + txn_id, + tok, + expect_code, + custom_headers=custom_headers, + ) + + def send_event( + self, + room_id, + type, + content: Optional[dict] = None, + txn_id=None, + tok=None, + expect_code=200, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + ): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id) + if tok: + path = path + "?access_token=%s" % tok + + channel = make_request( + self.hs.get_reactor(), + self.site, + "PUT", + path, + json.dumps(content or {}).encode("utf8"), + custom_headers=custom_headers, + ) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def _read_write_state( + self, + room_id: str, + event_type: str, + body: Optional[Dict[str, Any]], + tok: str, + expect_code: int = 200, + state_key: str = "", + method: str = "GET", + ) -> Dict: + """Read or write some state from a given room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. 
+ If None, the request to the server will have an empty body + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + method: "GET" or "PUT" for reading or writing state, respectively + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % ( + room_id, + event_type, + state_key, + ) + if tok: + path = path + "?access_token=%s" % tok + + # Set request body if provided + content = b"" + if body is not None: + content = json.dumps(body).encode("utf8") + + channel = make_request(self.hs.get_reactor(), self.site, method, path, content) + + assert ( + int(channel.result["code"]) == expect_code + ), "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def get_state( + self, + room_id: str, + event_type: str, + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Gets some state from a room + + Args: + room_id: + event_type: The type of state event + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, None, tok, expect_code, state_key, method="GET" + ) + + def send_state( + self, + room_id: str, + event_type: str, + body: Dict[str, Any], + tok: str, + expect_code: int = 200, + state_key: str = "", + ): + """Set some state in a room + + Args: + room_id: + event_type: The type of state event + body: Body that is sent when making the request. The content of the state event. + tok: The access token to use + expect_code: The HTTP code to expect in the response + state_key: + + Returns: + The response body from the server + + Raises: + AssertionError: if expect_code doesn't match the HTTP code we received + """ + return self._read_write_state( + room_id, event_type, body, tok, expect_code, state_key, method="PUT" + ) + + def upload_media( + self, + resource: Resource, + image_data: bytes, + tok: str, + filename: str = "test.png", + expect_code: int = 200, + ) -> dict: + """Upload a piece of test media to the media repo + Args: + resource: The resource that will handle the upload request + image_data: The image data to upload + tok: The user token to use during the upload + filename: The filename of the media to be uploaded + expect_code: The return code to expect from attempting to upload the media + """ + image_length = len(image_data) + path = "/_matrix/media/r0/upload?filename=%s" % (filename,) + channel = make_request( + self.hs.get_reactor(), + FakeSite(resource), + "POST", + path, + content=image_data, + access_token=tok, + custom_headers=[(b"Content-Length", str(image_length))], + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + int(channel.result["code"]), + channel.result["body"], + ) + + return channel.json_body + + def login_via_oidc(self, remote_user_id: str) -> JsonDict: + """Log in (as a new user) via OIDC + + Returns the result of the final token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". 
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+        client_redirect_url = "https://x"
+        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+
+        # expect a confirmation page
+        assert channel.code == 200, channel.result
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.text_body,
+        )
+        assert m, channel.text_body
+        login_token = m.group(1)
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token and device id.
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        assert channel.code == 200
+        return channel.json_body
+
+    def auth_via_oidc(
+        self,
+        user_info_dict: JsonDict,
+        client_redirect_url: Optional[str] = None,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> FakeChannel:
+        """Perform an OIDC authentication flow via a mock OIDC provider.
+
+        This can be used for either login or user-interactive auth.
+
+        Starts by making a request to the relevant synapse redirect endpoint, which is
+        expected to serve a 302 to the OIDC provider. We then make a request to the
+        OIDC callback endpoint, intercepting the HTTP requests that will get sent back
+        to the OIDC provider.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+
+        Args:
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+            client_redirect_url: for a login flow, the client redirect URL to pass to
+                the login redirect endpoint
+            ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
+                of the UI auth.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+            Note that the response code may be a 200, 302 or 400 depending on how things
+            went.
+        """
+
+        cookies = {}
+
+        # if we're doing a ui auth, hit the ui auth redirect endpoint
+        if ui_auth_session_id:
+            # can't set the client redirect url for UI Auth
+            assert client_redirect_url is None
+            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+        else:
+            # otherwise, hit the login redirect endpoint
+            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+
+        # we now have a URI for the OIDC IdP, but we skip that and go straight
+        # back to synapse's OIDC callback resource. However, we do need the "state"
+        # param that synapse passes to the IdP via query params, as well as the cookie
+        # that synapse passes to the client.
+
+        oauth_uri_path, _ = oauth_uri.split("?", 1)
+        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+            "unexpected SSO URI " + oauth_uri_path
+        )
+        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+    def complete_oidc_auth(
+        self,
+        oauth_uri: str,
+        cookies: Mapping[str, str],
+        user_info_dict: JsonDict,
+    ) -> FakeChannel:
+        """Mock out an OIDC authentication flow
+
+        Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+        initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+        Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+        sent back to the OIDC provider.
+
+        Requires the OIDC callback resource to be mounted at the normal place.
+
+        Args:
+            oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+                from initiate_sso_login or initiate_sso_ui_auth).
+            cookies: the cookies set by synapse's redirect endpoint, which will be
+                sent back to the callback endpoint.
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+        """
+        _, oauth_uri_qs = oauth_uri.split("?", 1)
+        params = urllib.parse.parse_qs(oauth_uri_qs)
+        callback_uri = "%s?%s" % (
+            urllib.parse.urlparse(params["redirect_uri"][0]).path,
+            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+        )
+
+        # before we hit the callback uri, stub out some methods in the http client so
+        # that we don't have to handle full HTTPS requests.
+        # (expected url, json response) pairs, in the order we expect them.
+        expected_requests = [
+            # first we get a hit to the token endpoint, which we tell to return
+            # a dummy OIDC access token
+            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
+            # and then one to the user_info endpoint, which returns our remote user id.
+            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
+        ]
+
+        async def mock_req(method: str, uri: str, data=None, headers=None):
+            (expected_uri, resp_obj) = expected_requests.pop(0)
+            assert uri == expected_uri
+            resp = FakeResponse(
+                code=200,
+                phrase=b"OK",
+                body=json.dumps(resp_obj).encode("utf-8"),
+            )
+            return resp
+
+        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+            # now hit the callback URI with the right params and a made-up code
+            channel = make_request(
+                self.hs.get_reactor(),
+                self.site,
+                "GET",
+                callback_uri,
+                custom_headers=[
+                    ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+                ],
+            )
+        return channel
+
+    def initiate_sso_login(
+        self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the login-via-sso redirect endpoint, and return the target
+
+        Assumes that exactly one SSO provider has been configured. Requires the login
+        servlet to be mounted.
+
+        Args:
+            client_redirect_url: the client redirect URL to pass to the login redirect
+                endpoint
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets redirected to (ie, the SSO server)
+        """
+        params = {}
+        if client_redirect_url:
+            params["redirectUrl"] = client_redirect_url
+
+        # hit the redirect url (which should redirect back to the redirect url. This
+        # is the easiest way of figuring out what the Host header ought to be set to
+        # to keep Synapse happy.
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+        )
+        assert channel.code == 302
+
+        # hit the redirect url again with the right Host header, which should now issue
+        # a cookie and redirect to the SSO provider.
+ location = channel.headers.getRawHeaders("Location")[0] + parts = urllib.parse.urlsplit(location) + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + urllib.parse.urlunsplit(("", "") + parts[2:]), + custom_headers=[ + ("Host", parts[1]), + ], + ) + + assert channel.code == 302 + channel.extract_cookies(cookies) + return channel.headers.getRawHeaders("Location")[0] + + def initiate_sso_ui_auth( + self, ui_auth_session_id: str, cookies: MutableMapping[str, str] + ) -> str: + """Make a request to the ui-auth-via-sso endpoint, and return the target + + Assumes that exactly one SSO provider has been configured. Requires the + AuthRestServlet to be mounted. + + Args: + ui_auth_session_id: the session id of the UI auth + cookies: any cookies returned will be added to this dict + + Returns: + the URI that the client gets linked to (ie, the SSO server) + """ + sso_redirect_endpoint = ( + "/_matrix/client/r0/auth/m.login.sso/fallback/web?" + + urllib.parse.urlencode({"session": ui_auth_session_id}) + ) + # hit the redirect url (which will issue a cookie and state) + channel = make_request( + self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint + ) + # that should serve a confirmation page + assert channel.code == 200, channel.text_body + channel.extract_cookies(cookies) + + # parse the confirmation page to fish out the link. + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + assert len(p.links) == 1, "not exactly one link in confirmation page" + oauth_uri = p.links[0] + return oauth_uri + + +# an 'oidc_config' suitable for login_via_oidc. +TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" +TEST_OIDC_CONFIG = { + "enabled": True, + "discover": False, + "issuer": "https://issuer.test", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py deleted file mode 100644 index 5e83dba2ed..0000000000 --- a/tests/rest/client/v1/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py deleted file mode 100644 index d2181ea907..0000000000 --- a/tests/rest/client/v1/test_directory.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.rest import admin -from synapse.rest.client import directory, login, room -from synapse.types import RoomAlias -from synapse.util.stringutils import random_string - -from tests import unittest -from tests.unittest import override_config - - -class DirectoryTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - directory.register_servlets, - login.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["require_membership_for_aliases"] = True - - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, homeserver): - self.room_owner = self.register_user("room_owner", "test") - self.room_owner_tok = self.login("room_owner", "test") - - self.room_id = self.helper.create_room_as( - self.room_owner, tok=self.room_owner_tok - ) - - self.user = self.register_user("user", "test") - self.user_tok = self.login("user", "test") - - def test_state_event_not_in_room(self): - self.ensure_user_left_room() - self.set_alias_via_state_event(403) - - def test_directory_endpoint_not_in_room(self): - self.ensure_user_left_room() - self.set_alias_via_directory(403) - - def test_state_event_in_room_too_long(self): - self.ensure_user_joined_room() - self.set_alias_via_state_event(400, alias_length=256) - - def test_directory_in_room_too_long(self): - self.ensure_user_joined_room() - self.set_alias_via_directory(400, alias_length=256) - - @override_config({"default_room_version": 5}) - def test_state_event_user_in_v5_room(self): - """Test that a regular user can add alias events before room v6""" - self.ensure_user_joined_room() - self.set_alias_via_state_event(200) - - @override_config({"default_room_version": 6}) - def test_state_event_v6_room(self): - """Test that a regular user can *not* add alias events from room v6""" - self.ensure_user_joined_room() - self.set_alias_via_state_event(403) - - def test_directory_in_room(self): - self.ensure_user_joined_room() - self.set_alias_via_directory(200) - - def test_room_creation_too_long(self): - url = "/_matrix/client/r0/createRoom" - - # We use deliberately a localpart under the length threshold so - # that we can make sure that the check is done on the whole alias. - data = {"room_alias_name": random_string(256 - len(self.hs.hostname))} - request_data = json.dumps(data) - channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_room_creation(self): - url = "/_matrix/client/r0/createRoom" - - # Check with an alias of allowed length. There should already be - # a test that ensures it works in test_register.py, but let's be - # as cautious as possible here. 
- data = {"room_alias_name": random_string(5)} - request_data = json.dumps(data) - channel = self.make_request( - "POST", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, 200, channel.result) - - def set_alias_via_state_event(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % ( - self.room_id, - self.hs.hostname, - ) - - data = {"aliases": [self.random_alias(alias_length)]} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def set_alias_via_directory(self, expected_code, alias_length=5): - url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length) - data = {"room_id": self.room_id} - request_data = json.dumps(data) - - channel = self.make_request( - "PUT", url, request_data, access_token=self.user_tok - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def random_alias(self, length): - return RoomAlias(random_string(length), self.hs.hostname).to_string() - - def ensure_user_left_room(self): - self.ensure_membership("leave") - - def ensure_user_joined_room(self): - self.ensure_membership("join") - - def ensure_membership(self, membership): - try: - if membership == "leave": - self.helper.leave(room=self.room_id, user=self.user, tok=self.user_tok) - if membership == "join": - self.helper.join(room=self.room_id, user=self.user, tok=self.user_tok) - except AssertionError: - # We don't care whether the leave request didn't return a 200 (e.g. - # if the user isn't already in the room), because we only want to - # make sure the user isn't in the room. - pass diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py deleted file mode 100644 index a90294003e..0000000000 --- a/tests/rest/client/v1/test_events.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
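The alias-length tests in the file above hinge on a single rule: the full alias string ("#localpart:servername") must fit within 255 characters, which is why the too-long cases build localparts of 256 or 256 - len(hostname) characters. A sketch of the check, assuming the 255-character limit that the tests pin down (Synapse keeps such a constant as MAX_ALIAS_LENGTH, but this helper is illustrative):

MAX_ALIAS_LENGTH = 255  # assumed; the tests only establish that 256 is too long

def check_alias_length(room_alias: str) -> None:
    # The whole alias is measured, not just the localpart, which is why the
    # room-creation test subtracts the hostname length from 256.
    if len(room_alias) > MAX_ALIAS_LENGTH:
        raise ValueError(
            "Can't create aliases longer than %d characters" % MAX_ALIAS_LENGTH
        )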
- -""" Tests REST events for /events paths.""" - -from unittest.mock import Mock - -import synapse.rest.admin -from synapse.rest.client import events, login, room - -from tests import unittest - - -class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): - """Tests event streaming (GET /events).""" - - servlets = [ - events.register_servlets, - room.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - config = self.default_config() - config["enable_registration_captcha"] = False - config["enable_registration"] = True - config["auto_join_rooms"] = [] - - hs = self.setup_test_homeserver(config=config) - - hs.get_federation_handler = Mock() - - return hs - - def prepare(self, reactor, clock, hs): - - # register an account - self.user_id = self.register_user("sid1", "pass") - self.token = self.login(self.user_id, "pass") - - # register a 2nd account - self.other_user = self.register_user("other2", "pass") - self.other_token = self.login(self.other_user, "pass") - - def test_stream_basic_permissions(self): - # invalid token, expect 401 - # note: this is in violation of the original v1 spec, which expected - # 403. However, since the v1 spec no longer exists and the v1 - # implementation is now part of the r0 implementation, the newer - # behaviour is used instead to be consistent with the r0 spec. - # see issue #2602 - channel = self.make_request( - "GET", "/events?access_token=%s" % ("invalid" + self.token,) - ) - self.assertEquals(channel.code, 401, msg=channel.result) - - # valid token, expect content - channel = self.make_request( - "GET", "/events?access_token=%s&timeout=0" % (self.token,) - ) - self.assertEquals(channel.code, 200, msg=channel.result) - self.assertTrue("chunk" in channel.json_body) - self.assertTrue("start" in channel.json_body) - self.assertTrue("end" in channel.json_body) - - def test_stream_room_permissions(self): - room_id = self.helper.create_room_as(self.other_user, tok=self.other_token) - self.helper.send(room_id, tok=self.other_token) - - # invited to room (expect no content for room) - self.helper.invite( - room_id, src=self.other_user, targ=self.user_id, tok=self.other_token - ) - - # valid token, expect content - channel = self.make_request( - "GET", "/events?access_token=%s&timeout=0" % (self.token,) - ) - self.assertEquals(channel.code, 200, msg=channel.result) - - # We may get a presence event for ourselves down - self.assertEquals( - 0, - len( - [ - c - for c in channel.json_body["chunk"] - if not ( - c.get("type") == "m.presence" - and c["content"].get("user_id") == self.user_id - ) - ] - ), - ) - - # joined room (expect all content for room) - self.helper.join(room=room_id, user=self.user_id, tok=self.token) - - # left to room (expect no content for room) - - def TODO_test_stream_items(self): - # new user, no content - - # join room, expect 1 item (join) - - # send message, expect 2 items (join,send) - - # set topic, expect 3 items (join,send,topic) - - # someone else join room, expect 4 (join,send,topic,join) - - # someone else send message, expect 5 (join,send.topic,join,send) - - # someone else set topic, expect 6 (join,send,topic,join,send,topic) - pass - - -class GetEventsTestCase(unittest.HomeserverTestCase): - servlets = [ - events.register_servlets, - room.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def prepare(self, hs, reactor, clock): - - # register 
an account
-        self.user_id = self.register_user("sid1", "pass")
-        self.token = self.login(self.user_id, "pass")
-
-        self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
-
-    def test_get_event_via_events(self):
-        resp = self.helper.send(self.room_id, tok=self.token)
-        event_id = resp["event_id"]
-
-        channel = self.make_request(
-            "GET",
-            "/events/" + event_id,
-            access_token=self.token,
-        )
-        self.assertEquals(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
deleted file mode 100644
index eba3552b19..0000000000
--- a/tests/rest/client/v1/test_login.py
+++ /dev/null
@@ -1,1345 +0,0 @@
-# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import time
-import urllib.parse
-from typing import Any, Dict, List, Optional, Union
-from unittest.mock import Mock
-from urllib.parse import urlencode
-
-import pymacaroons
-
-from twisted.web.resource import Resource
-
-import synapse.rest.admin
-from synapse.appservice import ApplicationService
-from synapse.rest.client import devices, login, logout, register
-from synapse.rest.client.account import WhoamiRestServlet
-from synapse.rest.synapse.client import build_synapse_client_resource_tree
-from synapse.types import create_requester
-
-from tests import unittest
-from tests.handlers.test_oidc import HAS_OIDC
-from tests.handlers.test_saml import has_saml2
-from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
-from tests.test_utils.html_parsers import TestHtmlParser
-from tests.unittest import HomeserverTestCase, override_config, skip_unless
-
-try:
-    import jwt
-
-    HAS_JWT = True
-except ImportError:
-    HAS_JWT = False
-
-
-# synapse server name: used to populate public_baseurl in some tests
-SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse"
-
-# public_baseurl for some tests. It uses an http:// scheme because
-# FakeChannel.isSecure() returns False, so synapse will see the requested uri as
-# http://..., so using http in the public_baseurl stops Synapse trying to redirect to
-# https://....
-BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,)
-
-# CAS server used in some tests
-CAS_SERVER = "https://fake.test"
-
-# just enough to tell pysaml2 where to redirect to
-SAML_SERVER = "https://test.saml.server/idp/sso"
-TEST_SAML_METADATA = """
-<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata">
-  <IDPSSODescriptor protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
-      <SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="%(SAML_SERVER)s"/>
-  </IDPSSODescriptor>
-</EntityDescriptor>
-""" % {
-    "SAML_SERVER": SAML_SERVER,
-}
-
-LOGIN_URL = b"/_matrix/client/r0/login"
-TEST_URL = b"/_matrix/client/r0/account/whoami"
-
-# a (valid) url with some annoying characters in.
-TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
-
-# the query params in TEST_CLIENT_REDIRECT_URL
-EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
-
-# (possibly experimental) login flows we expect to appear in the list after the normal
-# ones
-ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
-
-
-class LoginRestServletTestCase(unittest.HomeserverTestCase):
-
-    servlets = [
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-        login.register_servlets,
-        logout.register_servlets,
-        devices.register_servlets,
-        lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
-    ]
-
-    def make_homeserver(self, reactor, clock):
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.enable_registration = True
-        self.hs.config.registrations_require_3pid = []
-        self.hs.config.auto_join_rooms = []
-        self.hs.config.enable_registration_captcha = False
-
-        return self.hs
-
-    @override_config(
-        {
-            "rc_login": {
-                "address": {"per_second": 0.17, "burst_count": 5},
-                # Prevent the account login ratelimiter from raising first
-                #
-                # This is normally covered by the default test homeserver config
-                # which sets these values to 10000, but as we're overriding the entire
-                # rc_login dict here, we need to set this manually as well
-                "account": {"per_second": 10000, "burst_count": 10000},
-            }
-        }
-    )
-    def test_POST_ratelimiting_per_address(self):
-        # Create different users so we're sure not to be bothered by the per-user
-        # ratelimiter.
-        for i in range(0, 6):
-            self.register_user("kermit" + str(i), "monkey")
-
-        for i in range(0, 6):
-            params = {
-                "type": "m.login.password",
-                "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
-                "password": "monkey",
-            }
-            channel = self.make_request(b"POST", LOGIN_URL, params)
-
-            if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
-                retry_after_ms = int(channel.json_body["retry_after_ms"])
-            else:
-                self.assertEquals(channel.result["code"], b"200", channel.result)
-
-        # Since we're ratelimiting at 0.17 requests/second (roughly one every 6s),
-        # retry_after_ms should be lower than 6000ms.
-        self.assertTrue(retry_after_ms < 6000)
-
-        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
-
-        params = {
-            "type": "m.login.password",
-            "identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
-            "password": "monkey",
-        }
-        channel = self.make_request(b"POST", LOGIN_URL, params)
-
-        self.assertEquals(channel.result["code"], b"200", channel.result)
-
-    @override_config(
-        {
-            "rc_login": {
-                "account": {"per_second": 0.17, "burst_count": 5},
-                # Prevent the address login ratelimiter from raising first
-                #
-                # This is normally covered by the default test homeserver config
-                # which sets these values to 10000, but as we're overriding the entire
-                # rc_login dict here, we need to set this manually as well
-                "address": {"per_second": 10000, "burst_count": 10000},
-            }
-        }
-    )
-    def test_POST_ratelimiting_per_account(self):
-        self.register_user("kermit", "monkey")
-
-        for i in range(0, 6):
-            params = {
-                "type": "m.login.password",
-                "identifier": {"type": "m.id.user", "user": "kermit"},
-                "password": "monkey",
-            }
-            channel = self.make_request(b"POST", LOGIN_URL, params)
-
-            if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
-                retry_after_ms = int(channel.json_body["retry_after_ms"])
-            else:
-                self.assertEquals(channel.result["code"], b"200", channel.result)
-
-        # Since we're ratelimiting at 0.17 requests/second (roughly one every 6s),
-        # retry_after_ms should be lower than 6000ms.
-        self.assertTrue(retry_after_ms < 6000)
-
-        self.reactor.advance(retry_after_ms / 1000.0)
-
-        params = {
-            "type": "m.login.password",
-            "identifier": {"type": "m.id.user", "user": "kermit"},
-            "password": "monkey",
-        }
-        channel = self.make_request(b"POST", LOGIN_URL, params)
-
-        self.assertEquals(channel.result["code"], b"200", channel.result)
-
-    @override_config(
-        {
-            "rc_login": {
-                # Prevent the address login ratelimiter from raising first
-                #
-                # This is normally covered by the default test homeserver config
-                # which sets these values to 10000, but as we're overriding the entire
-                # rc_login dict here, we need to set this manually as well
-                "address": {"per_second": 10000, "burst_count": 10000},
-                "failed_attempts": {"per_second": 0.17, "burst_count": 5},
-            }
-        }
-    )
-    def test_POST_ratelimiting_per_account_failed_attempts(self):
-        self.register_user("kermit", "monkey")
-
-        for i in range(0, 6):
-            params = {
-                "type": "m.login.password",
-                "identifier": {"type": "m.id.user", "user": "kermit"},
-                "password": "notamonkey",
-            }
-            channel = self.make_request(b"POST", LOGIN_URL, params)
-
-            if i == 5:
-                self.assertEquals(channel.result["code"], b"429", channel.result)
-                retry_after_ms = int(channel.json_body["retry_after_ms"])
-            else:
-                self.assertEquals(channel.result["code"], b"403", channel.result)
-
-        # Since we're ratelimiting at 0.17 requests/second (roughly one every 6s),
-        # retry_after_ms should be lower than 6000ms.
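-        # (with per_second=0.17, the limiter refills roughly every 1/0.17 ≈ 5.9
-        # seconds, which is where the 6000ms bound used in these tests comes from)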
- self.assertTrue(retry_after_ms < 6000) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "notamonkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_soft_logout(self): - self.register_user("kermit", "monkey") - - # we shouldn't be able to make requests without an access token - channel = self.make_request(b"GET", TEST_URL) - self.assertEquals(channel.result["code"], b"401", channel.result) - self.assertEquals(channel.json_body["errcode"], "M_MISSING_TOKEN") - - # log in as normal - params = { - "type": "m.login.password", - "identifier": {"type": "m.id.user", "user": "kermit"}, - "password": "monkey", - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.code, 200, channel.result) - access_token = channel.json_body["access_token"] - device_id = channel.json_body["device_id"] - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # - # test behaviour after deleting the expired device - # - - # we now log in as a different device - access_token_2 = self.login("kermit", "monkey") - - # more requests with the expired token should still return a soft-logout - self.reactor.advance(3600) - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # ... 
but if we delete that device, it will be a proper logout - self._delete_device(access_token_2, "kermit", "monkey", device_id) - - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], False) - - def _delete_device(self, access_token, user_id, password, device_id): - """Perform the UI-Auth to delete a device""" - channel = self.make_request( - b"DELETE", "devices/" + device_id, access_token=access_token - ) - self.assertEquals(channel.code, 401, channel.result) - # check it's a UI-Auth fail - self.assertEqual( - set(channel.json_body.keys()), - {"flows", "params", "session"}, - channel.result, - ) - - auth = { - "type": "m.login.password", - # https://github.com/matrix-org/synapse/issues/5665 - # "identifier": {"type": "m.id.user", "user": user_id}, - "user": user_id, - "password": password, - "session": channel.json_body["session"], - } - - channel = self.make_request( - b"DELETE", - "devices/" + device_id, - access_token=access_token, - content={"auth": auth}, - ) - self.assertEquals(channel.code, 200, channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_after_being_soft_logged_out(self): - self.register_user("kermit", "monkey") - - # log in as normal - access_token = self.login("kermit", "monkey") - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # Now try to hard logout this session - channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config({"session_lifetime": "24h"}) - def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(self): - self.register_user("kermit", "monkey") - - # log in as normal - access_token = self.login("kermit", "monkey") - - # we should now be able to make requests with the access token - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 200, channel.result) - - # time passes - self.reactor.advance(24 * 3600) - - # ... 
and we should be soft-logouted - channel = self.make_request(b"GET", TEST_URL, access_token=access_token) - self.assertEquals(channel.code, 401, channel.result) - self.assertEquals(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") - self.assertEquals(channel.json_body["soft_logout"], True) - - # Now try to hard log out all of the user's sessions - channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEquals(channel.result["code"], b"200", channel.result) - - -@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") -class MultiSSOTestCase(unittest.HomeserverTestCase): - """Tests for homeservers with multiple SSO providers enabled""" - - servlets = [ - login.register_servlets, - ] - - def default_config(self) -> Dict[str, Any]: - config = super().default_config() - - config["public_baseurl"] = BASE_URL - - config["cas_config"] = { - "enabled": True, - "server_url": CAS_SERVER, - "service_url": "https://matrix.goodserver.com:8448", - } - - config["saml2_config"] = { - "sp_config": { - "metadata": {"inline": [TEST_SAML_METADATA]}, - # use the XMLSecurity backend to avoid relying on xmlsec1 - "crypto_backend": "XMLSecurity", - }, - } - - # default OIDC provider - config["oidc_config"] = TEST_OIDC_CONFIG - - # additional OIDC providers - config["oidc_providers"] = [ - { - "idp_id": "idp1", - "idp_name": "IDP1", - "discover": False, - "issuer": "https://issuer1", - "client_id": "test-client-id", - "client_secret": "test-client-secret", - "scopes": ["profile"], - "authorization_endpoint": "https://issuer1/auth", - "token_endpoint": "https://issuer1/token", - "userinfo_endpoint": "https://issuer1/userinfo", - "user_mapping_provider": { - "config": {"localpart_template": "{{ user.sub }}"} - }, - } - ] - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d.update(build_synapse_client_resource_tree(self.hs)) - return d - - def test_get_login_flows(self): - """GET /login should return password and SSO flows""" - channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) - - expected_flow_types = [ - "m.login.cas", - "m.login.sso", - "m.login.token", - "m.login.password", - ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] - - self.assertCountEqual( - [f["type"] for f in channel.json_body["flows"]], expected_flow_types - ) - - @override_config({"experimental_features": {"msc2858_enabled": True}}) - def test_get_msc2858_login_flows(self): - """The SSO flow should include IdP info if MSC2858 is enabled""" - channel = self.make_request("GET", "/_matrix/client/r0/login") - self.assertEqual(channel.code, 200, channel.result) - - # stick the flows results in a dict by type - flow_results: Dict[str, Any] = {} - for f in channel.json_body["flows"]: - flow_type = f["type"] - self.assertNotIn( - flow_type, flow_results, "duplicate flow type %s" % (flow_type,) - ) - flow_results[flow_type] = f - - self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") - sso_flow = flow_results.pop("m.login.sso") - # we should have a set of IdPs - self.assertCountEqual( - sso_flow["org.matrix.msc2858.identity_providers"], - [ - {"id": "cas", "name": "CAS"}, - {"id": "saml", "name": "SAML"}, - {"id": "oidc-idp1", "name": "IDP1"}, - {"id": "oidc", "name": "OIDC"}, - ], - ) - - # the rest of the flows are simple - expected_flows = [ - {"type": "m.login.cas"}, - {"type": "m.login.token"}, - {"type": "m.login.password"}, - ] + ADDITIONAL_LOGIN_FLOWS - - 
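-        # the remaining flows (after the SSO flow was popped off above) should be
-        # returned exactly as listed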
self.assertCountEqual(flow_results.values(), expected_flows) - - def test_multi_sso_redirect(self): - """/login/sso/redirect should redirect to an identity picker""" - # first hit the redirect url, which should redirect to our idp picker - channel = self._make_sso_redirect_request(False, None) - self.assertEqual(channel.code, 302, channel.result) - uri = channel.headers.getRawHeaders("Location")[0] - - # hitting that picker should give us some HTML - channel = self.make_request("GET", uri) - self.assertEqual(channel.code, 200, channel.result) - - # parse the form to check it has fields assumed elsewhere in this class - html = channel.result["body"].decode("utf-8") - p = TestHtmlParser() - p.feed(html) - p.close() - - # there should be a link for each href - returned_idps: List[str] = [] - for link in p.links: - path, query = link.split("?", 1) - self.assertEqual(path, "pick_idp") - params = urllib.parse.parse_qs(query) - self.assertEqual(params["redirectUrl"], [TEST_CLIENT_REDIRECT_URL]) - returned_idps.append(params["idp"][0]) - - self.assertCountEqual(returned_idps, ["cas", "oidc", "oidc-idp1", "saml"]) - - def test_multi_sso_redirect_to_cas(self): - """If CAS is chosen, should redirect to the CAS server""" - - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=cas", - shorthand=False, - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - cas_uri = location_headers[0] - cas_uri_path, cas_uri_query = cas_uri.split("?", 1) - - # it should redirect us to the login page of the cas server - self.assertEqual(cas_uri_path, CAS_SERVER + "/login") - - # check that the redirectUrl is correctly encoded in the service param - ie, the - # place that CAS will redirect to - cas_uri_params = urllib.parse.parse_qs(cas_uri_query) - service_uri = cas_uri_params["service"][0] - _, service_uri_query = service_uri.split("?", 1) - service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) - - def test_multi_sso_redirect_to_saml(self): - """If SAML is chosen, should redirect to the SAML server""" - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=saml", - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - saml_uri = location_headers[0] - saml_uri_path, saml_uri_query = saml_uri.split("?", 1) - - # it should redirect us to the login page of the SAML server - self.assertEqual(saml_uri_path, SAML_SERVER) - - # the RelayState is used to carry the client redirect url - saml_uri_params = urllib.parse.parse_qs(saml_uri_query) - relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - - def test_login_via_oidc(self): - """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - - # pick the default OIDC provider - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", - ) - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - oidc_uri = location_headers[0] - oidc_uri_path, oidc_uri_query = 
oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - # ... and should have set a cookie including the redirect url - cookie_headers = channel.headers.getRawHeaders("Set-Cookie") - assert cookie_headers - cookies: Dict[str, str] = {} - for h in cookie_headers: - key, value = h.split(";")[0].split("=", maxsplit=1) - cookies[key] = value - - oidc_session_cookie = cookies["oidc_session"] - macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) - self.assertEqual( - self._get_value_from_macaroon(macaroon, "client_redirect_url"), - TEST_CLIENT_REDIRECT_URL, - ) - - channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) - - # that should serve a confirmation page - self.assertEqual(channel.code, 200, channel.result) - content_type_headers = channel.headers.getRawHeaders("Content-Type") - assert content_type_headers - self.assertTrue(content_type_headers[-1].startswith("text/html")) - p = TestHtmlParser() - p.feed(channel.text_body) - p.close() - - # ... which should contain our redirect link - self.assertEqual(len(p.links), 1) - path, query = p.links[0].split("?", 1) - self.assertEqual(path, "https://x") - - # it will have url-encoded the params properly, so we'll have to parse them - params = urllib.parse.parse_qsl( - query, keep_blank_values=True, strict_parsing=True, errors="strict" - ) - self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) - self.assertEqual(params[2][0], "loginToken") - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. - login_token = params[2][1] - chan = self.make_request( - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@user1:test") - - def test_multi_sso_redirect_to_unknown(self): - """An unknown IdP should cause a 400""" - channel = self.make_request( - "GET", - "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_client_idp_redirect_to_unknown(self): - """If the client tries to pick an unknown IdP, return a 404""" - channel = self._make_sso_redirect_request(False, "xxx") - self.assertEqual(channel.code, 404, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - - def test_client_idp_redirect_to_oidc(self): - """If the client pick a known IdP, redirect to it""" - channel = self._make_sso_redirect_request(False, "oidc") - self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] - oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - @override_config({"experimental_features": {"msc2858_enabled": True}}) - def test_client_msc2858_redirect_to_oidc(self): - """Test the unstable API""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 302, channel.result) - oidc_uri = channel.headers.getRawHeaders("Location")[0] - oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) - - # it should redirect us to the auth page of the OIDC server - self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) - - def test_client_idp_redirect_msc2858_disabled(self): - """If the client tries to use the MSC2858 endpoint but MSC2858 is 
disabled, return a 400""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - def _make_sso_redirect_request( - self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None - ): - """Send a request to /_matrix/client/r0/login/sso/redirect - - ... or the unstable equivalent - - ... possibly specifying an IDP provider - """ - endpoint = ( - "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" - if unstable_endpoint - else "/_matrix/client/r0/login/sso/redirect" - ) - if idp_prov is not None: - endpoint += "/" + idp_prov - endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - - return self.make_request( - "GET", - endpoint, - custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], - ) - - @staticmethod - def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: - prefix = key + " = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(prefix): - return caveat.caveat_id[len(prefix) :] - raise ValueError("No %s caveat in macaroon" % (key,)) - - -class CASTestCase(unittest.HomeserverTestCase): - - servlets = [ - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.base_url = "https://matrix.goodserver.com/" - self.redirect_path = "_synapse/client/login/sso/redirect/confirm" - - config = self.default_config() - config["public_baseurl"] = ( - config.get("public_baseurl") or "https://matrix.goodserver.com:8448" - ) - config["cas_config"] = { - "enabled": True, - "server_url": CAS_SERVER, - } - - cas_user_id = "username" - self.user_id = "@%s:test" % cas_user_id - - async def get_raw(uri, args): - """Return an example response payload from a call to the `/proxyValidate` - endpoint of a CAS server, copied from - https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20 - - This needs to be returned by an async function (as opposed to set as the - mock's return value) because the corresponding Synapse code awaits on it. - """ - return ( - """ - - - %s - PGTIOU-84678-8a9d... - - https://proxy2/pgtUrl - https://proxy1/pgtUrl - - - - """ - % cas_user_id - ).encode("utf-8") - - mocked_http_client = Mock(spec=["get_raw"]) - mocked_http_client.get_raw.side_effect = get_raw - - self.hs = self.setup_test_homeserver( - config=config, - proxied_http_client=mocked_http_client, - ) - - return self.hs - - def prepare(self, reactor, clock, hs): - self.deactivate_account_handler = hs.get_deactivate_account_handler() - - def test_cas_redirect_confirm(self): - """Tests that the SSO login flow serves a confirmation page before redirecting a - user to the redirect URL. - """ - base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl" - redirect_url = "https://dodgy-site.com/" - - url_parts = list(urllib.parse.urlparse(base_url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({"redirectUrl": redirect_url}) - query.update({"ticket": "ticket"}) - url_parts[4] = urllib.parse.urlencode(query) - cas_ticket_url = urllib.parse.urlunparse(url_parts) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - # Test that the response is HTML. 
- self.assertEqual(channel.code, 200, channel.result) - content_type_header_value = "" - for header in channel.result.get("headers", []): - if header[0] == b"Content-Type": - content_type_header_value = header[1].decode("utf8") - - self.assertTrue(content_type_header_value.startswith("text/html")) - - # Test that the body isn't empty. - self.assertTrue(len(channel.result["body"]) > 0) - - # And that it contains our redirect link - self.assertIn(redirect_url, channel.result["body"].decode("UTF-8")) - - @override_config( - { - "sso": { - "client_whitelist": [ - "https://legit-site.com/", - "https://other-site.com/", - ] - } - } - ) - def test_cas_redirect_whitelisted(self): - """Tests that the SSO login flow serves a redirect to a whitelisted url""" - self._test_redirect("https://legit-site.com/") - - @override_config({"public_baseurl": "https://example.com"}) - def test_cas_redirect_login_fallback(self): - self._test_redirect("https://example.com/_matrix/static/client/login") - - def _test_redirect(self, redirect_url): - """Tests that the SSO login flow serves a redirect for the given redirect URL.""" - cas_ticket_url = ( - "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" - % (urllib.parse.quote(redirect_url)) - ) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - self.assertEqual(channel.code, 302) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) - - @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) - def test_deactivated_user(self): - """Logging in as a deactivated account should error.""" - redirect_url = "https://legit-site.com/" - - # First login (to create the user). - self._test_redirect(redirect_url) - - # Deactivate the account. - self.get_success( - self.deactivate_account_handler.deactivate_account( - self.user_id, False, create_requester(self.user_id) - ) - ) - - # Request the CAS ticket. - cas_ticket_url = ( - "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket" - % (urllib.parse.quote(redirect_url)) - ) - - # Get Synapse to call the fake CAS and serve the template. - channel = self.make_request("GET", cas_ticket_url) - - # Because the user is deactivated they are served an error template. - self.assertEqual(channel.code, 403) - self.assertIn(b"SSO account deactivated", channel.result["body"]) - - -@skip_unless(HAS_JWT, "requires jwt") -class JWTTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - jwt_secret = "secret" - jwt_algorithm = "HS256" - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_secret - self.hs.config.jwt_algorithm = self.jwt_algorithm - return self.hs - - def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. 
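-        # Normalise to str so the test works with both PyJWT 1.x (which returns
-        # bytes) and PyJWT 2.x (which returns str).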
- result: Union[str, bytes] = jwt.encode(payload, secret, self.jwt_algorithm) - if isinstance(result, bytes): - return result.decode("ascii") - return result - - def jwt_login(self, *args): - params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - channel = self.make_request(b"POST", LOGIN_URL, params) - return channel - - def test_login_jwt_valid_registered(self): - self.register_user("kermit", "monkey") - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - def test_login_jwt_valid_unregistered(self): - channel = self.jwt_login({"sub": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@frog:test") - - def test_login_jwt_invalid_signature(self): - channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: Signature verification failed", - ) - - def test_login_jwt_expired(self): - channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Signature has expired" - ) - - def test_login_jwt_not_before(self): - now = int(time.time()) - channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: The token is not yet valid (nbf)", - ) - - def test_login_no_sub(self): - channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual(channel.json_body["error"], "Invalid JWT") - - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "issuer": "test-issuer", - } - } - ) - def test_login_iss(self): - """Test validating the issuer claim.""" - # A valid issuer. - channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - # An invalid issuer. - channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid issuer" - ) - - # Not providing an issuer. 
- channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - 'JWT validation failed: Token is missing the "iss" claim', - ) - - def test_login_iss_no_config(self): - """Test providing an issuer claim without requiring it in the configuration.""" - channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - @override_config( - { - "jwt_config": { - "jwt_enabled": True, - "secret": jwt_secret, - "algorithm": jwt_algorithm, - "audiences": ["test-audience"], - } - } - ) - def test_login_aud(self): - """Test validating the audience claim.""" - # A valid audience. - channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - # An invalid audience. - channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" - ) - - # Not providing an audience. - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - 'JWT validation failed: Token is missing the "aud" claim', - ) - - def test_login_aud_no_config(self): - """Test providing an audience without requiring it in the configuration.""" - channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], "JWT validation failed: Invalid audience" - ) - - def test_login_no_token(self): - params = {"type": "org.matrix.login.jwt"} - channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") - - -# The JWTPubKeyTestCase is a complement to JWTTestCase where we instead use -# RSS256, with a public key configured in synapse as "jwt_secret", and tokens -# signed by the private key. -@skip_unless(HAS_JWT, "requires jwt") -class JWTPubKeyTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - ] - - # This key's pubkey is used as the jwt_secret setting of synapse. Valid - # tokens are signed by this and validated using the pubkey. It is generated - # with `openssl genrsa 512` (not a secure way to generate real keys, but - # good enough for tests!) 
- jwt_privatekey = "\n".join( - [ - "-----BEGIN RSA PRIVATE KEY-----", - "MIIBPAIBAAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7TKO1vSEWdq7u9x8SMFiB", - "492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQJAUv7OOSOtiU+wzJq82rnk", - "yR4NHqt7XX8BvkZPM7/+EjBRanmZNSp5kYZzKVaZ/gTOM9+9MwlmhidrUOweKfB/", - "kQIhAPZwHazbjo7dYlJs7wPQz1vd+aHSEH+3uQKIysebkmm3AiEA1nc6mDdmgiUq", - "TpIN8A4MBKmfZMWTLq6z05y/qjKyxb0CIQDYJxCwTEenIaEa4PdoJl+qmXFasVDN", - "ZU0+XtNV7yul0wIhAMI9IhiStIjS2EppBa6RSlk+t1oxh2gUWlIh+YVQfZGRAiEA", - "tqBR7qLZGJ5CVKxWmNhJZGt1QHoUtOch8t9C4IdOZ2g=", - "-----END RSA PRIVATE KEY-----", - ] - ) - - # Generated with `openssl rsa -in foo.key -pubout`, with the the above - # private key placed in foo.key (jwt_privatekey). - jwt_pubkey = "\n".join( - [ - "-----BEGIN PUBLIC KEY-----", - "MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAM50f1Q5gsdmzifLstzLHb5NhfajiOt7", - "TKO1vSEWdq7u9x8SMFiB492RM9W/XFoh8WUfL9uL6Now6tPRDsWv3xsCAwEAAQ==", - "-----END PUBLIC KEY-----", - ] - ) - - # This key is used to sign tokens that shouldn't be accepted by synapse. - # Generated just like jwt_privatekey. - bad_privatekey = "\n".join( - [ - "-----BEGIN RSA PRIVATE KEY-----", - "MIIBOgIBAAJBAL//SQrKpKbjCCnv/FlasJCv+t3k/MPsZfniJe4DVFhsktF2lwQv", - "gLjmQD3jBUTz+/FndLSBvr3F4OHtGL9O/osCAwEAAQJAJqH0jZJW7Smzo9ShP02L", - "R6HRZcLExZuUrWI+5ZSP7TaZ1uwJzGFspDrunqaVoPobndw/8VsP8HFyKtceC7vY", - "uQIhAPdYInDDSJ8rFKGiy3Ajv5KWISBicjevWHF9dbotmNO9AiEAxrdRJVU+EI9I", - "eB4qRZpY6n4pnwyP0p8f/A3NBaQPG+cCIFlj08aW/PbxNdqYoBdeBA0xDrXKfmbb", - "iwYxBkwL0JCtAiBYmsi94sJn09u2Y4zpuCbJeDPKzWkbuwQh+W1fhIWQJQIhAKR0", - "KydN6cRLvphNQ9c/vBTdlzWxzcSxREpguC7F1J1m", - "-----END RSA PRIVATE KEY-----", - ] - ) - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_pubkey - self.hs.config.jwt_algorithm = "RS256" - return self.hs - - def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: - # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str. 
- result: Union[bytes, str] = jwt.encode(payload, secret, "RS256") - if isinstance(result, bytes): - return result.decode("ascii") - return result - - def jwt_login(self, *args): - params = {"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)} - channel = self.make_request(b"POST", LOGIN_URL, params) - return channel - - def test_login_jwt_valid(self): - channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) - self.assertEqual(channel.json_body["user_id"], "@kermit:test") - - def test_login_jwt_invalid_signature(self): - channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"403", channel.result) - self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - self.assertEqual( - channel.json_body["error"], - "JWT validation failed: Signature verification failed", - ) - - -AS_USER = "as_user_alice" - - -class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): - servlets = [ - login.register_servlets, - register.register_servlets, - ] - - def register_as_user(self, username): - self.make_request( - b"POST", - "/_matrix/client/r0/register?access_token=%s" % (self.service.token,), - {"username": username}, - ) - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - - self.service = ApplicationService( - id="unique_identifier", - token="some_token", - hostname="example.com", - sender="@asbot:example.com", - namespaces={ - ApplicationService.NS_USERS: [ - {"regex": r"@as_user.*", "exclusive": False} - ], - ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [], - }, - ) - self.another_service = ApplicationService( - id="another__identifier", - token="another_token", - hostname="example.com", - sender="@as2bot:example.com", - namespaces={ - ApplicationService.NS_USERS: [ - {"regex": r"@as2_user.*", "exclusive": False} - ], - ApplicationService.NS_ROOMS: [], - ApplicationService.NS_ALIASES: [], - }, - ) - - self.hs.get_datastore().services_cache.append(self.service) - self.hs.get_datastore().services_cache.append(self.another_service) - return self.hs - - def test_login_appservice_user(self): - """Test that an appservice user can use /login""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_login_appservice_user_bot(self): - """Test that the appservice bot can use /login""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": self.service.sender}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_login_appservice_wrong_user(self): - """Test that non-as users cannot login with the as token""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": "fibble_wibble"}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.service.token - ) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - def test_login_appservice_wrong_as(self): - """Test that as users cannot login with wrong 
as token""" - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request( - b"POST", LOGIN_URL, params, access_token=self.another_service.token - ) - - self.assertEquals(channel.result["code"], b"403", channel.result) - - def test_login_appservice_no_token(self): - """Test that users must provide a token when using the appservice - login method - """ - self.register_as_user(AS_USER) - - params = { - "type": login.LoginRestServlet.APPSERVICE_TYPE, - "identifier": {"type": "m.id.user", "user": AS_USER}, - } - channel = self.make_request(b"POST", LOGIN_URL, params) - - self.assertEquals(channel.result["code"], b"401", channel.result) - - -@skip_unless(HAS_OIDC, "requires OIDC") -class UsernamePickerTestCase(HomeserverTestCase): - """Tests for the username picker flow of SSO login""" - - servlets = [login.register_servlets] - - def default_config(self): - config = super().default_config() - config["public_baseurl"] = BASE_URL - - config["oidc_config"] = {} - config["oidc_config"].update(TEST_OIDC_CONFIG) - config["oidc_config"]["user_mapping_provider"] = { - "config": {"display_name_template": "{{ user.displayname }}"} - } - - # whitelist this client URI so we redirect straight to it rather than - # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://x"]} - return config - - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d.update(build_synapse_client_resource_tree(self.hs)) - return d - - def test_username_picker(self): - """Test the happy path of a username picker flow.""" - - # do the start of the login flow - channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL - ) - - # that should redirect to the username picker - self.assertEqual(channel.code, 302, channel.result) - location_headers = channel.headers.getRawHeaders("Location") - assert location_headers - picker_url = location_headers[0] - self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") - - # ... with a username_mapping_session cookie - cookies: Dict[str, str] = {} - channel.extract_cookies(cookies) - self.assertIn("username_mapping_session", cookies) - session_id = cookies["username_mapping_session"] - - # introspect the sso handler a bit to check that the username mapping session - # looks ok. - username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions - self.assertIn( - session_id, - username_mapping_sessions, - "session id not found in map", - ) - session = username_mapping_sessions[session_id] - self.assertEqual(session.remote_user_id, "tester") - self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) - - # the expiry time should be about 15 minutes away - expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) - self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) - - # Now, submit a username to the username picker, which should serve a redirect - # to the completion page - content = urlencode({b"username": b"bobby"}).encode("utf8") - chan = self.make_request( - "POST", - path=picker_url, - content=content, - content_is_form=True, - custom_headers=[ - ("Cookie", "username_mapping_session=" + session_id), - # old versions of twisted don't do form-parsing without a valid - # content-length header. 
- ("Content-Length", str(len(content))), - ], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - assert location_headers - - # send a request to the completion page, which should 302 to the client redirectUrl - chan = self.make_request( - "GET", - path=location_headers[0], - custom_headers=[("Cookie", "username_mapping_session=" + session_id)], - ) - self.assertEqual(chan.code, 302, chan.result) - location_headers = chan.headers.getRawHeaders("Location") - assert location_headers - - # ensure that the returned location matches the requested redirect URL - path, query = location_headers[0].split("?", 1) - self.assertEqual(path, "https://x") - - # it will have url-encoded the params properly, so we'll have to parse them - params = urllib.parse.parse_qsl( - query, keep_blank_values=True, strict_parsing=True, errors="strict" - ) - self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) - self.assertEqual(params[2][0], "loginToken") - - # fish the login token out of the returned redirect uri - login_token = params[2][1] - - # finally, submit the matrix login token to the login API, which gives us our - # matrix access token, mxid, and device id. - chan = self.make_request( - "POST", - "/login", - content={"type": "m.login.token", "token": login_token}, - ) - self.assertEqual(chan.code, 200, chan.result) - self.assertEqual(chan.json_body["user_id"], "@bobby:test") diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py deleted file mode 100644 index 1d152352d1..0000000000 --- a/tests/rest/client/v1/test_presence.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from unittest.mock import Mock - -from twisted.internet import defer - -from synapse.handlers.presence import PresenceHandler -from synapse.rest.client import presence -from synapse.types import UserID - -from tests import unittest - - -class PresenceTestCase(unittest.HomeserverTestCase): - """Tests presence REST API.""" - - user_id = "@sid:red" - - user = UserID.from_string(user_id) - servlets = [presence.register_servlets] - - def make_homeserver(self, reactor, clock): - - presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = defer.succeed(None) - - hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - presence_handler=presence_handler, - ) - - return hs - - def test_put_presence(self): - """ - PUT to the status endpoint with use_presence enabled will call - set_state on the presence handler. 
- """ - self.hs.config.use_presence = True - - body = {"presence": "here", "status_msg": "beep boop"} - channel = self.make_request( - "PUT", "/presence/%s/status" % (self.user_id,), body - ) - - self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1) - - @unittest.override_config({"use_presence": False}) - def test_put_presence_disabled(self): - """ - PUT to the status endpoint with use_presence disabled will NOT call - set_state on the presence handler. - """ - - body = {"presence": "here", "status_msg": "beep boop"} - channel = self.make_request( - "PUT", "/presence/%s/status" % (self.user_id,), body - ) - - self.assertEqual(channel.code, 200) - self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0) diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py deleted file mode 100644 index 2860579c2e..0000000000 --- a/tests/rest/client/v1/test_profile.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /profile paths.""" -from synapse.rest import admin -from synapse.rest.client import login, profile, room - -from tests import unittest - - -class ProfileTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - return self.hs - - def prepare(self, reactor, clock, hs): - self.owner = self.register_user("owner", "pass") - self.owner_tok = self.login("owner", "pass") - self.other = self.register_user("other", "pass", displayname="Bob") - - def test_get_displayname(self): - res = self._get_displayname() - self.assertEqual(res, "owner") - - def test_set_displayname(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res = self._get_displayname() - self.assertEqual(res, "test") - - def test_set_displayname_noauth(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test"}, - ) - self.assertEqual(channel.code, 401, channel.result) - - def test_set_displayname_too_long(self): - """Attempts to set a stupid displayname should get a 400""" - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.owner,), - content={"displayname": "test" * 100}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - res = self._get_displayname() - self.assertEqual(res, "owner") - - def test_get_displayname_other(self): - res = self._get_displayname(self.other) - self.assertEquals(res, "Bob") - - def test_set_displayname_other(self): - channel = self.make_request( - "PUT", - "/profile/%s/displayname" % (self.other,), - content={"displayname": 
"test"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - def test_get_avatar_url(self): - res = self._get_avatar_url() - self.assertIsNone(res) - - def test_set_avatar_url(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res = self._get_avatar_url() - self.assertEqual(res, "http://my.server/pic.gif") - - def test_set_avatar_url_noauth(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif"}, - ) - self.assertEqual(channel.code, 401, channel.result) - - def test_set_avatar_url_too_long(self): - """Attempts to set a stupid avatar_url should get a 400""" - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.owner,), - content={"avatar_url": "http://my.server/pic.gif" * 100}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - res = self._get_avatar_url() - self.assertIsNone(res) - - def test_get_avatar_url_other(self): - res = self._get_avatar_url(self.other) - self.assertIsNone(res) - - def test_set_avatar_url_other(self): - channel = self.make_request( - "PUT", - "/profile/%s/avatar_url" % (self.other,), - content={"avatar_url": "http://my.server/pic.gif"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - - def _get_displayname(self, name=None): - channel = self.make_request( - "GET", "/profile/%s/displayname" % (name or self.owner,) - ) - self.assertEqual(channel.code, 200, channel.result) - return channel.json_body["displayname"] - - def _get_avatar_url(self, name=None): - channel = self.make_request( - "GET", "/profile/%s/avatar_url" % (name or self.owner,) - ) - self.assertEqual(channel.code, 200, channel.result) - return channel.json_body.get("avatar_url") - - -class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - config = self.default_config() - config["require_auth_for_profile_requests"] = True - config["limit_profile_requests_to_users_who_share_rooms"] = True - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, hs): - # User owning the requested profile. - self.owner = self.register_user("owner", "pass") - self.owner_tok = self.login("owner", "pass") - self.profile_url = "/profile/%s" % (self.owner) - - # User requesting the profile. 
- self.requester = self.register_user("requester", "pass") - self.requester_tok = self.login("requester", "pass") - - self.room_id = self.helper.create_room_as(self.owner, tok=self.owner_tok) - - def test_no_auth(self): - self.try_fetch_profile(401) - - def test_not_in_shared_room(self): - self.ensure_requester_left_room() - - self.try_fetch_profile(403, access_token=self.requester_tok) - - def test_in_shared_room(self): - self.ensure_requester_left_room() - - self.helper.join(room=self.room_id, user=self.requester, tok=self.requester_tok) - - self.try_fetch_profile(200, self.requester_tok) - - def try_fetch_profile(self, expected_code, access_token=None): - self.request_profile(expected_code, access_token=access_token) - - self.request_profile( - expected_code, url_suffix="/displayname", access_token=access_token - ) - - self.request_profile( - expected_code, url_suffix="/avatar_url", access_token=access_token - ) - - def request_profile(self, expected_code, url_suffix="", access_token=None): - channel = self.make_request( - "GET", self.profile_url + url_suffix, access_token=access_token - ) - self.assertEqual(channel.code, expected_code, channel.result) - - def ensure_requester_left_room(self): - try: - self.helper.leave( - room=self.room_id, user=self.requester, tok=self.requester_tok - ) - except AssertionError: - # We don't care whether the leave request didn't return a 200 (e.g. - # if the user isn't already in the room), because we only want to - # make sure the user isn't in the room. - pass - - -class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["require_auth_for_profile_requests"] = True - config["limit_profile_requests_to_users_who_share_rooms"] = True - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, hs): - # User requesting the profile. - self.requester = self.register_user("requester", "pass") - self.requester_tok = self.login("requester", "pass") - - def test_can_lookup_own_profile(self): - """Tests that a user can lookup their own profile without having to be in a room - if 'require_auth_for_profile_requests' is set to true in the server's config. - """ - channel = self.make_request( - "GET", "/profile/" + self.requester, access_token=self.requester_tok - ) - self.assertEqual(channel.code, 200, channel.result) - - channel = self.make_request( - "GET", - "/profile/" + self.requester + "/displayname", - access_token=self.requester_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - channel = self.make_request( - "GET", - "/profile/" + self.requester + "/avatar_url", - access_token=self.requester_tok, - ) - self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/v1/test_push_rule_attrs.py deleted file mode 100644 index d0ce91ccd9..0000000000 --- a/tests/rest/client/v1/test_push_rule_attrs.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2020 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import synapse -from synapse.api.errors import Codes -from synapse.rest.client import login, push_rule, room - -from tests.unittest import HomeserverTestCase - - -class PushRuleAttributesTestCase(HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - push_rule.register_servlets, - ] - hijack_auth = False - - def test_enabled_on_creation(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is enabled upon creation, even though a rule with that - ruleId existed previously and was disabled. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_on_recreation(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is enabled upon creation, even if a rule with that - ruleId existed previously and was disabled. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # disable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": False}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule disabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], False) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_disable(self): - """ - Tests the GET and PUT of push rules' `enabled` endpoints. - Tests that a rule is disabled and enabled when we ask for it. 
- """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # disable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": False}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule disabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], False) - - # re-enable the rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # check rule enabled - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["enabled"], True) - - def test_enabled_404_when_get_non_existent(self): - """ - Tests that `enabled` gives 404 when the rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET enabled for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 200) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # check 404 for deleted rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_get_non_existent_server_rule(self): - """ - Tests that `enabled` gives 404 when the server-default rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/.m.muahahaha/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_put_non_existent_rule(self): - """ - Tests that `enabled` gives 404 when we put to a rule that doesn't exist. 
- """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_enabled_404_when_put_non_existent_server_rule(self): - """ - Tests that `enabled` gives 404 when we put to a server-default rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/.m.muahahah/enabled", - {"enabled": True}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_get(self): - """ - Tests that `actions` gives you what you expect on a fresh rule. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # GET actions for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/actions", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body["actions"], ["notify", {"set_tweak": "highlight"}] - ) - - def test_actions_put(self): - """ - Tests that PUT on actions updates the value you'd get from GET. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # change the rule actions - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 200) - - # GET actions for that new rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/actions", access_token=token - ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["actions"], ["dont_notify"]) - - def test_actions_404_when_get_non_existent(self): - """ - Tests that `actions` gives 404 when the rule doesn't exist. 
- """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - body = { - "conditions": [ - {"kind": "event_match", "key": "sender", "pattern": "@user2:hs"} - ], - "actions": ["notify", {"set_tweak": "highlight"}], - } - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - # PUT a new rule - channel = self.make_request( - "PUT", "/pushrules/global/override/best.friend", body, access_token=token - ) - self.assertEqual(channel.code, 200) - - # DELETE the rule - channel = self.make_request( - "DELETE", "/pushrules/global/override/best.friend", access_token=token - ) - self.assertEqual(channel.code, 200) - - # check 404 for deleted rule - channel = self.make_request( - "GET", "/pushrules/global/override/best.friend/enabled", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_get_non_existent_server_rule(self): - """ - Tests that `actions` gives 404 when the server-default rule doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # check 404 for never-heard-of rule - channel = self.make_request( - "GET", "/pushrules/global/override/.m.muahahaha/actions", access_token=token - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_put_non_existent_rule(self): - """ - Tests that `actions` gives 404 when putting to a rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/best.friend/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) - - def test_actions_404_when_put_non_existent_server_rule(self): - """ - Tests that `actions` gives 404 when putting to a server-default rule that doesn't exist. - """ - self.register_user("user", "pass") - token = self.login("user", "pass") - - # enable & check 404 for never-heard-of rule - channel = self.make_request( - "PUT", - "/pushrules/global/override/.m.muahahah/actions", - {"actions": ["dont_notify"]}, - access_token=token, - ) - self.assertEqual(channel.code, 404) - self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py deleted file mode 100644 index 0c9cbb9aff..0000000000 --- a/tests/rest/client/v1/test_rooms.py +++ /dev/null @@ -1,2150 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests REST events for /rooms paths.""" - -import json -from typing import Iterable -from unittest.mock import Mock, call -from urllib import parse as urlparse - -from twisted.internet import defer - -import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, Membership -from synapse.api.errors import HttpResponseException -from synapse.handlers.pagination import PurgeStatus -from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, room -from synapse.types import JsonDict, RoomAlias, UserID, create_requester -from synapse.util.stringutils import random_string - -from tests import unittest -from tests.test_utils import make_awaitable - -PATH_PREFIX = b"/_matrix/client/api/v1" - - -class RoomBase(unittest.HomeserverTestCase): - rmcreator_id = None - - servlets = [room.register_servlets, room.register_deprecated_servlets] - - def make_homeserver(self, reactor, clock): - - self.hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - ) - - self.hs.get_federation_handler = Mock() - self.hs.get_federation_handler.return_value.maybe_backfill = Mock( - return_value=make_awaitable(None) - ) - - async def _insert_client_ip(*args, **kwargs): - return None - - self.hs.get_datastore().insert_client_ip = _insert_client_ip - - return self.hs - - -class RoomPermissionsTestCase(RoomBase): - """Tests room permissions.""" - - user_id = "@sid1:red" - rmcreator_id = "@notme:red" - - def prepare(self, reactor, clock, hs): - - self.helper.auth_user_id = self.rmcreator_id - # create some rooms under the name rmcreator_id - self.uncreated_rmid = "!aa:test" - self.created_rmid = self.helper.create_room_as( - self.rmcreator_id, is_public=False - ) - self.created_public_rmid = self.helper.create_room_as( - self.rmcreator_id, is_public=True - ) - - # send a message in one of the rooms - self.created_rmid_msg_path = ( - "rooms/%s/send/m.room.message/a1" % (self.created_rmid) - ).encode("ascii") - channel = self.make_request( - "PUT", self.created_rmid_msg_path, b'{"msgtype":"m.text","body":"test msg"}' - ) - self.assertEquals(200, channel.code, channel.result) - - # set topic for public room - channel = self.make_request( - "PUT", - ("rooms/%s/state/m.room.topic" % self.created_public_rmid).encode("ascii"), - b'{"topic":"Public Room Topic"}', - ) - self.assertEquals(200, channel.code, channel.result) - - # auth as user_id now - self.helper.auth_user_id = self.user_id - - def test_can_do_action(self): - msg_content = b'{"msgtype":"m.text","body":"hello"}' - - seq = iter(range(100)) - - def send_msg_path(): - return "/rooms/%s/send/m.room.message/mid%s" % ( - self.created_rmid, - str(next(seq)), - ) - - # send message in uncreated room, expect 403 - channel = self.make_request( - "PUT", - "/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,), - msg_content, - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # send message in created room not joined (no state), expect 403 - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # send message in created room and invited, expect 403 - self.helper.invite( - room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id - ) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, 
msg=channel.result["body"]) - - # send message in created room and joined, expect 200 - self.helper.join(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # send message in created room and left, expect 403 - self.helper.leave(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", send_msg_path(), msg_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_topic_perms(self): - topic_content = b'{"topic":"My Topic Name"}' - topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid - - # set/get topic in uncreated room, expect 403 - channel = self.make_request( - "PUT", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid, topic_content - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request( - "GET", "/rooms/%s/state/m.room.topic" % self.uncreated_rmid - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set/get topic in created PRIVATE room not joined, expect 403 - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set topic in created PRIVATE room and invited, expect 403 - self.helper.invite( - room=self.created_rmid, src=self.rmcreator_id, targ=self.user_id - ) - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # get topic in created PRIVATE room and invited, expect 403 - channel = self.make_request("GET", topic_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set/get topic in created PRIVATE room and joined, expect 200 - self.helper.join(room=self.created_rmid, user=self.user_id) - - # Only room ops can set topic by default - self.helper.auth_user_id = self.rmcreator_id - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.helper.auth_user_id = self.user_id - - channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(topic_content.decode("utf8")), channel.json_body) - - # set/get topic in created PRIVATE room and left, expect 403 - self.helper.leave(room=self.created_rmid, user=self.user_id) - channel = self.make_request("PUT", topic_path, topic_content) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - channel = self.make_request("GET", topic_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # get topic in PUBLIC room, not joined, expect 403 - channel = self.make_request( - "GET", "/rooms/%s/state/m.room.topic" % self.created_public_rmid - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - # set topic in PUBLIC room, not joined, expect 403 - channel = self.make_request( - "PUT", - "/rooms/%s/state/m.room.topic" % self.created_public_rmid, - topic_content, - ) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def _test_get_membership( - self, room=None, members: Iterable = frozenset(), expect_code=None - ): - for member in members: - path = "/rooms/%s/state/m.room.member/%s" % (room, member) - channel = self.make_request("GET", path) - 
self.assertEquals(expect_code, channel.code) - - def test_membership_basic_room_perms(self): - # === room does not exist === - room = self.uncreated_rmid - # get membership of self, get membership of other, uncreated room - # expect all 403s - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # trying to invite people to this room should 403 - self.helper.invite( - room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 - ) - - # set [invite/join/left] of self, set [invite/join/left] of other, - # expect all 404s because room doesn't exist on any server - for usr in [self.user_id, self.rmcreator_id]: - self.helper.join(room=room, user=usr, expect_code=404) - self.helper.leave(room=room, user=usr, expect_code=404) - - def test_membership_private_room_perms(self): - room = self.created_rmid - # get membership of self, get membership of other, private room + invite - # expect all 403s - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # get membership of self, get membership of other, private room + joined - # expect all 200s - self.helper.join(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - # get membership of self, get membership of other, private room + left - # expect all 200s - self.helper.leave(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - def test_membership_public_room_perms(self): - room = self.created_public_rmid - # get membership of self, get membership of other, public room + invite - # expect 403 - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=403 - ) - - # get membership of self, get membership of other, public room + joined - # expect all 200s - self.helper.join(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - # get membership of self, get membership of other, public room + left - # expect 200. 
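-        # (A user who has left a room can still read its state as it was when
-        # they left, which is why the "left" cases in these tests expect 200
-        # rather than 403.)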
- self.helper.leave(room=room, user=self.user_id) - self._test_get_membership( - members=[self.user_id, self.rmcreator_id], room=room, expect_code=200 - ) - - def test_invited_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - - # set [invite/join/left] of other user, expect 403s - self.helper.invite( - room=room, src=self.user_id, targ=self.rmcreator_id, expect_code=403 - ) - self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.JOIN, - expect_code=403, - ) - self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403, - ) - - def test_joined_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self.helper.join(room=room, user=self.user_id) - - # set invited of self, expect 403 - self.helper.invite( - room=room, src=self.user_id, targ=self.user_id, expect_code=403 - ) - - # set joined of self, expect 200 (NOOP) - self.helper.join(room=room, user=self.user_id) - - other = "@burgundy:red" - # set invited of other, expect 200 - self.helper.invite(room=room, src=self.user_id, targ=other, expect_code=200) - - # set joined of other, expect 403 - self.helper.change_membership( - room=room, - src=self.user_id, - targ=other, - membership=Membership.JOIN, - expect_code=403, - ) - - # set left of other, expect 403 - self.helper.change_membership( - room=room, - src=self.user_id, - targ=other, - membership=Membership.LEAVE, - expect_code=403, - ) - - # set left of self, expect 200 - self.helper.leave(room=room, user=self.user_id) - - def test_leave_permissions(self): - room = self.created_rmid - self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id) - self.helper.join(room=room, user=self.user_id) - self.helper.leave(room=room, user=self.user_id) - - # set [invite/join/left] of self, set [invite/join/left] of other, - # expect all 403s - for usr in [self.user_id, self.rmcreator_id]: - self.helper.change_membership( - room=room, - src=self.user_id, - targ=usr, - membership=Membership.INVITE, - expect_code=403, - ) - - self.helper.change_membership( - room=room, - src=self.user_id, - targ=usr, - membership=Membership.JOIN, - expect_code=403, - ) - - # It is always valid to LEAVE if you've already left (currently.) 
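-        # So we only assert that setting *another* user's membership to LEAVE
-        # (i.e. kicking them) is rejected; re-LEAVEing ourselves would succeed
-        # as a no-op.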
- self.helper.change_membership( - room=room, - src=self.user_id, - targ=self.rmcreator_id, - membership=Membership.LEAVE, - expect_code=403, - ) - - -class RoomsMemberListTestCase(RoomBase): - """Tests /rooms/$room_id/members/list REST events.""" - - user_id = "@sid1:red" - - def test_get_member_list(self): - room_id = self.helper.create_room_as(self.user_id) - channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - def test_get_member_list_no_room(self): - channel = self.make_request("GET", "/rooms/roomdoesnotexist/members") - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_get_member_list_no_permission(self): - room_id = self.helper.create_room_as("@some_other_guy:red") - channel = self.make_request("GET", "/rooms/%s/members" % room_id) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - def test_get_member_list_mixed_memberships(self): - room_creator = "@some_other_guy:red" - room_id = self.helper.create_room_as(room_creator) - room_path = "/rooms/%s/members" % room_id - self.helper.invite(room=room_id, src=room_creator, targ=self.user_id) - # can't see list if you're just invited. - channel = self.make_request("GET", room_path) - self.assertEquals(403, channel.code, msg=channel.result["body"]) - - self.helper.join(room=room_id, user=self.user_id) - # can see list now joined - channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - self.helper.leave(room=room_id, user=self.user_id) - # can see old list once left - channel = self.make_request("GET", room_path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - -class RoomsCreateTestCase(RoomBase): - """Tests /rooms and /rooms/$room_id REST events.""" - - user_id = "@sid1:red" - - def test_post_room_no_keys(self): - # POST with no config keys, expect new room id - channel = self.make_request("POST", "/createRoom", "{}") - - self.assertEquals(200, channel.code, channel.result) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_visibility_key(self): - # POST with visibility config key, expect new room id - channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}') - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_custom_key(self): - # POST with custom config keys, expect new room id - channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}') - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_known_and_unknown_keys(self): - # POST with custom + known config keys, expect new room id - channel = self.make_request( - "POST", "/createRoom", b'{"visibility":"private","custom":"things"}' - ) - self.assertEquals(200, channel.code) - self.assertTrue("room_id" in channel.json_body) - - def test_post_room_invalid_content(self): - # POST with invalid content / paths, expect 400 - channel = self.make_request("POST", "/createRoom", b'{"visibili') - self.assertEquals(400, channel.code) - - channel = self.make_request("POST", "/createRoom", b'["hello"]') - self.assertEquals(400, channel.code) - - def test_post_room_invitees_invalid_mxid(self): - # POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088 - # Note the trailing space in the MXID here! 
- channel = self.make_request( - "POST", "/createRoom", b'{"invite":["@alice:example.com "]}' - ) - self.assertEquals(400, channel.code) - - @unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}}) - def test_post_room_invitees_ratelimit(self): - """Test that invites sent when creating a room are ratelimited by a RateLimiter, - which ratelimits them correctly, including by not limiting when the requester is - exempt from ratelimiting. - """ - - # Build the request's content. We use local MXIDs because invites over federation - # are more difficult to mock. - content = json.dumps( - { - "invite": [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - } - ).encode("utf8") - - # Test that the invites are correctly ratelimited. - channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(400, channel.code) - self.assertEqual( - "Cannot invite so many users at once", - channel.json_body["error"], - ) - - # Add the current user to the ratelimit overrides, allowing them no ratelimiting. - self.get_success( - self.hs.get_datastore().set_ratelimit_for_user(self.user_id, 0, 0) - ) - - # Test that the invites aren't ratelimited anymore. - channel = self.make_request("POST", "/createRoom", content) - self.assertEqual(200, channel.code) - - -class RoomTopicTestCase(RoomBase): - """Tests /rooms/$room_id/topic REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - # create the room - self.room_id = self.helper.create_room_as(self.user_id) - self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,) - - def test_invalid_puts(self): - # missing keys or invalid json - channel = self.make_request("PUT", self.path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request( - "PUT", self.path, '[{"_name":"bo"},{"_name":"jill"}]' - ) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", self.path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # valid key, wrong type - content = '{"topic":["Topic name"]}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_topic(self): - # nothing should be there - channel = self.make_request("GET", self.path) - self.assertEquals(404, channel.code, msg=channel.result["body"]) - - # valid put - content = '{"topic":"Topic name"}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # valid get - channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assert_dict(json.loads(content), channel.json_body) - - def test_rooms_topic_with_extra_keys(self): - # valid put with extra keys - content = '{"topic":"Seasons","subtopic":"Summer"}' - channel = self.make_request("PUT", self.path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # valid get - channel = self.make_request("GET", self.path) - self.assertEquals(200, channel.code, 
msg=channel.result["body"]) - self.assert_dict(json.loads(content), channel.json_body) - - -class RoomMemberStateTestCase(RoomBase): - """Tests /rooms/$room_id/members/$user_id/state REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - - def test_invalid_puts(self): - path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id) - # missing keys or invalid json - channel = self.make_request("PUT", path, "{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, '{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, '{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, "text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, "") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # valid keys, wrong types - content = '{"membership":["%s","%s","%s"]}' % ( - Membership.INVITE, - Membership.JOIN, - Membership.LEAVE, - ) - channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_members_self(self): - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.user_id, - ) - - # valid join message (NOOP since we made the room) - content = '{"membership":"%s"}' % Membership.JOIN - channel = self.make_request("PUT", path, content.encode("ascii")) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - expected_response = {"membership": Membership.JOIN} - self.assertEquals(expected_response, channel.json_body) - - def test_rooms_members_other(self): - self.other_id = "@zzsid1:red" - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.other_id, - ) - - # valid invite message - content = '{"membership":"%s"}' % Membership.INVITE - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) - - def test_rooms_members_other_custom_keys(self): - self.other_id = "@zzsid1:red" - path = "/rooms/%s/state/m.room.member/%s" % ( - urlparse.quote(self.room_id), - self.other_id, - ) - - # valid invite message with custom key - content = '{"membership":"%s","invite_text":"%s"}' % ( - Membership.INVITE, - "Join us!", - ) - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - channel = self.make_request("GET", path, None) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - self.assertEquals(json.loads(content), channel.json_body) - - -class RoomInviteRatelimitTestCase(RoomBase): - user_id = "@sid1:red" - - servlets = [ - admin.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - @unittest.override_config( - {"rc_invites": {"per_room": {"per_second": 0.5, 
"burst_count": 3}}} - ) - def test_invites_by_rooms_ratelimit(self): - """Tests that invites in a room are actually rate-limited.""" - room_id = self.helper.create_room_as(self.user_id) - - for i in range(3): - self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) - - self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) - - @unittest.override_config( - {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_invites_by_users_ratelimit(self): - """Tests that invites to a specific user are actually rate-limited.""" - - for _ in range(3): - room_id = self.helper.create_room_as(self.user_id) - self.helper.invite(room_id, self.user_id, "@other-users:red") - - room_id = self.helper.create_room_as(self.user_id) - self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) - - -class RoomJoinRatelimitTestCase(RoomBase): - user_id = "@sid1:red" - - servlets = [ - admin.register_servlets, - profile.register_servlets, - room.register_servlets, - ] - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit(self): - """Tests that local joins are actually rate-limited.""" - for _ in range(3): - self.helper.create_room_as(self.user_id) - - self.helper.create_room_as(self.user_id, expect_code=429) - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit_profile_change(self): - """Tests that sending a profile update into all of the user's joined rooms isn't - rate-limited by the rate-limiter on joins.""" - - # Create and join as many rooms as the rate-limiting config allows in a second. - room_ids = [ - self.helper.create_room_as(self.user_id), - self.helper.create_room_as(self.user_id), - self.helper.create_room_as(self.user_id), - ] - # Let some time for the rate-limiter to forget about our multi-join. - self.reactor.advance(2) - # Add one to make sure we're joined to more rooms than the config allows us to - # join in a second. - room_ids.append(self.helper.create_room_as(self.user_id)) - - # Create a profile for the user, since it hasn't been done on registration. - store = self.hs.get_datastore() - self.get_success( - store.create_profile(UserID.from_string(self.user_id).localpart) - ) - - # Update the display name for the user. - path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id - channel = self.make_request("PUT", path, {"displayname": "John Doe"}) - self.assertEquals(channel.code, 200, channel.json_body) - - # Check that all the rooms have been sent a profile update into. - for room_id in room_ids: - path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % ( - room_id, - self.user_id, - ) - - channel = self.make_request("GET", path) - self.assertEquals(channel.code, 200) - - self.assertIn("displayname", channel.json_body) - self.assertEquals(channel.json_body["displayname"], "John Doe") - - @unittest.override_config( - {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} - ) - def test_join_local_ratelimit_idempotent(self): - """Tests that the room join endpoints remain idempotent despite rate-limiting - on room joins.""" - room_id = self.helper.create_room_as(self.user_id) - - # Let's test both paths to be sure. 
- paths_to_test = [ - "/_matrix/client/r0/rooms/%s/join", - "/_matrix/client/r0/join/%s", - ] - - for path in paths_to_test: - # Make sure we send more requests than the rate-limiting config would allow - # if all of these requests ended up joining the user to a room. - for _ in range(4): - channel = self.make_request("POST", path % room_id, {}) - self.assertEquals(channel.code, 200) - - @unittest.override_config( - { - "rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}, - "auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"], - "autocreate_auto_join_rooms": True, - }, - ) - def test_autojoin_rooms(self): - user_id = self.register_user("testuser", "password") - - # Check that the new user successfully joined the four rooms - rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id)) - self.assertEqual(len(rooms), 4) - - -class RoomMessagesTestCase(RoomBase): - """Tests /rooms/$room_id/messages/$user_id/$msg_id REST events.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - - def test_invalid_puts(self): - path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - # missing keys or invalid json - channel = self.make_request("PUT", path, b"{}") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'{"_name":"bo"}') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'{"nao') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b'[{"_name":"bo"},{"_name":"jill"}]') - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b"text only") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - channel = self.make_request("PUT", path, b"") - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - def test_rooms_messages_sent(self): - path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id)) - - content = b'{"body":"test","msgtype":{"type":"a"}}' - channel = self.make_request("PUT", path, content) - self.assertEquals(400, channel.code, msg=channel.result["body"]) - - # custom message types - content = b'{"body":"test","msgtype":"test.custom.text"}' - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - # m.text message type - path = "/rooms/%s/send/m.room.message/mid2" % (urlparse.quote(self.room_id)) - content = b'{"body":"test2","msgtype":"m.text"}' - channel = self.make_request("PUT", path, content) - self.assertEquals(200, channel.code, msg=channel.result["body"]) - - -class RoomInitialSyncTestCase(RoomBase): - """Tests /rooms/$room_id/initialSync.""" - - user_id = "@sid1:red" - - def prepare(self, reactor, clock, hs): - # create the room - self.room_id = self.helper.create_room_as(self.user_id) - - def test_initial_sync(self): - channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id) - self.assertEquals(200, channel.code) - - self.assertEquals(self.room_id, channel.json_body["room_id"]) - self.assertEquals("join", channel.json_body["membership"]) - - # Room state is easier to assert on if we unpack it into a dict - state = {} - for event in channel.json_body["state"]: - if "state_key" not in event: - continue - t = event["type"] - if t not in state: - state[t] = [] - 
state[t].append(event)
-
-        self.assertTrue("m.room.create" in state)
-
-        self.assertTrue("messages" in channel.json_body)
-        self.assertTrue("chunk" in channel.json_body["messages"])
-        self.assertTrue("end" in channel.json_body["messages"])
-
-        self.assertTrue("presence" in channel.json_body)
-
-        presence_by_user = {
-            e["content"]["user_id"]: e for e in channel.json_body["presence"]
-        }
-        self.assertTrue(self.user_id in presence_by_user)
-        self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
-
-
-class RoomMessageListTestCase(RoomBase):
-    """Tests /rooms/$room_id/messages REST events."""
-
-    user_id = "@sid1:red"
-
-    def prepare(self, reactor, clock, hs):
-        self.room_id = self.helper.create_room_as(self.user_id)
-
-    def test_topo_token_is_accepted(self):
-        token = "t1-0_0_0_0_0_0_0_0_0"
-        channel = self.make_request(
-            "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
-        )
-        self.assertEquals(200, channel.code)
-        self.assertTrue("start" in channel.json_body)
-        self.assertEquals(token, channel.json_body["start"])
-        self.assertTrue("chunk" in channel.json_body)
-        self.assertTrue("end" in channel.json_body)
-
-    def test_stream_token_is_accepted_for_fwd_pagination(self):
-        token = "s0_0_0_0_0_0_0_0_0"
-        channel = self.make_request(
-            "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
-        )
-        self.assertEquals(200, channel.code)
-        self.assertTrue("start" in channel.json_body)
-        self.assertEquals(token, channel.json_body["start"])
-        self.assertTrue("chunk" in channel.json_body)
-        self.assertTrue("end" in channel.json_body)
-
-    def test_room_messages_purge(self):
-        store = self.hs.get_datastore()
-        pagination_handler = self.hs.get_pagination_handler()
-
-        # Send a first message in the room, which will be removed by the purge.
-        first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
-        first_token = self.get_success(
-            store.get_topological_token_for_event(first_event_id)
-        )
-        first_token_str = self.get_success(first_token.to_string(store))
-
-        # Send a second message in the room, which won't be removed, and which we'll
-        # use as the marker to purge events before.
-        second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
-        second_token = self.get_success(
-            store.get_topological_token_for_event(second_event_id)
-        )
-        second_token_str = self.get_success(second_token.to_string(store))
-
-        # Send a third event in the room to ensure we don't fall under any edge case
-        # due to our marker being the latest forward extremity in the room.
-        self.helper.send(self.room_id, "message 3")
-
-        # Check that we get the first and second message when querying /messages.
-        channel = self.make_request(
-            "GET",
-            "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
-            % (
-                self.room_id,
-                second_token_str,
-                json.dumps({"types": [EventTypes.Message]}),
-            ),
-        )
-        self.assertEqual(channel.code, 200, channel.json_body)
-
-        chunk = channel.json_body["chunk"]
-        self.assertEqual(len(chunk), 2, [event["content"] for event in chunk])
-
-        # Purge every event before the second event.
-        purge_id = random_string(16)
-        pagination_handler._purges_by_id[purge_id] = PurgeStatus()
-        self.get_success(
-            pagination_handler._purge_history(
-                purge_id=purge_id,
-                room_id=self.room_id,
-                token=second_token_str,
-                delete_local_events=True,
-            )
-        )
-
-        # Check that we only get the second message through /message now that the first
-        # has been purged.
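-        # (The same query previously returned two events; with "message 1"
-        # purged it should now return just one.)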
- channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % ( - self.room_id, - second_token_str, - json.dumps({"types": [EventTypes.Message]}), - ), - ) - self.assertEqual(channel.code, 200, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) - - # Check that we get no event, but also no error, when querying /messages with - # the token that was pointing at the first event, because we don't have it - # anymore. - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" - % ( - self.room_id, - first_token_str, - json.dumps({"types": [EventTypes.Message]}), - ), - ) - self.assertEqual(channel.code, 200, channel.json_body) - - chunk = channel.json_body["chunk"] - self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) - - -class RoomSearchTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - user_id = True - hijack_auth = False - - def prepare(self, reactor, clock, hs): - - # Register the user who does the searching - self.user_id = self.register_user("user", "pass") - self.access_token = self.login("user", "pass") - - # Register the user who sends the message - self.other_user_id = self.register_user("otheruser", "pass") - self.other_access_token = self.login("otheruser", "pass") - - # Create a room - self.room = self.helper.create_room_as(self.user_id, tok=self.access_token) - - # Invite the other person - self.helper.invite( - room=self.room, - src=self.user_id, - tok=self.access_token, - targ=self.other_user_id, - ) - - # The other user joins - self.helper.join( - room=self.room, user=self.other_user_id, tok=self.other_access_token - ) - - def test_finds_message(self): - """ - The search functionality will search for content in messages if asked to - do so. - """ - # The other user sends some messages - self.helper.send(self.room, body="Hi!", tok=self.other_access_token) - self.helper.send(self.room, body="There!", tok=self.other_access_token) - - channel = self.make_request( - "POST", - "/search?access_token=%s" % (self.access_token,), - { - "search_categories": { - "room_events": {"keys": ["content.body"], "search_term": "Hi"} - } - }, - ) - - # Check we get the results we expect -- one search result, of the sent - # messages - self.assertEqual(channel.code, 200) - results = channel.json_body["search_categories"]["room_events"] - self.assertEqual(results["count"], 1) - self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") - - # No context was requested, so we should get none. - self.assertEqual(results["results"][0]["context"], {}) - - def test_include_context(self): - """ - When event_context includes include_profile, profile information will be - included in the search response. 
- """ - # The other user sends some messages - self.helper.send(self.room, body="Hi!", tok=self.other_access_token) - self.helper.send(self.room, body="There!", tok=self.other_access_token) - - channel = self.make_request( - "POST", - "/search?access_token=%s" % (self.access_token,), - { - "search_categories": { - "room_events": { - "keys": ["content.body"], - "search_term": "Hi", - "event_context": {"include_profile": True}, - } - } - }, - ) - - # Check we get the results we expect -- one search result, of the sent - # messages - self.assertEqual(channel.code, 200) - results = channel.json_body["search_categories"]["room_events"] - self.assertEqual(results["count"], 1) - self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!") - - # We should get context info, like the two users, and the display names. - context = results["results"][0]["context"] - self.assertEqual(len(context["profile_info"].keys()), 2) - self.assertEqual( - context["profile_info"][self.other_user_id]["displayname"], "otheruser" - ) - - -class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - - self.url = b"/_matrix/client/r0/publicRooms" - - config = self.default_config() - config["allow_public_rooms_without_auth"] = False - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def test_restricted_no_auth(self): - channel = self.make_request("GET", self.url) - self.assertEqual(channel.code, 401, channel.result) - - def test_restricted_auth(self): - self.register_user("user", "pass") - tok = self.login("user", "pass") - - channel = self.make_request("GET", self.url, access_token=tok) - self.assertEqual(channel.code, 200, channel.result) - - -class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): - """Test that we correctly fallback to local filtering if a remote server - doesn't support search. - """ - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - return self.setup_test_homeserver(federation_client=Mock()) - - def prepare(self, reactor, clock, hs): - self.register_user("user", "pass") - self.token = self.login("user", "pass") - - self.federation_client = hs.get_federation_client() - - def test_simple(self): - "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = ( - lambda *a, **k: defer.succeed({}) - ) - - search_filter = {"generic_search_term": "foobar"} - - channel = self.make_request( - "POST", - b"/_matrix/client/r0/publicRooms?server=testserv", - content={"filter": search_filter}, - access_token=self.token, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.federation_client.get_public_rooms.assert_called_once_with( - "testserv", - limit=100, - since_token=None, - search_filter=search_filter, - include_all_networks=False, - third_party_instance_id=None, - ) - - def test_fallback(self): - "Test that searching public rooms over federation falls back if it gets a 404" - - # The `get_public_rooms` should be called again if the first call fails - # with a 404, when using search filters. 
- self.federation_client.get_public_rooms.side_effect = ( - HttpResponseException(404, "Not Found", b""), - defer.succeed({}), - ) - - search_filter = {"generic_search_term": "foobar"} - - channel = self.make_request( - "POST", - b"/_matrix/client/r0/publicRooms?server=testserv", - content={"filter": search_filter}, - access_token=self.token, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.federation_client.get_public_rooms.assert_has_calls( - [ - call( - "testserv", - limit=100, - since_token=None, - search_filter=search_filter, - include_all_networks=False, - third_party_instance_id=None, - ), - call( - "testserv", - limit=None, - since_token=None, - search_filter=None, - include_all_networks=False, - third_party_instance_id=None, - ), - ] - ) - - -class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - profile.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["allow_per_room_profiles"] = False - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def prepare(self, reactor, clock, homeserver): - self.user_id = self.register_user("test", "test") - self.tok = self.login("test", "test") - - # Set a profile for the test user - self.displayname = "test user" - data = {"displayname": self.displayname} - request_data = json.dumps(data) - channel = self.make_request( - "PUT", - "/_matrix/client/r0/profile/%s/displayname" % (self.user_id,), - request_data, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - def test_per_room_profile_forbidden(self): - data = {"membership": "join", "displayname": "other test user"} - request_data = json.dumps(data) - channel = self.make_request( - "PUT", - "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" - % (self.room_id, self.user_id), - request_data, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - res_displayname = channel.json_body["content"]["displayname"] - self.assertEqual(res_displayname, self.displayname, channel.result) - - -class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): - """Tests that clients can add a "reason" field to membership events and - that they get correctly added to the generated events and propagated. 
- """ - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.creator = self.register_user("creator", "test") - self.creator_tok = self.login("creator", "test") - - self.second_user_id = self.register_user("second", "test") - self.second_tok = self.login("second", "test") - - self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) - - def test_join_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/join", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_leave_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/leave", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_kick_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/kick", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_ban_reason(self): - self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/ban", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_unban_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/unban", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_invite_reason(self): - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/invite", - content={"reason": reason, "user_id": self.second_user_id}, - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def test_reject_invite_reason(self): - self.helper.invite( - self.room_id, - src=self.creator, - targ=self.second_user_id, - tok=self.creator_tok, - ) - - reason = "hello" - channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room_id}/leave", - content={"reason": reason}, - access_token=self.second_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - self._check_for_reason(reason) - - def _check_for_reason(self, reason): - channel = self.make_request( - "GET", - "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( - self.room_id, self.second_user_id - ), - access_token=self.creator_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - event_content = channel.json_body - - self.assertEqual(event_content.get("reason"), reason, channel.result) - - -class LabelsTestCase(unittest.HomeserverTestCase): - servlets = [ - 
synapse.rest.admin.register_servlets_for_client_rest_resource,
-        room.register_servlets,
-        login.register_servlets,
-        profile.register_servlets,
-    ]
-
-    # Filter that should only catch messages with the label "#fun".
-    FILTER_LABELS = {
-        "types": [EventTypes.Message],
-        "org.matrix.labels": ["#fun"],
-    }
-    # Filter that should only catch messages without the label "#fun".
-    FILTER_NOT_LABELS = {
-        "types": [EventTypes.Message],
-        "org.matrix.not_labels": ["#fun"],
-    }
-    # Filter that should only catch messages with the label "#work" but without the label
-    # "#notfun".
-    FILTER_LABELS_NOT_LABELS = {
-        "types": [EventTypes.Message],
-        "org.matrix.labels": ["#work"],
-        "org.matrix.not_labels": ["#notfun"],
-    }
-
-    def prepare(self, reactor, clock, homeserver):
-        self.user_id = self.register_user("test", "test")
-        self.tok = self.login("test", "test")
-        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
-    def test_context_filter_labels(self):
-        """Test that we can filter by a label on a /context request."""
-        event_id = self._send_labelled_messages_in_room()
-
-        channel = self.make_request(
-            "GET",
-            "/rooms/%s/context/%s?filter=%s"
-            % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)),
-            access_token=self.tok,
-        )
-        self.assertEqual(channel.code, 200, channel.result)
-
-        events_before = channel.json_body["events_before"]
-
-        self.assertEqual(
-            len(events_before), 1, [event["content"] for event in events_before]
-        )
-        self.assertEqual(
-            events_before[0]["content"]["body"], "with right label", events_before[0]
-        )
-
-        events_after = channel.json_body["events_after"]
-
-        self.assertEqual(
-            len(events_after), 1, [event["content"] for event in events_after]
-        )
-        self.assertEqual(
-            events_after[0]["content"]["body"], "with right label", events_after[0]
-        )
-
-    def test_context_filter_not_labels(self):
-        """Test that we can filter by the absence of a label on a /context request."""
-        event_id = self._send_labelled_messages_in_room()
-
-        channel = self.make_request(
-            "GET",
-            "/rooms/%s/context/%s?filter=%s"
-            % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)),
-            access_token=self.tok,
-        )
-        self.assertEqual(channel.code, 200, channel.result)
-
-        events_before = channel.json_body["events_before"]
-
-        self.assertEqual(
-            len(events_before), 1, [event["content"] for event in events_before]
-        )
-        self.assertEqual(
-            events_before[0]["content"]["body"], "without label", events_before[0]
-        )
-
-        events_after = channel.json_body["events_after"]
-
-        self.assertEqual(
-            len(events_after), 2, [event["content"] for event in events_after]
-        )
-        self.assertEqual(
-            events_after[0]["content"]["body"], "with wrong label", events_after[0]
-        )
-        self.assertEqual(
-            events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
-        )
-
-    def test_context_filter_labels_not_labels(self):
-        """Test that we can filter by both a label and the absence of another label on a
-        /context request.
- """ - event_id = self._send_labelled_messages_in_room() - - channel = self.make_request( - "GET", - "/rooms/%s/context/%s?filter=%s" - % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.result) - - events_before = channel.json_body["events_before"] - - self.assertEqual( - len(events_before), 0, [event["content"] for event in events_before] - ) - - events_after = channel.json_body["events_after"] - - self.assertEqual( - len(events_after), 1, [event["content"] for event in events_after] - ) - self.assertEqual( - events_after[0]["content"]["body"], "with wrong label", events_after[0] - ) - - def test_messages_filter_labels(self): - """Test that we can filter by a label on a /messages request.""" - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_messages_filter_not_labels(self): - """Test that we can filter by the absence of a label on a /messages request.""" - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 4, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "without label", events[1]) - self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) - self.assertEqual( - events[3]["content"]["body"], "with two wrong labels", events[3] - ) - - def test_messages_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label on a - /messages request. 
- """ - self._send_labelled_messages_in_room() - - token = "s0_0_0_0_0_0_0_0_0" - channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" - % ( - self.room_id, - self.tok, - token, - json.dumps(self.FILTER_LABELS_NOT_LABELS), - ), - ) - - events = channel.json_body["chunk"] - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def test_search_filter_labels(self): - """Test that we can filter by a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 2, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "with right label", - results[0]["result"]["content"]["body"], - ) - self.assertEqual( - results[1]["result"]["content"]["body"], - "with right label", - results[1]["result"]["content"]["body"], - ) - - def test_search_filter_not_labels(self): - """Test that we can filter by the absence of a label on a /search request.""" - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_NOT_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 4, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "without label", - results[0]["result"]["content"]["body"], - ) - self.assertEqual( - results[1]["result"]["content"]["body"], - "without label", - results[1]["result"]["content"]["body"], - ) - self.assertEqual( - results[2]["result"]["content"]["body"], - "with wrong label", - results[2]["result"]["content"]["body"], - ) - self.assertEqual( - results[3]["result"]["content"]["body"], - "with two wrong labels", - results[3]["result"]["content"]["body"], - ) - - def test_search_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label on a - /search request. - """ - request_data = json.dumps( - { - "search_categories": { - "room_events": { - "search_term": "label", - "filter": self.FILTER_LABELS_NOT_LABELS, - } - } - } - ) - - self._send_labelled_messages_in_room() - - channel = self.make_request( - "POST", "/search?access_token=%s" % self.tok, request_data - ) - - results = channel.json_body["search_categories"]["room_events"]["results"] - - self.assertEqual( - len(results), - 1, - [result["result"]["content"] for result in results], - ) - self.assertEqual( - results[0]["result"]["content"]["body"], - "with wrong label", - results[0]["result"]["content"]["body"], - ) - - def _send_labelled_messages_in_room(self): - """Sends several messages to a room with different labels (or without any) to test - filtering by label. - Returns: - The ID of the event to use if we're testing filtering on /context. 
- """ - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=self.tok, - ) - - res = self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=self.tok, - ) - # Return this event's ID when we test filtering in /context requests. - event_id = res["event_id"] - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - tok=self.tok, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=self.tok, - ) - - return event_id - - -class ContextTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - account.register_servlets, - ] - - def prepare(self, reactor, clock, homeserver): - self.user_id = self.register_user("user", "password") - self.tok = self.login("user", "password") - self.room_id = self.helper.create_room_as( - self.user_id, tok=self.tok, is_public=False - ) - - self.other_user_id = self.register_user("user2", "password") - self.other_tok = self.login("user2", "password") - - self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok) - self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok) - - def test_erased_sender(self): - """Test that an erasure request results in the requester's events being hidden - from any new member of the room. - """ - - # Send a bunch of events in the room. - - self.helper.send(self.room_id, "message 1", tok=self.tok) - self.helper.send(self.room_id, "message 2", tok=self.tok) - event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"] - self.helper.send(self.room_id, "message 4", tok=self.tok) - self.helper.send(self.room_id, "message 5", tok=self.tok) - - # Check that we can still see the messages before the erasure request. 
-
-        channel = self.make_request(
-            "GET",
-            '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
-            % (self.room_id, event_id),
-            access_token=self.tok,
-        )
-        self.assertEqual(channel.code, 200, channel.result)
-
-        events_before = channel.json_body["events_before"]
-
-        self.assertEqual(len(events_before), 2, events_before)
-        self.assertEqual(
-            events_before[0].get("content", {}).get("body"),
-            "message 2",
-            events_before[0],
-        )
-        self.assertEqual(
-            events_before[1].get("content", {}).get("body"),
-            "message 1",
-            events_before[1],
-        )
-
-        self.assertEqual(
-            channel.json_body["event"].get("content", {}).get("body"),
-            "message 3",
-            channel.json_body["event"],
-        )
-
-        events_after = channel.json_body["events_after"]
-
-        self.assertEqual(len(events_after), 2, events_after)
-        self.assertEqual(
-            events_after[0].get("content", {}).get("body"),
-            "message 4",
-            events_after[0],
-        )
-        self.assertEqual(
-            events_after[1].get("content", {}).get("body"),
-            "message 5",
-            events_after[1],
-        )
-
-        # Deactivate the first account and erase the user's data.
-
-        deactivate_account_handler = self.hs.get_deactivate_account_handler()
-        self.get_success(
-            deactivate_account_handler.deactivate_account(
-                self.user_id, True, create_requester(self.user_id)
-            )
-        )
-
-        # Invite another user in the room. This is needed because messages will be
-        # pruned only if the user wasn't a member of the room when the messages were
-        # sent.
-
-        invited_user_id = self.register_user("user3", "password")
-        invited_tok = self.login("user3", "password")
-
-        self.helper.invite(
-            self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok
-        )
-        self.helper.join(self.room_id, invited_user_id, tok=invited_tok)
-
-        # Check that a user that joined the room after the erasure request can't see
-        # the messages anymore.
-
-        channel = self.make_request(
-            "GET",
-            '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
-            % (self.room_id, event_id),
-            access_token=invited_tok,
-        )
-        self.assertEqual(channel.code, 200, channel.result)
-
-        events_before = channel.json_body["events_before"]
-
-        self.assertEqual(len(events_before), 2, events_before)
-        self.assertDictEqual(events_before[0].get("content"), {}, events_before[0])
-        self.assertDictEqual(events_before[1].get("content"), {}, events_before[1])
-
-        self.assertDictEqual(
-            channel.json_body["event"].get("content"), {}, channel.json_body["event"]
-        )
-
-        events_after = channel.json_body["events_after"]
-
-        self.assertEqual(len(events_after), 2, events_after)
-        self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
-        self.assertEqual(events_after[1].get("content"), {}, events_after[1])
-
-
-class RoomAliasListTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-        directory.register_servlets,
-        login.register_servlets,
-        room.register_servlets,
-    ]
-
-    def prepare(self, reactor, clock, homeserver):
-        self.room_owner = self.register_user("room_owner", "test")
-        self.room_owner_tok = self.login("room_owner", "test")
-
-        self.room_id = self.helper.create_room_as(
-            self.room_owner, tok=self.room_owner_tok
-        )
-
-    def test_no_aliases(self):
-        res = self._get_aliases(self.room_owner_tok)
-        self.assertEqual(res["aliases"], [])
-
-    def test_not_in_room(self):
-        self.register_user("user", "test")
-        user_tok = self.login("user", "test")
-        res = self._get_aliases(user_tok, expected_code=403)
-        self.assertEqual(res["errcode"], "M_FORBIDDEN")
-
-    def test_admin_user(self):
-        alias1 = self._random_alias()
-        self._set_alias_via_directory(alias1)
-
-        self.register_user("user", "test", admin=True)
-        user_tok = self.login("user", "test")
-
-        res = self._get_aliases(user_tok)
-        self.assertEqual(res["aliases"], [alias1])
-
-    def test_with_aliases(self):
-        alias1 = self._random_alias()
-        alias2 = self._random_alias()
-
-        self._set_alias_via_directory(alias1)
-        self._set_alias_via_directory(alias2)
-
-        res = self._get_aliases(self.room_owner_tok)
-        self.assertEqual(set(res["aliases"]), {alias1, alias2})
-
-    def test_peekable_room(self):
-        alias1 = self._random_alias()
-        self._set_alias_via_directory(alias1)
-
-        self.helper.send_state(
-            self.room_id,
-            EventTypes.RoomHistoryVisibility,
-            body={"history_visibility": "world_readable"},
-            tok=self.room_owner_tok,
-        )
-
-        self.register_user("user", "test")
-        user_tok = self.login("user", "test")
-
-        res = self._get_aliases(user_tok)
-        self.assertEqual(res["aliases"], [alias1])
-
-    def _get_aliases(self, access_token: str, expected_code: int = 200) -> JsonDict:
-        """Calls the endpoint under test. returns the json response object."""
-        channel = self.make_request(
-            "GET",
-            "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,),
-            access_token=access_token,
-        )
-        self.assertEqual(channel.code, expected_code, channel.result)
-        res = channel.json_body
-        self.assertIsInstance(res, dict)
-        if expected_code == 200:
-            self.assertIsInstance(res["aliases"], list)
-        return res
-
-    def _random_alias(self) -> str:
-        return RoomAlias(random_string(5), self.hs.hostname).to_string()
-
-    def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
-        url = "/_matrix/client/r0/directory/room/" + alias
-        data = {"room_id": self.room_id}
-        request_data = json.dumps(data)
-
-        channel = self.make_request(
-            "PUT", url, request_data, access_token=self.room_owner_tok
-        )
-        self.assertEqual(channel.code, expected_code, channel.result)
-
-
-class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-        directory.register_servlets,
-        login.register_servlets,
-        room.register_servlets,
-    ]
-
-    def prepare(self, reactor, clock, homeserver):
-        self.room_owner = self.register_user("room_owner", "test")
-        self.room_owner_tok = self.login("room_owner", "test")
-
-        self.room_id = self.helper.create_room_as(
-            self.room_owner, tok=self.room_owner_tok
-        )
-
-        self.alias = "#alias:test"
-        self._set_alias_via_directory(self.alias)
-
-    def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
-        url = "/_matrix/client/r0/directory/room/" + alias
-        data = {"room_id": self.room_id}
-        request_data = json.dumps(data)
-
-        channel = self.make_request(
-            "PUT", url, request_data, access_token=self.room_owner_tok
-        )
-        self.assertEqual(channel.code, expected_code, channel.result)
-
-    def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
-        """Calls the endpoint under test. returns the json response object."""
-        channel = self.make_request(
-            "GET",
-            "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
-            access_token=self.room_owner_tok,
-        )
-        self.assertEqual(channel.code, expected_code, channel.result)
-        res = channel.json_body
-        self.assertIsInstance(res, dict)
-        return res
-
-    def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
-        """Calls the endpoint under test. returns the json response object."""
-        channel = self.make_request(
-            "PUT",
-            "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
-            json.dumps(content),
-            access_token=self.room_owner_tok,
-        )
-        self.assertEqual(channel.code, expected_code, channel.result)
-        res = channel.json_body
-        self.assertIsInstance(res, dict)
-        return res
-
-    def test_canonical_alias(self):
-        """Test a basic alias message."""
-        # There is no canonical alias to start with.
-        self._get_canonical_alias(expected_code=404)
-
-        # Create an alias.
-        self._set_canonical_alias({"alias": self.alias})
-
-        # Canonical alias now exists!
-        res = self._get_canonical_alias()
-        self.assertEqual(res, {"alias": self.alias})
-
-        # Now remove the alias.
-        self._set_canonical_alias({})
-
-        # There is an alias event, but it is empty.
-        res = self._get_canonical_alias()
-        self.assertEqual(res, {})
-
-    def test_alt_aliases(self):
-        """Test a canonical alias message with alt_aliases."""
-        # Create an alias.
-        self._set_canonical_alias({"alt_aliases": [self.alias]})
-
-        # Canonical alias now exists!
-        res = self._get_canonical_alias()
-        self.assertEqual(res, {"alt_aliases": [self.alias]})
-
-        # Now remove the alt_aliases.
-        self._set_canonical_alias({})
-
-        # There is an alias event, but it is empty.
-        res = self._get_canonical_alias()
-        self.assertEqual(res, {})
-
-    def test_alias_alt_aliases(self):
-        """Test a canonical alias message with an alias and alt_aliases."""
-        # Create an alias.
-        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
-
-        # Canonical alias now exists!
-        res = self._get_canonical_alias()
-        self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
-
-        # Now remove the alias and alt_aliases.
-        self._set_canonical_alias({})
-
-        # There is an alias event, but it is empty.
-        res = self._get_canonical_alias()
-        self.assertEqual(res, {})
-
-    def test_partial_modify(self):
-        """Test removing only the alt_aliases."""
-        # Create an alias.
-        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
-
-        # Canonical alias now exists!
-        res = self._get_canonical_alias()
-        self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
-
-        # Now remove the alt_aliases.
-        self._set_canonical_alias({"alias": self.alias})
-
-        # There is an alias event, but it is empty.
-        res = self._get_canonical_alias()
-        self.assertEqual(res, {"alias": self.alias})
-
-    def test_add_alias(self):
-        """Test removing only the alt_aliases."""
-        # Create an additional alias.
-        second_alias = "#second:test"
-        self._set_alias_via_directory(second_alias)
-
-        # Add the canonical alias.
-        self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
-
-        # Then add the second alias.
-        self._set_canonical_alias(
-            {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
-        )
-
-        # Canonical alias now exists!
-        res = self._get_canonical_alias()
-        self.assertEqual(
-            res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
-        )
-
-    def test_bad_data(self):
-        """Invalid data for alt_aliases should cause errors."""
-        self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
-        self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
-        self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
-        self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
-        self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
-        self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
-        self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
-
-    def test_bad_alias(self):
-        """An alias which does not point to the room raises a SynapseError."""
-        self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
-        self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
deleted file mode 100644
index b54b004733..0000000000
--- a/tests/rest/client/v1/test_typing.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -"""Tests REST events for /rooms paths.""" - -from unittest.mock import Mock - -from synapse.rest.client import room -from synapse.types import UserID - -from tests import unittest - -PATH_PREFIX = "/_matrix/client/api/v1" - - -class RoomTypingTestCase(unittest.HomeserverTestCase): - """Tests /rooms/$room_id/typing/$user_id REST API.""" - - user_id = "@sid:red" - - user = UserID.from_string(user_id) - servlets = [room.register_servlets] - - def make_homeserver(self, reactor, clock): - - hs = self.setup_test_homeserver( - "red", - federation_http_client=None, - federation_client=Mock(), - ) - - self.event_source = hs.get_event_sources().sources["typing"] - - hs.get_federation_handler = Mock() - - async def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.auth_user_id), - "token_id": 1, - "is_guest": False, - } - - hs.get_auth().get_user_by_access_token = get_user_by_access_token - - async def _insert_client_ip(*args, **kwargs): - return None - - hs.get_datastore().insert_client_ip = _insert_client_ip - - return hs - - def prepare(self, reactor, clock, hs): - self.room_id = self.helper.create_room_as(self.user_id) - # Need another user to make notifications actually work - self.helper.join(self.room_id, user="@jim:red") - - def test_set_typing(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 1) - events = self.get_success( - self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) - ) - self.assertEquals( - events[0], - [ - { - "type": "m.typing", - "room_id": self.room_id, - "content": {"user_ids": [self.user_id]}, - } - ], - ) - - def test_set_not_typing(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": false}', - ) - self.assertEquals(200, channel.code) - - def test_typing_timeout(self): - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 1) - - self.reactor.advance(36) - - self.assertEquals(self.event_source.get_current_key(), 2) - - channel = self.make_request( - "PUT", - "/rooms/%s/typing/%s" % (self.room_id, self.user_id), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - self.assertEquals(self.event_source.get_current_key(), 3) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py deleted file mode 100644 index 954ad1a1fd..0000000000 --- a/tests/rest/client/v1/utils.py +++ /dev/null @@ -1,654 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019-2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-
-import json
-import re
-import time
-import urllib.parse
-from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
-from unittest.mock import patch
-
-import attr
-
-from twisted.web.resource import Resource
-from twisted.web.server import Site
-
-from synapse.api.constants import Membership
-from synapse.types import JsonDict
-
-from tests.server import FakeChannel, FakeSite, make_request
-from tests.test_utils import FakeResponse
-from tests.test_utils.html_parsers import TestHtmlParser
-
-
-@attr.s
-class RestHelper:
-    """Contains extra helper functions to quickly and clearly perform a given
-    REST action, which isn't the focus of the test.
-    """
-
-    hs = attr.ib()
-    site = attr.ib(type=Site)
-    auth_user_id = attr.ib()
-
-    def create_room_as(
-        self,
-        room_creator: Optional[str] = None,
-        is_public: bool = True,
-        room_version: Optional[str] = None,
-        tok: Optional[str] = None,
-        expect_code: int = 200,
-        extra_content: Optional[Dict] = None,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
-    ) -> str:
-        """
-        Create a room.
-
-        Args:
-            room_creator: The user ID to create the room with.
-            is_public: If True, the `visibility` parameter will be set to the
-                default (public). Otherwise, the `visibility` parameter will be set
-                to "private".
-            room_version: The room version to create the room as. Defaults to Synapse's
-                default room version.
-            tok: The access token to use in the request.
-            expect_code: The expected HTTP response code.
-
-        Returns:
-            The ID of the newly created room.
-        """
-        temp_id = self.auth_user_id
-        self.auth_user_id = room_creator
-        path = "/_matrix/client/r0/createRoom"
-        content = extra_content or {}
-        if not is_public:
-            content["visibility"] = "private"
-        if room_version:
-            content["room_version"] = room_version
-        if tok:
-            path = path + "?access_token=%s" % tok
-
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "POST",
-            path,
-            json.dumps(content).encode("utf8"),
-            custom_headers=custom_headers,
-        )
-
-        assert channel.result["code"] == b"%d" % expect_code, channel.result
-        self.auth_user_id = temp_id
-
-        if expect_code == 200:
-            return channel.json_body["room_id"]
-
-    def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
-        self.change_membership(
-            room=room,
-            src=src,
-            targ=targ,
-            tok=tok,
-            membership=Membership.INVITE,
-            expect_code=expect_code,
-        )
-
-    def join(self, room=None, user=None, expect_code=200, tok=None):
-        self.change_membership(
-            room=room,
-            src=user,
-            targ=user,
-            tok=tok,
-            membership=Membership.JOIN,
-            expect_code=expect_code,
-        )
-
-    def leave(self, room=None, user=None, expect_code=200, tok=None):
-        self.change_membership(
-            room=room,
-            src=user,
-            targ=user,
-            tok=tok,
-            membership=Membership.LEAVE,
-            expect_code=expect_code,
-        )
-
-    def change_membership(
-        self,
-        room: str,
-        src: str,
-        targ: str,
-        membership: str,
-        extra_data: Optional[dict] = None,
-        tok: Optional[str] = None,
-        expect_code: int = 200,
-    ) -> None:
-        """
-        Send a membership state event into a room.
-
-        Args:
-            room: The ID of the room to send to
-            src: The mxid of the event sender
-            targ: The mxid of the event's target. The state key
-            membership: The type of membership event
-            extra_data: Extra information to include in the content of the event
-            tok: The user access token to use
-            expect_code: The expected HTTP response code
-        """
-        temp_id = self.auth_user_id
-        self.auth_user_id = src
-
-        path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
-        if tok:
-            path = path + "?access_token=%s" % tok
-
-        data = {"membership": membership}
-        data.update(extra_data or {})
-
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "PUT",
-            path,
-            json.dumps(data).encode("utf8"),
-        )
-
-        assert (
-            int(channel.result["code"]) == expect_code
-        ), "Expected: %d, got: %d, resp: %r" % (
-            expect_code,
-            int(channel.result["code"]),
-            channel.result["body"],
-        )
-
-        self.auth_user_id = temp_id
-
-    def send(
-        self,
-        room_id,
-        body=None,
-        txn_id=None,
-        tok=None,
-        expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
-    ):
-        if body is None:
-            body = "body_text_here"
-
-        content = {"msgtype": "m.text", "body": body}
-
-        return self.send_event(
-            room_id,
-            "m.room.message",
-            content,
-            txn_id,
-            tok,
-            expect_code,
-            custom_headers=custom_headers,
-        )
-
-    def send_event(
-        self,
-        room_id,
-        type,
-        content: Optional[dict] = None,
-        txn_id=None,
-        tok=None,
-        expect_code=200,
-        custom_headers: Optional[
-            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
-        ] = None,
-    ):
-        if txn_id is None:
-            txn_id = "m%s" % (str(time.time()))
-
-        path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
-        if tok:
-            path = path + "?access_token=%s" % tok
-
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "PUT",
-            path,
-            json.dumps(content or {}).encode("utf8"),
-            custom_headers=custom_headers,
-        )
-
-        assert (
-            int(channel.result["code"]) == expect_code
-        ), "Expected: %d, got: %d, resp: %r" % (
-            expect_code,
-            int(channel.result["code"]),
-            channel.result["body"],
-        )
-
-        return channel.json_body
-
-    def _read_write_state(
-        self,
-        room_id: str,
-        event_type: str,
-        body: Optional[Dict[str, Any]],
-        tok: str,
-        expect_code: int = 200,
-        state_key: str = "",
-        method: str = "GET",
-    ) -> Dict:
-        """Read or write some state from a given room
-
-        Args:
-            room_id:
-            event_type: The type of state event
-            body: Body that is sent when making the request. The content of the state event.
-                If None, the request to the server will have an empty body
-            tok: The access token to use
-            expect_code: The HTTP code to expect in the response
-            state_key:
-            method: "GET" or "PUT" for reading or writing state, respectively
-
-        Returns:
-            The response body from the server
-
-        Raises:
-            AssertionError: if expect_code doesn't match the HTTP code we received
-        """
-        path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
-            room_id,
-            event_type,
-            state_key,
-        )
-        if tok:
-            path = path + "?access_token=%s" % tok
-
-        # Set request body if provided
-        content = b""
-        if body is not None:
-            content = json.dumps(body).encode("utf8")
-
-        channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
-
-        assert (
-            int(channel.result["code"]) == expect_code
-        ), "Expected: %d, got: %d, resp: %r" % (
-            expect_code,
-            int(channel.result["code"]),
-            channel.result["body"],
-        )
-
-        return channel.json_body
-
-    def get_state(
-        self,
-        room_id: str,
-        event_type: str,
-        tok: str,
-        expect_code: int = 200,
-        state_key: str = "",
-    ):
-        """Gets some state from a room
-
-        Args:
-            room_id:
-            event_type: The type of state event
-            tok: The access token to use
-            expect_code: The HTTP code to expect in the response
-            state_key:
-
-        Returns:
-            The response body from the server
-
-        Raises:
-            AssertionError: if expect_code doesn't match the HTTP code we received
-        """
-        return self._read_write_state(
-            room_id, event_type, None, tok, expect_code, state_key, method="GET"
-        )
-
-    def send_state(
-        self,
-        room_id: str,
-        event_type: str,
-        body: Dict[str, Any],
-        tok: str,
-        expect_code: int = 200,
-        state_key: str = "",
-    ):
-        """Set some state in a room
-
-        Args:
-            room_id:
-            event_type: The type of state event
-            body: Body that is sent when making the request. The content of the state event.
-            tok: The access token to use
-            expect_code: The HTTP code to expect in the response
-            state_key:
-
-        Returns:
-            The response body from the server
-
-        Raises:
-            AssertionError: if expect_code doesn't match the HTTP code we received
-        """
-        return self._read_write_state(
-            room_id, event_type, body, tok, expect_code, state_key, method="PUT"
-        )
-
-    def upload_media(
-        self,
-        resource: Resource,
-        image_data: bytes,
-        tok: str,
-        filename: str = "test.png",
-        expect_code: int = 200,
-    ) -> dict:
-        """Upload a piece of test media to the media repo
-        Args:
-            resource: The resource that will handle the upload request
-            image_data: The image data to upload
-            tok: The user token to use during the upload
-            filename: The filename of the media to be uploaded
-            expect_code: The return code to expect from attempting to upload the media
-        """
-        image_length = len(image_data)
-        path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
-        channel = make_request(
-            self.hs.get_reactor(),
-            FakeSite(resource),
-            "POST",
-            path,
-            content=image_data,
-            access_token=tok,
-            custom_headers=[(b"Content-Length", str(image_length))],
-        )
-
-        assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
-            expect_code,
-            int(channel.result["code"]),
-            channel.result["body"],
-        )
-
-        return channel.json_body
-
-    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
-        """Log in (as a new user) via OIDC
-
-        Returns the result of the final token login.
-
-        Requires that "oidc_config" in the homeserver config be set appropriately
-        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
-        "public_base_url".
-
-        Also requires the login servlet and the OIDC callback resource to be mounted at
-        the normal places.
-        """
-        client_redirect_url = "https://x"
-        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
-
-        # expect a confirmation page
-        assert channel.code == 200, channel.result
-
-        # fish the matrix login token out of the body of the confirmation page
-        m = re.search(
-            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
-            channel.text_body,
-        )
-        assert m, channel.text_body
-        login_token = m.group(1)
-
-        # finally, submit the matrix login token to the login API, which gives us our
-        # matrix access token and device id.
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "POST",
-            "/login",
-            content={"type": "m.login.token", "token": login_token},
-        )
-        assert channel.code == 200
-        return channel.json_body
-
-    def auth_via_oidc(
-        self,
-        user_info_dict: JsonDict,
-        client_redirect_url: Optional[str] = None,
-        ui_auth_session_id: Optional[str] = None,
-    ) -> FakeChannel:
-        """Perform an OIDC authentication flow via a mock OIDC provider.
-
-        This can be used for either login or user-interactive auth.
-
-        Starts by making a request to the relevant synapse redirect endpoint, which is
-        expected to serve a 302 to the OIDC provider. We then make a request to the
-        OIDC callback endpoint, intercepting the HTTP requests that will get sent back
-        to the OIDC provider.
-
-        Requires that "oidc_config" in the homeserver config be set appropriately
-        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
-        "public_base_url".
-
-        Also requires the login servlet and the OIDC callback resource to be mounted at
-        the normal places.
-
-        Args:
-            user_info_dict: the remote userinfo that the OIDC provider should present.
-                Typically this should be '{"sub": "<remote user id>"}'.
-            client_redirect_url: for a login flow, the client redirect URL to pass to
-                the login redirect endpoint
-            ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
-                of the UI auth.
-
-        Returns:
-            A FakeChannel containing the result of calling the OIDC callback endpoint.
-            Note that the response code may be a 200, 302 or 400 depending on how things
-            went.
-        """
-
-        cookies = {}
-
-        # if we're doing a ui auth, hit the ui auth redirect endpoint
-        if ui_auth_session_id:
-            # can't set the client redirect url for UI Auth
-            assert client_redirect_url is None
-            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
-        else:
-            # otherwise, hit the login redirect endpoint
-            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
-
-        # we now have a URI for the OIDC IdP, but we skip that and go straight
-        # back to synapse's OIDC callback resource. However, we do need the "state"
-        # param that synapse passes to the IdP via query params, as well as the cookie
-        # that synapse passes to the client.
-
-        oauth_uri_path, _ = oauth_uri.split("?", 1)
-        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
-            "unexpected SSO URI " + oauth_uri_path
-        )
-        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
-
-    def complete_oidc_auth(
-        self,
-        oauth_uri: str,
-        cookies: Mapping[str, str],
-        user_info_dict: JsonDict,
-    ) -> FakeChannel:
-        """Mock out an OIDC authentication flow
-
-        Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
-        initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
-        Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
-        sent back to the OIDC provider.
-
-        Requires the OIDC callback resource to be mounted at the normal place.
-
-        Args:
-            oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
-                from initiate_sso_login or initiate_sso_ui_auth).
-            cookies: the cookies set by synapse's redirect endpoint, which will be
-                sent back to the callback endpoint.
-            user_info_dict: the remote userinfo that the OIDC provider should present.
-                Typically this should be '{"sub": "<remote user id>"}'.
-
-        Returns:
-            A FakeChannel containing the result of calling the OIDC callback endpoint.
-        """
-        _, oauth_uri_qs = oauth_uri.split("?", 1)
-        params = urllib.parse.parse_qs(oauth_uri_qs)
-        callback_uri = "%s?%s" % (
-            urllib.parse.urlparse(params["redirect_uri"][0]).path,
-            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
-        )
-
-        # before we hit the callback uri, stub out some methods in the http client so
-        # that we don't have to handle full HTTPS requests.
-        # (expected url, json response) pairs, in the order we expect them.
-        expected_requests = [
-            # first we get a hit to the token endpoint, which we tell to return
-            # a dummy OIDC access token
-            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
-            # and then one to the user_info endpoint, which returns our remote user id.
-            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
-        ]
-
-        async def mock_req(method: str, uri: str, data=None, headers=None):
-            (expected_uri, resp_obj) = expected_requests.pop(0)
-            assert uri == expected_uri
-            resp = FakeResponse(
-                code=200,
-                phrase=b"OK",
-                body=json.dumps(resp_obj).encode("utf-8"),
-            )
-            return resp
-
-        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
-            # now hit the callback URI with the right params and a made-up code
-            channel = make_request(
-                self.hs.get_reactor(),
-                self.site,
-                "GET",
-                callback_uri,
-                custom_headers=[
-                    ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
-                ],
-            )
-            return channel
-
-    def initiate_sso_login(
-        self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
-    ) -> str:
-        """Make a request to the login-via-sso redirect endpoint, and return the target
-
-        Assumes that exactly one SSO provider has been configured. Requires the login
-        servlet to be mounted.
-
-        Args:
-            client_redirect_url: the client redirect URL to pass to the login redirect
-                endpoint
-            cookies: any cookies returned will be added to this dict
-
-        Returns:
-            the URI that the client gets redirected to (ie, the SSO server)
-        """
-        params = {}
-        if client_redirect_url:
-            params["redirectUrl"] = client_redirect_url
-
-        # hit the redirect url (which should redirect back to the redirect url. This
-        # is the easiest way of figuring out what the Host header ought to be set to
-        # to keep Synapse happy.
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "GET",
-            "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
-        )
-        assert channel.code == 302
-
-        # hit the redirect url again with the right Host header, which should now issue
-        # a cookie and redirect to the SSO provider.
-        location = channel.headers.getRawHeaders("Location")[0]
-        parts = urllib.parse.urlsplit(location)
-        channel = make_request(
-            self.hs.get_reactor(),
-            self.site,
-            "GET",
-            urllib.parse.urlunsplit(("", "") + parts[2:]),
-            custom_headers=[
-                ("Host", parts[1]),
-            ],
-        )
-
-        assert channel.code == 302
-        channel.extract_cookies(cookies)
-        return channel.headers.getRawHeaders("Location")[0]
-
-    def initiate_sso_ui_auth(
-        self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
-    ) -> str:
-        """Make a request to the ui-auth-via-sso endpoint, and return the target
-
-        Assumes that exactly one SSO provider has been configured. Requires the
-        AuthRestServlet to be mounted.
-
-        Args:
-            ui_auth_session_id: the session id of the UI auth
-            cookies: any cookies returned will be added to this dict
-
-        Returns:
-            the URI that the client gets linked to (ie, the SSO server)
-        """
-        sso_redirect_endpoint = (
-            "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
-            + urllib.parse.urlencode({"session": ui_auth_session_id})
-        )
-        # hit the redirect url (which will issue a cookie and state)
-        channel = make_request(
-            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
-        )
-        # that should serve a confirmation page
-        assert channel.code == 200, channel.text_body
-        channel.extract_cookies(cookies)
-
-        # parse the confirmation page to fish out the link.
-        p = TestHtmlParser()
-        p.feed(channel.text_body)
-        p.close()
-        assert len(p.links) == 1, "not exactly one link in confirmation page"
-        oauth_uri = p.links[0]
-        return oauth_uri
-
-
-# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
-TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
-TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
-TEST_OIDC_CONFIG = {
-    "enabled": True,
-    "discover": False,
-    "issuer": "https://issuer.test",
-    "client_id": "test-client-id",
-    "client_secret": "test-client-secret",
-    "scopes": ["profile"],
-    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
-    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
-    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
-    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
-}
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
deleted file mode 100644
index b946fca8b3..0000000000
--- a/tests/rest/client/v2_alpha/test_account.py
+++ /dev/null
@@ -1,1000 +0,0 @@
-# Copyright 2015-2016 OpenMarket Ltd
-# Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import json
-import os
-import re
-from email.parser import Parser
-from typing import Optional
-
-import pkg_resources
-
-import synapse.rest.admin
-from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes, HttpResponseException
-from synapse.appservice import ApplicationService
-from synapse.rest.client import account, login, register, room
-from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
-
-from tests import unittest
-from tests.server import FakeSite, make_request
-from tests.unittest import override_config
-
-
-class PasswordResetTestCase(unittest.HomeserverTestCase):
-
-    servlets = [
-        account.register_servlets,
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-        register.register_servlets,
-        login.register_servlets,
-    ]
-
-    def make_homeserver(self, reactor, clock):
-        config = self.default_config()
-
-        # Email config.
-        config["email"] = {
-            "enable_notifs": False,
-            "template_dir": os.path.abspath(
-                pkg_resources.resource_filename("synapse", "res/templates")
-            ),
-            "smtp_host": "127.0.0.1",
-            "smtp_port": 20,
-            "require_transport_security": False,
-            "smtp_user": None,
-            "smtp_pass": None,
-            "notif_from": "test@example.com",
-        }
-        config["public_baseurl"] = "https://example.com"
-
-        hs = self.setup_test_homeserver(config=config)
-
-        async def sendmail(
-            reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
-        ):
-            self.email_attempts.append(msg)
-
-        self.email_attempts = []
-        hs.get_send_email_handler()._sendmail = sendmail
-
-        return hs
-
-    def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-        self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
-
-    def test_basic_password_reset(self):
-        """Test basic password reset flow"""
-        old_password = "monkey"
-        new_password = "kangeroo"
-
-        user_id = self.register_user("kermit", old_password)
-        self.login("kermit", old_password)
-
-        email = "test@example.com"
-
-        # Add a threepid
-        self.get_success(
-            self.store.user_add_threepid(
-                user_id=user_id,
-                medium="email",
-                address=email,
-                validated_at=0,
-                added_at=0,
-            )
-        )
-
-        client_secret = "foobar"
-        session_id = self._request_token(email, client_secret)
-
-        self.assertEquals(len(self.email_attempts), 1)
-        link = self._get_link_from_email()
-
-        self._validate_token(link)
-
-        self._reset_password(new_password, session_id, client_secret)
-
-        # Assert we can log in with the new password
-        self.login("kermit", new_password)
-
-        # Assert we can't log in with the old password
-        self.attempt_wrong_password_login("kermit", old_password)
-
-    @override_config({"rc_3pid_validation": {"burst_count": 3}})
-    def test_ratelimit_by_email(self):
-        """Test that we ratelimit /requestToken for the same email."""
-        old_password = "monkey"
-        new_password = "kangeroo"
-
-        user_id = self.register_user("kermit", old_password)
-        self.login("kermit", old_password)
-
-        email = "test1@example.com"
-
-        # Add a threepid
-        self.get_success(
-            self.store.user_add_threepid(
-                user_id=user_id,
-                medium="email",
-                address=email,
-                validated_at=0,
-                added_at=0,
-            )
-        )
-
-        def reset(ip):
-            client_secret = "foobar"
-            session_id = self._request_token(email, client_secret, ip)
-
-            self.assertEquals(len(self.email_attempts), 1)
-            link = self._get_link_from_email()
-
-            self._validate_token(link)
-
-            self._reset_password(new_password, session_id, client_secret)
-
-            self.email_attempts.clear()
-
-        # We expect to be able to make three requests before getting rate
-        # limited.
-        #
-        # We change IPs to ensure that we're not being ratelimited due to the
-        # same IP
-        reset("127.0.0.1")
-        reset("127.0.0.2")
-        reset("127.0.0.3")
-
-        with self.assertRaises(HttpResponseException) as cm:
-            reset("127.0.0.4")
-
-        self.assertEqual(cm.exception.code, 429)
-
-    def test_basic_password_reset_canonicalise_email(self):
-        """Test basic password reset flow
-        Request password reset with different spelling
-        """
-        old_password = "monkey"
-        new_password = "kangeroo"
-
-        user_id = self.register_user("kermit", old_password)
-        self.login("kermit", old_password)
-
-        email_profile = "test@example.com"
-        email_passwort_reset = "TEST@EXAMPLE.COM"
-
-        # Add a threepid
-        self.get_success(
-            self.store.user_add_threepid(
-                user_id=user_id,
-                medium="email",
-                address=email_profile,
-                validated_at=0,
-                added_at=0,
-            )
-        )
-
-        client_secret = "foobar"
-        session_id = self._request_token(email_passwort_reset, client_secret)
-
-        self.assertEquals(len(self.email_attempts), 1)
-        link = self._get_link_from_email()
-
-        self._validate_token(link)
-
-        self._reset_password(new_password, session_id, client_secret)
-
-        # Assert we can log in with the new password
-        self.login("kermit", new_password)
-
-        # Assert we can't log in with the old password
-        self.attempt_wrong_password_login("kermit", old_password)
-
-    def test_cant_reset_password_without_clicking_link(self):
-        """Test that we do actually need to click the link in the email"""
-        old_password = "monkey"
-        new_password = "kangeroo"
-
-        user_id = self.register_user("kermit", old_password)
-        self.login("kermit", old_password)
-
-        email = "test@example.com"
-
-        # Add a threepid
-        self.get_success(
-            self.store.user_add_threepid(
-                user_id=user_id,
-                medium="email",
-                address=email,
-                validated_at=0,
-                added_at=0,
-            )
-        )
-
-        client_secret = "foobar"
-        session_id = self._request_token(email, client_secret)
-
-        self.assertEquals(len(self.email_attempts), 1)
-
-        # Attempt to reset password without clicking the link
-        self._reset_password(new_password, session_id, client_secret, expected_code=401)
-
-        # Assert we can log in with the old password
-        self.login("kermit", old_password)
-
-        # Assert we can't log in with the new password
-        self.attempt_wrong_password_login("kermit", new_password)
-
-    def test_no_valid_token(self):
-        """Test that we do actually need to request a token and can't just
-        make a session up.
-        """
-        old_password = "monkey"
-        new_password = "kangeroo"
-
-        user_id = self.register_user("kermit", old_password)
-        self.login("kermit", old_password)
-
-        email = "test@example.com"
-
-        # Add a threepid
-        self.get_success(
-            self.store.user_add_threepid(
-                user_id=user_id,
-                medium="email",
-                address=email,
-                validated_at=0,
-                added_at=0,
-            )
-        )
-
-        client_secret = "foobar"
-        session_id = "weasle"
-
-        # Attempt to reset password without even requesting an email
-        self._reset_password(new_password, session_id, client_secret, expected_code=401)
-
-        # Assert we can log in with the old password
-        self.login("kermit", old_password)
-
-        # Assert we can't log in with the new password
-        self.attempt_wrong_password_login("kermit", new_password)
-
-    @unittest.override_config({"request_token_inhibit_3pid_errors": True})
-    def test_password_reset_bad_email_inhibit_error(self):
-        """Test that triggering a password reset with an email address that isn't bound
-        to an account doesn't leak the lack of binding for that address if configured
-        that way.
- """ - self.register_user("kermit", "monkey") - self.login("kermit", "monkey") - - email = "test@example.com" - - client_secret = "foobar" - session_id = self._request_token(email, client_secret) - - self.assertIsNotNone(session_id) - - def _request_token(self, email, client_secret, ip="127.0.0.1"): - channel = self.make_request( - "POST", - b"account/password/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - client_ip=ip, - ) - - if channel.code != 200: - raise HttpResponseException( - channel.code, - channel.result["reason"], - channel.result["body"], - ) - - return channel.json_body["sid"] - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - # Load the password reset confirmation page - channel = make_request( - self.reactor, - FakeSite(self.submit_token_resource), - "GET", - path, - shorthand=False, - ) - - self.assertEquals(200, channel.code, channel.result) - - # Now POST to the same endpoint, mimicking the same behaviour as clicking the - # password reset confirm button - - # Confirm the password reset - channel = make_request( - self.reactor, - FakeSite(self.submit_token_resource), - "POST", - path, - content=b"", - shorthand=False, - content_is_form=True, - ) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") - break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) - - def _reset_password( - self, new_password, session_id, client_secret, expected_code=200 - ): - channel = self.make_request( - "POST", - b"account/password", - { - "new_password": new_password, - "auth": { - "type": LoginType.EMAIL_IDENTITY, - "threepid_creds": { - "client_secret": client_secret, - "sid": session_id, - }, - }, - }, - ) - self.assertEquals(expected_code, channel.code, channel.result) - - -class DeactivateTestCase(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - account.register_servlets, - room.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() - return self.hs - - def test_deactivate_account(self): - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - - self.deactivate(user_id, tok) - - store = self.hs.get_datastore() - - # Check that the user has been marked as deactivated. - self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) - - # Check that this access token has been invalidated. - channel = self.make_request("GET", "account/whoami", access_token=tok) - self.assertEqual(channel.code, 401) - - def test_pending_invites(self): - """Tests that deactivating a user rejects every pending invite for them.""" - store = self.hs.get_datastore() - - inviter_id = self.register_user("inviter", "test") - inviter_tok = self.login("inviter", "test") - - invitee_id = self.register_user("invitee", "test") - invitee_tok = self.login("invitee", "test") - - # Make @inviter:test invite @invitee:test in a new room. 
-        room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok)
-        self.helper.invite(
-            room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok
-        )
-
-        # Make sure the invite is here.
-        pending_invites = self.get_success(
-            store.get_invited_rooms_for_local_user(invitee_id)
-        )
-        self.assertEqual(len(pending_invites), 1, pending_invites)
-        self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
-
-        # Deactivate @invitee:test.
-        self.deactivate(invitee_id, invitee_tok)
-
-        # Check that the invite isn't there anymore.
-        pending_invites = self.get_success(
-            store.get_invited_rooms_for_local_user(invitee_id)
-        )
-        self.assertEqual(len(pending_invites), 0, pending_invites)
-
-        # Check that the membership of @invitee:test in the room is now "leave".
-        memberships = self.get_success(
-            store.get_rooms_for_local_user_where_membership_is(
-                invitee_id, [Membership.LEAVE]
-            )
-        )
-        self.assertEqual(len(memberships), 1, memberships)
-        self.assertEqual(memberships[0].room_id, room_id, memberships)
-
-    def deactivate(self, user_id, tok):
-        request_data = json.dumps(
-            {
-                "auth": {
-                    "type": "m.login.password",
-                    "user": user_id,
-                    "password": "test",
-                },
-                "erase": False,
-            }
-        )
-        channel = self.make_request(
-            "POST", "account/deactivate", request_data, access_token=tok
-        )
-        self.assertEqual(channel.code, 200)
-
-
-class WhoamiTestCase(unittest.HomeserverTestCase):
-
-    servlets = [
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-        login.register_servlets,
-        account.register_servlets,
-        register.register_servlets,
-    ]
-
-    def test_GET_whoami(self):
-        device_id = "wouldgohere"
-        user_id = self.register_user("kermit", "test")
-        tok = self.login("kermit", "test", device_id=device_id)
-
-        whoami = self.whoami(tok)
-        self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id})
-
-    def test_GET_whoami_appservices(self):
-        user_id = "@as:test"
-        as_token = "i_am_an_app_service"
-
-        appservice = ApplicationService(
-            as_token,
-            self.hs.config.server_name,
-            id="1234",
-            namespaces={"users": [{"regex": user_id, "exclusive": True}]},
-            sender=user_id,
-        )
-        self.hs.get_datastore().services_cache.append(appservice)
-
-        whoami = self.whoami(as_token)
-        self.assertEqual(whoami, {"user_id": user_id})
-        self.assertFalse(hasattr(whoami, "device_id"))
-
-    def whoami(self, tok):
-        channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
-        self.assertEqual(channel.code, 200)
-        return channel.json_body
-
-
-class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
-
-    servlets = [
-        account.register_servlets,
-        login.register_servlets,
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-    ]
-
-    def make_homeserver(self, reactor, clock):
-        config = self.default_config()
-
-        # Email config.
- config["email"] = { - "enable_notifs": False, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "https://example.com" - - self.hs = self.setup_test_homeserver(config=config) - - async def sendmail( - reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs - ): - self.email_attempts.append(msg) - - self.email_attempts = [] - self.hs.get_send_email_handler()._sendmail = sendmail - - return self.hs - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - - self.user_id = self.register_user("kermit", "test") - self.user_id_tok = self.login("kermit", "test") - self.email = "test@example.com" - self.url_3pid = b"account/3pid" - - def test_add_valid_email(self): - self.get_success(self._add_email(self.email, self.email)) - - def test_add_valid_email_second_time(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - self.email, - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) - ) - - def test_add_valid_email_second_time_canonicalise(self): - self.get_success(self._add_email(self.email, self.email)) - self.get_success( - self._request_token_invalid_email( - "TEST@EXAMPLE.COM", - expected_errcode=Codes.THREEPID_IN_USE, - expected_error="Email is already in use", - ) - ) - - def test_add_email_no_at(self): - self.get_success( - self._request_token_invalid_email( - "address-without-at.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_two_at(self): - self.get_success( - self._request_token_invalid_email( - "foo@foo@test.bar", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_bad_format(self): - self.get_success( - self._request_token_invalid_email( - "user@bad.example.net@good.example.com", - expected_errcode=Codes.UNKNOWN, - expected_error="Unable to parse email address", - ) - ) - - def test_add_email_domain_to_lower(self): - self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar")) - - def test_add_email_domain_with_umlaut(self): - self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")) - - def test_add_email_address_casefold(self): - self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com")) - - def test_address_trim(self): - self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) - - @override_config({"rc_3pid_validation": {"burst_count": 3}}) - def test_ratelimit_by_ip(self): - """Tests that adding emails is ratelimited by IP""" - - # We expect to be able to set three emails before getting ratelimited. 
- self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) - self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) - self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) - - with self.assertRaises(HttpResponseException) as cm: - self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) - - self.assertEqual(cm.exception.code, 429) - - def test_add_email_if_disabled(self): - """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False - - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - link = self._get_link_from_email() - - self._validate_token(link) - - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email(self): - """Test deleting an email from profile""" - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertFalse(channel.json_body["threepids"]) - - def test_delete_email_if_disabled(self): - """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False - - # Add a threepid - self.get_success( - self.store.user_add_threepid( - user_id=self.user_id, - medium="email", - address=self.email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"account/3pid/delete", - {"medium": "email", "address": self.email}, - access_token=self.user_id_tok, - ) - - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) - - def test_cant_add_email_without_clicking_link(self): - """Test that we do actually need to click the link in the email""" - client_secret = "foobar" - session_id = self._request_token(self.email, client_secret) - - self.assertEquals(len(self.email_attempts), 1) - - # Attempt to add email without clicking the link - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": 
client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
- )
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
- def test_no_valid_token(self):
- """Test that we do actually need to request a token and can't just
- make a session up.
- """
- client_secret = "foobar"
- session_id = "weasle"
-
- # Attempt to add email without even requesting an email
- channel = self.make_request(
- "POST",
- b"/_matrix/client/unstable/account/3pid/add",
- {
- "client_secret": client_secret,
- "sid": session_id,
- "auth": {
- "type": "m.login.password",
- "user": self.user_id,
- "password": "test",
- },
- },
- access_token=self.user_id_tok,
- )
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
-
- # Get user
- channel = self.make_request(
- "GET",
- self.url_3pid,
- access_token=self.user_id_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
- self.assertFalse(channel.json_body["threepids"])
-
- @override_config({"next_link_domain_whitelist": None})
- def test_next_link(self):
- """Tests a valid next_link parameter value with no whitelist (good case)"""
- self._request_token(
- "something@example.com",
- "some_secret",
- next_link="https://example.com/a/good/site",
- expect_code=200,
- )
-
- @override_config({"next_link_domain_whitelist": None})
- def test_next_link_exotic_protocol(self):
- """Tests using an esoteric protocol as a next_link parameter value.
- Someone may be hosting a client on IPFS etc.
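Arbitrary non-HTTP schemes are therefore accepted as next_link values here.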
- """ - self._request_token( - "something@example.com", - "some_secret", - next_link="some-protocol://abcdefghijklmopqrstuvwxyz", - expect_code=200, - ) - - @override_config({"next_link_domain_whitelist": None}) - def test_next_link_file_uri(self): - """Tests next_link parameters cannot be file URI""" - # Attempt to use a next_link value that points to the local disk - self._request_token( - "something@example.com", - "some_secret", - next_link="file:///host/path", - expect_code=400, - ) - - @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) - def test_next_link_domain_whitelist(self): - """Tests next_link parameters must fit the whitelist if provided""" - - # Ensure not providing a next_link parameter still works - self._request_token( - "something@example.com", - "some_secret", - next_link=None, - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.com/some/good/page", - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.org/some/also/good/page", - expect_code=200, - ) - - self._request_token( - "something@example.com", - "some_secret", - next_link="https://bad.example.org/some/bad/page", - expect_code=400, - ) - - @override_config({"next_link_domain_whitelist": []}) - def test_empty_next_link_domain_whitelist(self): - """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially - disallowed - """ - self._request_token( - "something@example.com", - "some_secret", - next_link="https://example.com/a/page", - expect_code=400, - ) - - def _request_token( - self, - email: str, - client_secret: str, - next_link: Optional[str] = None, - expect_code: int = 200, - ) -> str: - """Request a validation token to add an email address to a user's account - - Args: - email: The email address to validate - client_secret: A secret string - next_link: A link to redirect the user to after validation - expect_code: Expected return code of the call - - Returns: - The ID of the new threepid validation session - """ - body = {"client_secret": client_secret, "email": email, "send_attempt": 1} - if next_link: - body["next_link"] = next_link - - channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - body, - ) - - if channel.code != expect_code: - raise HttpResponseException( - channel.code, - channel.result["reason"], - channel.result["body"], - ) - - return channel.json_body.get("sid") - - def _request_token_invalid_email( - self, - email, - expected_errcode, - expected_error, - client_secret="foobar", - ): - channel = self.make_request( - "POST", - b"account/3pid/email/requestToken", - {"client_secret": client_secret, "email": email, "send_attempt": 1}, - ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(expected_errcode, channel.json_body["errcode"]) - self.assertEqual(expected_error, channel.json_body["error"]) - - def _validate_token(self, link): - # Remove the host - path = link.replace("https://example.com", "") - - channel = self.make_request("GET", path, shorthand=False) - self.assertEquals(200, channel.code, channel.result) - - def _get_link_from_email(self): - assert self.email_attempts, "No emails have been sent" - - raw_msg = self.email_attempts[-1].decode("UTF-8") - mail = Parser().parsestr(raw_msg) - - text = None - for part in mail.walk(): - if part.get_content_type() == "text/plain": - text = part.get_payload(decode=True).decode("UTF-8") 
- break - - if not text: - self.fail("Could not find text portion of email to parse") - - match = re.search(r"https://example.com\S+", text) - assert match, "Could not find link in email" - - return match.group(0) - - def _add_email(self, request_email, expected_email): - """Test adding an email to profile""" - previous_email_attempts = len(self.email_attempts) - - client_secret = "foobar" - session_id = self._request_token(request_email, client_secret) - - self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) - link = self._get_link_from_email() - - self._validate_token(link) - - channel = self.make_request( - "POST", - b"/_matrix/client/unstable/account/3pid/add", - { - "client_secret": client_secret, - "sid": session_id, - "auth": { - "type": "m.login.password", - "user": self.user_id, - "password": "test", - }, - }, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - - # Get user - channel = self.make_request( - "GET", - self.url_3pid, - access_token=self.user_id_tok, - ) - - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - - threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} - self.assertIn(expected_email, threepids) diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py deleted file mode 100644 index cf5cfb910c..0000000000 --- a/tests/rest/client/v2_alpha/test_auth.py +++ /dev/null @@ -1,717 +0,0 @@ -# Copyright 2018 New Vector -# Copyright 2020-2021 The Matrix.org Foundation C.I.C -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
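The fallback-auth and UIA tests in this file all follow the same user-interactive
authentication loop: the first request is rejected with a 401 carrying a session ID
and the available flows, and the client then repeats the request with an "auth" dict
referencing that session. A minimal sketch of the loop, using the same make_request
helper these tests use (the endpoint and stage values are illustrative, borrowed from
the tests below):

# Sketch only: the generic UIA loop exercised throughout this file.
channel = self.make_request("POST", "register", {"username": "user", "password": "bar"})
assert channel.code == 401                 # a UIA challenge, not a failure
session = channel.json_body["session"]     # ties follow-up requests together
flows = channel.json_body["flows"]         # e.g. [{"stages": ["m.login.recaptcha"]}]

# Complete a stage, then repeat the original request quoting the session.
channel = self.make_request(
    "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
)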
-from typing import Optional, Union
-
-from twisted.internet.defer import succeed
-
-import synapse.rest.admin
-from synapse.api.constants import LoginType
-from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.rest.client import account, auth, devices, login, register
-from synapse.rest.synapse.client import build_synapse_client_resource_tree
-from synapse.types import JsonDict, UserID
-
-from tests import unittest
-from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
-from tests.server import FakeChannel
-from tests.unittest import override_config, skip_unless
-
-
-class DummyRecaptchaChecker(UserInteractiveAuthChecker):
- def __init__(self, hs):
- super().__init__(hs)
- self.recaptcha_attempts = []
-
- def check_auth(self, authdict, clientip):
- self.recaptcha_attempts.append((authdict, clientip))
- return succeed(True)
-
-
-class FallbackAuthTests(unittest.HomeserverTestCase):
-
- servlets = [
- auth.register_servlets,
- register.register_servlets,
- ]
- hijack_auth = False
-
- def make_homeserver(self, reactor, clock):
-
- config = self.default_config()
-
- config["enable_registration_captcha"] = True
- config["recaptcha_public_key"] = "brokencake"
- config["registrations_require_3pid"] = []
-
- hs = self.setup_test_homeserver(config=config)
- return hs
-
- def prepare(self, reactor, clock, hs):
- self.recaptcha_checker = DummyRecaptchaChecker(hs)
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
- def register(self, expected_response: int, body: JsonDict) -> FakeChannel:
- """Make a register request."""
- channel = self.make_request("POST", "register", body)
-
- self.assertEqual(channel.code, expected_response)
- return channel
-
- def recaptcha(
- self,
- session: str,
- expected_post_response: int,
- post_session: Optional[str] = None,
- ) -> None:
- """Get and respond to a fallback recaptcha."""
- if post_session is None:
- post_session = session
-
- channel = self.make_request(
- "GET", "auth/m.login.recaptcha/fallback/web?session=" + session
- )
- self.assertEqual(channel.code, 200)
-
- channel = self.make_request(
- "POST",
- "auth/m.login.recaptcha/fallback/web?session="
- + post_session
- + "&g-recaptcha-response=a",
- )
- self.assertEqual(channel.code, expected_post_response)
-
- # The recaptcha handler is called with the response given
- attempts = self.recaptcha_checker.recaptcha_attempts
- self.assertEqual(len(attempts), 1)
- self.assertEqual(attempts[0][0]["response"], "a")
-
- def test_fallback_captcha(self):
- """Ensure that fallback auth via a captcha works."""
- # Returns a 401 as per the spec
- channel = self.register(
- 401,
- {"username": "user", "type": "m.login.password", "password": "bar"},
- )
-
- # Grab the session
- session = channel.json_body["session"]
- # Assert our configured public key is being given
- self.assertEqual(
- channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
- )
-
- # Complete the recaptcha step.
- self.recaptcha(session, 200)
-
- # also complete the dummy auth
- self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}})
-
- # Now we should have fulfilled a complete auth flow, including
- # the recaptcha fallback step; we can now send a
- # request to the register API with the session in the authdict.
- channel = self.register(200, {"auth": {"session": session}})
-
- # We're given a registered user.
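# The test homeserver is named "test", so the registered localpart "user"
# comes back as the fully-qualified Matrix ID "@user:test".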
- self.assertEqual(channel.json_body["user_id"], "@user:test") - - def test_complete_operation_unknown_session(self): - """ - Attempting to mark an invalid session as complete should error. - """ - # Make the initial request to register. (Later on a different password - # will be used.) - # Returns a 401 as per the spec - channel = self.register( - 401, {"username": "user", "type": "m.login.password", "password": "bar"} - ) - - # Grab the session - session = channel.json_body["session"] - # Assert our configured public key is being given - self.assertEqual( - channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" - ) - - # Attempt to complete the recaptcha step with an unknown session. - # This results in an error. - self.recaptcha(session, 400, session + "unknown") - - -class UIAuthTests(unittest.HomeserverTestCase): - servlets = [ - auth.register_servlets, - devices.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - register.register_servlets, - ] - - def default_config(self): - config = super().default_config() - - # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns - # False, so synapse will see the requested uri as http://..., so using http in - # the public_baseurl stops Synapse trying to redirect to https. - config["public_baseurl"] = "http://synapse.test" - - if HAS_OIDC: - # we enable OIDC as a way of testing SSO flows - oidc_config = {} - oidc_config.update(TEST_OIDC_CONFIG) - oidc_config["allow_existing_users"] = True - config["oidc_config"] = oidc_config - - return config - - def create_resource_dict(self): - resource_dict = super().create_resource_dict() - resource_dict.update(build_synapse_client_resource_tree(self.hs)) - return resource_dict - - def prepare(self, reactor, clock, hs): - self.user_pass = "pass" - self.user = self.register_user("test", self.user_pass) - self.device_id = "dev1" - self.user_tok = self.login("test", self.user_pass, self.device_id) - - def delete_device( - self, - access_token: str, - device: str, - expected_response: int, - body: Union[bytes, JsonDict] = b"", - ) -> FakeChannel: - """Delete an individual device.""" - channel = self.make_request( - "DELETE", - "devices/" + device, - body, - access_token=access_token, - ) - - # Ensure the response is sane. - self.assertEqual(channel.code, expected_response) - - return channel - - def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel: - """Delete 1 or more devices.""" - # Note that this uses the delete_devices endpoint so that we can modify - # the payload half-way through some tests. - channel = self.make_request( - "POST", - "delete_devices", - body, - access_token=self.user_tok, - ) - - # Ensure the response is sane. - self.assertEqual(channel.code, expected_response) - - return channel - - def test_ui_auth(self): - """ - Test user interactive authentication outside of registration. - """ - # Attempt to delete this device. - # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, self.device_id, 401) - - # Grab the session - session = channel.json_body["session"] - # Ensure that flows are what is expected. - self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) - - # Make another request providing the UI auth flow. 
- self.delete_device(
- self.user_tok,
- self.device_id,
- 200,
- {
- "auth": {
- "type": "m.login.password",
- "identifier": {"type": "m.id.user", "user": self.user},
- "password": self.user_pass,
- "session": session,
- },
- },
- )
-
- def test_grandfathered_identifier(self):
- """Check behaviour without "identifier" dict
-
- Synapse used to require clients to submit a "user" field for m.login.password
- UIA - check that still works.
- """
-
- channel = self.delete_device(self.user_tok, self.device_id, 401)
- session = channel.json_body["session"]
-
- # Make another request providing the UI auth flow.
- self.delete_device(
- self.user_tok,
- self.device_id,
- 200,
- {
- "auth": {
- "type": "m.login.password",
- "user": self.user,
- "password": self.user_pass,
- "session": session,
- },
- },
- )
-
- def test_can_change_body(self):
- """
- The client dict can be modified during the user interactive authentication session.
-
- Note that it is not spec compliant to modify the client dict during a
- user interactive authentication session, but many clients currently do.
-
- When Synapse is updated to be spec compliant, the call to re-use the
- session ID should be rejected.
- """
- # Create a second login.
- self.login("test", self.user_pass, "dev2")
-
- # Attempt to delete the first device.
- # Returns a 401 as per the spec
- channel = self.delete_devices(401, {"devices": [self.device_id]})
-
- # Grab the session
- session = channel.json_body["session"]
- # Ensure that flows are what is expected.
- self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
-
- # Make another request providing the UI auth flow, but try to delete the
- # second device.
- self.delete_devices(
- 200,
- {
- "devices": ["dev2"],
- "auth": {
- "type": "m.login.password",
- "identifier": {"type": "m.id.user", "user": self.user},
- "password": self.user_pass,
- "session": session,
- },
- },
- )
-
- def test_cannot_change_uri(self):
- """
- The initial requested URI cannot be modified during the user interactive authentication session.
- """
- # Create a second login.
- self.login("test", self.user_pass, "dev2")
-
- # Attempt to delete the first device.
- # Returns a 401 as per the spec
- channel = self.delete_device(self.user_tok, self.device_id, 401)
-
- # Grab the session
- session = channel.json_body["session"]
- # Ensure that flows are what is expected.
- self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
-
- # Make another request providing the UI auth flow, but try to delete the
- # second device. This results in an error.
- #
- # This makes use of the fact that the device ID is embedded into the URL.
- self.delete_device(
- self.user_tok,
- "dev2",
- 403,
- {
- "auth": {
- "type": "m.login.password",
- "identifier": {"type": "m.id.user", "user": self.user},
- "password": self.user_pass,
- "session": session,
- },
- },
- )
-
- @unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
- def test_can_reuse_session(self):
- """
- The session can be reused if configured.
-
- Compare to test_cannot_change_uri.
- """
- # Create a second and third login.
- self.login("test", self.user_pass, "dev2")
- self.login("test", self.user_pass, "dev3")
-
- # Attempt to delete a device. This works since the user just logged in.
- self.delete_device(self.user_tok, "dev2", 200)
-
- # Move the clock forward past the validation timeout.
- self.reactor.advance(6)
-
- # Deleting another device throws the user into UI auth.
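# (ui_auth.session_timeout is "5s" for this test and the clock was advanced
# by 6 seconds above, so the previously-validated session no longer counts.)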
- channel = self.delete_device(self.user_tok, "dev3", 401)
-
- # Grab the session
- session = channel.json_body["session"]
- # Ensure that flows are what is expected.
- self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
-
- # Make another request providing the UI auth flow.
- self.delete_device(
- self.user_tok,
- "dev3",
- 200,
- {
- "auth": {
- "type": "m.login.password",
- "identifier": {"type": "m.id.user", "user": self.user},
- "password": self.user_pass,
- "session": session,
- },
- },
- )
-
- # Make another request, but try to delete the first device. This works
- # due to re-using the previous session.
- #
- # Note that *no auth* information is provided, not even a session ID!
- self.delete_device(self.user_tok, self.device_id, 200)
-
- @skip_unless(HAS_OIDC, "requires OIDC")
- @override_config({"oidc_config": TEST_OIDC_CONFIG})
- def test_ui_auth_via_sso(self):
- """Test a successful UI Auth flow via SSO
-
- This includes:
- * hitting the UIA SSO redirect endpoint
- * checking it serves a confirmation page which links to the OIDC provider
- * calling back to the synapse oidc callback
- * checking that the original operation succeeds
- """
-
- # log the user in
- remote_user_id = UserID.from_string(self.user).localpart
- login_resp = self.helper.login_via_oidc(remote_user_id)
- self.assertEqual(login_resp["user_id"], self.user)
-
- # initiate a UI Auth process by attempting to delete the device
- channel = self.delete_device(self.user_tok, self.device_id, 401)
-
- # check that SSO is offered
- flows = channel.json_body["flows"]
- self.assertIn({"stages": ["m.login.sso"]}, flows)
-
- # run the UIA-via-SSO flow
- session_id = channel.json_body["session"]
- channel = self.helper.auth_via_oidc(
- {"sub": remote_user_id}, ui_auth_session_id=session_id
- )
-
- # that should serve a confirmation page
- self.assertEqual(channel.code, 200, channel.result)
-
- # and now the delete request should succeed.
- self.delete_device(
- self.user_tok,
- self.device_id,
- 200,
- body={"auth": {"session": session_id}},
- )
-
- @skip_unless(HAS_OIDC, "requires OIDC")
- @override_config({"oidc_config": TEST_OIDC_CONFIG})
- def test_does_not_offer_password_for_sso_user(self):
- login_resp = self.helper.login_via_oidc("username")
- user_tok = login_resp["access_token"]
- device_id = login_resp["device_id"]
-
- # now call the device deletion API: we should get the option to auth with SSO
- # and not password.
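# An SSO-only user has no password hash stored, so m.login.password cannot
# be offered; m.login.sso should be the only advertised flow.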
- channel = self.delete_device(user_tok, device_id, 401) - - flows = channel.json_body["flows"] - self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) - - def test_does_not_offer_sso_for_password_user(self): - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - self.assertEqual(flows, [{"stages": ["m.login.password"]}]) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_offers_both_flows_for_upgraded_user(self): - """A user that had a password and then logged in with SSO should get both flows""" - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) - self.assertEqual(login_resp["user_id"], self.user) - - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - # we have no particular expectations of ordering here - self.assertIn({"stages": ["m.login.password"]}, flows) - self.assertIn({"stages": ["m.login.sso"]}, flows) - self.assertEqual(len(flows), 2) - - @skip_unless(HAS_OIDC, "requires OIDC") - @override_config({"oidc_config": TEST_OIDC_CONFIG}) - def test_ui_auth_fails_for_incorrect_sso_user(self): - """If the user tries to authenticate with the wrong SSO user, they get an error""" - # log the user in - login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) - self.assertEqual(login_resp["user_id"], self.user) - - # start a UI Auth flow by attempting to delete a device - channel = self.delete_device(self.user_tok, self.device_id, 401) - - flows = channel.json_body["flows"] - self.assertIn({"stages": ["m.login.sso"]}, flows) - session_id = channel.json_body["session"] - - # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc( - {"sub": "wrong_user"}, ui_auth_session_id=session_id - ) - - # that should return a failure message - self.assertSubstring("We were unable to validate", channel.text_body) - - # ... and the delete op should now fail with a 403 - self.delete_device( - self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} - ) - - -class RefreshAuthTests(unittest.HomeserverTestCase): - servlets = [ - auth.register_servlets, - account.register_servlets, - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - register.register_servlets, - ] - hijack_auth = False - - def prepare(self, reactor, clock, hs): - self.user_pass = "pass" - self.user = self.register_user("test", self.user_pass) - - def test_login_issue_refresh_token(self): - """ - A login response should include a refresh_token only if asked. - """ - # Test login - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - - login_without_refresh = self.make_request( - "POST", "/_matrix/client/r0/login", body - ) - self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) - self.assertNotIn("refresh_token", login_without_refresh.json_body) - - login_with_refresh = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) - self.assertIn("refresh_token", login_with_refresh.json_body) - self.assertIn("expires_in_ms", login_with_refresh.json_body) - - def test_register_issue_refresh_token(self): - """ - A register response should include a refresh_token only if asked. 
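(Opting in is done via the org.matrix.msc2918.refresh_token=true query
parameter, as exercised by the requests below.)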
- """ - register_without_refresh = self.make_request( - "POST", - "/_matrix/client/r0/register", - { - "username": "test2", - "password": self.user_pass, - "auth": {"type": LoginType.DUMMY}, - }, - ) - self.assertEqual( - register_without_refresh.code, 200, register_without_refresh.result - ) - self.assertNotIn("refresh_token", register_without_refresh.json_body) - - register_with_refresh = self.make_request( - "POST", - "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", - { - "username": "test3", - "password": self.user_pass, - "auth": {"type": LoginType.DUMMY}, - }, - ) - self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) - self.assertIn("refresh_token", register_with_refresh.json_body) - self.assertIn("expires_in_ms", register_with_refresh.json_body) - - def test_token_refresh(self): - """ - A refresh token can be used to issue a new access token. - """ - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - - refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) - self.assertIn("access_token", refresh_response.json_body) - self.assertIn("refresh_token", refresh_response.json_body) - self.assertIn("expires_in_ms", refresh_response.json_body) - - # The access and refresh tokens should be different from the original ones after refresh - self.assertNotEqual( - login_response.json_body["access_token"], - refresh_response.json_body["access_token"], - ) - self.assertNotEqual( - login_response.json_body["refresh_token"], - refresh_response.json_body["refresh_token"], - ) - - @override_config({"access_token_lifetime": "1m"}) - def test_refresh_token_expiration(self): - """ - The access token should have some time as specified in the config. - """ - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - self.assertApproximates( - login_response.json_body["expires_in_ms"], 60 * 1000, 100 - ) - - refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual(refresh_response.code, 200, refresh_response.result) - self.assertApproximates( - refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 - ) - - def test_refresh_token_invalidation(self): - """Refresh tokens are invalidated after first use of the next token. 
- - A refresh token is considered invalid if: - - it was already used at least once - - and either - - the next access token was used - - the next refresh token was used - - The chain of tokens goes like this: - - login -|-> first_refresh -> third_refresh (fails) - |-> second_refresh -> fifth_refresh - |-> fourth_refresh (fails) - """ - - body = {"type": "m.login.password", "user": "test", "password": self.user_pass} - login_response = self.make_request( - "POST", - "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", - body, - ) - self.assertEqual(login_response.code, 200, login_response.result) - - # This first refresh should work properly - first_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - first_refresh_response.code, 200, first_refresh_response.result - ) - - # This one as well, since the token in the first one was never used - second_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - second_refresh_response.code, 200, second_refresh_response.result - ) - - # This one should not, since the token from the first refresh is not valid anymore - third_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": first_refresh_response.json_body["refresh_token"]}, - ) - self.assertEqual( - third_refresh_response.code, 401, third_refresh_response.result - ) - - # The associated access token should also be invalid - whoami_response = self.make_request( - "GET", - "/_matrix/client/r0/account/whoami", - access_token=first_refresh_response.json_body["access_token"], - ) - self.assertEqual(whoami_response.code, 401, whoami_response.result) - - # But all other tokens should work (they will expire after some time) - for access_token in [ - second_refresh_response.json_body["access_token"], - login_response.json_body["access_token"], - ]: - whoami_response = self.make_request( - "GET", "/_matrix/client/r0/account/whoami", access_token=access_token - ) - self.assertEqual(whoami_response.code, 200, whoami_response.result) - - # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail - fourth_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": login_response.json_body["refresh_token"]}, - ) - self.assertEqual( - fourth_refresh_response.code, 403, fourth_refresh_response.result - ) - - # But refreshing from the last valid refresh token still works - fifth_refresh_response = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", - {"refresh_token": second_refresh_response.json_body["refresh_token"]}, - ) - self.assertEqual( - fifth_refresh_response.code, 200, fifth_refresh_response.result - ) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py deleted file mode 100644 index ac31e5ceaf..0000000000 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import synapse.rest.admin
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.rest.client import capabilities, login
-
-from tests import unittest
-from tests.unittest import override_config
-
-
-class CapabilitiesTestCase(unittest.HomeserverTestCase):
-
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- capabilities.register_servlets,
- login.register_servlets,
- ]
-
- def make_homeserver(self, reactor, clock):
- self.url = b"/_matrix/client/r0/capabilities"
- hs = self.setup_test_homeserver()
- self.config = hs.config
- self.auth_handler = hs.get_auth_handler()
- return hs
-
- def prepare(self, reactor, clock, hs):
- self.localpart = "user"
- self.password = "pass"
- self.user = self.register_user(self.localpart, self.password)
-
- def test_check_auth_required(self):
- channel = self.make_request("GET", self.url)
-
- self.assertEqual(channel.code, 401)
-
- def test_get_room_version_capabilities(self):
- access_token = self.login(self.localpart, self.password)
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- for room_version in capabilities["m.room_versions"]["available"].keys():
- self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
-
- self.assertEqual(
- self.config.default_room_version.identifier,
- capabilities["m.room_versions"]["default"],
- )
-
- def test_get_change_password_capabilities_password_login(self):
- access_token = self.login(self.localpart, self.password)
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertTrue(capabilities["m.change_password"]["enabled"])
-
- @override_config({"password_config": {"localdb_enabled": False}})
- def test_get_change_password_capabilities_localdb_disabled(self):
- access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
- self.user, device_id=None, valid_until_ms=None
- )
- )
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities["m.change_password"]["enabled"])
-
- @override_config({"password_config": {"enabled": False}})
- def test_get_change_password_capabilities_password_disabled(self):
- access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
- self.user, device_id=None, valid_until_ms=None
- )
- )
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities["m.change_password"]["enabled"])
-
- def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
- """Test that, with msc3283 disabled (the default), the server only returns `m.change_password`."""
- access_token = self.login(self.localpart, self.password)
-
- channel = self.make_request("GET", self.url,
access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertTrue(capabilities["m.change_password"]["enabled"])
- self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
- self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
- self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
-
- @override_config({"experimental_features": {"msc3283_enabled": True}})
- def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
- """Test that the server returns the msc3283 capabilities when the feature is enabled."""
- access_token = self.login(self.localpart, self.password)
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertTrue(capabilities["m.change_password"]["enabled"])
- self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
- self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
- self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
-
- @override_config(
- {
- "enable_set_displayname": False,
- "experimental_features": {"msc3283_enabled": True},
- }
- )
- def test_get_set_displayname_capabilities_displayname_disabled(self):
- """Test that the server reports it when changing the displayname is disabled."""
- access_token = self.login(self.localpart, self.password)
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
-
- @override_config(
- {
- "enable_set_avatar_url": False,
- "experimental_features": {"msc3283_enabled": True},
- }
- )
- def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
- """Test that the server reports it when changing the avatar_url is disabled."""
- access_token = self.login(self.localpart, self.password)
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
-
- @override_config(
- {
- "enable_3pid_changes": False,
- "experimental_features": {"msc3283_enabled": True},
- }
- )
- def test_change_3pid_capabilities_3pid_disabled(self):
- """Test that the server reports it when changing 3pids is disabled."""
- access_token = self.login(self.localpart, self.password)
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
-
- def test_get_does_not_include_msc3244_fields_by_default(self):
- access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
- self.user, device_id=None, valid_until_ms=None
- )
- )
-
- channel = self.make_request("GET", self.url, access_token=access_token)
- capabilities = channel.json_body["capabilities"]
-
- self.assertEqual(channel.code, 200)
- self.assertNotIn(
- "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"]
- )
-
- @override_config({"experimental_features": {"msc3244_enabled": True}})
- def test_get_does_include_msc3244_fields_when_enabled(self):
- access_token = self.get_success(
- self.auth_handler.get_access_token_for_user_id(
- self.user,
device_id=None, valid_until_ms=None - ) - ) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] - - self.assertEqual(channel.code, 200) - for details in capabilities["m.room_versions"][ - "org.matrix.msc3244.room_capabilities" - ].values(): - if details["preferred"] is not None: - self.assertTrue( - details["preferred"] in KNOWN_ROOM_VERSIONS, - str(details["preferred"]), - ) - - self.assertGreater(len(details["support"]), 0) - for room_version in details["support"]: - self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py deleted file mode 100644 index 475c6bed3d..0000000000 --- a/tests/rest/client/v2_alpha/test_filter.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import defer - -from synapse.api.errors import Codes -from synapse.rest.client import filter - -from tests import unittest - -PATH_PREFIX = "/_matrix/client/v2_alpha" - - -class FilterTestCase(unittest.HomeserverTestCase): - - user_id = "@apple:test" - hijack_auth = True - EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} - EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' - servlets = [filter.register_servlets] - - def prepare(self, reactor, clock, hs): - self.filtering = hs.get_filtering() - self.store = hs.get_datastore() - - def test_add_filter(self): - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % (self.user_id), - self.EXAMPLE_FILTER_JSON, - ) - - self.assertEqual(channel.result["code"], b"200") - self.assertEqual(channel.json_body, {"filter_id": "0"}) - filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) - self.pump() - self.assertEquals(filter.result, self.EXAMPLE_FILTER) - - def test_add_filter_for_other_user(self): - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), - self.EXAMPLE_FILTER_JSON, - ) - - self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - - def test_add_filter_non_local_user(self): - _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False - channel = self.make_request( - "POST", - "/_matrix/client/r0/user/%s/filter" % (self.user_id), - self.EXAMPLE_FILTER_JSON, - ) - - self.hs.is_mine = _is_mine - self.assertEqual(channel.result["code"], b"403") - self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) - - def test_get_filter(self): - filter_id = defer.ensureDeferred( - self.filtering.add_user_filter( - user_localpart="apple", user_filter=self.EXAMPLE_FILTER - ) - ) - self.reactor.advance(1) - filter_id = filter_id.result - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id) - ) - - self.assertEqual(channel.result["code"], 
b"200") - self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) - - def test_get_filter_non_existant(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"404") - self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) - - # Currently invalid params do not have an appropriate errcode - # in errors.py - def test_get_filter_invalid_id(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"400") - - # No ID also returns an invalid_id error - def test_get_filter_no_id(self): - channel = self.make_request( - "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id) - ) - - self.assertEqual(channel.result["code"], b"400") diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py deleted file mode 100644 index d7fa635eae..0000000000 --- a/tests/rest/client/v2_alpha/test_keys.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License - -from http import HTTPStatus - -from synapse.api.errors import Codes -from synapse.rest import admin -from synapse.rest.client import keys, login - -from tests import unittest - - -class KeyQueryTestCase(unittest.HomeserverTestCase): - servlets = [ - keys.register_servlets, - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - ] - - def test_rejects_device_id_ice_key_outside_of_list(self): - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - bob = self.register_user("bob", "uncle") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - { - "device_keys": { - bob: "device_id1", - }, - }, - alice_token, - ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) - - def test_rejects_device_key_given_as_map_to_bool(self): - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - bob = self.register_user("bob", "uncle") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - { - "device_keys": { - bob: { - "device_id1": True, - }, - }, - }, - alice_token, - ) - - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) - - def test_requires_device_key(self): - """`device_keys` is required. 
We should complain if it's missing.""" - self.register_user("alice", "wonderland") - alice_token = self.login("alice", "wonderland") - channel = self.make_request( - "POST", - "/_matrix/client/r0/keys/query", - {}, - alice_token, - ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.BAD_JSON, - channel.result, - ) diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/v2_alpha/test_password_policy.py deleted file mode 100644 index 3cf5871899..0000000000 --- a/tests/rest/client/v2_alpha/test_password_policy.py +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from synapse.api.constants import LoginType -from synapse.api.errors import Codes -from synapse.rest import admin -from synapse.rest.client import account, login, password_policy, register - -from tests import unittest - - -class PasswordPolicyTestCase(unittest.HomeserverTestCase): - """Tests the password policy feature and its compliance with MSC2000. - - When validating a password, Synapse does the necessary checks in this order: - - 1. Password is long enough - 2. Password contains digit(s) - 3. Password contains symbol(s) - 4. Password contains uppercase letter(s) - 5. Password contains lowercase letter(s) - - For each test below that checks whether a password triggers the right error code, - that test provides a password good enough to pass the previous tests, but not the - one it is currently testing (nor any test that comes afterward). 
- """ - - servlets = [ - admin.register_servlets_for_client_rest_resource, - login.register_servlets, - register.register_servlets, - password_policy.register_servlets, - account.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - self.register_url = "/_matrix/client/r0/register" - self.policy = { - "enabled": True, - "minimum_length": 10, - "require_digit": True, - "require_symbol": True, - "require_lowercase": True, - "require_uppercase": True, - } - - config = self.default_config() - config["password_config"] = { - "policy": self.policy, - } - - hs = self.setup_test_homeserver(config=config) - return hs - - def test_get_policy(self): - """Tests if the /password_policy endpoint returns the configured policy.""" - - channel = self.make_request("GET", "/_matrix/client/r0/password_policy") - - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual( - channel.json_body, - { - "m.minimum_length": 10, - "m.require_digit": True, - "m.require_symbol": True, - "m.require_lowercase": True, - "m.require_uppercase": True, - }, - channel.result, - ) - - def test_password_too_short(self): - request_data = json.dumps({"username": "kermit", "password": "shorty"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_TOO_SHORT, - channel.result, - ) - - def test_password_no_digit(self): - request_data = json.dumps({"username": "kermit", "password": "longerpassword"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_DIGIT, - channel.result, - ) - - def test_password_no_symbol(self): - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_SYMBOL, - channel.result, - ) - - def test_password_no_uppercase(self): - request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_UPPERCASE, - channel.result, - ) - - def test_password_no_lowercase(self): - request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"}) - channel = self.make_request("POST", self.register_url, request_data) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual( - channel.json_body["errcode"], - Codes.PASSWORD_NO_LOWERCASE, - channel.result, - ) - - def test_password_compliant(self): - request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"}) - channel = self.make_request("POST", self.register_url, request_data) - - # Getting a 401 here means the password has passed validation and the server has - # responded with a list of registration flows. - self.assertEqual(channel.code, 401, channel.result) - - def test_password_change(self): - """This doesn't test every possible use case, only that hitting /account/password - triggers the password validation code. 
- """ - compliant_password = "C0mpl!antpassword" - not_compliant_password = "notcompliantpassword" - - user_id = self.register_user("kermit", compliant_password) - tok = self.login("kermit", compliant_password) - - request_data = json.dumps( - { - "new_password": not_compliant_password, - "auth": { - "password": compliant_password, - "type": LoginType.PASSWORD, - "user": user_id, - }, - } - ) - channel = self.make_request( - "POST", - "/_matrix/client/r0/account/password", - request_data, - access_token=tok, - ) - - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], Codes.PASSWORD_NO_DIGIT) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py deleted file mode 100644 index fecda037a5..0000000000 --- a/tests/rest/client/v2_alpha/test_register.py +++ /dev/null @@ -1,746 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import datetime -import json -import os - -import pkg_resources - -import synapse.rest.admin -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType -from synapse.api.errors import Codes -from synapse.appservice import ApplicationService -from synapse.rest.client import account, account_validity, login, logout, register, sync - -from tests import unittest -from tests.unittest import override_config - - -class RegisterRestServletTestCase(unittest.HomeserverTestCase): - - servlets = [ - login.register_servlets, - register.register_servlets, - synapse.rest.admin.register_servlets, - ] - url = b"/_matrix/client/r0/register" - - def default_config(self): - config = super().default_config() - config["allow_guest_access"] = True - return config - - def test_POST_appservice_registration_valid(self): - user_id = "@as_user_kermit:test" - as_token = "i_am_an_app_service" - - appservice = ApplicationService( - as_token, - self.hs.config.server_name, - id="1234", - namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, - sender="@as:test", - ) - - self.hs.get_datastore().services_cache.append(appservice) - request_data = json.dumps( - {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) - - channel = self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"200", channel.result) - det_data = {"user_id": user_id, "home_server": self.hs.hostname} - self.assertDictContainsSubset(det_data, channel.json_body) - - def test_POST_appservice_registration_no_type(self): - as_token = "i_am_an_app_service" - - appservice = ApplicationService( - as_token, - self.hs.config.server_name, - id="1234", - namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, - sender="@as:test", - ) - - self.hs.get_datastore().services_cache.append(appservice) - request_data = json.dumps({"username": "as_user_kermit"}) - - channel = 
self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"400", channel.result) - - def test_POST_appservice_registration_invalid(self): - self.appservice = None # no application service exists - request_data = json.dumps( - {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE} - ) - channel = self.make_request( - b"POST", self.url + b"?access_token=i_am_an_app_service", request_data - ) - - self.assertEquals(channel.result["code"], b"401", channel.result) - - def test_POST_bad_password(self): - request_data = json.dumps({"username": "kermit", "password": 666}) - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid password") - - def test_POST_bad_username(self): - request_data = json.dumps({"username": 777, "password": "monkey"}) - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"400", channel.result) - self.assertEquals(channel.json_body["error"], "Invalid username") - - def test_POST_user_valid(self): - user_id = "@kermit:test" - device_id = "frogfone" - params = { - "username": "kermit", - "password": "monkey", - "device_id": device_id, - "auth": {"type": LoginType.DUMMY}, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) - - det_data = { - "user_id": user_id, - "home_server": self.hs.hostname, - "device_id": device_id, - } - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) - - @override_config({"enable_registration": False}) - def test_POST_disabled_registration(self): - request_data = json.dumps({"username": "kermit", "password": "monkey"}) - self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - - channel = self.make_request(b"POST", self.url, request_data) - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Registration has been disabled") - self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") - - def test_POST_guest_registration(self): - self.hs.config.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEquals(channel.result["code"], b"200", channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) - - def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals(channel.json_body["error"], "Guest access is disabled") - - @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting_guest(self): - for i in range(0, 6): - url = self.url + b"?kind=guest" - channel = self.make_request(b"POST", url, b"{}") - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - channel = self.make_request(b"POST", self.url + 
b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"200", channel.result) - - @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) - def test_POST_ratelimiting(self): - for i in range(0, 6): - params = { - "username": "kermit" + str(i), - "password": "monkey", - "device_id": "frogfone", - "auth": {"type": LoginType.DUMMY}, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", self.url, request_data) - - if i == 5: - self.assertEquals(channel.result["code"], b"429", channel.result) - retry_after_ms = int(channel.json_body["retry_after_ms"]) - else: - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(retry_after_ms / 1000.0 + 1.0) - - channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_advertised_flows(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - # with the stock config, we only expect the dummy flow - self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows)) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "enable_registration_captcha": True, - "user_consent": { - "version": "1", - "template_dir": "/", - "require_at_registration": True, - }, - "account_threepid_delegates": { - "email": "https://id_server", - "msisdn": "https://id_server", - }, - } - ) - def test_advertised_flows_captcha_and_terms_and_3pids(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - self.assertCountEqual( - [ - ["m.login.recaptcha", "m.login.terms", "m.login.dummy"], - ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"], - ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"], - [ - "m.login.recaptcha", - "m.login.terms", - "m.login.msisdn", - "m.login.email.identity", - ], - ], - (f["stages"] for f in flows), - ) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "registrations_require_3pid": ["email"], - "disable_msisdn_registration": True, - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_advertised_flows_no_msisdn_email_required(self): - channel = self.make_request(b"POST", self.url, b"{}") - self.assertEquals(channel.result["code"], b"401", channel.result) - flows = channel.json_body["flows"] - - # with the stock config, we expect all four combinations of 3pid - self.assertCountEqual( - [["m.login.email.identity"]], (f["stages"] for f in flows) - ) - - @unittest.override_config( - { - "request_token_inhibit_3pid_errors": True, - "public_baseurl": "https://test_server", - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_request_token_existing_email_inhibit_error(self): - """Test that requesting a token via this endpoint doesn't leak existing - associations if configured that way. 
- """ - user_id = self.register_user("kermit", "monkey") - self.login("kermit", "monkey") - - email = "test@example.com" - - # Add a threepid - self.get_success( - self.hs.get_datastore().user_add_threepid( - user_id=user_id, - medium="email", - address=email, - validated_at=0, - added_at=0, - ) - ) - - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": email, "send_attempt": 1}, - ) - self.assertEquals(200, channel.code, channel.result) - - self.assertIsNotNone(channel.json_body.get("sid")) - - @unittest.override_config( - { - "public_baseurl": "https://test_server", - "email": { - "smtp_host": "mail_server", - "smtp_port": 2525, - "notif_from": "sender@host", - }, - } - ) - def test_reject_invalid_email(self): - """Check that bad emails are rejected""" - - # Test for email with multiple @ - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - # Check error to ensure that we're not erroring due to a bug in the test. - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - # Test for email with no @ - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": "email", "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - # Test for super long email - email = "a@" + "a" * 1000 - channel = self.make_request( - "POST", - b"register/email/requestToken", - {"client_secret": "foobar", "email": email, "send_attempt": 1}, - ) - self.assertEquals(400, channel.code, channel.result) - self.assertEquals( - channel.json_body, - {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, - ) - - -class AccountValidityTestCase(unittest.HomeserverTestCase): - - servlets = [ - register.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - sync.register_servlets, - logout.register_servlets, - account_validity.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - # Test for account expiring after a week. - config["enable_registration"] = True - config["account_validity"] = { - "enabled": True, - "period": 604800000, # Time in ms for 1 week - } - self.hs = self.setup_test_homeserver(config=config) - - return self.hs - - def test_validity_period(self): - self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. 
- channel = self.make_request(b"GET", "/sync", access_token=tok) - - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - - channel = self.make_request(b"GET", "/sync", access_token=tok) - - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result - ) - - def test_manual_renewal(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) - - # If we register the admin user at the beginning of the test, it will - # expire at the same time as the normal user and the renewal request - # will be denied. - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = {"user_id": user_id} - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. - channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_manual_expire(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = { - "user_id": user_id, - "expiration_ts": 0, - "enable_renewal_emails": False, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # The specific endpoint doesn't matter, all we need is an authenticated - # endpoint. 
- channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"403", channel.result) - self.assertEquals( - channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result - ) - - def test_logging_out_expired_user(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - self.register_user("admin", "adminpassword", admin=True) - admin_tok = self.login("admin", "adminpassword") - - url = "/_synapse/admin/v1/account_validity/validity" - params = { - "user_id": user_id, - "expiration_ts": 0, - "enable_renewal_emails": False, - } - request_data = json.dumps(params) - channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Try to log the user out - channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Log the user in again (allowed for expired accounts) - tok = self.login("kermit", "monkey") - - # Try to log out all of the user's sessions - channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - -class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): - - servlets = [ - register.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - login.register_servlets, - sync.register_servlets, - account_validity.register_servlets, - account.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - - # Test for account expiring after a week and renewal emails being sent 2 - # days before expiry. - config["enable_registration"] = True - config["account_validity"] = { - "enabled": True, - "period": 604800000, # Time in ms for 1 week - "renew_at": 172800000, # Time in ms for 2 days - "renew_by_email_enabled": True, - "renew_email_subject": "Renew your account", - "account_renewed_html_path": "account_renewed.html", - "invalid_token_html_path": "invalid_token.html", - } - - # Email config. - - config["email"] = { - "enable_notifs": True, - "template_dir": os.path.abspath( - pkg_resources.resource_filename("synapse", "res/templates") - ), - "expiry_template_html": "notice_expiry.html", - "expiry_template_text": "notice_expiry.txt", - "notif_template_html": "notif_mail.html", - "notif_template_text": "notif_mail.txt", - "smtp_host": "127.0.0.1", - "smtp_port": 20, - "require_transport_security": False, - "smtp_user": None, - "smtp_pass": None, - "notif_from": "test@example.com", - } - config["public_baseurl"] = "aaa" - - self.hs = self.setup_test_homeserver(config=config) - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) - - self.email_attempts = [] - self.hs.get_send_email_handler()._sendmail = sendmail - - self.store = self.hs.get_datastore() - - return self.hs - - def test_renewal_email(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - - # Move 5 days forward. This should trigger a renewal email to be sent. - self.reactor.advance(datetime.timedelta(days=5).total_seconds()) - self.assertEqual(len(self.email_attempts), 1) - - # Retrieving the URL from the email is too much pain for now, so we - # retrieve the token from the DB. 
- renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) - url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect on a successful renewal. - expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id)) - expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render( - expiration_ts=expiration_ts - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - # Move 1 day forward. Try to renew with the same token again. - url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"200", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect when reusing a - # token. The account expiration date should not have changed. - expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render( - expiration_ts=expiration_ts - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - # Move 3 days forward. If the renewal failed, every authed request with - # our access token should be denied from now, otherwise they should - # succeed. - self.reactor.advance(datetime.timedelta(days=3).total_seconds()) - channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEquals(channel.result["code"], b"200", channel.result) - - def test_renewal_invalid_token(self): - # Hit the renewal endpoint with an invalid token and check that it behaves as - # expected, i.e. that it responds with 404 Not Found and the correct HTML. - url = "/_matrix/client/unstable/account_validity/renew?token=123" - channel = self.make_request(b"GET", url) - self.assertEquals(channel.result["code"], b"404", channel.result) - - # Check that we're getting HTML back. - content_type = channel.headers.getRawHeaders(b"Content-Type") - self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result) - - # Check that the HTML we're getting is the one we expect when using an - # invalid/unknown token. 
- expected_html = ( - self.hs.config.account_validity.account_validity_invalid_token_template.render() - ) - self.assertEqual( - channel.result["body"], expected_html.encode("utf8"), channel.result - ) - - def test_manual_email_send(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - channel = self.make_request( - b"POST", - "/_matrix/client/unstable/account_validity/send_mail", - access_token=tok, - ) - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.assertEqual(len(self.email_attempts), 1) - - def test_deactivated_user(self): - self.email_attempts = [] - - (user_id, tok) = self.create_user() - - request_data = json.dumps( - { - "auth": { - "type": "m.login.password", - "user": user_id, - "password": "monkey", - }, - "erase": False, - } - ) - channel = self.make_request( - "POST", "account/deactivate", request_data, access_token=tok - ) - self.assertEqual(channel.code, 200) - - self.reactor.advance(datetime.timedelta(days=8).total_seconds()) - - self.assertEqual(len(self.email_attempts), 0) - - def create_user(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - # We need to manually add an email address otherwise the handler will do - # nothing. - now = self.hs.get_clock().time_msec() - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address="kermit@example.com", - validated_at=now, - added_at=now, - ) - ) - return user_id, tok - - def test_manual_email_send_expired_account(self): - user_id = self.register_user("kermit", "monkey") - tok = self.login("kermit", "monkey") - - # We need to manually add an email address otherwise the handler will do - # nothing. - now = self.hs.get_clock().time_msec() - self.get_success( - self.store.user_add_threepid( - user_id=user_id, - medium="email", - address="kermit@example.com", - validated_at=now, - added_at=now, - ) - ) - - # Make the account expire. - self.reactor.advance(datetime.timedelta(days=8).total_seconds()) - - # Ignore all emails sent by the automatic background task and only focus on the - # ones sent manually. - self.email_attempts = [] - - # Test that we're still able to manually trigger a mail to be sent. - channel = self.make_request( - b"POST", - "/_matrix/client/unstable/account_validity/send_mail", - access_token=tok, - ) - self.assertEquals(channel.result["code"], b"200", channel.result) - - self.assertEqual(len(self.email_attempts), 1) - - -class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): - - servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource] - - def make_homeserver(self, reactor, clock): - self.validity_period = 10 - self.max_delta = self.validity_period * 10.0 / 100.0 - - config = self.default_config() - - config["enable_registration"] = True - config["account_validity"] = {"enabled": False} - - self.hs = self.setup_test_homeserver(config=config) - - # We need to set these directly, instead of in the homeserver config dict above. - # This is due to account validity-related config options not being read by - # Synapse when account_validity.enabled is False. 
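The `max_delta` configured here drives the assertions at the end of this test case; the invariant, spelled out as a sketch with the values from `make_homeserver`:

    validity_period = 10  # as set in make_homeserver above
    max_delta = validity_period * 10.0 / 100.0  # 1.0, i.e. 10% of the period

    # The startup job may pull an expiry forward by up to max_delta (this
    # randomised spread stops all pre-existing users from expiring at the same
    # instant), so a user stamped at now_ms must end up with:
    #   now_ms + validity_period - max_delta <= res <= now_ms + validity_period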
-        self.hs.get_datastore()._account_validity_period = self.validity_period
-        self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
-
-        self.store = self.hs.get_datastore()
-
-        return self.hs
-
-    def test_background_job(self):
-        """
-        Tests that the account validity startup job sets an expiration date on
-        users who are missing one, and that with startup_job_max_delta set the
-        chosen date falls within the allowed range.
-        """
-        user_id = self.register_user("kermit_delta", "user")
-
-        self.hs.config.account_validity.startup_job_max_delta = self.max_delta
-
-        now_ms = self.hs.get_clock().time_msec()
-        self.get_success(self.store._set_expiration_date_when_missing())
-
-        res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
-
-        self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
-        self.assertLessEqual(res, now_ms + self.validity_period)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
deleted file mode 100644
index 02b5e9a8d0..0000000000
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ /dev/null
@@ -1,724 +0,0 @@
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import itertools
-import json
-import urllib
-from typing import Optional
-
-from synapse.api.constants import EventTypes, RelationTypes
-from synapse.rest import admin
-from synapse.rest.client import login, register, relations, room
-
-from tests import unittest
-
-
-class RelationsTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        relations.register_servlets,
-        room.register_servlets,
-        login.register_servlets,
-        register.register_servlets,
-        admin.register_servlets_for_client_rest_resource,
-    ]
-    hijack_auth = False
-
-    def make_homeserver(self, reactor, clock):
-        # We need to enable msc1849 support for aggregations
-        config = self.default_config()
-        config["experimental_msc1849_support_enabled"] = True
-
-        # We enable frozen dicts as relations/edits change event contents, so we
-        # want to test that we don't modify the events in the caches.
-        config["use_frozen_dicts"] = True
-
-        return self.setup_test_homeserver(config=config)
-
-    def prepare(self, reactor, clock, hs):
-        self.user_id, self.user_token = self._create_user("alice")
-        self.user2_id, self.user2_token = self._create_user("bob")
-
-        self.room = self.helper.create_room_as(self.user_id, tok=self.user_token)
-        self.helper.join(self.room, user=self.user2_id, tok=self.user2_token)
-        res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
-        self.parent_id = res["event_id"]
-
-    def test_send_relation(self):
-        """Tests that sending a relation using the new /send_relation endpoint
-        creates the right shape of event.
- """ - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") - self.assertEquals(200, channel.code, channel.json_body) - - event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assert_dict( - { - "type": "m.reaction", - "sender": self.user_id, - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "key": "👍", - "rel_type": RelationTypes.ANNOTATION, - } - }, - }, - channel.json_body, - ) - - def test_deny_membership(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) - self.assertEquals(400, channel.code, channel.json_body) - - def test_deny_double_react(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(400, channel.code, channel.json_body) - - def test_basic_paginate_relations(self): - """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the full - # relation event we sent above. - self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, - channel.json_body["chunk"][0], - ) - - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEquals( - channel.json_body["original_event"]["event_id"], self.parent_id - ) - - # Make sure next_batch has something in it that looks like it could be a - # valid token. - self.assertIsInstance( - channel.json_body.get("next_batch"), str, channel.json_body - ) - - def test_repeated_paginate_relations(self): - """Test that if we paginate using a limit and tokens then we get the - expected events. 
- """ - - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction") - self.assertEquals(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - prev_token = None - found_event_ids = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) - - def test_aggregation_pagination_groups(self): - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"👍": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - self.assertEquals(200, channel.code, channel.json_body) - - idx += 1 - idx %= len(access_tokens) - - prev_token = None - found_groups = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEquals(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self): - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. 
- access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="👍", - access_token=access_tokens[idx], - ) - self.assertEquals(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEquals(200, channel.code, channel.json_body) - - prev_token = None - found_event_ids = [] - encoded_key = urllib.parse.quote_plus("👍".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s" - "/aggregations/%s/%s/m.reaction/%s?limit=1%s" - % ( - self.room, - self.parent_id, - RelationTypes.ANNOTATION, - encoded_key, - from_token, - ), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEquals(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEquals(found_event_ids, expected_event_ids) - - def test_aggregation(self): - """Test that annotations get correctly aggregated.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body, - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - ) - - def test_aggregation_redactions(self): - """Test that annotations get correctly aggregated after a redaction.""" - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - to_redact_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Now lets redact one of the 'a' reactions - channel = self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id), - access_token=self.user_token, - content={}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s" - % (self.room, self.parent_id), - 
access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body, - {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]}, - ) - - def test_aggregation_must_be_annotation(self): - """Test that aggregations must be annotations.""" - - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1" - % (self.room, self.parent_id, RelationTypes.REPLACE), - access_token=self.user_token, - ) - self.assertEquals(400, channel.code, channel.json_body) - - def test_aggregation_get_event(self): - """Test that annotations and references get correctly bundled when - getting the parent event. - """ - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEquals(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEquals(200, channel.code, channel.json_body) - reply_2 = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - RelationTypes.REFERENCE: { - "chunk": [{"event_id": reply_1}, {"event_id": reply_2}] - }, - }, - ) - - def test_edit(self): - """Test that a simple edit works.""" - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals(channel.json_body["content"], new_body) - - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_multi_edit(self): - """Test that multiple edits, including attempts by people who - shouldn't be allowed, are correctly handled. 
- """ - - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "First edit"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message.WRONG_TYPE", - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertEquals(channel.json_body["content"], new_body) - - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_edit_reply(self): - """Test that editing a reply works.""" - - # Create a reply to edit. - channel = self._send_relation( - RelationTypes.REFERENCE, - "m.room.message", - content={"msgtype": "m.text", "body": "A reply!"}, - ) - self.assertEquals(200, channel.code, channel.json_body) - reply = channel.json_body["event_id"] - - new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, - parent_id=reply, - ) - self.assertEquals(200, channel.code, channel.json_body) - - edit_event_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - "/rooms/%s/event/%s" % (self.room, reply), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # We expect to see the new body in the dict, as well as the reference - # metadata sill intact. - self.assertDictContainsSubset(new_body, channel.json_body["content"]) - self.assertDictContainsSubset( - { - "m.relates_to": { - "event_id": self.parent_id, - "key": None, - "rel_type": "m.reference", - } - }, - channel.json_body["content"], - ) - - # We expect that the edit relation appears in the unsigned relations - # section. - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) - - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) - - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict - ) - - def test_relations_redaction_redacts_edits(self): - """Test that edits of an event are redacted when the original event - is redacted. 
- """ - # Send a new event - res = self.helper.send(self.room, body="Heyo!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.REPLACE, - "m.room.message", - parent_id=original_event_id, - content={ - "msgtype": "m.text", - "body": "Wibble", - "m.new_content": {"msgtype": "m.text", "body": "First edit"}, - }, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check the relation is returned - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEquals(len(channel.json_body["chunk"]), 1) - - # Redact the original event - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % (self.room, original_event_id, "test_relations_redaction_redacts_edits"), - access_token=self.user_token, - content="{}", - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Try to check for remaining m.replace relations - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check that no relations are returned - self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) - - def test_aggregations_redaction_prevents_access_to_aggregations(self): - """Test that annotations of an event are redacted when the original event - is redacted. - """ - # Send a new event - res = self.helper.send(self.room, body="Hello!", tok=self.user_token) - original_event_id = res["event_id"] - - # Add a relation - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Redact the original - channel = self.make_request( - "PUT", - "/rooms/%s/redact/%s/%s" - % ( - self.room, - original_event_id, - "test_aggregations_redaction_prevents_access_to_aggregations", - ), - access_token=self.user_token, - content="{}", - ) - self.assertEquals(200, channel.code, channel.json_body) - - # Check that aggregations returns zero - channel = self.make_request( - "GET", - "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction" - % (self.room, original_event_id), - access_token=self.user_token, - ) - self.assertEquals(200, channel.code, channel.json_body) - - self.assertIn("chunk", channel.json_body) - self.assertEquals(channel.json_body["chunk"], []) - - def _send_relation( - self, - relation_type, - event_type, - key=None, - content: Optional[dict] = None, - access_token=None, - parent_id=None, - ): - """Helper function to send a relation pointing at `self.parent_id` - - Args: - relation_type (str): One of `RelationTypes` - event_type (str): The type of the event to create - parent_id (str): The event_id this relation relates to. If None, then self.parent_id - key (str|None): The aggregation key used for m.annotation relation - type. - content(dict|None): The content of the created event. 
- access_token (str|None): The access token used to send the relation, - defaults to `self.user_token` - - Returns: - FakeChannel - """ - if not access_token: - access_token = self.user_token - - query = "" - if key: - query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) - - original_id = parent_id if parent_id else self.parent_id - - channel = self.make_request( - "POST", - "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" - % (self.room, original_id, relation_type, event_type, query), - json.dumps(content or {}).encode("utf-8"), - access_token=access_token, - ) - return channel - - def _create_user(self, localpart): - user_id = self.register_user(localpart, "abc123") - access_token = self.login(localpart, "abc123") - - return user_id, access_token diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py deleted file mode 100644 index ee6b0b9ebf..0000000000 --- a/tests/rest/client/v2_alpha/test_report_event.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2021 Callum Brown -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -import synapse.rest.admin -from synapse.rest.client import login, report_event, room - -from tests import unittest - - -class ReportEventTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - report_event.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") - self.other_user = self.register_user("user", "pass") - self.other_user_tok = self.login("user", "pass") - - self.room_id = self.helper.create_room_as( - self.other_user, tok=self.other_user_tok, is_public=True - ) - self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) - resp = self.helper.send(self.room_id, tok=self.admin_user_tok) - self.event_id = resp["event_id"] - self.report_path = f"rooms/{self.room_id}/report/{self.event_id}" - - def test_reason_str_and_score_int(self): - data = {"reason": "this makes me sad", "score": -100} - self._assert_status(200, data) - - def test_no_reason(self): - data = {"score": 0} - self._assert_status(200, data) - - def test_no_score(self): - data = {"reason": "this makes me sad"} - self._assert_status(200, data) - - def test_no_reason_and_no_score(self): - data = {} - self._assert_status(200, data) - - def test_reason_int_and_score_str(self): - data = {"reason": 10, "score": "string"} - self._assert_status(400, data) - - def test_reason_zero_and_score_blank(self): - data = {"reason": 0, "score": ""} - self._assert_status(400, data) - - def test_reason_and_score_null(self): - data = {"reason": None, "score": None} - self._assert_status(400, data) - - def _assert_status(self, response_status, data): - channel = self.make_request( - "POST", - self.report_path, - json.dumps(data), - access_token=self.other_user_tok, - ) - self.assertEqual( - 
response_status, int(channel.result["code"]), msg=channel.result["body"] - ) diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/v2_alpha/test_sendtodevice.py deleted file mode 100644 index 6db7062a8e..0000000000 --- a/tests/rest/client/v2_alpha/test_sendtodevice.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.rest import admin -from synapse.rest.client import login, sendtodevice, sync - -from tests.unittest import HomeserverTestCase, override_config - - -class SendToDeviceTestCase(HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - sendtodevice.register_servlets, - sync.register_servlets, - ] - - def test_user_to_user(self): - """A to-device message from one user to another should get delivered""" - - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send the message - test_msg = {"foo": "bar"} - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.test/1234", - content={"messages": {user2: {"d2": test_msg}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # check it appears - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - expected_result = { - "events": [ - { - "sender": user1, - "type": "m.test", - "content": test_msg, - } - ] - } - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should re-appear if we do another sync - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body["to_device"], expected_result) - - # it should *not* appear if we do an incremental sync - sync_token = channel.json_body["next_batch"] - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), []) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_local_room_key_request(self): - """m.room_key_request has special-casing; test from local user""" - user1 = self.register_user("u1", "pass") - user1_tok = self.login("u1", "pass", "d1") - - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - # send three messages - for i in range(3): - chan = self.make_request( - "PUT", - f"/_matrix/client/r0/sendToDevice/m.room_key_request/{i}", - content={"messages": {user2: {"d2": {"idx": i}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # now sync: we should get two of the three - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - 
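Why two of three? Under the `rc_key_requests` override above, a sketch of the token-bucket arithmetic (parameter names from the config; the limiter internals are Synapse's own, so treat this as illustrative):

    burst_count = 2   # at most 2 m.room_key_request sends back to back
    per_second = 10   # the bucket refills at 10 actions per second

    # Three sends with no clock movement: the first two fit in the burst, the
    # third is dropped rather than queued. reactor.advance(1) further down
    # adds a second's worth of capacity, so the follow-up send gets through.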
msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": i}}, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - chan = self.make_request( - "PUT", - "/_matrix/client/r0/sendToDevice/m.room_key_request/3", - content={"messages": {user2: {"d2": {"idx": 3}}}}, - access_token=user1_tok, - ) - self.assertEqual(chan.code, 200, chan.result) - - # ... which should arrive - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - {"sender": user1, "type": "m.room_key_request", "content": {"idx": 3}}, - ) - - @override_config({"rc_key_requests": {"per_second": 10, "burst_count": 2}}) - def test_remote_room_key_request(self): - """m.room_key_request has special-casing; test from remote user""" - user2 = self.register_user("u2", "pass") - user2_tok = self.login("u2", "pass", "d2") - - federation_registry = self.hs.get_federation_registry() - - # send three messages - for i in range(3): - self.get_success( - federation_registry.on_edu( - "m.direct_to_device", - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": i}}}, - "message_id": f"{i}", - }, - ) - ) - - # now sync: we should get two of the three - channel = self.make_request("GET", "/sync", access_token=user2_tok) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 2) - for i in range(2): - self.assertEqual( - msgs[i], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": i}, - }, - ) - sync_token = channel.json_body["next_batch"] - - # ... time passes - self.reactor.advance(1) - - # and we can send more messages - self.get_success( - federation_registry.on_edu( - "m.direct_to_device", - "remote_server", - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "messages": {user2: {"d2": {"idx": 3}}}, - "message_id": "3", - }, - ) - ) - - # ... which should arrive - channel = self.make_request( - "GET", f"/sync?since={sync_token}", access_token=user2_tok - ) - self.assertEqual(channel.code, 200, channel.result) - msgs = channel.json_body["to_device"]["events"] - self.assertEqual(len(msgs), 1) - self.assertEqual( - msgs[0], - { - "sender": "@user:remote_server", - "type": "m.room_key_request", - "content": {"idx": 3}, - }, - ) diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py deleted file mode 100644 index 283eccd53f..0000000000 --- a/tests/rest/client/v2_alpha/test_shared_rooms.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright 2020 Half-Shot -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -import synapse.rest.admin -from synapse.rest.client import login, room, shared_rooms - -from tests import unittest -from tests.server import FakeChannel - - -class UserSharedRoomsTest(unittest.HomeserverTestCase): - """ - Tests the UserSharedRoomsServlet. - """ - - servlets = [ - login.register_servlets, - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - shared_rooms.register_servlets, - ] - - def make_homeserver(self, reactor, clock): - config = self.default_config() - config["update_user_directory"] = True - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.handler = hs.get_user_directory_handler() - - def _get_shared_rooms(self, token, other_user) -> FakeChannel: - return self.make_request( - "GET", - "/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/%s" - % other_user, - access_token=token, - ) - - def test_shared_room_list_public(self): - """ - A room should show up in the shared list of rooms between two users - if it is public. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True) - - def test_shared_room_list_private(self): - """ - A room should show up in the shared list of rooms between two users - if it is private. - """ - self._check_shared_rooms_with( - room_one_is_public=False, room_two_is_public=False - ) - - def test_shared_room_list_mixed(self): - """ - The shared room list between two users should contain both public and private - rooms. - """ - self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False) - - def _check_shared_rooms_with( - self, room_one_is_public: bool, room_two_is_public: bool - ): - """Checks that shared public or private rooms between two users appear in - their shared room lists - """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - # Create a room. user1 invites user2, who joins - room_id_one = self.helper.create_room_as( - u1, is_public=room_one_is_public, tok=u1_token - ) - self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_one, user=u2, tok=u2_token) - - # Check shared rooms from user1's perspective. - # We should see the one room in common - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room_id_one) - - # Create another room and invite user2 to it - room_id_two = self.helper.create_room_as( - u1, is_public=room_two_is_public, tok=u1_token - ) - self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token) - self.helper.join(room_id_two, user=u2, tok=u2_token) - - # Check shared rooms again. We should now see both rooms. - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 2) - for room_id_id in channel.json_body["joined"]: - self.assertIn(room_id_id, [room_id_one, room_id_two]) - - def test_shared_room_list_after_leave(self): - """ - A room should no longer be considered shared if the other - user has left it. 
- """ - u1 = self.register_user("user1", "pass") - u1_token = self.login(u1, "pass") - u2 = self.register_user("user2", "pass") - u2_token = self.login(u2, "pass") - - room = self.helper.create_room_as(u1, is_public=True, tok=u1_token) - self.helper.invite(room, src=u1, targ=u2, tok=u1_token) - self.helper.join(room, user=u2, tok=u2_token) - - # Assert user directory is not empty - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 1) - self.assertEquals(channel.json_body["joined"][0], room) - - self.helper.leave(room, user=u1, tok=u1_token) - - # Check user1's view of shared rooms with user2 - channel = self._get_shared_rooms(u1_token, u2) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) - - # Check user2's view of shared rooms with user1 - channel = self._get_shared_rooms(u2_token, u1) - self.assertEquals(200, channel.code, channel.result) - self.assertEquals(len(channel.json_body["joined"]), 0) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py deleted file mode 100644 index 95be369d4b..0000000000 --- a/tests/rest/client/v2_alpha/test_sync.py +++ /dev/null @@ -1,686 +0,0 @@ -# Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import json - -import synapse.rest.admin -from synapse.api.constants import ( - EventContentFields, - EventTypes, - ReadReceiptEventFields, - RelationTypes, -) -from synapse.rest.client import knock, login, read_marker, receipts, room, sync - -from tests import unittest -from tests.federation.transport.test_knocking import ( - KnockingStrippedStateEventHelperMixin, -) -from tests.server import TimedOutException -from tests.unittest import override_config - - -class FilterTestCase(unittest.HomeserverTestCase): - - user_id = "@apple:test" - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_sync_argless(self): - channel = self.make_request("GET", "/sync") - - self.assertEqual(channel.code, 200) - self.assertIn("next_batch", channel.json_body) - - -class SyncFilterTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_sync_filter_labels(self): - """Test that we can filter by a label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.labels": ["#fun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_sync_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.not_labels": ["#fun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 3, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual( - events[2]["content"]["body"], "with two wrong labels", events[2] - ) - - def test_sync_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps( - { - "room": { - "timeline": { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - } - } - ) - - events = self._test_sync_filter_labels(sync_filter) - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def _test_sync_filter_labels(self, sync_filter): - user_id = self.register_user("kermit", "test") - tok = self.login("kermit", "test") - - room_id = self.helper.create_room_as(user_id, tok=tok) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - tok=tok, - ) - - self.helper.send_event( - 
room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - tok=tok, - ) - - self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - tok=tok, - ) - - channel = self.make_request( - "GET", "/sync?filter=%s" % sync_filter, access_token=tok - ) - self.assertEqual(channel.code, 200, channel.result) - - return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"] - - -class SyncTypingTests(unittest.HomeserverTestCase): - - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - user_id = True - hijack_auth = False - - def test_sync_backwards_typing(self): - """ - If the typing serial goes backwards and the typing handler is then reset - (such as when the master restarts and sets the typing serial to 0), we - do not incorrectly return typing information that had a serial greater - than the now-reset serial. - """ - typing_url = "/rooms/%s/typing/%s?access_token=%s" - sync_url = "/sync?timeout=3000000&access_token=%s&since=%s" - - # Register the user who gets notified - user_id = self.register_user("user", "pass") - access_token = self.login("user", "pass") - - # Register the user who sends the message - other_user_id = self.register_user("otheruser", "pass") - other_access_token = self.login("otheruser", "pass") - - # Create a room - room = self.helper.create_room_as(user_id, tok=access_token) - - # Invite the other person - self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id) - - # The other user joins - self.helper.join(room=room, user=other_user_id, tok=other_access_token) - - # The other user sends some messages - self.helper.send(room, body="Hi!", tok=other_access_token) - self.helper.send(room, body="There!", tok=other_access_token) - - # Start typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - channel = self.make_request("GET", "/sync?access_token=%s" % (access_token,)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Stop typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": false}', - ) - self.assertEquals(200, channel.code) - - # Start typing. - channel = self.make_request( - "PUT", - typing_url % (room, other_user_id, other_access_token), - b'{"typing": true, "timeout": 30000}', - ) - self.assertEquals(200, channel.code) - - # Should return immediately - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Reset typing serial back to 0, as if the master had. - typing = self.hs.get_typing_handler() - typing._latest_room_serial = 0 - - # Since it checks the state token, we need some state to update to - # invalidate the stream token. - self.helper.send(room, body="There!", tok=other_access_token) - - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # This should time out! 
But it does not, because our stream token is - # ahead, and therefore it's saying the typing (that we've actually - # already seen) is new, since it's got a token above our new, now-reset - # stream token. - channel = self.make_request("GET", sync_url % (access_token, next_batch)) - self.assertEquals(200, channel.code) - next_batch = channel.json_body["next_batch"] - - # Clear the typing information, so that it doesn't think everything is - # in the future. - typing._reset() - - # Now it SHOULD fail as it never completes! - with self.assertRaises(TimedOutException): - self.make_request("GET", sync_url % (access_token, next_batch)) - - -class SyncKnockTestCase( - unittest.HomeserverTestCase, KnockingStrippedStateEventHelperMixin -): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - room.register_servlets, - sync.register_servlets, - knock.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to create the room to knock on). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll knock on. - self.room_id = self.helper.create_room_as( - self.user_id, - is_public=False, - room_version="7", - tok=self.tok, - ) - - # Register the second user (used to knock on the room). - self.knocker = self.register_user("knocker", "monkey") - self.knocker_tok = self.login("knocker", "monkey") - - # Perform an initial sync for the knocking user. - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - # Set up some room state to test with. 
- self.expected_room_state = self.send_example_state_events_to_room( - hs, self.room_id, self.user_id - ) - - @override_config({"experimental_features": {"msc2403_enabled": True}}) - def test_knock_room_state(self): - """Tests that /sync returns state from a room after knocking on it.""" - # Knock on a room - channel = self.make_request( - "POST", - "/_matrix/client/r0/knock/%s" % (self.room_id,), - b"{}", - self.knocker_tok, - ) - self.assertEquals(200, channel.code, channel.result) - - # We expect to see the knock event in the stripped room state later - self.expected_room_state[EventTypes.Member] = { - "content": {"membership": "knock", "displayname": "knocker"}, - "state_key": "@knocker:test", - } - - # Check that /sync includes stripped state from the room - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.knocker_tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Extract the stripped room state events from /sync - knock_entry = channel.json_body["rooms"]["knock"] - room_state_events = knock_entry[self.room_id]["knock_state"]["events"] - - # Validate that the knock membership event came last - self.assertEqual(room_state_events[-1]["type"], EventTypes.Member) - - # Validate the stripped room state events - self.check_knock_room_state_against_room_state( - room_state_events, self.expected_room_state - ) - - -class ReadReceiptsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - receipts.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Join the second user - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - - @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self): - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a read receipt to tell the server the first user's message was read - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that the first user can't see the other user's hidden read receipt - self.assertEqual(self._get_read_receipt(), None) - - def test_read_receipt_with_empty_body(self): - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a read receipt for this message with an empty body - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - def _get_read_receipt(self): - """Syncs and returns the read receipt.""" - - # Checks if event is a read receipt - def is_read_receipt(event): - return event["type"] == "m.receipt" - - # Sync - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - 
self.assertEqual(channel.code, 200) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - # Return the read receipt - ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ - "ephemeral" - ]["events"] - return next(filter(is_read_receipt, ephemeral_events), None) - - -class UnreadMessagesTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - read_marker.register_servlets, - room.register_servlets, - sync.register_servlets, - receipts.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user (used to check the unread counts). - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room we'll check unread counts for. - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user (used to send events to the room). - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Change the power levels of the room so that the second user can send state - # events. - self.helper.send_state( - self.room_id, - EventTypes.PowerLevels, - { - "users": {self.user_id: 100, self.user2: 100}, - "users_default": 0, - "events": { - "m.room.name": 50, - "m.room.power_levels": 100, - "m.room.history_visibility": 100, - "m.room.canonical_alias": 50, - "m.room.avatar": 50, - "m.room.tombstone": 100, - "m.room.server_acl": 100, - "m.room.encryption": 100, - }, - "events_default": 0, - "state_default": 50, - "ban": 50, - "kick": 50, - "redact": 50, - "invite": 0, - }, - tok=self.tok, - ) - - def test_unread_counts(self): - """Tests that /sync returns the right value for the unread count (MSC2654).""" - - # Check that our own messages don't increase the unread count. - self.helper.send(self.room_id, "hello", tok=self.tok) - self._check_unread_count(0) - - # Join the new user and check that this doesn't increase the unread count. - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - self._check_unread_count(0) - - # Check that the new user sending a message increases our unread count. - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/read_markers" % self.room_id, - body, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that hidden read receipts don't break unread counts - res = self.helper.send(self.room_id, "hello", tok=self.tok2) - self._check_unread_count(1) - - # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") - channel = self.make_request( - "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200, channel.json_body) - - # Check that the unread counter is back to 0. - self._check_unread_count(0) - - # Check that room name changes increase the unread counter. 
- self.helper.send_state( - self.room_id, - "m.room.name", - {"name": "my super room"}, - tok=self.tok2, - ) - self._check_unread_count(1) - - # Check that room topic changes increase the unread counter. - self.helper.send_state( - self.room_id, - "m.room.topic", - {"topic": "welcome!!!"}, - tok=self.tok2, - ) - self._check_unread_count(2) - - # Check that encrypted messages increase the unread counter. - self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) - self._check_unread_count(3) - - # Check that custom events with a body increase the unread counter. - self.helper.send_event( - self.room_id, - "org.matrix.custom_type", - {"body": "hello"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that edits don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "body": "hello", - "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, - }, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that notices don't increase the unread counter. - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"body": "hello", "msgtype": "m.notice"}, - tok=self.tok2, - ) - self._check_unread_count(4) - - # Check that tombstone events changes increase the unread counter. - self.helper.send_state( - self.room_id, - EventTypes.Tombstone, - {"replacement_room": "!someroom:test"}, - tok=self.tok2, - ) - self._check_unread_count(5) - - def _check_unread_count(self, expected_count: int): - """Syncs and compares the unread count with the expected value.""" - - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - - self.assertEqual(channel.code, 200, channel.json_body) - - room_entry = channel.json_body["rooms"]["join"][self.room_id] - self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], - expected_count, - room_entry, - ) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - -class SyncCacheTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def test_noop_sync_does_not_tightloop(self): - """If the sync times out, we shouldn't cache the result - - Essentially a regression test for #8518. - """ - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # we should immediately get an initial sync response - channel = self.make_request("GET", "/sync", access_token=self.tok) - self.assertEqual(channel.code, 200, channel.json_body) - - # now, make an incremental sync request, with a timeout - next_batch = channel.json_body["next_batch"] - channel = self.make_request( - "GET", - f"/sync?since={next_batch}&timeout=10000", - access_token=self.tok, - await_result=False, - ) - # that should block for 10 seconds - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=9900) - channel.await_result(timeout_ms=200) - self.assertEqual(channel.code, 200, channel.json_body) - - # we expect the next_batch in the result to be the same as before - self.assertEqual(channel.json_body["next_batch"], next_batch) - - # another incremental sync should also block. 
- channel = self.make_request( - "GET", - f"/sync?since={next_batch}&timeout=10000", - access_token=self.tok, - await_result=False, - ) - # that should block for 10 seconds - with self.assertRaises(TimedOutException): - channel.await_result(timeout_ms=9900) - channel.await_result(timeout_ms=200) - self.assertEqual(channel.code, 200, channel.json_body) diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/v2_alpha/test_upgrade_room.py deleted file mode 100644 index 72f976d8e2..0000000000 --- a/tests/rest/client/v2_alpha/test_upgrade_room.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional - -from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.rest import admin -from synapse.rest.client import login, room, room_upgrade_rest_servlet - -from tests import unittest -from tests.server import FakeChannel - - -class UpgradeRoomTest(unittest.HomeserverTestCase): - servlets = [ - admin.register_servlets, - login.register_servlets, - room.register_servlets, - room_upgrade_rest_servlet.register_servlets, - ] - - def prepare(self, reactor, clock, hs): - self.store = hs.get_datastore() - self.handler = hs.get_user_directory_handler() - - self.creator = self.register_user("creator", "pass") - self.creator_token = self.login(self.creator, "pass") - - self.other = self.register_user("user", "pass") - self.other_token = self.login(self.other, "pass") - - self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_token) - self.helper.join(self.room_id, self.other, tok=self.other_token) - - def _upgrade_room(self, token: Optional[str] = None) -> FakeChannel: - # We never want a cached response. - self.reactor.advance(5 * 60 + 1) - - return self.make_request( - "POST", - "/_matrix/client/r0/rooms/%s/upgrade" % self.room_id, - # This will upgrade a room to the same version, but that's fine. - content={"new_version": DEFAULT_ROOM_VERSION}, - access_token=token or self.creator_token, - ) - - def test_upgrade(self): - """ - Upgrading a room should work fine. - """ - channel = self._upgrade_room() - self.assertEquals(200, channel.code, channel.result) - self.assertIn("replacement_room", channel.json_body) - - def test_not_in_room(self): - """ - Upgrading a room should work fine. - """ - # THe user isn't in the room. - roomless = self.register_user("roomless", "pass") - roomless_token = self.login(roomless, "pass") - - channel = self._upgrade_room(roomless_token) - self.assertEquals(403, channel.code, channel.result) - - def test_power_levels(self): - """ - Another user can upgrade the room if their power level is increased. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. 
- power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["users"][self.other] = 100 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - def test_power_levels_user_default(self): - """ - Another user can upgrade the room if the default power level for users is increased. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["users_default"] = 100 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - def test_power_levels_tombstone(self): - """ - Another user can upgrade the room if they can send the tombstone event. - """ - # The other user doesn't have the proper power level. - channel = self._upgrade_room(self.other_token) - self.assertEquals(403, channel.code, channel.result) - - # Increase the power levels so that this user can upgrade. - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - power_levels["events"]["m.room.tombstone"] = 0 - self.helper.send_state( - self.room_id, - "m.room.power_levels", - body=power_levels, - tok=self.creator_token, - ) - - # The upgrade should succeed! - channel = self._upgrade_room(self.other_token) - self.assertEquals(200, channel.code, channel.result) - - power_levels = self.helper.get_state( - self.room_id, - "m.room.power_levels", - tok=self.creator_token, - ) - self.assertNotIn(self.other, power_levels["users"]) diff --git a/tests/unittest.py b/tests/unittest.py index 3eec9c4d5b..f2c90cc47b 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase): reactor=self.reactor, ) - from tests.rest.client.v1.utils import RestHelper + from tests.rest.client.utils import RestHelper self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None)) -- cgit 1.5.1 From 947dbbdfd1e0029da66f956d277b7c089928e1e7 Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Sat, 21 Aug 2021 22:14:43 +0100 Subject: Implement MSC3231: Token authenticated registration (#10142) Signed-off-by: Callum Brown This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231). 
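For context, the client-visible flow this patch enables looks roughly like the sketch below. The endpoint paths and the `auth` type are the unstable MSC3231 identifiers added in this patch; the homeserver URL, the credentials, and the use of the `requests` library are illustrative assumptions only, not part of the change itself.

```
import requests

BASE = "https://homeserver.example"  # hypothetical homeserver
TOKEN = "abcd"  # a token created via the new admin API

# Optional pre-check of the token, served by the new validity endpoint
# (which is also routable to workers, per the docs/workers.md change).
r = requests.get(
    f"{BASE}/_matrix/client/unstable/org.matrix.msc3231/register/"
    "org.matrix.msc3231.login.registration_token/validity",
    params={"token": TOKEN},
)
assert r.json()["valid"]

# The first /register call returns a 401 with the UIA flows, which now
# start with the registration token stage; grab the UIA session ID.
r = requests.post(
    f"{BASE}/_matrix/client/r0/register",
    json={"username": "alice", "password": "s3cret"},
)
session = r.json()["session"]

# Complete the registration token stage; any further stages follow as usual.
r = requests.post(
    f"{BASE}/_matrix/client/r0/register",
    json={
        "username": "alice",
        "password": "s3cret",
        "auth": {
            "type": "org.matrix.msc3231.login.registration_token",
            "token": TOKEN,
            "session": session,
        },
    },
)
```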
--- changelog.d/10142.feature | 1 + docs/SUMMARY.md | 1 + docs/sample_config.yaml | 15 + .../admin_api/registration_tokens.md | 295 +++++++++ docs/workers.md | 1 + synapse/api/constants.py | 1 + synapse/app/generic_worker.py | 6 +- synapse/config/ratelimiting.py | 11 + synapse/config/registration.py | 15 + synapse/handlers/ui_auth/__init__.py | 5 + synapse/handlers/ui_auth/checkers.py | 65 ++ synapse/res/templates/registration_token.html | 23 + synapse/rest/admin/__init__.py | 8 + synapse/rest/admin/registration_tokens.py | 321 ++++++++++ synapse/rest/client/auth.py | 24 + synapse/rest/client/register.py | 72 +++ synapse/storage/databases/main/registration.py | 316 +++++++++ synapse/storage/databases/main/ui_auth.py | 43 ++ .../main/delta/63/01create_registration_tokens.sql | 23 + tests/rest/admin/test_registration_tokens.py | 710 +++++++++++++++++++++ tests/rest/client/test_register.py | 434 +++++++++++++ 21 files changed, 2389 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10142.feature create mode 100644 docs/usage/administration/admin_api/registration_tokens.md create mode 100644 synapse/res/templates/registration_token.html create mode 100644 synapse/rest/admin/registration_tokens.py create mode 100644 synapse/storage/schema/main/delta/63/01create_registration_tokens.sql create mode 100644 tests/rest/admin/test_registration_tokens.py diff --git a/changelog.d/10142.feature b/changelog.d/10142.feature new file mode 100644 index 0000000000..5353f6269d --- /dev/null +++ b/changelog.d/10142.feature @@ -0,0 +1 @@ +Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 634bc833ab..4fcd2b7852 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -53,6 +53,7 @@ - [Media](admin_api/media_admin_api.md) - [Purge History](admin_api/purge_history_api.md) - [Register Users](admin_api/register_api.md) + - [Registration Tokens](usage/administration/admin_api/registration_tokens.md) - [Manipulate Room Membership](admin_api/room_membership.md) - [Rooms](admin_api/rooms.md) - [Server Notices](admin_api/server_notices.md) diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2b0c453242..935841dbfa 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -793,6 +793,8 @@ log_config: "CONFDIR/SERVERNAME.log.config" # is using # - one for registration that ratelimits registration requests based on the # client's IP address. +# - one for checking the validity of registration tokens that ratelimits +# requests based on the client's IP address. # - one for login that ratelimits login requests based on the client's IP # address. # - one for login that ratelimits login requests based on the account the @@ -821,6 +823,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # per_second: 0.17 # burst_count: 3 # +#rc_registration_token_validity: +# per_second: 0.1 +# burst_count: 5 +# #rc_login: # address: # per_second: 0.17 @@ -1169,6 +1175,15 @@ url_preview_accept_language: # #enable_3pid_lookup: true +# Require users to submit a token during registration. +# Tokens can be managed using the admin API: +# https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html +# Note that `enable_registration` must be set to `true`. +# Disabling this option will not delete any tokens previously generated. +# Defaults to false. 
Uncomment the following to require tokens: +# +#registration_requires_token: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/docs/usage/administration/admin_api/registration_tokens.md b/docs/usage/administration/admin_api/registration_tokens.md new file mode 100644 index 0000000000..828c0277d6 --- /dev/null +++ b/docs/usage/administration/admin_api/registration_tokens.md @@ -0,0 +1,295 @@ +# Registration Tokens + +This API allows you to manage tokens which can be used to authenticate +registration requests, as proposed in [MSC3231](https://github.com/govynnus/matrix-doc/blob/token-registration/proposals/3231-token-authenticated-registration.md). +To use it, you will need to enable the `registration_requires_token` config +option, and authenticate by providing an `access_token` for a server admin: +see [Admin API](../../usage/administration/admin_api). +Note that this API is still experimental; not all clients may support it yet. + + +## Registration token objects + +Most endpoints make use of JSON objects that contain details about tokens. +These objects have the following fields: +- `token`: The token which can be used to authenticate registration. +- `uses_allowed`: The number of times the token can be used to complete a + registration before it becomes invalid. +- `pending`: The number of pending uses the token has. When someone uses + the token to authenticate themselves, the pending counter is incremented + so that the token is not used more than the permitted number of times. + When the person completes registration the pending counter is decremented, + and the completed counter is incremented. +- `completed`: The number of times the token has been used to successfully + complete a registration. +- `expiry_time`: The latest time the token is valid. Given as the number of + milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch). + To convert this into a human-readable form you can remove the milliseconds + and use the `date` command. For example, `date -d '@1625394937'`. + + +## List all tokens + +Lists all tokens and details about them. If the request is successful, the top +level JSON object will have a `registration_tokens` key which is an array of +registration token objects. + +``` +GET /_synapse/admin/v1/registration_tokens +``` + +Optional query parameters: +- `valid`: `true` or `false`. If `true`, only valid tokens are returned. + If `false`, only tokens that have expired or have had all uses exhausted are + returned. If omitted, all tokens are returned regardless of validity. 
+
+Example:
+
+```
+GET /_synapse/admin/v1/registration_tokens
+```
+```
+200 OK
+
+{
+    "registration_tokens": [
+        {
+            "token": "abcd",
+            "uses_allowed": 3,
+            "pending": 0,
+            "completed": 1,
+            "expiry_time": null
+        },
+        {
+            "token": "pqrs",
+            "uses_allowed": 2,
+            "pending": 1,
+            "completed": 1,
+            "expiry_time": null
+        },
+        {
+            "token": "wxyz",
+            "uses_allowed": null,
+            "pending": 0,
+            "completed": 9,
+            "expiry_time": 1625394937000    // 2021-07-04 10:35:37 UTC
+        }
+    ]
+}
+```
+
+Example using the `valid` query parameter:
+
+```
+GET /_synapse/admin/v1/registration_tokens?valid=false
+```
+```
+200 OK
+
+{
+    "registration_tokens": [
+        {
+            "token": "pqrs",
+            "uses_allowed": 2,
+            "pending": 1,
+            "completed": 1,
+            "expiry_time": null
+        },
+        {
+            "token": "wxyz",
+            "uses_allowed": null,
+            "pending": 0,
+            "completed": 9,
+            "expiry_time": 1625394937000    // 2021-07-04 10:35:37 UTC
+        }
+    ]
+}
+```
+
+
+## Get one token
+
+Get details about a single token. If the request is successful, the response
+body will be a registration token object.
+
+```
+GET /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to return details of.
+
+Example:
+
+```
+GET /_synapse/admin/v1/registration_tokens/abcd
+```
+```
+200 OK
+
+{
+    "token": "abcd",
+    "uses_allowed": 3,
+    "pending": 0,
+    "completed": 1,
+    "expiry_time": null
+}
+```
+
+
+## Create token
+
+Create a new registration token. If the request is successful, the newly created
+token will be returned as a registration token object in the response body.
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+```
+
+The request body must be a JSON object and can contain the following fields:
+- `token`: The registration token. A string of no more than 64 characters that
+  consists only of characters matched by the regex `[A-Za-z0-9-_]`.
+  Default: randomly generated.
+- `uses_allowed`: The integer number of times the token can be used to complete
+  a registration before it becomes invalid.
+  Default: `null` (unlimited uses).
+- `expiry_time`: The latest time the token is valid. Given as the number of
+  milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+  You could use, for example, `date '+%s000' -d 'tomorrow'`.
+  Default: `null` (token does not expire).
+- `length`: The length of the token randomly generated if `token` is not
+  specified. Must be between 1 and 64 inclusive. Default: `16`.
+
+If a field is omitted the default is used.
+
+Example using defaults:
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+
+{}
+```
+```
+200 OK
+
+{
+    "token": "0M-9jbkf2t_Tgiw1",
+    "uses_allowed": null,
+    "pending": 0,
+    "completed": 0,
+    "expiry_time": null
+}
+```
+
+Example specifying some fields:
+
+```
+POST /_synapse/admin/v1/registration_tokens/new
+
+{
+    "token": "defg",
+    "uses_allowed": 1
+}
+```
+```
+200 OK
+
+{
+    "token": "defg",
+    "uses_allowed": 1,
+    "pending": 0,
+    "completed": 0,
+    "expiry_time": null
+}
+```
+
+
+## Update token
+
+Update the number of allowed uses or expiry time of a token. If the request is
+successful, the updated token will be returned as a registration token object
+in the response body.
+
+```
+PUT /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to update.
+
+The request body must be a JSON object and can contain the following fields:
+- `uses_allowed`: The integer number of times the token can be used to complete
+  a registration before it becomes invalid. By setting `uses_allowed` to `0`
+  the token can be easily made invalid without deleting it.
+  If `null` the token will have an unlimited number of uses.
+- `expiry_time`: The latest time the token is valid. Given as the number of
+  milliseconds since 1970-01-01 00:00:00 UTC (the start of the Unix epoch).
+  If `null` the token will not expire.
+
+If a field is omitted its value is not modified.
+
+Example:
+
+```
+PUT /_synapse/admin/v1/registration_tokens/defg
+
+{
+    "expiry_time": 4781243146000    // 2121-07-06 11:05:46 UTC
+}
+```
+```
+200 OK
+
+{
+    "token": "defg",
+    "uses_allowed": 1,
+    "pending": 0,
+    "completed": 0,
+    "expiry_time": 4781243146000
+}
+```
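The shell `date '+%s000'` recipe mentioned under "Create token" has a natural Python equivalent. This is only an illustrative sketch for computing a value to pass as `expiry_time`; the one-day lifetime is an arbitrary example:

```
from datetime import datetime, timedelta

# expiry_time is measured in milliseconds since the Unix epoch.
expiry_time = int((datetime.now() + timedelta(days=1)).timestamp() * 1000)
print(expiry_time)
```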
+
+
+## Delete token
+
+Delete a registration token. If the request is successful, the response body
+will be an empty JSON object.
+
+```
+DELETE /_synapse/admin/v1/registration_tokens/<token>
+```
+
+Path parameters:
+- `token`: The registration token to delete.
+
+Example:
+
+```
+DELETE /_synapse/admin/v1/registration_tokens/wxyz
+```
+```
+200 OK
+
+{}
+```
+
+
+## Errors
+
+If a request fails a "standard error response" will be returned as defined in
+the [Matrix Client-Server API specification](https://matrix.org/docs/spec/client_server/r0.6.1#api-standards).
+
+For example, if the token specified in a path parameter does not exist a
+`404 Not Found` error will be returned.
+
+```
+GET /_synapse/admin/v1/registration_tokens/1234
+```
+```
+404 Not Found
+
+{
+    "errcode": "M_NOT_FOUND",
+    "error": "No such registration token: 1234"
+}
+```
diff --git a/docs/workers.md b/docs/workers.md
index 2e63f03452..3121241894 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -236,6 +236,7 @@ expressions:
     # Registration/login requests
     ^/_matrix/client/(api/v1|r0|unstable)/login$
     ^/_matrix/client/(r0|unstable)/register$
+    ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$
 
     # Event sending requests
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e0e24fddac..829061c870 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -79,6 +79,7 @@ class LoginType:
     TERMS = "m.login.terms"
     SSO = "m.login.sso"
     DUMMY = "m.login.dummy"
+    REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token"
 
 
 # This is used in the `type` parameter for /register when called by
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 845e6a8220..fd2626dbe1 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -95,7 +95,10 @@ from synapse.rest.client.profile import (
     ProfileRestServlet,
 )
 from synapse.rest.client.push_rule import PushRuleRestServlet
-from synapse.rest.client.register import RegisterRestServlet
+from synapse.rest.client.register import (
+    RegisterRestServlet,
+    RegistrationTokenValidityRestServlet,
+)
 from synapse.rest.client.sendtodevice import SendToDeviceRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
 from synapse.rest.client.voip import VoipRestServlet
@@ -279,6 +282,7 @@ class GenericWorkerServer(HomeServer):
         resource = JsonResource(self, canonical_json=False)
 
         RegisterRestServlet(self).register(resource)
+        RegistrationTokenValidityRestServlet(self).register(resource)
         login.register_servlets(self, resource)
         ThreepidRestServlet(self).register(resource)
         DevicesRestServlet(self).register(resource)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 7a8d5851c4..f856327bd8 ---
a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -79,6 +79,11 @@ class RatelimitConfig(Config): self.rc_registration = RateLimitConfig(config.get("rc_registration", {})) + self.rc_registration_token_validity = RateLimitConfig( + config.get("rc_registration_token_validity", {}), + defaults={"per_second": 0.1, "burst_count": 5}, + ) + rc_login_config = config.get("rc_login", {}) self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {})) self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {})) @@ -143,6 +148,8 @@ class RatelimitConfig(Config): # is using # - one for registration that ratelimits registration requests based on the # client's IP address. + # - one for checking the validity of registration tokens that ratelimits + # requests based on the client's IP address. # - one for login that ratelimits login requests based on the client's IP # address. # - one for login that ratelimits login requests based on the account the @@ -171,6 +178,10 @@ class RatelimitConfig(Config): # per_second: 0.17 # burst_count: 3 # + #rc_registration_token_validity: + # per_second: 0.1 + # burst_count: 5 + # #rc_login: # address: # per_second: 0.17 diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 0ad919b139..7cffdacfa5 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -33,6 +33,9 @@ class RegistrationConfig(Config): self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) + self.registration_requires_token = config.get( + "registration_requires_token", False + ) self.registration_shared_secret = config.get("registration_shared_secret") self.bcrypt_rounds = config.get("bcrypt_rounds", 12) @@ -140,6 +143,9 @@ class RegistrationConfig(Config): "mechanism by removing the `access_token_lifetime` option." ) + # The fallback template used for authenticating using a registration token + self.registration_token_template = self.read_template("registration_token.html") + # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") @@ -199,6 +205,15 @@ class RegistrationConfig(Config): # #enable_3pid_lookup: true + # Require users to submit a token during registration. + # Tokens can be managed using the admin API: + # https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/registration_tokens.html + # Note that `enable_registration` must be set to `true`. + # Disabling this option will not delete any tokens previously generated. + # Defaults to false. Uncomment the following to require tokens: + # + #registration_requires_token: true + # If set, allows registration of standard or admin accounts by anyone who # has the shared secret, even if registration is otherwise disabled. # diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py index 4c3b669fae..13b0c61d2e 100644 --- a/synapse/handlers/ui_auth/__init__.py +++ b/synapse/handlers/ui_auth/__init__.py @@ -34,3 +34,8 @@ class UIAuthSessionDataConstants: # used by validate_user_via_ui_auth to store the mxid of the user we are validating # for. 
REQUEST_USER_ID = "request_user_id" + + # used during registration to store the registration token used (if required) so that: + # - we can prevent a token being used twice by one session + # - we can 'use up' the token after registration has successfully completed + REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 270541cc76..d3828dec6b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -241,11 +241,76 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): return await self._check_threepid("msisdn", authdict) +class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): + AUTH_TYPE = LoginType.REGISTRATION_TOKEN + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.hs = hs + self._enabled = bool(hs.config.registration_requires_token) + self.store = hs.get_datastore() + + def is_enabled(self) -> bool: + return self._enabled + + async def check_auth(self, authdict: dict, clientip: str) -> Any: + if "token" not in authdict: + raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM) + if not isinstance(authdict["token"], str): + raise LoginError( + 400, "Registration token must be a string", Codes.INVALID_PARAM + ) + if "session" not in authdict: + raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM) + + # Get these here to avoid cyclic dependencies + from synapse.handlers.ui_auth import UIAuthSessionDataConstants + + auth_handler = self.hs.get_auth_handler() + + session = authdict["session"] + token = authdict["token"] + + # If the LoginType.REGISTRATION_TOKEN stage has already been completed, + # return early to avoid incrementing `pending` again. + stored_token = await auth_handler.get_session_data( + session, UIAuthSessionDataConstants.REGISTRATION_TOKEN + ) + if stored_token: + if token != stored_token: + raise LoginError( + 400, "Registration token has changed", Codes.INVALID_PARAM + ) + else: + return token + + if await self.store.registration_token_is_valid(token): + # Increment pending counter, so that if token has limited uses it + # can't be used up by someone else in the meantime. + await self.store.set_registration_token_pending(token) + # Store the token in the UIA session, so that once registration + # is complete `completed` can be incremented. + await auth_handler.set_session_data( + session, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + token, + ) + # The token will be stored as the result of the authentication stage + # in ui_auth_sessions_credentials. This allows the pending counter + # for tokens to be decremented when expired sessions are deleted. + return token + else: + raise LoginError( + 401, "Invalid registration token", errcode=Codes.UNAUTHORIZED + ) + + INTERACTIVE_AUTH_CHECKERS = [ DummyAuthChecker, TermsAuthChecker, RecaptchaAuthChecker, EmailIdentityAuthChecker, MsisdnAuthChecker, + RegistrationTokenAuthChecker, ] """A list of UserInteractiveAuthChecker classes""" diff --git a/synapse/res/templates/registration_token.html b/synapse/res/templates/registration_token.html new file mode 100644 index 0000000000..4577ce1702 --- /dev/null +++ b/synapse/res/templates/registration_token.html @@ -0,0 +1,23 @@ + + +Authentication + + + + + +

+    <div>
+        {% if error is defined %}
+            <p class="error"><strong>Error: {{ error }}</strong></p>
+        {% endif %}
+        <p>
+            Please enter a registration token.
+        </p>
+        <input type="hidden" name="session" value="{{ session }}"/>
+        <input type="text" name="token"/>
+        <input type="submit" value="Authenticate"/>
+    </div>
+</form>
+</body>
+</html>
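The fallback template above only needs the `session`, `myurl` and (optionally) `error` variables, matching the `render()` calls in the `AuthRestServlet` changes later in this patch. A minimal sketch of rendering it outside Synapse with plain Jinja2 follows; Synapse itself loads it via `read_template`, and the repository-relative loader path here is an assumption for illustration:

```
from jinja2 import Environment, FileSystemLoader

# Assumed relative path to Synapse's bundled templates directory.
env = Environment(loader=FileSystemLoader("synapse/res/templates"))
template = env.get_template("registration_token.html")

# `session` ties the submission to the in-progress UIA session; `myurl`
# is the fallback endpoint the form posts the token back to; `error` is
# only defined when a previous attempt failed.
print(
    template.render(
        session="sEsSiOnId",
        myurl="/_matrix/client/r0/auth/"
        "org.matrix.msc3231.login.registration_token/fallback/web",
        error="Invalid registration token",
    )
)
```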
+ + + diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 7f3051aef1..6e1c8736e1 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -36,6 +36,11 @@ from synapse.rest.admin.event_reports import ( ) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo +from synapse.rest.admin.registration_tokens import ( + ListRegistrationTokensRestServlet, + NewRegistrationTokenRestServlet, + RegistrationTokenRestServlet, +) from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, ForwardExtremitiesRestServlet, @@ -238,6 +243,9 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomEventContextServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server) UsernameAvailableRestServlet(hs).register(http_server) + ListRegistrationTokensRestServlet(hs).register(http_server) + NewRegistrationTokenRestServlet(hs).register(http_server) + RegistrationTokenRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py new file mode 100644 index 0000000000..5a1c929d85 --- /dev/null +++ b/synapse/rest/admin/registration_tokens.py @@ -0,0 +1,321 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import string +from typing import TYPE_CHECKING, Tuple + +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_boolean, + parse_json_object_from_request, +) +from synapse.http.site import SynapseRequest +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ListRegistrationTokensRestServlet(RestServlet): + """List registration tokens. + + To list all tokens: + + GET /_synapse/admin/v1/registration_tokens + + 200 OK + + { + "registration_tokens": [ + { + "token": "abcd", + "uses_allowed": 3, + "pending": 0, + "completed": 1, + "expiry_time": null + }, + { + "token": "wxyz", + "uses_allowed": null, + "pending": 0, + "completed": 9, + "expiry_time": 1625394937000 + } + ] + } + + The optional query parameter `valid` can be used to filter the response. + If it is `true`, only valid tokens are returned. If it is `false`, only + tokens that have expired or have had all uses exhausted are returned. + If it is omitted, all tokens are returned regardless of validity. 
+ """ + + PATTERNS = admin_patterns("/registration_tokens$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + valid = parse_boolean(request, "valid") + token_list = await self.store.get_registration_tokens(valid) + return 200, {"registration_tokens": token_list} + + +class NewRegistrationTokenRestServlet(RestServlet): + """Create a new registration token. + + For example, to create a token specifying some fields: + + POST /_synapse/admin/v1/registration_tokens/new + + { + "token": "defg", + "uses_allowed": 1 + } + + 200 OK + + { + "token": "defg", + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": null + } + + Defaults are used for any fields not specified. + """ + + PATTERNS = admin_patterns("/registration_tokens/new$") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + # A string of all the characters allowed to be in a registration_token + self.allowed_chars = string.ascii_letters + string.digits + "-_" + self.allowed_chars_set = set(self.allowed_chars) + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + body = parse_json_object_from_request(request) + + if "token" in body: + token = body["token"] + if not isinstance(token, str): + raise SynapseError(400, "token must be a string", Codes.INVALID_PARAM) + if not (0 < len(token) <= 64): + raise SynapseError( + 400, + "token must not be empty and must not be longer than 64 characters", + Codes.INVALID_PARAM, + ) + if not set(token).issubset(self.allowed_chars_set): + raise SynapseError( + 400, + "token must consist only of characters matched by the regex [A-Za-z0-9-_]", + Codes.INVALID_PARAM, + ) + + else: + # Get length of token to generate (default is 16) + length = body.get("length", 16) + if not isinstance(length, int): + raise SynapseError( + 400, "length must be an integer", Codes.INVALID_PARAM + ) + if not (0 < length <= 64): + raise SynapseError( + 400, + "length must be greater than zero and not greater than 64", + Codes.INVALID_PARAM, + ) + + # Generate token + token = await self.store.generate_registration_token( + length, self.allowed_chars + ) + + uses_allowed = body.get("uses_allowed", None) + if not ( + uses_allowed is None + or (isinstance(uses_allowed, int) and uses_allowed >= 0) + ): + raise SynapseError( + 400, + "uses_allowed must be a non-negative integer or null", + Codes.INVALID_PARAM, + ) + + expiry_time = body.get("expiry_time", None) + if not isinstance(expiry_time, (int, type(None))): + raise SynapseError( + 400, "expiry_time must be an integer or null", Codes.INVALID_PARAM + ) + if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec(): + raise SynapseError( + 400, "expiry_time must not be in the past", Codes.INVALID_PARAM + ) + + created = await self.store.create_registration_token( + token, uses_allowed, expiry_time + ) + if not created: + raise SynapseError( + 400, f"Token already exists: {token}", Codes.INVALID_PARAM + ) + + resp = { + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + } + return 200, resp + + +class RegistrationTokenRestServlet(RestServlet): + """Retrieve, update, or delete the given token. 
+
+    For example,
+
+    to retrieve a token:
+
+    GET /_synapse/admin/v1/registration_tokens/abcd
+
+    200 OK
+
+    {
+        "token": "abcd",
+        "uses_allowed": 3,
+        "pending": 0,
+        "completed": 1,
+        "expiry_time": null
+    }
+
+
+    to update a token:
+
+    PUT /_synapse/admin/v1/registration_tokens/defg
+
+    {
+        "uses_allowed": 5,
+        "expiry_time": 4781243146000
+    }
+
+    200 OK
+
+    {
+        "token": "defg",
+        "uses_allowed": 5,
+        "pending": 0,
+        "completed": 0,
+        "expiry_time": 4781243146000
+    }
+
+
+    to delete a token:
+
+    DELETE /_synapse/admin/v1/registration_tokens/wxyz
+
+    200 OK
+
+    {}
+    """
+
+    PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
+
+    def __init__(self, hs: "HomeServer"):
+        self.hs = hs
+        self.clock = hs.get_clock()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastore()
+
+    async def on_GET(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+        """Retrieve a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+        token_info = await self.store.get_one_registration_token(token)
+
+        # If no result return a 404
+        if token_info is None:
+            raise NotFoundError(f"No such registration token: {token}")
+
+        return 200, token_info
+
+    async def on_PUT(self, request: SynapseRequest, token: str) -> Tuple[int, JsonDict]:
+        """Update a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+        body = parse_json_object_from_request(request)
+        new_attributes = {}
+
+        # Only add uses_allowed to new_attributes if it is present and valid
+        if "uses_allowed" in body:
+            uses_allowed = body["uses_allowed"]
+            if not (
+                uses_allowed is None
+                or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+            ):
+                raise SynapseError(
+                    400,
+                    "uses_allowed must be a non-negative integer or null",
+                    Codes.INVALID_PARAM,
+                )
+            new_attributes["uses_allowed"] = uses_allowed
+
+        if "expiry_time" in body:
+            expiry_time = body["expiry_time"]
+            if not isinstance(expiry_time, (int, type(None))):
+                raise SynapseError(
+                    400, "expiry_time must be an integer or null", Codes.INVALID_PARAM
+                )
+            if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+                raise SynapseError(
+                    400, "expiry_time must not be in the past", Codes.INVALID_PARAM
+                )
+            new_attributes["expiry_time"] = expiry_time
+
+        if len(new_attributes) == 0:
+            # Nothing to update, get token info to return
+            token_info = await self.store.get_one_registration_token(token)
+        else:
+            token_info = await self.store.update_registration_token(
+                token, new_attributes
+            )
+
+        # If no result return a 404
+        if token_info is None:
+            raise NotFoundError(f"No such registration token: {token}")
+
+        return 200, token_info
+
+    async def on_DELETE(
+        self, request: SynapseRequest, token: str
+    ) -> Tuple[int, JsonDict]:
+        """Delete a registration token."""
+        await assert_requester_is_admin(self.auth, request)
+
+        if await self.store.delete_registration_token(token):
+            return 200, {}
+
+        raise NotFoundError(f"No such registration token: {token}")
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 73284e48ec..91800c0278 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -46,6 +46,7 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
+        self.registration_token_template = hs.config.registration_token_template
         self.success_template = hs.config.fallback_success_template
 
     async def on_GET(self, request, stagetype):
@@
-74,6 +75,12 @@ class AuthRestServlet(RestServlet): # re-authenticate with their SSO provider. html = await self.auth_handler.start_sso_ui_auth(request, session) + elif stagetype == LoginType.REGISTRATION_TOKEN: + html = self.registration_token_template.render( + session=session, + myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web", + ) + else: raise SynapseError(404, "Unknown auth stage type") @@ -140,6 +147,23 @@ class AuthRestServlet(RestServlet): # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") + elif stagetype == LoginType.REGISTRATION_TOKEN: + token = parse_string(request, "token", required=True) + authdict = {"session": session, "token": token} + + try: + await self.auth_handler.add_oob_auth( + LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP() + ) + except LoginError as e: + html = self.registration_token_template.render( + session=session, + myurl=f"{CLIENT_API_PREFIX}/r0/auth/{LoginType.REGISTRATION_TOKEN}/fallback/web", + error=e.msg, + ) + else: + html = self.success_template.render() + else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 58b8e8f261..2781a0ea96 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -28,6 +28,7 @@ from synapse.api.errors import ( ThreepidValidationError, UnrecognizedRequestError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError from synapse.config.captcha import CaptchaConfig from synapse.config.consent import ConsentConfig @@ -379,6 +380,55 @@ class UsernameAvailabilityRestServlet(RestServlet): return 200, {"available": True} +class RegistrationTokenValidityRestServlet(RestServlet): + """Check the validity of a registration token. 
+ + Example: + + GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd + + 200 OK + + { + "valid": true + } + """ + + PATTERNS = client_patterns( + f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity", + releases=(), + unstable=True, + ) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super().__init__() + self.hs = hs + self.store = hs.get_datastore() + self.ratelimiter = Ratelimiter( + store=self.store, + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, + burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, + ) + + async def on_GET(self, request): + await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) + + if not self.hs.config.enable_registration: + raise SynapseError( + 403, "Registration has been disabled", errcode=Codes.FORBIDDEN + ) + + token = parse_string(request, "token", required=True) + valid = await self.store.registration_token_is_valid(token) + + return 200, {"valid": valid} + + class RegisterRestServlet(RestServlet): PATTERNS = client_patterns("/register$") @@ -686,6 +736,22 @@ class RegisterRestServlet(RestServlet): ) if registered: + # Check if a token was used to authenticate registration + registration_token = await self.auth_handler.get_session_data( + session_id, + UIAuthSessionDataConstants.REGISTRATION_TOKEN, + ) + if registration_token: + # Increment the `completed` counter for the token + await self.store.use_registration_token(registration_token) + # Indicate that the token has been successfully used so that + # pending is not decremented again when expiring old UIA sessions. + await self.store.mark_ui_auth_stage_complete( + session_id, + LoginType.REGISTRATION_TOKEN, + True, + ) + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, @@ -868,6 +934,11 @@ def _calculate_registration_flows( for flow in flows: flow.insert(0, LoginType.RECAPTCHA) + # Prepend registration token to all flows if we're requiring a token + if config.registration_requires_token: + for flow in flows: + flow.insert(0, LoginType.REGISTRATION_TOKEN) + return flows @@ -876,4 +947,5 @@ def register_servlets(hs, http_server): MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) RegistrationSubmitTokenServlet(hs).register(http_server) + RegistrationTokenValidityRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 469dd53e0c..a6517962f6 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1168,6 +1168,322 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + async def registration_token_is_valid(self, token: str) -> bool: + """Checks if a token can be used to authenticate a registration. + + Args: + token: The registration token to be checked + Returns: + True if the token is valid, False otherwise. 
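+
+        For example (a usage sketch; `store` here is a reference to this
+        datastore):
+            valid = await store.registration_token_is_valid("abcd")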
+ """ + res = await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + ) + + # Check if the token exists + if res is None: + return False + + # Check if the token has expired + now = self._clock.time_msec() + if res["expiry_time"] and res["expiry_time"] < now: + return False + + # Check if the token has been used up + if ( + res["uses_allowed"] + and res["pending"] + res["completed"] >= res["uses_allowed"] + ): + return False + + # Otherwise, the token is valid + return True + + async def set_registration_token_pending(self, token: str) -> None: + """Increment the pending registrations counter for a token. + + Args: + token: The registration token pending use + """ + + def _set_registration_token_pending_txn(txn): + pending = self.db_pool.simple_select_one_onecol_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": pending + 1}, + ) + + return await self.db_pool.runInteraction( + "set_registration_token_pending", _set_registration_token_pending_txn + ) + + async def use_registration_token(self, token: str) -> None: + """Complete a use of the given registration token. + + The `pending` counter will be decremented, and the `completed` + counter will be incremented. + + Args: + token: The registration token to be 'used' + """ + + def _use_registration_token_txn(txn): + # Normally, res is Optional[Dict[str, Any]]. + # Override type because the return type is only optional if + # allow_none is True, and we don't want mypy throwing errors + # about None not being indexable. + res: Dict[str, Any] = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) # type: ignore + + # Decrement pending and increment completed + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues={ + "completed": res["completed"] + 1, + "pending": res["pending"] - 1, + }, + ) + + return await self.db_pool.runInteraction( + "use_registration_token", _use_registration_token_txn + ) + + async def get_registration_tokens( + self, valid: Optional[bool] = None + ) -> List[Dict[str, Any]]: + """List all registration tokens. Used by the admin API. + + Args: + valid: If True, only valid tokens are returned. + If False, only invalid tokens are returned. + Default is None: return all tokens regardless of validity. + + Returns: + A list of dicts, each containing details of a token. + """ + + def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]): + if valid is None: + # Return all tokens regardless of validity + txn.execute("SELECT * FROM registration_tokens") + + elif valid: + # Select valid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "(uses_allowed > pending + completed OR uses_allowed IS NULL) " + "AND (expiry_time > ? OR expiry_time IS NULL)" + ) + txn.execute(sql, [now]) + + else: + # Select invalid tokens only + sql = ( + "SELECT * FROM registration_tokens WHERE " + "uses_allowed <= pending + completed OR expiry_time <= ?" 
+ ) + txn.execute(sql, [now]) + + return self.db_pool.cursor_to_dict(txn) + + return await self.db_pool.runInteraction( + "select_registration_tokens", + select_registration_tokens_txn, + self._clock.time_msec(), + valid, + ) + + async def get_one_registration_token(self, token: str) -> Optional[Dict[str, Any]]: + """Get info about the given registration token. Used by the admin API. + + Args: + token: The token to retrieve information about. + + Returns: + A dict, or None if token doesn't exist. + """ + return await self.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["token", "uses_allowed", "pending", "completed", "expiry_time"], + allow_none=True, + desc="get_one_registration_token", + ) + + async def generate_registration_token( + self, length: int, chars: str + ) -> Optional[str]: + """Generate a random registration token. Used by the admin API. + + Args: + length: The length of the token to generate. + chars: A string of the characters allowed in the generated token. + + Returns: + The generated token. + + Raises: + SynapseError if a unique registration token could still not be + generated after a few tries. + """ + # Make a few attempts at generating a unique token of the required + # length before failing. + for _i in range(3): + # Generate token + token = "".join(random.choices(chars, k=length)) + + # Check if the token already exists + existing_token = await self.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="token", + allow_none=True, + desc="check_if_registration_token_exists", + ) + + if existing_token is None: + # The generated token doesn't exist yet, return it + return token + + raise SynapseError( + 500, + "Unable to generate a unique registration token. Try again with a greater length", + Codes.UNKNOWN, + ) + + async def create_registration_token( + self, token: str, uses_allowed: Optional[int], expiry_time: Optional[int] + ) -> bool: + """Create a new registration token. Used by the admin API. + + Args: + token: The token to create. + uses_allowed: The number of times the token can be used to complete + a registration before it becomes invalid. A value of None indicates + unlimited uses. + expiry_time: The latest time the token is valid. Given as the + number of milliseconds since 1970-01-01 00:00:00 UTC. A value of + None indicates that the token does not expire. + + Returns: + Whether the row was inserted or not. + """ + + def _create_registration_token_txn(txn): + row = self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=["token"], + allow_none=True, + ) + + if row is not None: + # Token already exists + return False + + self.db_pool.simple_insert_txn( + txn, + "registration_tokens", + values={ + "token": token, + "uses_allowed": uses_allowed, + "pending": 0, + "completed": 0, + "expiry_time": expiry_time, + }, + ) + + return True + + return await self.db_pool.runInteraction( + "create_registration_token", _create_registration_token_txn + ) + + async def update_registration_token( + self, token: str, updatevalues: Dict[str, Optional[int]] + ) -> Optional[Dict[str, Any]]: + """Update a registration token. Used by the admin API. + + Args: + token: The token to update. + updatevalues: A dict with the fields to update. E.g.: + `{"uses_allowed": 3}` to update just uses_allowed, or + `{"uses_allowed": 3, "expiry_time": None}` to update both. + This is passed straight to simple_update_one. 
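+                Fields not included in the dict are left unchanged.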
+ + Returns: + A dict with all info about the token, or None if token doesn't exist. + """ + + def _update_registration_token_txn(txn): + try: + self.db_pool.simple_update_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + updatevalues=updatevalues, + ) + except StoreError: + # Update failed because token does not exist + return None + + # Get all info about the token so it can be sent in the response + return self.db_pool.simple_select_one_txn( + txn, + "registration_tokens", + keyvalues={"token": token}, + retcols=[ + "token", + "uses_allowed", + "pending", + "completed", + "expiry_time", + ], + allow_none=True, + ) + + return await self.db_pool.runInteraction( + "update_registration_token", _update_registration_token_txn + ) + + async def delete_registration_token(self, token: str) -> bool: + """Delete a registration token. Used by the admin API. + + Args: + token: The token to delete. + + Returns: + Whether the token was successfully deleted or not. + """ + try: + await self.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + desc="delete_registration_token", + ) + except StoreError: + # Deletion failed because token does not exist + return False + + return True + @cached() async def mark_access_token_as_used(self, token_id: int) -> None: """ diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 38bfdf5dad..4d6bbc94c7 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import attr +from synapse.api.constants import LoginType from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction @@ -329,6 +330,48 @@ class UIAuthWorkerStore(SQLBaseStore): keyvalues={}, ) + # If a registration token was used, decrement the pending counter + # before deleting the session. + rows = self.db_pool.simple_select_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={"stage_type": LoginType.REGISTRATION_TOKEN}, + retcols=["result"], + ) + + # Get the tokens used and how much pending needs to be decremented by. + token_counts: Dict[str, int] = {} + for r in rows: + # If registration was successfully completed, the result of the + # registration token stage for that session will be True. + # If a token was used to authenticate, but registration was + # never completed, the result will be the token used. + token = db_to_json(r["result"]) + if isinstance(token, str): + token_counts[token] = token_counts.get(token, 0) + 1 + + # Update the `pending` counters. + if len(token_counts) > 0: + token_rows = self.db_pool.simple_select_many_txn( + txn, + table="registration_tokens", + column="token", + iterable=list(token_counts.keys()), + keyvalues={}, + retcols=["token", "pending"], + ) + for token_row in token_rows: + token = token_row["token"] + new_pending = token_row["pending"] - token_counts[token] + self.db_pool.simple_update_one_txn( + txn, + table="registration_tokens", + keyvalues={"token": token}, + updatevalues={"pending": new_pending}, + ) + # Delete the corresponding completed credentials. 
self.db_pool.simple_delete_many_txn( txn, diff --git a/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql new file mode 100644 index 0000000000..ee6cf958f4 --- /dev/null +++ b/synapse/storage/schema/main/delta/63/01create_registration_tokens.sql @@ -0,0 +1,23 @@ +/* Copyright 2021 Callum Brown + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS registration_tokens( + token TEXT NOT NULL, -- The token that can be used for authentication. + uses_allowed INT, -- The total number of times this token can be used. NULL if no limit. + pending INT NOT NULL, -- The number of in progress registrations using this token. + completed INT NOT NULL, -- The number of times this token has been used to complete a registration. + expiry_time BIGINT, -- The latest time this token will be valid (epoch time in milliseconds). NULL if token doesn't expire. + UNIQUE (token) +); diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py new file mode 100644 index 0000000000..4927321e5a --- /dev/null +++ b/tests/rest/admin/test_registration_tokens.py @@ -0,0 +1,710 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
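+
+# These tests exercise the registration token admin API
+# (/_synapse/admin/v1/registration_tokens) introduced in this change.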
+ +import random +import string + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login + +from tests import unittest + + +class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/registration_tokens" + + def _new_token(self, **kwargs): + """Helper function to create a token.""" + token = kwargs.get( + "token", + "".join(random.choices(string.ascii_letters, k=8)), + ) + self.get_success( + self.store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": kwargs.get("uses_allowed", None), + "pending": kwargs.get("pending", 0), + "completed": kwargs.get("completed", 0), + "expiry_time": kwargs.get("expiry_time", None), + }, + ) + ) + return token + + # CREATION + + def test_create_no_auth(self): + """Try to create a token without authentication.""" + channel = self.make_request("POST", self.url + "/new", {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_create_requester_not_admin(self): + """Try to create a token while not an admin.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_create_using_defaults(self): + """Create a token using all the defaults.""" + channel = self.make_request( + "POST", + self.url + "/new", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_specifying_fields(self): + """Create a token specifying the value of all fields.""" + data = { + "token": "abcd", + "uses_allowed": 1, + "expiry_time": self.clock.time_msec() + 1000000, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], "abcd") + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_with_null_value(self): + """Create a token specifying unlimited uses and no expiry.""" + data = { + "uses_allowed": None, + "expiry_time": None, + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 16) + 
self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + def test_create_token_too_long(self): + """Check token longer than 64 chars is invalid.""" + data = {"token": "a" * 65} + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_invalid_chars(self): + """Check you can't create token with invalid characters.""" + data = { + "token": "abc/def", + } + + channel = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_token_already_exists(self): + """Check you can't create token that already exists.""" + data = { + "token": "abcd", + } + + channel1 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"]) + + channel2 = self.make_request( + "POST", + self.url + "/new", + data, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"]) + self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_unable_to_generate_token(self): + """Check right error is raised when server can't generate unique token.""" + # Create all possible single character tokens + tokens = [] + for c in string.ascii_letters + string.digits + "-_": + tokens.append( + { + "token": c, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + } + ) + self.get_success( + self.store.db_pool.simple_insert_many( + "registration_tokens", + tokens, + "create_all_registration_tokens", + ) + ) + + # Check creating a single character token fails with a 500 status code + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 1}, + access_token=self.admin_user_tok, + ) + self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"]) + + def test_create_uses_allowed(self): + """Check you can only create a token with good values for uses_allowed.""" + # Should work with 0 (token is invalid from the start) + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 0) + + # Should fail with negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"uses_allowed": 1.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_expiry_time(self): + """Check you can't create a token with an invalid 
expiry_time.""" + # Should fail with a time in the past + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() - 10000}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with float + channel = self.make_request( + "POST", + self.url + "/new", + {"expiry_time": self.clock.time_msec() + 1000000.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + def test_create_length(self): + """Check you can only generate a token with a valid length.""" + # Should work with 64 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 64}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["token"]), 64) + + # Should fail with 0 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 0}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a negative integer + channel = self.make_request( + "POST", + self.url + "/new", + {"length": -5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with a float + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 8.5}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # Should fail with 65 + channel = self.make_request( + "POST", + self.url + "/new", + {"length": 65}, + access_token=self.admin_user_tok, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # UPDATING + + def test_update_no_auth(self): + """Try to update a token without authentication.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_update_requester_not_admin(self): + """Try to update a token while not an admin.""" + channel = self.make_request( + "PUT", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_update_non_existent(self): + """Try to update a token that doesn't exist.""" + channel = self.make_request( + "PUT", + self.url + "/1234", + {"uses_allowed": 1}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_update_uses_allowed(self): + """Test updating just uses_allowed.""" + # Create new token using default values + token = 
self._new_token()
+
+        # Should succeed with 1
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 1},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should succeed with 0 (makes token invalid)
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 0)
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should succeed with null
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": None},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should fail with a float
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 1.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a negative integer
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_update_expiry_time(self):
+        """Test updating just expiry_time."""
+        # Create new token using default values
+        token = self._new_token()
+        new_expiry_time = self.clock.time_msec() + 1000000
+
+        # Should succeed with a time in the future
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": new_expiry_time},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+
+        # Should succeed with null
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": None},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertIsNone(channel.json_body["uses_allowed"])
+
+        # Should fail with a time in the past
+        past_time = self.clock.time_msec() - 10000
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": past_time},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a float
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": new_expiry_time + 0.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_update_both(self):
+        """Test updating both uses_allowed and expiry_time."""
+        # Create new token using default values
+        token = self._new_token()
+        new_expiry_time =
self.clock.time_msec() + 1000000 + + data = { + "uses_allowed": 1, + "expiry_time": new_expiry_time, + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["uses_allowed"], 1) + self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) + + def test_update_invalid_type(self): + """Test using invalid types doesn't work.""" + # Create new token using default values + token = self._new_token() + + data = { + "uses_allowed": False, + "expiry_time": "1626430124000", + } + + channel = self.make_request( + "PUT", + self.url + "/" + token, + data, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) + + # DELETING + + def test_delete_no_auth(self): + """Try to delete a token without authentication.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_delete_requester_not_admin(self): + """Try to delete a token while not an admin.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_delete_non_existent(self): + """Try to delete a token that doesn't exist.""" + channel = self.make_request( + "DELETE", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_delete(self): + """Test deleting a token.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "DELETE", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + # GETTING ONE + + def test_get_no_auth(self): + """Try to get a token without authentication.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + ) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_get_requester_not_admin(self): + """Try to get a token while not an admin.""" + channel = self.make_request( + "GET", + self.url + "/1234", # Token doesn't exist but that doesn't matter + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_get_non_existent(self): + """Try to get a token that doesn't exist.""" + channel = self.make_request( + "GET", + self.url + "/1234", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_get(self): + """Test getting a token.""" + # 
Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url + "/" + token, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["token"], token) + self.assertIsNone(channel.json_body["uses_allowed"]) + self.assertIsNone(channel.json_body["expiry_time"]) + self.assertEqual(channel.json_body["pending"], 0) + self.assertEqual(channel.json_body["completed"], 0) + + # LISTING + + def test_list_no_auth(self): + """Try to list tokens without authentication.""" + channel = self.make_request("GET", self.url, {}) + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_list_requester_not_admin(self): + """Try to list tokens while not an admin.""" + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.other_user_tok, + ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_list_all(self): + """Test listing all tokens.""" + # Create new token using default values + token = self._new_token() + + channel = self.make_request( + "GET", + self.url, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 1) + token_info = channel.json_body["registration_tokens"][0] + self.assertEqual(token_info["token"], token) + self.assertIsNone(token_info["uses_allowed"]) + self.assertIsNone(token_info["expiry_time"]) + self.assertEqual(token_info["pending"], 0) + self.assertEqual(token_info["completed"], 0) + + def test_list_invalid_query_parameter(self): + """Test with `valid` query parameter not `true` or `false`.""" + channel = self.make_request( + "GET", + self.url + "?valid=x", + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + def _test_list_query_parameter(self, valid: str): + """Helper used to test both valid=true and valid=false.""" + # Create 2 valid and 2 invalid tokens. 
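+        # A token is valid only if it still has uses remaining (or has no
+        # use limit) and has not expired; the two invalid tokens below cover
+        # both failure modes.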
+ now = self.hs.get_clock().time_msec() + # Create always valid token + valid1 = self._new_token() + # Create token that hasn't been used up + valid2 = self._new_token(uses_allowed=1) + # Create token that has expired + invalid1 = self._new_token(expiry_time=now - 10000) + # Create token that has been used up but hasn't expired + invalid2 = self._new_token( + uses_allowed=2, + pending=1, + completed=1, + expiry_time=now + 1000000, + ) + + if valid == "true": + tokens = [valid1, valid2] + else: + tokens = [invalid1, invalid2] + + channel = self.make_request( + "GET", + self.url + "?valid=" + valid, + {}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(len(channel.json_body["registration_tokens"]), 2) + token_info_1 = channel.json_body["registration_tokens"][0] + token_info_2 = channel.json_body["registration_tokens"][1] + self.assertIn(token_info_1["token"], tokens) + self.assertIn(token_info_2["token"], tokens) + + def test_list_valid(self): + """Test listing just valid tokens.""" + self._test_list_query_parameter(valid="true") + + def test_list_invalid(self): + """Test listing just invalid tokens.""" + self._test_list_query_parameter(valid="false") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index fecda037a5..9f3ab2c985 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync +from synapse.storage._base import db_to_json from tests import unittest from tests.unittest import override_config @@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"registration_requires_token": True}) + def test_POST_registration_requires_token(self): + username = "kermit" + device_id = "frogfone" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params = { + "username": username, + "password": "monkey", + "device_id": device_id, + } + + # Request without auth to get flows and session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + flows = channel.json_body["flows"] + # Synapse adds a dummy stage to differentiate flows where otherwise one + # flow would be a subset of another flow. 
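+        # Here the only expected flow is [registration_token, dummy].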
+ self.assertCountEqual( + [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]], + (f["stages"] for f in flows), + ) + session = channel.json_body["session"] + + # Do the registration token stage and check it has completed + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + self.assertEquals(channel.result["code"], b"401", channel.result) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + # Do the m.login.dummy stage and check registration was successful + params["auth"] = { + "type": LoginType.DUMMY, + "session": session, + } + request_data = json.dumps(params) + channel = self.make_request(b"POST", self.url, request_data) + det_data = { + "user_id": f"@{username}:{self.hs.hostname}", + "home_server": self.hs.hostname, + "device_id": device_id, + } + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertDictContainsSubset(det_data, channel.json_body) + + # Check the `completed` counter has been incremented and pending is 0 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["completed"], 1) + self.assertEquals(res["pending"], 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_invalid(self): + params = { + "username": "kermit", + "password": "monkey", + } + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Test with token param missing (invalid) + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with non-string (invalid) + params["auth"]["token"] = 1234 + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM) + self.assertEquals(channel.json_body["completed"], []) + + # Test with unknown token (invalid) + params["auth"]["token"] = "1234" + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_limit_uses(self): + token = "abcd" + store = self.hs.get_datastore() + # Create token that can be used once + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": 1, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + # Do 2 requests without auth to get two session IDs + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) 
+ session2 = channel2.json_body["session"] + + # Use token with session1 and check `pending` is 1 + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + # Repeat request to make sure pending isn't increased again + self.make_request(b"POST", self.url, json.dumps(params1)) + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 1) + + # Check auth fails when using token with session2 + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + # Check pending=0 and completed=1 + res = self.get_success( + store.db_pool.simple_select_one( + "registration_tokens", + keyvalues={"token": token}, + retcols=["pending", "completed"], + ) + ) + self.assertEquals(res["pending"], 0) + self.assertEquals(res["completed"], 1) + + # Check auth still fails when using token with session2 + channel = self.make_request(b"POST", self.url, json.dumps(params2)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_expiry(self): + token = "abcd" + now = self.hs.get_clock().time_msec() + store = self.hs.get_datastore() + # Create token that expired yesterday + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": now - 24 * 60 * 60 * 1000, + }, + ) + ) + params = {"username": "kermit", "password": "monkey"} + # Request without auth to get session + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Check authentication fails with expired token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + channel = self.make_request(b"POST", self.url, json.dumps(params)) + self.assertEquals(channel.result["code"], b"401", channel.result) + self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED) + self.assertEquals(channel.json_body["completed"], []) + + # Update token so it expires tomorrow + self.get_success( + store.db_pool.simple_update_one( + "registration_tokens", + keyvalues={"token": token}, + updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000}, + ) + ) + + # Check authentication succeeds + channel = self.make_request(b"POST", self.url, json.dumps(params)) + completed = channel.json_body["completed"] + self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry(self): + """Test `pending` is decremented when an uncompleted session expires.""" + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( 
+ "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do 2 requests without auth to get two session IDs + params1 = {"username": "bert", "password": "monkey"} + params2 = {"username": "ernie", "password": "monkey"} + channel1 = self.make_request(b"POST", self.url, json.dumps(params1)) + session1 = channel1.json_body["session"] + channel2 = self.make_request(b"POST", self.url, json.dumps(params2)) + session2 = channel2.json_body["session"] + + # Use token with both sessions + params1["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session1, + } + self.make_request(b"POST", self.url, json.dumps(params1)) + + params2["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session2, + } + self.make_request(b"POST", self.url, json.dumps(params2)) + + # Complete registration with session1 + params1["auth"]["type"] = LoginType.DUMMY + self.make_request(b"POST", self.url, json.dumps(params1)) + + # Check `result` of registration token stage for session1 is `True` + result1 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session1, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertTrue(db_to_json(result1)) + + # Check `result` for session2 is the token used + result2 = self.get_success( + store.db_pool.simple_select_one_onecol( + "ui_auth_sessions_credentials", + keyvalues={ + "session_id": session2, + "stage_type": LoginType.REGISTRATION_TOKEN, + }, + retcol="result", + ) + ) + self.assertEquals(db_to_json(result2), token) + + # Delete both sessions (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + + # Check pending is now 0 + pending = self.get_success( + store.db_pool.simple_select_one_onecol( + "registration_tokens", + keyvalues={"token": token}, + retcol="pending", + ) + ) + self.assertEquals(pending, 0) + + @override_config({"registration_requires_token": True}) + def test_POST_registration_token_session_expiry_deleted_token(self): + """Test session expiry doesn't break when the token is deleted. + + 1. Start but don't complete UIA with a registration token + 2. Delete the token from the database + 3. 
Expire the session + """ + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + # Do request without auth to get a session ID + params = {"username": "kermit", "password": "monkey"} + channel = self.make_request(b"POST", self.url, json.dumps(params)) + session = channel.json_body["session"] + + # Use token + params["auth"] = { + "type": LoginType.REGISTRATION_TOKEN, + "token": token, + "session": session, + } + self.make_request(b"POST", self.url, json.dumps(params)) + + # Delete token + self.get_success( + store.db_pool.simple_delete_one( + "registration_tokens", + keyvalues={"token": token}, + ) + ) + + # Delete session (mimics expiry) + self.get_success( + store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec()) + ) + def test_advertised_flows(self): channel = self.make_request(b"POST", self.url, b"{}") self.assertEquals(channel.result["code"], b"401", channel.result) @@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta) self.assertLessEqual(res, now_ms + self.validity_period) + + +class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): + servlets = [register.register_servlets] + url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity" + + def default_config(self): + config = super().default_config() + config["registration_requires_token"] = True + return config + + def test_GET_token_valid(self): + token = "abcd" + store = self.hs.get_datastore() + self.get_success( + store.db_pool.simple_insert( + "registration_tokens", + { + "token": token, + "uses_allowed": None, + "pending": 0, + "completed": 0, + "expiry_time": None, + }, + ) + ) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], True) + + def test_GET_token_invalid(self): + token = "1234" + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) + self.assertEquals(channel.json_body["valid"], False) + + @override_config( + {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}} + ) + def test_GET_ratelimiting(self): + token = "1234" + + for i in range(0, 6): + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + + if i == 5: + self.assertEquals(channel.result["code"], b"429", channel.result) + retry_after_ms = int(channel.json_body["retry_after_ms"]) + else: + self.assertEquals(channel.result["code"], b"200", channel.result) + + self.reactor.advance(retry_after_ms / 1000.0 + 1.0) + + channel = self.make_request( + b"GET", + f"{self.url}?token={token}", + ) + self.assertEquals(channel.result["code"], b"200", channel.result) -- cgit 1.5.1 From 4db65f911ab338ea2a9465f3ee8bee3832bd0346 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 23 Aug 2021 11:12:45 +0100 Subject: Run a nightly CI build against Twisted trunk. (#10651) This creates a GHA workflow which runs at 8am every day, and runs mypy, trial and sytest against Twisted's current trunk. If any of the jobs fail, it opens an issue. 
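
A usage note (a sketch, assuming the stock GitHub CLI): the workflow also
declares a `workflow_dispatch` trigger, so a run can be started on demand
with:

    gh workflow run "Twisted Trunk"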
--- .ci/patch_for_twisted_trunk.sh | 8 +++ .ci/twisted_trunk_build_failed_issue_template.md | 4 ++ .github/workflows/twisted_trunk.yml | 89 ++++++++++++++++++++++++ changelog.d/10651.misc | 1 + 4 files changed, 102 insertions(+) create mode 100755 .ci/patch_for_twisted_trunk.sh create mode 100644 .ci/twisted_trunk_build_failed_issue_template.md create mode 100644 .github/workflows/twisted_trunk.yml create mode 100644 changelog.d/10651.misc diff --git a/.ci/patch_for_twisted_trunk.sh b/.ci/patch_for_twisted_trunk.sh new file mode 100755 index 0000000000..f524581986 --- /dev/null +++ b/.ci/patch_for_twisted_trunk.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +# replaces the dependency on Twisted in `python_dependencies` with trunk. + +set -e +cd "$(dirname "$0")"/.. + +sed -i -e 's#"Twisted.*"#"Twisted @ git+https://github.com/twisted/twisted"#' synapse/python_dependencies.py diff --git a/.ci/twisted_trunk_build_failed_issue_template.md b/.ci/twisted_trunk_build_failed_issue_template.md new file mode 100644 index 0000000000..2ead1dc394 --- /dev/null +++ b/.ci/twisted_trunk_build_failed_issue_template.md @@ -0,0 +1,4 @@ +--- +title: CI run against Twisted trunk is failing +--- +See https://github.com/{{env.GITHUB_REPOSITORY}}/actions/runs/{{env.GITHUB_RUN_ID}} diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml new file mode 100644 index 0000000000..0bf7790545 --- /dev/null +++ b/.github/workflows/twisted_trunk.yml @@ -0,0 +1,89 @@ +name: Twisted Trunk + +on: + schedule: + - cron: 0 8 * * * + + workflow_dispatch: + +jobs: + mypy: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - run: .ci/patch_for_twisted_trunk.sh + - run: pip install tox + - run: tox -e mypy + + trial: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - run: sudo apt-get -qq install xmlsec1 + - uses: actions/setup-python@v2 + with: + python-version: 3.6 + - run: .ci/patch_for_twisted_trunk.sh + - run: pip install tox + - run: tox -e py + env: + TRIAL_FLAGS: "--jobs=2" + + - name: Dump logs + # Note: Dumps to workflow logs instead of using actions/upload-artifact + # This keeps logs colocated with failing jobs + # It also ignores find's exit code; this is a best effort affair + run: >- + find _trial_temp -name '*.log' + -exec echo "::group::{}" \; + -exec cat {} \; + -exec echo "::endgroup::" \; + || true + + sytest: + runs-on: ubuntu-latest + container: + image: matrixdotorg/sytest-synapse:buster + volumes: + - ${{ github.workspace }}:/src + + steps: + - uses: actions/checkout@v2 + - name: Patch dependencies + run: .ci/patch_for_twisted_trunk.sh + working-directory: /src + - name: Run SyTest + run: /bootstrap.sh synapse + working-directory: /src + - name: Summarise results.tap + if: ${{ always() }} + run: /sytest/scripts/tap_to_gha.pl /logs/results.tap + - name: Upload SyTest logs + uses: actions/upload-artifact@v2 + if: ${{ always() }} + with: + name: Sytest Logs - ${{ job.status }} - (${{ join(matrix.*, ', ') }}) + path: | + /logs/results.tap + /logs/**/*.log* + + # open an issue if the build fails, so we know about it. 
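+  # (`failure()` is true when any of the jobs listed in `needs` failed.)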
+  open-issue:
+    if: failure()
+    needs:
+      - mypy
+      - trial
+      - sytest
+
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - uses: JasonEtco/create-an-issue@v2
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          filename: .ci/twisted_trunk_build_failed_issue_template.md
diff --git a/changelog.d/10651.misc b/changelog.d/10651.misc
new file mode 100644
index 0000000000..7104c121e0
--- /dev/null
+++ b/changelog.d/10651.misc
@@ -0,0 +1 @@
+Run a nightly CI build against Twisted trunk.
--
cgit 1.5.1


From 31dac7ffeeb02f68d1dbe068fd241239e02208dc Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 23 Aug 2021 08:00:25 -0400
Subject: Do not include stack traces for known exceptions when trying multiple
 federation destinations. (#10662)

---
 changelog.d/10662.misc                  | 1 +
 synapse/federation/federation_client.py | 7 ++++++-
 2 files changed, 7 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/10662.misc

diff --git a/changelog.d/10662.misc b/changelog.d/10662.misc
new file mode 100644
index 0000000000..593f9ceaad
--- /dev/null
+++ b/changelog.d/10662.misc
@@ -0,0 +1 @@
+Do not print out stack traces for network errors when fetching data over federation.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 29979414e3..44d9e8a5c7 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
     Codes,
     FederationDeniedError,
     HttpResponseException,
+    RequestSendFailed,
     SynapseError,
     UnsupportedRoomVersionError,
 )
@@ -558,7 +559,11 @@ class FederationClient(FederationBase):
             try:
                 return await callback(destination)
-            except InvalidResponseError as e:
+            except (
+                RequestSendFailed,
+                InvalidResponseError,
+                NotRetryingDestination,
+            ) as e:
                 logger.warning("Failed to %s via %s: %s", description, destination, e)
             except UnsupportedRoomVersionError:
                 raise
--
cgit 1.5.1


From 2af6d31b78109a989e27128ac655990c35b29d62 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Mon, 23 Aug 2021 08:14:17 -0400
Subject: Additional type hints for the REST servlets. (#10665)

---
 changelog.d/10665.misc                           |  1 +
 synapse/rest/client/account_validity.py          | 39 ++++++------
 synapse/rest/client/capabilities.py              |  3 +-
 synapse/rest/client/directory.py                 | 78 +++++++++++++++---------
 synapse/rest/client/pusher.py                    | 22 ++++---
 synapse/rest/client/read_marker.py               | 15 ++++-
 synapse/rest/client/room_upgrade_rest_servlet.py | 18 ++++--
 synapse/rest/client/shared_rooms.py              | 16 +++--
 synapse/rest/client/tags.py                      | 25 ++++++--
 synapse/rest/client/thirdparty.py                | 36 +++++++----
 synapse/rest/client/tokenrefresh.py              | 14 ++++-
 synapse/rest/client/user_directory.py            | 17 +++---
 synapse/rest/client/versions.py                  | 14 ++++-
 synapse/rest/client/voip.py                      | 13 +++-
 14 files changed, 204 insertions(+), 107 deletions(-)
 create mode 100644 changelog.d/10665.misc

diff --git a/changelog.d/10665.misc b/changelog.d/10665.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10665.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/synapse/rest/client/account_validity.py b/synapse/rest/client/account_validity.py
index 3ebe401861..6c24b96c54 100644
--- a/synapse/rest/client/account_validity.py
+++ b/synapse/rest/client/account_validity.py
@@ -13,24 +13,27 @@
 # limitations under the License.
import logging +from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import SynapseError -from synapse.http.server import respond_with_html -from synapse.http.servlet import RestServlet +from twisted.web.server import Request + +from synapse.http.server import HttpServer, respond_with_html +from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): PATTERNS = client_patterns("/account_validity/renew$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -46,18 +49,14 @@ class AccountValidityRenewServlet(RestServlet): hs.config.account_validity.account_validity_invalid_token_template ) - async def on_GET(self, request): - if b"token" not in request.args: - raise SynapseError(400, "Missing renewal token") - renewal_token = request.args[b"token"][0] + async def on_GET(self, request: Request) -> None: + renewal_token = parse_string(request, "token", required=True) ( token_valid, token_stale, expiration_ts, - ) = await self.account_activity_handler.renew_account( - renewal_token.decode("utf8") - ) + ) = await self.account_activity_handler.renew_account(renewal_token) if token_valid: status_code = 200 @@ -77,11 +76,7 @@ class AccountValidityRenewServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet): PATTERNS = client_patterns("/account_validity/send_mail$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -91,7 +86,7 @@ class AccountValiditySendMailServlet(RestServlet): hs.config.account_validity.account_validity_renew_by_email_enabled ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() await self.account_activity_handler.send_renewal_email_to_user(user_id) @@ -99,6 +94,6 @@ class AccountValiditySendMailServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AccountValidityRenewServlet(hs).register(http_server) AccountValiditySendMailServlet(hs).register(http_server) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 093549512e..65b3b5ce2c 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest from synapse.types import JsonDict @@ -75,5 +76,5 @@ class CapabilitiesRestServlet(RestServlet): return 200, response -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: CapabilitiesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index ffa075c8e5..ee247e3d1e 100644 --- a/synapse/rest/client/directory.py +++ 
b/synapse/rest/client/directory.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.server import Request from synapse.api.errors import ( AuthError, @@ -22,14 +24,19 @@ from synapse.api.errors import ( NotFoundError, SynapseError, ) +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import RoomAlias +from synapse.types import JsonDict, RoomAlias + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ClientDirectoryServer(hs).register(http_server) ClientDirectoryListServer(hs).register(http_server) ClientAppserviceDirectoryListServer(hs).register(http_server) @@ -38,21 +45,23 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(RestServlet): PATTERNS = client_patterns("/directory/room/(?P[^/]*)$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_alias): - room_alias = RoomAlias.from_string(room_alias) + async def on_GET(self, request: Request, room_alias: str) -> Tuple[int, JsonDict]: + room_alias_obj = RoomAlias.from_string(room_alias) - res = await self.directory_handler.get_association(room_alias) + res = await self.directory_handler.get_association(room_alias_obj) return 200, res - async def on_PUT(self, request, room_alias): - room_alias = RoomAlias.from_string(room_alias) + async def on_PUT( + self, request: SynapseRequest, room_alias: str + ) -> Tuple[int, JsonDict]: + room_alias_obj = RoomAlias.from_string(room_alias) content = parse_json_object_from_request(request) if "room_id" not in content: @@ -61,7 +70,7 @@ class ClientDirectoryServer(RestServlet): ) logger.debug("Got content: %s", content) - logger.debug("Got room name: %s", room_alias.to_string()) + logger.debug("Got room name: %s", room_alias_obj.to_string()) room_id = content["room_id"] servers = content["servers"] if "servers" in content else None @@ -78,22 +87,25 @@ class ClientDirectoryServer(RestServlet): requester = await self.auth.get_user_by_req(request) await self.directory_handler.create_association( - requester, room_alias, room_id, servers + requester, room_alias_obj, room_id, servers ) return 200, {} - async def on_DELETE(self, request, room_alias): + async def on_DELETE( + self, request: SynapseRequest, room_alias: str + ) -> Tuple[int, JsonDict]: + room_alias_obj = RoomAlias.from_string(room_alias) + try: service = self.auth.get_appservice_by_req(request) - room_alias = RoomAlias.from_string(room_alias) await self.directory_handler.delete_appservice_association( - service, room_alias + service, room_alias_obj ) logger.info( "Application service at %s deleted alias %s", service.url, - room_alias.to_string(), + room_alias_obj.to_string(), ) return 200, {} except InvalidClientCredentialsError: @@ -103,12 +115,10 @@ class ClientDirectoryServer(RestServlet): requester = await self.auth.get_user_by_req(request) user = requester.user - room_alias = RoomAlias.from_string(room_alias) - - await 
self.directory_handler.delete_association(requester, room_alias) + await self.directory_handler.delete_association(requester, room_alias_obj) logger.info( - "User %s deleted alias %s", user.to_string(), room_alias.to_string() + "User %s deleted alias %s", user.to_string(), room_alias_obj.to_string() ) return 200, {} @@ -117,20 +127,22 @@ class ClientDirectoryServer(RestServlet): class ClientDirectoryListServer(RestServlet): PATTERNS = client_patterns("/directory/list/room/(?P[^/]*)$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]: room = await self.store.get_room(room_id) if room is None: raise NotFoundError("Unknown room") return 200, {"visibility": "public" if room["is_public"] else "private"} - async def on_PUT(self, request, room_id): + async def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -142,7 +154,9 @@ class ClientDirectoryListServer(RestServlet): return 200, {} - async def on_DELETE(self, request, room_id): + async def on_DELETE( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await self.directory_handler.edit_published_room_list( @@ -157,21 +171,27 @@ class ClientAppserviceDirectoryListServer(RestServlet): "/directory/list/appservice/(?P[^/]*)/(?P[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() - def on_PUT(self, request, network_id, room_id): + async def on_PUT( + self, request: SynapseRequest, network_id: str, room_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - return self._edit(request, network_id, room_id, visibility) + return await self._edit(request, network_id, room_id, visibility) - def on_DELETE(self, request, network_id, room_id): - return self._edit(request, network_id, room_id, "private") + async def on_DELETE( + self, request: SynapseRequest, network_id: str, room_id: str + ) -> Tuple[int, JsonDict]: + return await self._edit(request, network_id, room_id, "private") - async def _edit(self, request, network_id, room_id, visibility): + async def _edit( + self, request: SynapseRequest, network_id: str, room_id: str, visibility: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not requester.app_service: raise AuthError( diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 84619c5e41..98604a9388 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -13,17 +13,23 @@ # limitations under the License. 
import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.http.server import respond_with_html_bytes +from synapse.http.server import HttpServer, respond_with_html_bytes from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.push import PusherConfigException from synapse.rest.client._base import client_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -31,12 +37,12 @@ logger = logging.getLogger(__name__) class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = requester.user @@ -50,14 +56,14 @@ class PushersRestServlet(RestServlet): class PushersSetRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/set$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = requester.user @@ -132,14 +138,14 @@ class PushersRemoveRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/remove$", v1=True) SUCCESS_HTML = b"You have been unsubscribed" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.notifier = hs.get_notifier() self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request, rights="delete_pusher") user = requester.user @@ -165,7 +171,7 @@ class PushersRemoveRestServlet(RestServlet): return None -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushersRestServlet(hs).register(http_server) PushersSetRestServlet(hs).register(http_server) PushersRemoveRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 027f8b81fa..43c04fac6f 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -13,27 +13,36 @@ # limitations under the License. 
import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ReadMarkerRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/read_markers$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) await self.presence_handler.bump_presence_active_time(requester.user) @@ -70,5 +79,5 @@ class ReadMarkerRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReadMarkerRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py index 6d1b083acb..6a7792e18b 100644 --- a/synapse/rest/client/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/room_upgrade_rest_servlet.py @@ -13,18 +13,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from synapse.util import stringutils from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -41,9 +48,6 @@ class RoomUpgradeRestServlet(RestServlet): } Creates a new room and shuts down the old one. Returns the ID of the new room. 
- - Args: - hs (synapse.server.HomeServer): """ PATTERNS = client_patterns( @@ -51,13 +55,15 @@ class RoomUpgradeRestServlet(RestServlet): "/rooms/(?P[^/]*)/upgrade$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._hs = hs self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -84,5 +90,5 @@ class RoomUpgradeRestServlet(RestServlet): return 200, ret -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomUpgradeRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index d2e7f04b40..1d90493eb0 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -12,13 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet -from synapse.types import UserID +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -32,13 +38,15 @@ class UserSharedRoomsServlet(RestServlet): releases=(), # This is an unstable feature ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.user_directory_active = hs.config.update_user_directory - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: if not self.user_directory_active: raise SynapseError( @@ -63,5 +71,5 @@ class UserSharedRoomsServlet(RestServlet): return 200, {"joined": list(rooms)} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index c14f83be18..c88cb9367c 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -13,12 +13,19 @@ # limitations under the License. 
import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -29,12 +36,14 @@ class TagListServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/rooms/(?P[^/]*)/tags") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, user_id, room_id): + async def on_GET( + self, request: SynapseRequest, user_id: str, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get tags for other users.") @@ -54,12 +63,14 @@ class TagServlet(RestServlet): "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.handler = hs.get_account_data_handler() - async def on_PUT(self, request, user_id, room_id, tag): + async def on_PUT( + self, request: SynapseRequest, user_id: str, room_id: str, tag: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") @@ -70,7 +81,9 @@ class TagServlet(RestServlet): return 200, {} - async def on_DELETE(self, request, user_id, room_id, tag): + async def on_DELETE( + self, request: SynapseRequest, user_id: str, room_id: str, tag: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") @@ -80,6 +93,6 @@ class TagServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: TagListServlet(hs).register(http_server) TagServlet(hs).register(http_server) diff --git a/synapse/rest/client/thirdparty.py b/synapse/rest/client/thirdparty.py index b5c67c9bb6..b895c73acf 100644 --- a/synapse/rest/client/thirdparty.py +++ b/synapse/rest/client/thirdparty.py @@ -12,27 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import logging +from typing import TYPE_CHECKING, Dict, List, Tuple from synapse.api.constants import ThirdPartyEntityKind +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocols") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) protocols = await self.appservice_handler.get_3pe_protocols() @@ -42,13 +48,15 @@ class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocol/(?P[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - async def on_GET(self, request, protocol): + async def on_GET( + self, request: SynapseRequest, protocol: str + ) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) protocols = await self.appservice_handler.get_3pe_protocols( @@ -63,16 +71,18 @@ class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyUserServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/user(/(?P[^/]+))?$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - async def on_GET(self, request, protocol): + async def on_GET( + self, request: SynapseRequest, protocol: str + ) -> Tuple[int, List[JsonDict]]: await self.auth.get_user_by_req(request, allow_guest=True) - fields = request.args + fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment] fields.pop(b"access_token", None) results = await self.appservice_handler.query_3pe( @@ -85,16 +95,18 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/location(/(?P[^/]+))?$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - async def on_GET(self, request, protocol): + async def on_GET( + self, request: SynapseRequest, protocol: str + ) -> Tuple[int, List[JsonDict]]: await self.auth.get_user_by_req(request, allow_guest=True) - fields = request.args + fields: Dict[bytes, List[bytes]] = request.args # type: ignore[assignment] fields.pop(b"access_token", None) results = await self.appservice_handler.query_3pe( @@ -104,7 +116,7 @@ class ThirdPartyLocationServlet(RestServlet): return 200, results -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ThirdPartyProtocolsServlet(hs).register(http_server) ThirdPartyProtocolServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server) diff --git a/synapse/rest/client/tokenrefresh.py b/synapse/rest/client/tokenrefresh.py index b2f858545c..c8c3b25bd3 100644 --- a/synapse/rest/client/tokenrefresh.py +++ 
b/synapse/rest/client/tokenrefresh.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + +from twisted.web.server import Request + from synapse.api.errors import AuthError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + class TokenRefreshRestServlet(RestServlet): """ @@ -26,12 +34,12 @@ class TokenRefreshRestServlet(RestServlet): PATTERNS = client_patterns("/tokenrefresh") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> None: raise AuthError(403, "tokenrefresh is no longer supported.") -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: TokenRefreshRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index 7e8912f0b9..8852811114 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -13,29 +13,32 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class UserDirectorySearchRestServlet(RestServlet): PATTERNS = client_patterns("/user_directory/search$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Searches for users in directory Returns: @@ -75,5 +78,5 @@ class UserDirectorySearchRestServlet(RestServlet): return 200, results -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: UserDirectorySearchRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index fa2e4e9cba..a1a815cf82 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -17,9 +17,17 @@ import logging import re +from typing import TYPE_CHECKING, Tuple + +from twisted.web.server import Request from synapse.api.constants import RoomCreationPreset +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -27,7 +35,7 @@ logger = logging.getLogger(__name__) class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config @@ -45,7 +53,7 @@ class VersionsRestServlet(RestServlet): in self.config.encryption_enabled_by_default_for_room_presets ) - def on_GET(self, request): + def on_GET(self, request: Request) -> Tuple[int, 
JsonDict]: return ( 200, { @@ -89,5 +97,5 @@ class VersionsRestServlet(RestServlet): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: VersionsRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py index f53020520d..9d46ed3af3 100644 --- a/synapse/rest/client/voip.py +++ b/synapse/rest/client/voip.py @@ -15,20 +15,27 @@ import base64 import hashlib import hmac +from typing import TYPE_CHECKING, Tuple +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer class VoipRestServlet(RestServlet): PATTERNS = client_patterns("/voip/turnServer$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req( request, self.hs.config.turn_allow_guests ) @@ -69,5 +76,5 @@ class VoipRestServlet(RestServlet): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: VoipRestServlet(hs).register(http_server) -- cgit 1.5.1 From bd7d398b05aaa18d5b0629153ababeea7539256c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 23 Aug 2021 08:14:42 -0400 Subject: Additional type hints for the sync REST servlet. (#10666) --- changelog.d/10666.misc | 1 + synapse/handlers/sync.py | 21 +++---- synapse/rest/client/sync.py | 132 +++++++++++++++++++++++++++----------------- 3 files changed, 93 insertions(+), 61 deletions(-) create mode 100644 changelog.d/10666.misc diff --git a/changelog.d/10666.misc b/changelog.d/10666.misc new file mode 100644 index 0000000000..39a37b90b1 --- /dev/null +++ b/changelog.d/10666.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. 
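The core of the change in this commit is replacing loosely-typed attrs fields (e.g. `List[JsonDict]`) with precise element types such as `List[UserPresenceState]`, so mypy can reject structurally wrong values at the call site. A minimal, self-contained sketch of that pattern (not Synapse code; the class definitions here are illustrative stand-ins):

```python
from typing import List

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserPresenceState:
    """Illustrative stand-in for synapse.api.presence.UserPresenceState."""

    user_id: str
    state: str  # e.g. "online"


@attr.s(slots=True, auto_attribs=True)
class SyncResultSketch:
    # Previously this would have been List[JsonDict]; with the precise
    # element type, mypy now flags `result.presence.append({...})` as an
    # error instead of silently accepting a bare dict.
    presence: List[UserPresenceState] = attr.Factory(list)


result = SyncResultSketch()
result.presence.append(UserPresenceState(user_id="@alice:example.com", state="online"))
```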
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2203c45dcc..86c3c7f0df 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -30,6 +30,7 @@ from prometheus_client import Counter from synapse.api.constants import AccountDataTypes, EventTypes, Membership from synapse.api.filtering import FilterCollection +from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging.context import current_context @@ -231,7 +232,7 @@ class SyncResult: """ next_batch: StreamToken - presence: List[JsonDict] + presence: List[UserPresenceState] account_data: List[JsonDict] joined: List[JoinedSyncResult] invited: List[InvitedSyncResult] @@ -2177,14 +2178,14 @@ class SyncResultBuilder: joined_room_ids: List of rooms the user is joined to # The following mirror the fields in a sync response - presence (list) - account_data (list) - joined (list[JoinedSyncResult]) - invited (list[InvitedSyncResult]) - knocked (list[KnockedSyncResult]) - archived (list[ArchivedSyncResult]) - groups (GroupsSyncResult|None) - to_device (list) + presence + account_data + joined + invited + knocked + archived + groups + to_device """ sync_config: SyncConfig @@ -2193,7 +2194,7 @@ class SyncResultBuilder: now_token: StreamToken joined_room_ids: FrozenSet[str] - presence: List[JsonDict] = attr.Factory(list) + presence: List[UserPresenceState] = attr.Factory(list) account_data: List[JsonDict] = attr.Factory(list) joined: List[JoinedSyncResult] = attr.Factory(list) invited: List[InvitedSyncResult] = attr.Factory(list) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index e18f4d01b3..65c37be3e9 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -14,17 +14,26 @@ import itertools import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection +from synapse.api.presence import UserPresenceState from synapse.events.utils import ( format_event_for_client_v2_without_room_id, format_event_raw, ) from synapse.handlers.presence import format_user_presence_state -from synapse.handlers.sync import KnockedSyncResult, SyncConfig +from synapse.handlers.sync import ( + ArchivedSyncResult, + InvitedSyncResult, + JoinedSyncResult, + KnockedSyncResult, + SyncConfig, + SyncResult, +) +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.types import JsonDict, StreamToken @@ -192,6 +201,8 @@ class SyncRestServlet(RestServlet): return 200, {} time_now = self.clock.time_msec() + # We know that the the requester has an access token since appservices + # cannot use sync. 
response_content = await self.encode_response( time_now, sync_result, requester.access_token_id, filter_collection ) @@ -199,7 +210,13 @@ class SyncRestServlet(RestServlet): logger.debug("Event formatting complete") return 200, response_content - async def encode_response(self, time_now, sync_result, access_token_id, filter): + async def encode_response( + self, + time_now: int, + sync_result: SyncResult, + access_token_id: Optional[int], + filter: FilterCollection, + ) -> JsonDict: logger.debug("Formatting events in sync response") if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id @@ -234,7 +251,7 @@ class SyncRestServlet(RestServlet): logger.debug("building sync response dict") - response: dict = defaultdict(dict) + response: JsonDict = defaultdict(dict) response["next_batch"] = await sync_result.next_batch.to_string(self.store) if sync_result.account_data: @@ -274,6 +291,8 @@ class SyncRestServlet(RestServlet): if archived: response["rooms"][Membership.LEAVE] = archived + # By the time we get here groups is no longer optional. + assert sync_result.groups is not None if sync_result.groups.join: response["groups"][Membership.JOIN] = sync_result.groups.join if sync_result.groups.invite: @@ -284,7 +303,7 @@ class SyncRestServlet(RestServlet): return response @staticmethod - def encode_presence(events, time_now): + def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict: return { "events": [ { @@ -299,25 +318,27 @@ class SyncRestServlet(RestServlet): } async def encode_joined( - self, rooms, time_now, token_id, event_fields, event_formatter - ): + self, + rooms: List[JoinedSyncResult], + time_now: int, + token_id: Optional[int], + event_fields: List[str], + event_formatter: Callable[[JsonDict], JsonDict], + ) -> JsonDict: """ Encode the joined rooms in a sync result Args: - rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync - results for rooms this user is joined to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing + rooms: list of sync results for rooms this user is joined to + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_fields(list): List of event fields to include. If empty, + event_fields: List of event fields to include. If empty, all fields will be returned. 
- event_formatter (func[dict]): function to convert from federation format + event_formatter: function to convert from federation format to client format Returns: - dict[str, dict[str, object]]: the joined rooms list, in our - response format + The joined rooms list, in our response format """ joined = {} for room in rooms: @@ -332,23 +353,26 @@ class SyncRestServlet(RestServlet): return joined - async def encode_invited(self, rooms, time_now, token_id, event_formatter): + async def encode_invited( + self, + rooms: List[InvitedSyncResult], + time_now: int, + token_id: Optional[int], + event_formatter: Callable[[JsonDict], JsonDict], + ) -> JsonDict: """ Encode the invited rooms in a sync result Args: - rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of - sync results for rooms this user is invited to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing + rooms: list of sync results for rooms this user is invited to + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_formatter (func[dict]): function to convert from federation format + event_formatter: function to convert from federation format to client format Returns: - dict[str, dict[str, object]]: the invited rooms list, in our - response format + The invited rooms list, in our response format """ invited = {} for room in rooms: @@ -371,7 +395,7 @@ class SyncRestServlet(RestServlet): self, rooms: List[KnockedSyncResult], time_now: int, - token_id: int, + token_id: Optional[int], event_formatter: Callable[[Dict], Dict], ) -> Dict[str, Dict[str, Any]]: """ @@ -422,25 +446,26 @@ class SyncRestServlet(RestServlet): return knocked async def encode_archived( - self, rooms, time_now, token_id, event_fields, event_formatter - ): + self, + rooms: List[ArchivedSyncResult], + time_now: int, + token_id: Optional[int], + event_fields: List[str], + event_formatter: Callable[[JsonDict], JsonDict], + ) -> JsonDict: """ Encode the archived rooms in a sync result Args: - rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of - sync results for rooms this user is joined to - time_now(int): current time - used as a baseline for age - calculations - token_id(int): ID of the user's auth token - used for namespacing + rooms: list of sync results for rooms this user is joined to + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs - event_fields(list): List of event fields to include. If empty, + event_fields: List of event fields to include. If empty, all fields will be returned. 
- event_formatter (func[dict]): function to convert from federation format - to client format + event_formatter: function to convert from federation format to client format Returns: - dict[str, dict[str, object]]: The invited rooms list, in our - response format + The archived rooms list, in our response format """ joined = {} for room in rooms: @@ -456,23 +481,27 @@ class SyncRestServlet(RestServlet): return joined async def encode_room( - self, room, time_now, token_id, joined, only_fields, event_formatter - ): + self, + room: Union[JoinedSyncResult, ArchivedSyncResult], + time_now: int, + token_id: Optional[int], + joined: bool, + only_fields: Optional[List[str]], + event_formatter: Callable[[JsonDict], JsonDict], + ) -> JsonDict: """ Args: - room (JoinedSyncResult|ArchivedSyncResult): sync result for a - single room - time_now (int): current time - used as a baseline for age - calculations - token_id (int): ID of the user's auth token - used for namespacing + room: sync result for a single room + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs - joined (bool): True if the user is joined to this room - will mean + joined: True if the user is joined to this room - will mean we handle ephemeral events - only_fields(list): Optional. The list of event fields to include. - event_formatter (func[dict]): function to convert from federation format + only_fields: Optional. The list of event fields to include. + event_formatter: function to convert from federation format to client format Returns: - dict[str, object]: the room, encoded in our response format + The room, encoded in our response format """ def serialize(events): @@ -508,7 +537,7 @@ class SyncRestServlet(RestServlet): account_data = room.account_data - result = { + result: JsonDict = { "timeline": { "events": serialized_timeline, "prev_batch": await room.timeline.prev_batch.to_string(self.store), @@ -519,6 +548,7 @@ class SyncRestServlet(RestServlet): } if joined: + assert isinstance(room, JoinedSyncResult) ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications @@ -528,5 +558,5 @@ class SyncRestServlet(RestServlet): return result -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) -- cgit 1.5.1 From 2efc838f050f0608f6648c0235eaade813d66f08 Mon Sep 17 00:00:00 2001 From: Dan Callahan Date: Mon, 23 Aug 2021 14:06:49 +0100 Subject: Avoid duplicate issues from Twisted trunk failures (#10672) Setting `update_existing: true` in the `create-an-issue` GitHub Action will avoid opening duplicate issues if an open issue already exists with an identical title. If no open issues match the title, then a new issue will be created. This helps avoid spamming our issue tracker should there be a failure when testing against Twisted's trunk. This PR also pins the SHA of the `create-an-issue` action to mitigate the risk of a malicious actor gaining access to JasonEtco's account. 
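If the pinned SHA ever needs updating, the commit behind a tag can be looked up with `git ls-remote`. A hedged sketch of that lookup (not part of this patch; the repository and tag below are simply the ones referenced above):

```python
import subprocess


def resolve_tag_sha(repo_url: str, tag: str) -> str:
    """Return the commit SHA behind `tag` in `repo_url`, suitable for pinning."""
    out = subprocess.run(
        ["git", "ls-remote", "--tags", repo_url],
        check=True,
        capture_output=True,
        text=True,
    ).stdout
    sha = None
    for line in out.splitlines():
        obj, ref = line.split("\t")
        if ref == f"refs/tags/{tag}^{{}}":
            return obj  # peeled line: the commit an annotated tag points at
        if ref == f"refs/tags/{tag}":
            sha = obj  # lightweight tag, or fallback if no peeled line exists
    if sha is None:
        raise ValueError(f"tag {tag!r} not found in {repo_url}")
    return sha


print(resolve_tag_sha("https://github.com/JasonEtco/create-an-issue", "v2.5.0"))
```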
See GitHub's page on security hardening third party actions for more: https://docs.github.com/en/actions/learn-github-actions/security-hardening-for-github-actions#using-third-party-actions Signed-off-by: Dan Callahan --- .github/workflows/twisted_trunk.yml | 3 ++- changelog.d/10672.misc | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10672.misc diff --git a/.github/workflows/twisted_trunk.yml b/.github/workflows/twisted_trunk.yml index 0bf7790545..b5c729888f 100644 --- a/.github/workflows/twisted_trunk.yml +++ b/.github/workflows/twisted_trunk.yml @@ -82,8 +82,9 @@ jobs: steps: - uses: actions/checkout@v2 - - uses: JasonEtco/create-an-issue@v2 + - uses: JasonEtco/create-an-issue@5d9504915f79f9cc6d791934b8ef34f2353dd74d # v2.5.0, 2020-12-06 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: + update_existing: true filename: .ci/twisted_trunk_build_failed_issue_template.md diff --git a/changelog.d/10672.misc b/changelog.d/10672.misc new file mode 100644 index 0000000000..7104c121e0 --- /dev/null +++ b/changelog.d/10672.misc @@ -0,0 +1 @@ +Run a nightly CI build against Twisted trunk. -- cgit 1.5.1 From 3e83f97154e12fc50ccf2d8b4a01007cff012c47 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 23 Aug 2021 14:58:31 +0100 Subject: Fix the titles in the OIDC documentation (#10639) * Fix the titles in the OIDC documentation Having them as links broke the table-of-contents rendering in mdbook. Plus there's no reason for only some of the provider titles to be links. * Changelog * Add link to google idp docs --- changelog.d/10639.doc | 1 + docs/openid.md | 16 ++++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10639.doc diff --git a/changelog.d/10639.doc b/changelog.d/10639.doc new file mode 100644 index 0000000000..acbac4aad8 --- /dev/null +++ b/changelog.d/10639.doc @@ -0,0 +1 @@ +Fix some of the titles not rendering in the OIDC documentation. diff --git a/docs/openid.md b/docs/openid.md index f685fd551a..f121bc8a6e 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -79,7 +79,7 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` -### [Dex][dex-idp] +### Dex [Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider. Although it is designed to help building a full-blown provider with an @@ -117,7 +117,7 @@ oidc_providers: localpart_template: "{{ user.name }}" display_name_template: "{{ user.name|capitalize }}" ``` -### [Keycloak][keycloak-idp] +### Keycloak [Keycloak][keycloak-idp] is an opensource IdP maintained by Red Hat. @@ -166,7 +166,9 @@ oidc_providers: localpart_template: "{{ user.preferred_username }}" display_name_template: "{{ user.name }}" ``` -### [Auth0][auth0] +### Auth0 + +[Auth0][auth0] is a hosted SaaS IdP solution. 1. Create a regular web application for Synapse 2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback` @@ -209,7 +211,7 @@ oidc_providers: ### GitHub -GitHub is a bit special as it is not an OpenID Connect compliant provider, but +[GitHub][github-idp] is a bit special as it is not an OpenID Connect compliant provider, but just a regular OAuth2 provider. The [`/user` API endpoint](https://developer.github.com/v3/users/#get-the-authenticated-user) @@ -242,11 +244,13 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` -### [Google][google-idp] +### Google + +[Google][google-idp] is an OpenID certified authentication and authorisation provider. 1. 
Set up a project in the Google API Console (see https://developers.google.com/identity/protocols/oauth2/openid-connect#appsetup). -2. add an "OAuth Client ID" for a Web Application under "Credentials". +2. Add an "OAuth Client ID" for a Web Application under "Credentials". 3. Copy the Client ID and Client Secret, and add the following to your synapse config: ```yaml oidc_providers: -- cgit 1.5.1 From 0c1d6f65d7c65efd8491adf4efc2620148e2841a Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Mon, 23 Aug 2021 16:25:33 +0100 Subject: Enforce the max length for per-room display names / avatar URLs. (#10654) To match the maximum lengths allowed for profile data. --- changelog.d/10654.bugfix | 1 + synapse/handlers/room_member.py | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10654.bugfix diff --git a/changelog.d/10654.bugfix b/changelog.d/10654.bugfix new file mode 100644 index 0000000000..b0bd78453f --- /dev/null +++ b/changelog.d/10654.bugfix @@ -0,0 +1 @@ +Enforce the maximum length for per-room display names and avatar URLs. \ No newline at end of file diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ba13196218..401b84aad1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -36,6 +36,7 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase from synapse.events.snapshot import EventContext +from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.types import ( JsonDict, Requester, @@ -79,7 +80,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.account_data_handler = hs.get_account_data_handler() self.event_auth_handler = hs.get_event_auth_handler() - self.member_linearizer = Linearizer(name="member") + self.member_linearizer: Linearizer = Linearizer(name="member") self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() @@ -556,6 +557,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content.pop("displayname", None) content.pop("avatar_url", None) + if len(content.get("displayname") or "") > MAX_DISPLAYNAME_LEN: + raise SynapseError( + 400, + f"Displayname is too long (max {MAX_DISPLAYNAME_LEN})", + errcode=Codes.BAD_JSON, + ) + + if len(content.get("avatar_url") or "") > MAX_AVATAR_URL_LEN: + raise SynapseError( + 400, + f"Avatar URL is too long (max {MAX_AVATAR_URL_LEN})", + errcode=Codes.BAD_JSON, + ) + effective_membership_state = action if action in ["kick", "unban"]: effective_membership_state = "leave" -- cgit 1.5.1 From 86415f162d5dca9f38054a76e00130d947f2e8e2 Mon Sep 17 00:00:00 2001 From: Hugo DELVAL Date: Mon, 23 Aug 2021 19:12:36 +0200 Subject: doc: add django-oauth-toolkit to oidc doc (#10192) Signed-off-by: Hugo Delval --- changelog.d/10192.doc | 1 + docs/openid.md | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 changelog.d/10192.doc diff --git a/changelog.d/10192.doc b/changelog.d/10192.doc new file mode 100644 index 0000000000..3dd00537e8 --- /dev/null +++ b/changelog.d/10192.doc @@ -0,0 +1 @@ +Add documentation on how to connect Django with synapse using oidc and django-oauth-toolkit. Contributed by @HugoDelval. 
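The user-mapping templates in the documentation added below are Jinja2 expressions evaluated against the userinfo claims. A hedged, standalone illustration of what they produce (the claim values are invented for the example):

```python
from jinja2 import Environment

env = Environment()

# Sample OIDC userinfo claims, roughly as the Django provider documented
# below would return them (values invented for illustration).
userinfo = {
    "email": "alice@example.com",
    "first_name": "Alice",
    "last_name": "Liddell",
}

localpart = env.from_string("{{ user.email.split('@')[0] }}").render(user=userinfo)
display_name = env.from_string(
    "{{ user.first_name }} {{ user.last_name }}"
).render(user=userinfo)

print(localpart)     # alice
print(display_name)  # Alice Liddell
```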
diff --git a/docs/openid.md b/docs/openid.md index f121bc8a6e..49180eec52 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -450,3 +450,51 @@ The synapse config will look like this: config: email_template: "{{ user.email }}" ``` + +## Django OAuth Toolkit + +[django-oauth-toolkit](https://github.com/jazzband/django-oauth-toolkit) is a +Django application providing out of the box all the endpoints, data and logic +needed to add OAuth2 capabilities to your Django projects. It supports +[OpenID Connect too](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html). + +Configuration on Django's side: + +1. Add an application: https://example.com/admin/oauth2_provider/application/add/ and choose parameters like this: +* `Redirect uris`: https://synapse.example.com/_synapse/client/oidc/callback +* `Client type`: `Confidential` +* `Authorization grant type`: `Authorization code` +* `Algorithm`: `HMAC with SHA-2 256` +2. You can [customize the claims](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html#customizing-the-oidc-responses) Django gives to synapse (optional): +
+ Code sample + + ```python + class CustomOAuth2Validator(OAuth2Validator): + + def get_additional_claims(self, request): + return { + "sub": request.user.email, + "email": request.user.email, + "first_name": request.user.first_name, + "last_name": request.user.last_name, + } + ``` +
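With a validator like the one in the code sample above, the claims handed to Synapse would look roughly like this (a hedged illustration; the values are invented):

```python
# Approximate extra claims the custom validator above adds for a Django
# user (illustrative values only):
claims = {
    "sub": "alice@example.com",
    "email": "alice@example.com",
    "first_name": "Alice",
    "last_name": "Liddell",
}
```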
+Your synapse config is then: + +```yaml +oidc_providers: + - idp_id: django_example + idp_name: "Django Example" + issuer: "https://example.com/o/" + client_id: "your-client-id" # CHANGE ME + client_secret: "your-client-secret" # CHANGE ME + scopes: ["openid"] + user_profile_method: "userinfo_endpoint" # needed because oauth-toolkit does not include user information in the authorization response + user_mapping_provider: + config: + localpart_template: "{{ user.email.split('@')[0] }}" + display_name_template: "{{ user.first_name }} {{ user.last_name }}" + email_template: "{{ user.email }}" +``` -- cgit 1.5.1 From 15db8b7c7f13f33ca49104183e0642892c3b83f1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Aug 2021 10:17:51 +0100 Subject: Correctly initialise the `synapse_user_logins` metric. (#10677) Fix a bug where the prometheus metrics for SSO logins wouldn't be initialised until the first user logged in with a given auth provider. --- changelog.d/10677.bugfix | 1 + synapse/handlers/register.py | 18 ++++++++++++++++++ synapse/handlers/sso.py | 2 ++ synapse/rest/client/login.py | 29 +++++++++++++++++++++++------ 4 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10677.bugfix diff --git a/changelog.d/10677.bugfix b/changelog.d/10677.bugfix new file mode 100644 index 0000000000..9964afaaee --- /dev/null +++ b/changelog.d/10677.bugfix @@ -0,0 +1 @@ +Fix a bug which caused the `synapse_user_logins_total` Prometheus metric not to be correctly initialised on restart. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8cf614136e..0ed59d757b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -56,6 +56,22 @@ login_counter = Counter( ) +def init_counters_for_auth_provider(auth_provider_id: str) -> None: + """Ensure the prometheus counters for the given auth provider are initialised + + This fixes a problem where the counters are not reported for a given auth provider + until the user first logs in/registers. 
+ """ + for is_guest in (True, False): + login_counter.labels(guest=is_guest, auth_provider=auth_provider_id) + for shadow_banned in (True, False): + registration_counter.labels( + guest=is_guest, + shadow_banned=shadow_banned, + auth_provider=auth_provider_id, + ) + + class LoginDict(TypedDict): device_id: str access_token: str @@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime self.access_token_lifetime = hs.config.access_token_lifetime + init_counters_for_auth_provider("") + async def check_username( self, localpart: str, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 1b855a685c..0e6ebb574e 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -37,6 +37,7 @@ from twisted.web.server import Request from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.config.sso import SsoAttributeRequirement +from synapse.handlers.register import init_counters_for_auth_provider from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -213,6 +214,7 @@ class SsoHandler: p_id = p.idp_id assert p_id not in self._identity_providers self._identity_providers[p_id] = p + init_counters_for_auth_provider(p_id) def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]: """Get the configured identity providers""" diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 0c8d8967b7..11d07776b2 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -104,6 +104,12 @@ class LoginRestServlet(RestServlet): burst_count=self.hs.config.rc_login_account.burst_count, ) + # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. + # The reason for this is to ensure that the auth_provider_ids are registered + # with SsoHandler, which in turn ensures that the login/registration prometheus + # counters are initialised for the auth_provider_ids. + _load_sso_handlers(hs) + def on_GET(self, request: SynapseRequest): flows = [] if self.jwt_enabled: @@ -499,12 +505,7 @@ class SsoRedirectServlet(RestServlet): def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they # register themselves with the main SSOHandler. - if hs.config.cas_enabled: - hs.get_cas_handler() - if hs.config.saml2_enabled: - hs.get_saml_handler() - if hs.config.oidc_enabled: - hs.get_oidc_handler() + _load_sso_handlers(hs) self._sso_handler = hs.get_sso_handler() self._msc2858_enabled = hs.config.experimental.msc2858_enabled self._public_baseurl = hs.config.public_baseurl @@ -598,3 +599,19 @@ def register_servlets(hs, http_server): SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: CasTicketServlet(hs).register(http_server) + + +def _load_sso_handlers(hs: "HomeServer"): + """Ensure that the SSO handlers are loaded, if they are enabled by configuration. + + This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves + with the main SsoHandler. + + It's safe to call this multiple times. 
+ """ + if hs.config.cas.cas_enabled: + hs.get_cas_handler() + if hs.config.saml2.saml2_enabled: + hs.get_saml_handler() + if hs.config.oidc.oidc_enabled: + hs.get_oidc_handler() -- cgit 1.5.1 From d12ba52f178982ecb47207471bee14472f9597b6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 24 Aug 2021 08:14:03 -0400 Subject: Persist room hierarchy pagination sessions to the database. (#10613) --- changelog.d/10613.feature | 1 + mypy.ini | 1 + synapse/app/generic_worker.py | 2 + synapse/handlers/room_summary.py | 76 +++++------ synapse/storage/databases/main/__init__.py | 2 + synapse/storage/databases/main/session.py | 145 +++++++++++++++++++++ .../schema/main/delta/62/02session_store.sql | 23 ++++ 7 files changed, 212 insertions(+), 38 deletions(-) create mode 100644 changelog.d/10613.feature create mode 100644 synapse/storage/databases/main/session.py create mode 100644 synapse/storage/schema/main/delta/62/02session_store.sql diff --git a/changelog.d/10613.feature b/changelog.d/10613.feature new file mode 100644 index 0000000000..ffc4e4289c --- /dev/null +++ b/changelog.d/10613.feature @@ -0,0 +1 @@ +Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/mypy.ini b/mypy.ini index b17872211e..745e6b78eb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -57,6 +57,7 @@ files = synapse/storage/databases/main/keys.py, synapse/storage/databases/main/pusher.py, synapse/storage/databases/main/registration.py, + synapse/storage/databases/main/session.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fd2626dbe1..9b71dd75e6 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -118,6 +118,7 @@ from synapse.storage.databases.main.monthly_active_users import ( from synapse.storage.databases.main.presence import PresenceStore from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.search import SearchStore +from synapse.storage.databases.main.session import SessionStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore @@ -253,6 +254,7 @@ class GenericWorkerSlavedStore( SearchStore, TransactionWorkerStore, LockStore, + SessionStore, BaseSlavedStore, ): pass diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index ac6cfc0da9..906985c754 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -28,12 +28,11 @@ from synapse.api.constants import ( Membership, RoomTypes, ) -from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache -from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -76,6 +75,9 @@ class _PaginationSession: class RoomSummaryHandler: + # A unique key used for pagination sessions for the room hierarchy endpoint. + _PAGINATION_SESSION_TYPE = "room_hierarchy_pagination" + # The time a pagination session remains valid for. 
_PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 @@ -87,12 +89,6 @@ class RoomSummaryHandler: self._server_name = hs.hostname self._federation_client = hs.get_federation_client() - # A map of query information to the current pagination state. - # - # TODO Allow for multiple workers to share this data. - # TODO Expire pagination tokens. - self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {} - # If a user tries to fetch the same page multiple times in quick succession, # only process the first attempt and return its result to subsequent requests. self._pagination_response_cache: ResponseCache[ @@ -102,21 +98,6 @@ class RoomSummaryHandler: "get_room_hierarchy", ) - def _expire_pagination_sessions(self): - """Expire pagination session which are old.""" - expire_before = ( - self._clock.time_msec() - self._PAGINATION_SESSION_VALIDITY_PERIOD_MS - ) - to_expire = [] - - for key, value in self._pagination_sessions.items(): - if value.creation_time_ms < expire_before: - to_expire.append(key) - - for key in to_expire: - logger.debug("Expiring pagination session id %s", key) - del self._pagination_sessions[key] - async def get_space_summary( self, requester: str, @@ -327,18 +308,29 @@ class RoomSummaryHandler: # If this is continuing a previous session, pull the persisted data. if from_token: - self._expire_pagination_sessions() + try: + pagination_session = await self._store.get_session( + session_type=self._PAGINATION_SESSION_TYPE, + session_id=from_token, + ) + except StoreError: + raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, from_token - ) - if pagination_key not in self._pagination_sessions: + # If the requester, room ID, suggested-only, or max depth were modified, + # the session is invalid. + if ( + requester != pagination_session["requester"] + or requested_room_id != pagination_session["room_id"] + or suggested_only != pagination_session["suggested_only"] + or max_depth != pagination_session["max_depth"] + ): raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM) # Load the previous state. - pagination_session = self._pagination_sessions[pagination_key] - room_queue = pagination_session.room_queue - processed_rooms = pagination_session.processed_rooms + room_queue = [ + _RoomQueueEntry(*fields) for fields in pagination_session["room_queue"] + ] + processed_rooms = set(pagination_session["processed_rooms"]) else: # The queue of rooms to process, the next room is last on the stack. room_queue = [_RoomQueueEntry(requested_room_id, ())] @@ -456,13 +448,21 @@ class RoomSummaryHandler: # If there's additional data, generate a pagination token (and persist state). if room_queue: - next_batch = random_string(24) - result["next_batch"] = next_batch - pagination_key = _PaginationKey( - requested_room_id, suggested_only, max_depth, next_batch - ) - self._pagination_sessions[pagination_key] = _PaginationSession( - self._clock.time_msec(), room_queue, processed_rooms + result["next_batch"] = await self._store.create_session( + session_type=self._PAGINATION_SESSION_TYPE, + value={ + # Information which must be identical across pagination. + "requester": requester, + "room_id": requested_room_id, + "suggested_only": suggested_only, + "max_depth": max_depth, + # The stored state.
+ "room_queue": [ + attr.astuple(room_entry) for room_entry in room_queue + ], + "processed_rooms": list(processed_rooms), + }, + expiry_ms=self._PAGINATION_SESSION_VALIDITY_PERIOD_MS, ) return result diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 01b918e12e..00a644e8f7 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -63,6 +63,7 @@ from .relations import RelationsStore from .room import RoomStore from .roommember import RoomMemberStore from .search import SearchStore +from .session import SessionStore from .signatures import SignatureStore from .state import StateStore from .stats import StatsStore @@ -121,6 +122,7 @@ class DataStore( ServerMetricsStore, EventForwardExtremitiesStore, LockStore, + SessionStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py new file mode 100644 index 0000000000..172f27d109 --- /dev/null +++ b/synapse/storage/databases/main/session.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +import synapse.util.stringutils as stringutils +from synapse.api.errors import StoreError +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.types import JsonDict +from synapse.util import json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class SessionStore(SQLBaseStore): + """ + A store for generic session data. + + Each type of session should provide a unique type (to separate sessions). + + Sessions are automatically removed when they expire. + """ + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + # Create a background job for culling expired sessions. + if hs.config.run_background_tasks: + self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000) + + async def create_session( + self, session_type: str, value: JsonDict, expiry_ms: int + ) -> str: + """ + Creates a new session of the given type. + + Args: + session_type: The type for this session. + value: The value to store. + expiry_ms: How long the session should remain valid for, in + milliseconds; once this period has passed, the session can no + longer be retrieved and becomes eligible for deletion. + + Returns: + The newly created session ID. + + Raises: + StoreError if a unique session ID cannot be generated. + """ + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually.
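To put a number on the retry bound that follows: `random_string(24)` draws each character from a few dozen alternatives (the exact alphabet is an implementation detail of `synapse.util.stringutils`), giving at least 52**24, roughly 10**41, possible IDs, so colliding with an existing row is vanishingly unlikely and five attempts is effectively always enough. A hedged usage sketch of the API defined here:

    # Hypothetical round trip against a SessionStore instance `store`; the
    # session type and value mirror the room-hierarchy caller above.
    async def paginate(store) -> None:
        token = await store.create_session(
            session_type="room_hierarchy_pagination",
            value={"processed_rooms": []},
            expiry_ms=5 * 60 * 1000,  # matches _PAGINATION_SESSION_VALIDITY_PERIOD_MS
        )
        # Raises StoreError(404, "No session") once expired or never created.
        restored = await store.get_session("room_hierarchy_pagination", token)
        assert restored == {"processed_rooms": []}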
+ attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db_pool.simple_insert( + table="sessions", + values={ + "session_id": session_id, + "session_type": session_type, + "value": json_encoder.encode(value), + "expiry_time_ms": self.hs.get_clock().time_msec() + expiry_ms, + }, + desc="create_session", + ) + + return session_id + except self.db_pool.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_session(self, session_type: str, session_id: str) -> JsonDict: + """ + Retrieve data stored with create_session + + Args: + session_type: The type for this session. + session_id: The session ID returned from create_session. + + Raises: + StoreError if the session cannot be found. + """ + + def _get_session( + txn: LoggingTransaction, session_type: str, session_id: str, ts: int + ) -> JsonDict: + # This includes the expiry time since items are only periodically + # deleted, not upon expiry. + select_sql = """ + SELECT value FROM sessions WHERE + session_type = ? AND session_id = ? AND expiry_time_ms > ? + """ + txn.execute(select_sql, [session_type, session_id, ts]) + row = txn.fetchone() + + if not row: + raise StoreError(404, "No session") + + return db_to_json(row[0]) + + return await self.db_pool.runInteraction( + "get_session", + _get_session, + session_type, + session_id, + self._clock.time_msec(), + ) + + @wrap_as_background_process("delete_expired_sessions") + async def _delete_expired_sessions(self) -> None: + """Remove sessions with expiry dates that have passed.""" + + def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None: + sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?" + txn.execute(sql, (ts,)) + + await self.db_pool.runInteraction( + "delete_expired_sessions", + _delete_expired_sessions_txn, + self._clock.time_msec(), + ) diff --git a/synapse/storage/schema/main/delta/62/02session_store.sql b/synapse/storage/schema/main/delta/62/02session_store.sql new file mode 100644 index 0000000000..535fb34c10 --- /dev/null +++ b/synapse/storage/schema/main/delta/62/02session_store.sql @@ -0,0 +1,23 @@ +/* + * Copyright 2021 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS sessions( + session_type TEXT NOT NULL, -- The unique key for this type of session. + session_id TEXT NOT NULL, -- The session ID passed to the client. + value TEXT NOT NULL, -- A JSON dictionary to persist. + expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds). 
+    UNIQUE (session_type, session_id) +); -- cgit 1.5.1 From 6f77a3d433c683223024075e805f87bec3327036 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 Aug 2021 15:31:55 +0100 Subject: 1.41.0 --- CHANGES.md | 9 +++++++++ changelog.d/10571.feature | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 4 files changed, 16 insertions(+), 2 deletions(-) delete mode 100644 changelog.d/10571.feature diff --git a/CHANGES.md b/CHANGES.md index cad9423ebd..35456cded6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,12 @@ +Synapse 1.41.0 (2021-08-24) +=========================== + +Features +-------- + +- Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version for restricted rooms. ([\#10571](https://github.com/matrix-org/synapse/issues/10571)) + + Synapse 1.41.0rc1 (2021-08-18) ============================== diff --git a/changelog.d/10571.feature b/changelog.d/10571.feature deleted file mode 100644 index 0da318cd5b..0000000000 --- a/changelog.d/10571.feature +++ /dev/null @@ -1 +0,0 @@ -Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version for restricted rooms. diff --git a/debian/changelog b/debian/changelog index 68f309b0b2..4da4bc018c 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.41.0) stable; urgency=medium + + * New synapse release 1.41.0. + + -- Synapse Packaging team Tue, 24 Aug 2021 15:31:45 +0100 + matrix-synapse-py3 (1.41.0~rc1) stable; urgency=medium * New synapse release 1.41.0~rc1. diff --git a/synapse/__init__.py b/synapse/__init__.py index 6ada20a77f..ef3770262e 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.41.0rc1" +__version__ = "1.41.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From f03cafb50c49a1569f1f99485f9cc42abfdc7b21 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 24 Aug 2021 16:06:33 +0100 Subject: Update changelog --- CHANGES.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 35456cded6..f8da8771aa 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,10 +1,15 @@ Synapse 1.41.0 (2021-08-24) =========================== +This release adds support for Debian 12 (Bookworm), but **removes support for Ubuntu 20.10 (Groovy Gorilla)**, which reached End of Life last month. + +Note that when using workers, the `/_synapse/admin/v1/users/{userId}/media` endpoint must now be handled by media workers. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. + + Features -------- -- Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version for restricted rooms. ([\#10571](https://github.com/matrix-org/synapse/issues/10571)) +- Enable room capabilities ([MSC3244](https://github.com/matrix-org/matrix-doc/pull/3244)) by default and set room version 8 as the preferred room version when creating restricted rooms. ([\#10571](https://github.com/matrix-org/synapse/issues/10571)) Synapse 1.41.0rc1 (2021-08-18) ============================== @@ -16,7 +21,7 @@ Features - Add `get_userinfo_by_id` method to ModuleApi.
([\#9581](https://github.com/matrix-org/synapse/issues/9581)) - Initial local support for [MSC3266](https://github.com/matrix-org/matrix-doc/pull/3266), Room Summary over the unstable `/rooms/{roomIdOrAlias}/summary` API. ([\#10394](https://github.com/matrix-org/synapse/issues/10394)) - Experimental support for [MSC3288](https://github.com/matrix-org/matrix-doc/pull/3288), sending `room_type` to the identity server for 3pid invites over the `/store-invite` API. ([\#10435](https://github.com/matrix-org/synapse/issues/10435)) - Add support for sending federation requests through a proxy. Contributed by @Bubu and @dklimpel. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html) for more information. ([\#10475](https://github.com/matrix-org/synapse/issues/10475), [\#10596](https://github.com/matrix-org/synapse/issues/10596)) - Add support for "marker" events which makes historical events discoverable for servers that already have all of the scrollback history (part of [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716)). ([\#10498](https://github.com/matrix-org/synapse/issues/10498)) - Add a configuration setting for the time a `/sync` response is cached for. ([\#10513](https://github.com/matrix-org/synapse/issues/10513)) - The default logging handler for new installations is now `PeriodicallyFlushingMemoryHandler`, a buffered logging handler which periodically flushes itself. ([\#10518](https://github.com/matrix-org/synapse/issues/10518)) @@ -38,7 +43,7 @@ Bugfixes - Add some clarification to the sample config file. Contributed by @Kentokamoto. ([\#10129](https://github.com/matrix-org/synapse/issues/10129)) - Fix a long-standing bug where protocols which are not implemented by any appservices were incorrectly returned via `GET /_matrix/client/r0/thirdparty/protocols`. ([\#10532](https://github.com/matrix-org/synapse/issues/10532)) - Fix exceptions in logs when failing to get remote room list. ([\#10541](https://github.com/matrix-org/synapse/issues/10541)) - Fix a long-standing bug which caused the user's presence "status message" to be reset when the user went offline. Contributed by @dklimpel. ([\#10550](https://github.com/matrix-org/synapse/issues/10550)) - Allow public rooms to be previewed in the spaces summary APIs from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10580](https://github.com/matrix-org/synapse/issues/10580)) - Fix a bug introduced in v1.37.1 where an error could occur in the asynchronous processing of PDUs when the queue was empty. ([\#10592](https://github.com/matrix-org/synapse/issues/10592)) - Fix errors on /sync when read receipt data is a string. Only affects homeservers with the experimental flag for [MSC2285](https://github.com/matrix-org/matrix-doc/pull/2285) enabled. Contributed by @SimonBrandner. ([\#10606](https://github.com/matrix-org/synapse/issues/10606)) @@ -49,7 +54,7 @@ Improved Documentation ---------------------- - Add documentation for configuring a forward proxy.
([\#10443](https://github.com/matrix-org/synapse/issues/10443)) - Updated the reverse proxy documentation to highlight the homeserver configuration that is needed to make Synapse aware that it is intentionally reverse proxied. ([\#10551](https://github.com/matrix-org/synapse/issues/10551)) - Update CONTRIBUTING.md to fix index links and the instructions for SyTest in docker. ([\#10599](https://github.com/matrix-org/synapse/issues/10599)) -- cgit 1.5.1 From 7367473f965fed1160cb8633de341c5833e5b662 Mon Sep 17 00:00:00 2001 From: Sean Date: Wed, 25 Aug 2021 10:51:08 +0100 Subject: Fix error when selecting between thumbnails with the same quality (#10684) Fixes #10318 --- changelog.d/10684.bugfix | 1 + synapse/rest/media/v1/thumbnail_resource.py | 26 ++++++++++++------- tests/rest/media/v1/test_media_storage.py | 39 ++++++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10684.bugfix diff --git a/changelog.d/10684.bugfix b/changelog.d/10684.bugfix new file mode 100644 index 0000000000..311b17601a --- /dev/null +++ b/changelog.d/10684.bugfix @@ -0,0 +1 @@ +Fix long-standing issue which caused an error when a thumbnail is requested and there are multiple thumbnails with the same quality rating. diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index a029d426f0..12bd745cb2 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -15,7 +15,7 @@ import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from twisted.web.server import Request @@ -414,9 +414,9 @@ class ThumbnailResource(DirectServeJsonResource): if desired_method == "crop": # Thumbnails that match equal or larger sizes of desired width/height. - crop_info_list = [] + crop_info_list: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = [] # Other thumbnails. - crop_info_list2 = [] + crop_info_list2: List[Tuple[int, int, int, bool, int, Dict[str, Any]]] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. if info["thumbnail_method"] != "crop": @@ -451,15 +451,19 @@ class ThumbnailResource(DirectServeJsonResource): info, ) ) + # Pick the most appropriate thumbnail. Some values of `desired_width` and + # `desired_height` may result in a tie, in which case we avoid comparing on + # the thumbnail info dictionary and pick the thumbnail that appears earlier + # in the list of candidates. if crop_info_list: - thumbnail_info = min(crop_info_list)[-1] + thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1] elif crop_info_list2: - thumbnail_info = min(crop_info_list2)[-1] + thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1] elif desired_method == "scale": # Thumbnails that match equal or larger sizes of desired width/height. - info_list = [] + info_list: List[Tuple[int, bool, int, Dict[str, Any]]] = [] # Other thumbnails. - info_list2 = [] + info_list2: List[Tuple[int, bool, int, Dict[str, Any]]] = [] for info in thumbnail_infos: # Skip thumbnails generated with different methods. @@ -477,10 +481,14 @@ class ThumbnailResource(DirectServeJsonResource): info_list2.append( (size_quality, type_quality, length_quality, info) ) + # Pick the most appropriate thumbnail.
Some values of `desired_width` and + # `desired_height` may result in a tie, in which case we avoid comparing on + # the thumbnail info dictionary and pick the thumbnail that appears earlier + # in the list of candidates. if info_list: - thumbnail_info = min(info_list)[-1] + thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1] elif info_list2: - thumbnail_info = min(info_list2)[-1] + thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1] if thumbnail_info: return FileInfo( diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 6085444b9d..2f7eebfe69 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -21,7 +21,7 @@ from unittest.mock import Mock from urllib import parse import attr -from parameterized import parameterized_class +from parameterized import parameterized, parameterized_class from PIL import Image as Image from twisted.internet import defer @@ -473,6 +473,43 @@ class MediaRepoTests(unittest.HomeserverTestCase): }, ) + @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)]) + def test_same_quality(self, method, desired_size): + """Test that choosing between thumbnails with the same quality rating succeeds. + + We are not particular about which thumbnail is chosen.""" + self.assertIsNotNone( + self.thumbnail_resource._select_thumbnail( + desired_width=desired_size, + desired_height=desired_size, + desired_method=method, + desired_type=self.test_image.content_type, + # Provide two identical thumbnails which are guaranteed to have the same + # quality rating. + thumbnail_infos=[ + { + "thumbnail_width": 32, + "thumbnail_height": 32, + "thumbnail_method": method, + "thumbnail_type": self.test_image.content_type, + "thumbnail_length": 256, + "filesystem_id": f"thumbnail1{self.test_image.extension}", + }, + { + "thumbnail_width": 32, + "thumbnail_height": 32, + "thumbnail_method": method, + "thumbnail_type": self.test_image.content_type, + "thumbnail_length": 256, + "filesystem_id": f"thumbnail2{self.test_image.extension}", + }, + ], + file_id=f"image{self.test_image.extension}", + url_cache=None, + server_name=None, + ) + ) + def test_x_robots_tag_header(self): """ Tests that the `X-Robots-Tag` header is present, which informs web crawlers -- cgit 1.5.1 From 882539e423d3eaad703cdee80582f12a27b34d58 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 25 Aug 2021 10:18:23 -0400 Subject: Ensure the base Docker image is rebuilt when running complement with workers. (#10686) We now always rebuild the matrixdotorg/synapse image, then build the matrixdotorg/synapse-workers image on top of it. --- changelog.d/10686.misc | 1 + scripts-dev/complement.sh | 14 +++++++------- 2 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 changelog.d/10686.misc diff --git a/changelog.d/10686.misc b/changelog.d/10686.misc new file mode 100644 index 0000000000..b76908d74e --- /dev/null +++ b/changelog.d/10686.misc @@ -0,0 +1 @@ +Update `complement.sh` to rebuild the base Docker image when run with workers. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 5d0ef8dd3a..89af7a4fde 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -35,25 +35,25 @@ if [[ -z "$COMPLEMENT_DIR" ]]; then echo "Checkout available at 'complement-master'" fi +# Build the base Synapse image from the local checkout +docker build -t matrixdotorg/synapse -f "docker/Dockerfile" . + # If we're using workers, modify the docker files slightly. 
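The `key=lambda t: t[:-1]` change above is easy to demonstrate in isolation: when two candidate tuples tie on every quality field, a plain `min()` falls through to comparing the trailing info dicts, which Python cannot order. A standalone snippet (illustrative values, not Synapse code):

    candidates = [
        (32, 32, 256, False, 256, {"filesystem_id": "thumbnail1"}),
        (32, 32, 256, False, 256, {"filesystem_id": "thumbnail2"}),
    ]

    try:
        min(candidates)  # ties on the numbers, then tries dict < dict
    except TypeError as exc:
        print(f"plain min() fails on a tie: {exc}")

    # Comparing everything except the dict breaks ties by list order instead,
    # which is exactly the behaviour the fix above relies on.
    best = min(candidates, key=lambda t: t[:-1])
    print(best[-1])  # {'filesystem_id': 'thumbnail1'}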
if [[ -n "$WORKERS" ]]; then - BASE_IMAGE=matrixdotorg/synapse-workers - BASE_DOCKERFILE=docker/Dockerfile-workers + # Build the workers docker image (from the base Synapse image). + docker build -t matrixdotorg/synapse-workers -f "docker/Dockerfile-workers" . + export COMPLEMENT_BASE_IMAGE=complement-synapse-workers COMPLEMENT_DOCKERFILE=SynapseWorkers.Dockerfile # And provide some more configuration to complement. export COMPLEMENT_CA=true export COMPLEMENT_VERSION_CHECK_ITERATIONS=500 else - BASE_IMAGE=matrixdotorg/synapse - BASE_DOCKERFILE=docker/Dockerfile export COMPLEMENT_BASE_IMAGE=complement-synapse COMPLEMENT_DOCKERFILE=Synapse.Dockerfile fi -# Build the base Synapse image from the local checkout -docker build -t $BASE_IMAGE -f "$BASE_DOCKERFILE" . -# Build the Synapse monolith image from Complement, based on the above image we just built +# Build the Complement image from the Synapse image we just built. docker build -t $COMPLEMENT_BASE_IMAGE -f "$COMPLEMENT_DIR/dockerfiles/$COMPLEMENT_DOCKERFILE" "$COMPLEMENT_DIR/dockerfiles" cd "$COMPLEMENT_DIR" -- cgit 1.5.1 From b45cc1530b1438b8bfd9c09f179c7338e85ac083 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 25 Aug 2021 17:00:44 +0100 Subject: Make a note to leave a summary when one is bumping the schema version (#10621) I found this easy to miss (and evidently, it looks like it was missed for schema version 62). --- changelog.d/10621.misc | 1 + synapse/storage/schema/__init__.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/10621.misc diff --git a/changelog.d/10621.misc b/changelog.d/10621.misc new file mode 100644 index 0000000000..b8de2e1911 --- /dev/null +++ b/changelog.d/10621.misc @@ -0,0 +1 @@ +Add a comment asking developers to leave a reason when bumping the database schema version. \ No newline at end of file diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index a5bc0ee8a5..af9cc69949 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +# When updating these values, please leave a short summary of the changes below. + SCHEMA_VERSION = 63 """Represents the expectations made by the codebase about the database schema -- cgit 1.5.1 From 5548fe097881b543cba37c7cda27ff7efe55025d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Aug 2021 07:16:53 -0400 Subject: Cache the result of fetching the room hierarchy over federation. (#10647) --- changelog.d/10647.misc | 1 + synapse/federation/federation_client.py | 106 ++++++++++++++++++++------------ 2 files changed, 67 insertions(+), 40 deletions(-) create mode 100644 changelog.d/10647.misc diff --git a/changelog.d/10647.misc b/changelog.d/10647.misc new file mode 100644 index 0000000000..4407a9030d --- /dev/null +++ b/changelog.d/10647.misc @@ -0,0 +1 @@ +Improve the performance of the `/hierarchy` API (from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)) by caching responses received over federation. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 44d9e8a5c7..1416abd0fb 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -111,6 +111,23 @@ class FederationClient(FederationBase): reset_expiry_on_get=False, ) + # A cache for fetching the room hierarchy over federation. 
+ # + # Some stale data over federation is OK, but must be refreshed + # periodically since the local server is in the room. + # + # It is a map of (room ID, suggested-only) -> the response of + # get_room_hierarchy. + self._get_room_hierarchy_cache: ExpiringCache[ + Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]] + ] = ExpiringCache( + cache_name="get_room_hierarchy_cache", + clock=self._clock, + max_len=1000, + expiry_ms=5 * 60 * 1000, + reset_expiry_on_get=False, + ) + def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" now = self._clock.time_msec() @@ -1324,6 +1341,10 @@ class FederationClient(FederationBase): remote servers """ + cached_result = self._get_room_hierarchy_cache.get((room_id, suggested_only)) + if cached_result: + return cached_result + async def send_request( destination: str, ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: @@ -1370,58 +1391,63 @@ class FederationClient(FederationBase): return room, children, inaccessible_children try: - return await self._try_destination_list( + result = await self._try_destination_list( "fetch room hierarchy", destinations, send_request, failover_on_unknown_endpoint=True, ) except SynapseError as e: + # If an unexpected error occurred, re-raise it. + if e.code != 502: + raise + # Fallback to the old federation API and translate the results if # no servers implement the new API. # # The algorithm below is a bit inefficient as it only attempts to - # get information for the requested room, but the legacy API may + # parse information for the requested room, but the legacy API may # return additional layers. - if e.code == 502: - legacy_result = await self.get_space_summary( - destinations, - room_id, - suggested_only, - max_rooms_per_space=None, - exclude_rooms=[], - ) + legacy_result = await self.get_space_summary( + destinations, + room_id, + suggested_only, + max_rooms_per_space=None, + exclude_rooms=[], + ) - # Find the requested room in the response (and remove it). - for _i, room in enumerate(legacy_result.rooms): - if room.get("room_id") == room_id: - break - else: - # The requested room was not returned, nothing we can do. - raise - requested_room = legacy_result.rooms.pop(_i) - - # Find any children events of the requested room. - children_events = [] - children_room_ids = set() - for event in legacy_result.events: - if event.room_id == room_id: - children_events.append(event.data) - children_room_ids.add(event.state_key) - # And add them under the requested room. - requested_room["children_state"] = children_events - - # Find the children rooms. - children = [] - for room in legacy_result.rooms: - if room.get("room_id") in children_room_ids: - children.append(room) - - # It isn't clear from the response whether some of the rooms are - # not accessible. - return requested_room, children, () - - raise + # Find the requested room in the response (and remove it). + for _i, room in enumerate(legacy_result.rooms): + if room.get("room_id") == room_id: + break + else: + # The requested room was not returned, nothing we can do. + raise + requested_room = legacy_result.rooms.pop(_i) + + # Find any children events of the requested room. + children_events = [] + children_room_ids = set() + for event in legacy_result.events: + if event.room_id == room_id: + children_events.append(event.data) + children_room_ids.add(event.state_key) + # And add them under the requested room. + requested_room["children_state"] = children_events + + # Find the children rooms. 
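Reduced to its skeleton, the caching added in this commit is a read-through cache keyed on the request parameters, with a TTL so hierarchy data fetched over federation is eventually refreshed. A condensed sketch under stated assumptions (the plain dict stands in for the `ExpiringCache` above, and the computed result is a placeholder for the federation round trip):

    from typing import Dict, Tuple

    # Stand-in state: in the patch this is an ExpiringCache with a five-minute
    # expiry, so entries age out and are re-fetched over federation.
    _hierarchy_cache: Dict[Tuple[str, bool], dict] = {}

    async def get_room_hierarchy_cached(room_id: str, suggested_only: bool) -> dict:
        key = (room_id, suggested_only)

        cached = _hierarchy_cache.get(key)
        if cached is not None:
            return cached  # possibly a few minutes stale, which is acceptable here

        # Stand-in for the federation fetch performed by the real method.
        result = {"room_id": room_id, "children_state": [], "suggested_only": suggested_only}

        _hierarchy_cache[key] = result
        return result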
+ children = [] + for room in legacy_result.rooms: + if room.get("room_id") in children_room_ids: + children.append(room) + + # It isn't clear from the response whether some of the rooms are + # not accessible. + result = (requested_room, children, ()) + + # Cache the result to avoid fetching data over federation every time. + self._get_room_hierarchy_cache[(room_id, suggested_only)] = result + return result @attr.s(frozen=True, slots=True, auto_attribs=True) -- cgit 1.5.1 From 1aa0dad02187c3b972187f5952cfbce336b0ca5c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 Aug 2021 07:53:52 -0400 Subject: Additional type hints for REST servlets (part 2). (#10674) Applies the changes from #10665 to additional modules. --- changelog.d/10674.misc | 1 + synapse/handlers/presence.py | 5 +++ synapse/rest/client/auth.py | 11 ++++--- synapse/rest/client/devices.py | 48 +++++++++++++++------------- synapse/rest/client/events.py | 38 ++++++++++++++--------- synapse/rest/client/filter.py | 26 +++++++++++----- synapse/rest/client/groups.py | 3 +- synapse/rest/client/initial_sync.py | 16 +++++++--- synapse/rest/client/keys.py | 57 +++++++++++++--------------------- synapse/rest/client/knock.py | 3 +- synapse/rest/client/login.py | 21 ++++++------- synapse/rest/client/logout.py | 17 +++++++--- synapse/rest/client/notifications.py | 13 ++++++-- synapse/rest/client/openid.py | 16 +++++++--- synapse/rest/client/password_policy.py | 18 ++++++----- synapse/rest/client/presence.py | 24 +++++++++----- synapse/rest/client/profile.py | 37 ++++++++++++++++------ 17 files changed, 216 insertions(+), 138 deletions(-) create mode 100644 changelog.d/10674.misc diff --git a/changelog.d/10674.misc b/changelog.d/10674.misc new file mode 100644 index 0000000000..39a37b90b1 --- /dev/null +++ b/changelog.d/10674.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 7ca14e1d84..4418d63df7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -353,6 +353,11 @@ class BasePresenceHandler(abc.ABC): # otherwise would not do). await self.set_state(UserID.from_string(user_id), state, force_notify=True) + async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: + raise NotImplementedError( + "Attempting to check presence on a non-presence worker." 
+ ) + class _NullContextManager(ContextManager[None]): """A context manager which does nothing.""" diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 91800c0278..df8cc4ac7a 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -15,11 +15,14 @@ import logging from typing import TYPE_CHECKING +from twisted.web.server import Request + from synapse.api.constants import LoginType from synapse.api.errors import LoginError, SynapseError from synapse.api.urls import CLIENT_API_PREFIX -from synapse.http.server import respond_with_html +from synapse.http.server import HttpServer, respond_with_html from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest from ._base import client_patterns @@ -49,7 +52,7 @@ class AuthRestServlet(RestServlet): self.registration_token_template = hs.config.registration_token_template self.success_template = hs.config.fallback_success_template - async def on_GET(self, request, stagetype): + async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -88,7 +91,7 @@ class AuthRestServlet(RestServlet): respond_with_html(request, 200, html) return None - async def on_POST(self, request, stagetype): + async def on_POST(self, request: Request, stagetype: str) -> None: session = parse_string(request, "session") if not session: @@ -172,5 +175,5 @@ class AuthRestServlet(RestServlet): return None -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 8b9674db06..25bc3c8f47 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -14,34 +14,36 @@ # limitations under the License. 
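The servlet-by-servlet changes in this commit all instantiate the same template; a minimal sketch of that template follows (the endpoint itself is invented for illustration, but the imports and signatures mirror the ones being added here):

    import logging
    from typing import TYPE_CHECKING, Tuple

    from synapse.http.server import HttpServer
    from synapse.http.servlet import RestServlet
    from synapse.http.site import SynapseRequest
    from synapse.rest.client._base import client_patterns
    from synapse.types import JsonDict

    if TYPE_CHECKING:
        from synapse.server import HomeServer

    logger = logging.getLogger(__name__)


    class ExampleRestServlet(RestServlet):
        PATTERNS = client_patterns("/example$")  # hypothetical endpoint

        def __init__(self, hs: "HomeServer"):
            super().__init__()
            self.auth = hs.get_auth()

        async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
            # Authenticate, then return a (status code, JSON body) pair.
            requester = await self.auth.get_user_by_req(request)
            return 200, {"user_id": requester.user.to_string()}


    def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
        ExampleRestServlet(hs).register(http_server)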
import logging +from typing import TYPE_CHECKING, Tuple from synapse.api import errors +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class DevicesRestServlet(RestServlet): PATTERNS = client_patterns("/devices$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) devices = await self.device_handler.get_devices_by_user( requester.user.to_string() @@ -57,7 +59,7 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = client_patterns("/delete_devices") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -65,7 +67,7 @@ class DeleteDevicesRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) try: @@ -100,18 +102,16 @@ class DeleteDevicesRestServlet(RestServlet): class DeviceRestServlet(RestServlet): PATTERNS = client_patterns("/devices/(?P[^/]*)$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() - async def on_GET(self, request, device_id): + async def on_GET( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) device = await self.device_handler.get_device( requester.user.to_string(), device_id @@ -119,7 +119,9 @@ class DeviceRestServlet(RestServlet): return 200, device @interactive_auth_handler - async def on_DELETE(self, request, device_id): + async def on_DELETE( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) try: @@ -146,7 +148,9 @@ class DeviceRestServlet(RestServlet): await self.device_handler.delete_device(requester.user.to_string(), device_id) return 200, {} - async def on_PUT(self, request, device_id): + async def on_PUT( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_json_object_from_request(request) @@ -193,13 +197,13 @@ class DehydratedDeviceServlet(RestServlet): PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=()) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_GET(self, request: SynapseRequest): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) dehydrated_device = 
await self.device_handler.get_dehydrated_device( requester.user.to_string() @@ -211,7 +215,7 @@ class DehydratedDeviceServlet(RestServlet): else: raise errors.NotFoundError("No dehydrated device available") - async def on_PUT(self, request: SynapseRequest): + async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: submission = parse_json_object_from_request(request) requester = await self.auth.get_user_by_req(request) @@ -259,13 +263,13 @@ class ClaimDehydratedDeviceServlet(RestServlet): "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=() ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - async def on_POST(self, request: SynapseRequest): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) submission = parse_json_object_from_request(request) @@ -292,7 +296,7 @@ class ClaimDehydratedDeviceServlet(RestServlet): return (200, result) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 52bb579cfd..13b72a045a 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -14,11 +14,18 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging +from typing import TYPE_CHECKING, Dict, List, Tuple, Union from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_string +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -28,31 +35,30 @@ class EventStreamRestServlet(RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest - room_id = None + args: Dict[bytes, List[bytes]] = request.args # type: ignore if is_guest: - if b"room_id" not in request.args: + if b"room_id" not in args: raise SynapseError(400, "Guest users must specify room_id param") - if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode("ascii") + room_id = parse_string(request, "room_id") pagin_config = await PaginationConfig.from_request(self.store, request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS - if b"timeout" in request.args: + if b"timeout" in args: try: - timeout = int(request.args[b"timeout"][0]) + timeout = int(args[b"timeout"][0]) except ValueError: raise SynapseError(400, "timeout must be in milliseconds.") - as_client_event = b"raw" not in request.args + as_client_event = b"raw" not in args chunk = await self.event_stream_handler.get_stream( requester.user.to_string(), @@ 
-70,25 +76,27 @@ class EventStreamRestServlet(RestServlet): class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P[^/]*)$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() self._event_serializer = hs.get_event_client_serializer() - async def on_GET(self, request, event_id): + async def on_GET( + self, request: SynapseRequest, event_id: str + ) -> Tuple[int, Union[str, JsonDict]]: requester = await self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event + result = await self._event_serializer.serialize_event(event, time_now) + return 200, result else: return 404, "Event not found." -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EventStreamRestServlet(hs).register(http_server) EventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 411667a9c8..6ed60c7418 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -13,26 +13,34 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import UserID +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict, UserID from ._base import client_patterns, set_timeline_upper_limit +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class GetFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() - async def on_GET(self, request, user_id, filter_id): + async def on_GET( + self, request: SynapseRequest, user_id: str, filter_id: str + ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -43,13 +51,13 @@ class GetFilterRestServlet(RestServlet): raise AuthError(403, "Can only get filters for local users") try: - filter_id = int(filter_id) + filter_id_int = int(filter_id) except Exception: raise SynapseError(400, "Invalid filter_id") try: filter_collection = await self.filtering.get_user_filter( - user_localpart=target_user.localpart, filter_id=filter_id + user_localpart=target_user.localpart, filter_id=filter_id_int ) except StoreError as e: if e.code != 404: @@ -62,13 +70,15 @@ class GetFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P[^/]*)/filter") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -89,6 +99,6 @@ class CreateFilterRestServlet(RestServlet): 
return 200, {"filter_id": str(filter_id)} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: GetFilterRestServlet(hs).register(http_server) CreateFilterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index 6285680c00..c3667ff8aa 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -26,6 +26,7 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, SynapseError from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -930,7 +931,7 @@ class GroupsForUserServlet(RestServlet): return 200, result -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: GroupServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index 12ba0e91db..49b1037b28 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -12,25 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Dict, List, Tuple +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.streams.config import PaginationConfig +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer # TODO: Needs unit testing class InitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/initialSync$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - as_client_event = b"raw" not in request.args + args: Dict[bytes, List[bytes]] = request.args # type: ignore + as_client_event = b"raw" not in args pagination_config = await PaginationConfig.from_request(self.store, request) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( @@ -43,5 +51,5 @@ class InitialSyncRestServlet(RestServlet): return 200, content -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: InitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 012491f597..7281b2ee29 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -15,20 +15,25 @@ # limitations under the License. 
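Several of the rewrites above lean on the typed parsing helpers from `synapse.http.servlet`; a hedged composite of their use (the handler itself is invented, but each helper call mirrors one visible in this commit):

    from typing import Tuple

    from synapse.http.servlet import (
        assert_params_in_dict,
        parse_boolean,
        parse_integer,
        parse_json_object_from_request,
        parse_string,
    )
    from synapse.http.site import SynapseRequest
    from synapse.types import JsonDict

    async def handle(request: SynapseRequest) -> Tuple[int, JsonDict]:
        # Query parameters: a mandatory string, plus typed values with defaults.
        from_token = parse_string(request, "from", required=True)
        timeout = parse_integer(request, "timeout", default=10 * 1000)
        archived = parse_boolean(request, "archived", default=False)

        # Body: parse the JSON object, then insist on the fields we need.
        body = parse_json_object_from_request(request)
        assert_params_in_dict(body, ["device_id"])

        return 200, {"from": from_token, "timeout": timeout, "archived": archived}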
import logging -from typing import Any +from typing import TYPE_CHECKING, Any, Optional, Tuple from synapse.api.errors import InvalidAPICallError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.logging.opentracing import log_kv, set_tag, trace -from synapse.types import StreamToken +from synapse.types import JsonDict, StreamToken from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -60,18 +65,16 @@ class KeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/upload(/(?P[^/]+))?$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() @trace(opname="upload_keys") - async def on_POST(self, request, device_id): + async def on_POST( + self, request: SynapseRequest, device_id: Optional[str] + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -149,16 +152,12 @@ class KeyQueryServlet(RestServlet): PATTERNS = client_patterns("/keys/query$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() device_id = requester.device_id @@ -195,17 +194,13 @@ class KeyChangesServlet(RestServlet): PATTERNS = client_patterns("/keys/changes$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() self.store = hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from", required=True) @@ -245,12 +240,12 @@ class OneTimeKeyServlet(RestServlet): PATTERNS = client_patterns("/keys/claim$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) @@ -269,11 +264,7 @@ class SigningKeyUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -281,7 +272,7 @@ class SigningKeyUploadServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async 
def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -329,16 +320,12 @@ class SignaturesUploadServlet(RestServlet): PATTERNS = client_patterns("/keys/signatures/upload$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -349,7 +336,7 @@ class SignaturesUploadServlet(RestServlet): return 200, result -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 7d1bc40658..68fb08d0ba 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -19,6 +19,7 @@ from twisted.web.server import Request from synapse.api.constants import Membership from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -103,5 +104,5 @@ class KnockRoomAliasServlet(RestServlet): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 11d07776b2..4be502a77b 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -1,4 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple from typing_extensions import TypedDict @@ -110,7 +110,7 @@ class LoginRestServlet(RestServlet): # counters are initialised for the auth_provider_ids. 
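The comment above is worth unpacking: `prometheus_client` only exports a labelled series once `.labels(...)` has been called for that combination, so a provider that has seen no logins would otherwise be missing from `/metrics` entirely instead of reporting 0, which breaks rate and alerting queries. A standalone illustration (the metric name and label are invented):

    from prometheus_client import Counter, generate_latest

    logins = Counter("example_logins", "Logins per auth provider", ["auth_provider"])

    # Pre-register the label combination so the series is exported as 0 even
    # before the first login is counted.
    logins.labels(auth_provider="oidc")

    print(generate_latest().decode("ascii"))  # includes the oidc series at 0.0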
_load_sso_handlers(hs) - def on_GET(self, request: SynapseRequest): + def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) @@ -157,7 +157,7 @@ class LoginRestServlet(RestServlet): return 200, {"flows": flows} - async def on_POST(self, request: SynapseRequest): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]: login_submission = parse_json_object_from_request(request) if self._msc2918_enabled: @@ -217,7 +217,7 @@ class LoginRestServlet(RestServlet): login_submission: JsonDict, appservice: ApplicationService, should_issue_refresh_token: bool = False, - ): + ) -> LoginResponse: identifier = login_submission.get("identifier") logger.info("Got appservice login request with identifier: %r", identifier) @@ -467,10 +467,7 @@ class RefreshTokenServlet(RestServlet): self._clock = hs.get_clock() self.access_token_lifetime = hs.config.access_token_lifetime - async def on_POST( - self, - request: SynapseRequest, - ): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) assert_params_in_dict(refresh_submission, ["refresh_token"]) @@ -570,7 +567,7 @@ class SsoRedirectServlet(RestServlet): class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._cas_handler = hs.get_cas_handler() @@ -592,7 +589,7 @@ class CasTicketServlet(RestServlet): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) if hs.config.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) @@ -601,7 +598,7 @@ def register_servlets(hs, http_server): CasTicketServlet(hs).register(http_server) -def _load_sso_handlers(hs: "HomeServer"): +def _load_sso_handlers(hs: "HomeServer") -> None: """Ensure that the SSO handlers are loaded, if they are enabled by configuration. This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 6055cac2bd..193a6951b9 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -13,9 +13,16 @@ # limitations under the License. 
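One idiom recurs throughout this commit: importing `HomeServer` only under `TYPE_CHECKING` and quoting it in annotations. A distilled example of the pattern and its motivation:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only evaluated by type checkers such as mypy: at runtime the import
        # never executes, which sidesteps the circular dependency between
        # synapse.server and the many modules it constructs.
        from synapse.server import HomeServer


    class ExampleHandler:
        def __init__(self, hs: "HomeServer"):  # quoted: a forward reference
            self._clock = hs.get_clock()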
import logging +from typing import TYPE_CHECKING, Tuple +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -23,13 +30,13 @@ logger = logging.getLogger(__name__) class LogoutRestServlet(RestServlet): PATTERNS = client_patterns("/logout$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) if requester.device_id is None: @@ -48,13 +55,13 @@ class LogoutRestServlet(RestServlet): class LogoutAllRestServlet(RestServlet): PATTERNS = client_patterns("/logout/all$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() @@ -67,6 +74,6 @@ class LogoutAllRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LogoutRestServlet(hs).register(http_server) LogoutAllRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index 0ede643c2d..d1d8a984c6 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -13,26 +13,33 @@ # limitations under the License. 
import logging +from typing import TYPE_CHECKING, Tuple from synapse.events.utils import format_event_for_client_v2_without_room_id +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -87,5 +94,5 @@ class NotificationsServlet(RestServlet): return 200, {"notifications": returned_push_actions, "next_token": next_token} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: NotificationsServlet(hs).register(http_server) diff --git a/synapse/rest/client/openid.py b/synapse/rest/client/openid.py index e8d2673819..4dda6dce4b 100644 --- a/synapse/rest/client/openid.py +++ b/synapse/rest/client/openid.py @@ -12,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from synapse.util.stringutils import random_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -58,14 +64,16 @@ class IdTokenServlet(RestServlet): EXPIRES_MS = 3600 * 1000 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() self.server_name = hs.config.server_name - async def on_POST(self, request, user_id): + async def on_POST( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot request tokens for other users.") @@ -90,5 +98,5 @@ class IdTokenServlet(RestServlet): ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: IdTokenServlet(hs).register(http_server) diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py index a83927aee6..6d64efb165 100644 --- a/synapse/rest/client/password_policy.py +++ b/synapse/rest/client/password_policy.py @@ -13,28 +13,32 @@ # limitations under the License. 
import logging +from typing import TYPE_CHECKING, Tuple +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class PasswordPolicyServlet(RestServlet): PATTERNS = client_patterns("/password_policy$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled - def on_GET(self, request): + def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.enabled or not self.policy: return (200, {}) @@ -53,5 +57,5 @@ return (200, policy) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PasswordPolicyServlet(hs).register(http_server) diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py index 6c27e5faf9..94dd4fe2f4 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py @@ -15,12 +15,18 @@ """ This module contains REST servlets to do with presence: /presence/ """ import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -28,7 +34,7 @@ logger = logging.getLogger(__name__) class PresenceStatusRestServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.presence_handler = hs.get_presence_handler() @@ -37,7 +43,9 @@ class PresenceStatusRestServlet(RestServlet): self._use_presence = hs.config.server.use_presence - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) @@ -53,13 +61,15 @@ raise AuthError(403, "You are not allowed to see their presence.") state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state( + result = format_user_presence_state( state, self.clock.time_msec(), include_user_id=False ) - return 200, state + return 200, result - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) @@ -91,5 +101,5 @@ return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PresenceStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index 5463ed2c4f..d0f20de569 100644 ---
a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -14,22 +14,31 @@ """ This module contains REST servlets to do with profile: /profile/ """ +from typing import TYPE_CHECKING, Tuple + from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer class ProfileDisplaynameRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -48,7 +57,9 @@ class ProfileDisplaynameRestServlet(RestServlet): return 200, ret - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester.user) @@ -72,13 +83,15 @@ class ProfileDisplaynameRestServlet(RestServlet): class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -97,7 +110,9 @@ class ProfileAvatarURLRestServlet(RestServlet): return 200, ret - async def on_PUT(self, request, user_id): + async def on_PUT( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) is_admin = await self.auth.is_server_admin(requester.user) @@ -120,13 +135,15 @@ class ProfileAvatarURLRestServlet(RestServlet): class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - async def on_GET(self, request, user_id): + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: requester_user = None if self.hs.config.require_auth_for_profile_requests: @@ -149,7 +166,7 @@ class ProfileRestServlet(RestServlet): return 200, ret -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ProfileDisplaynameRestServlet(hs).register(http_server) ProfileAvatarURLRestServlet(hs).register(http_server) ProfileRestServlet(hs).register(http_server) -- cgit 1.5.1 From ad17fbd20eb2dd9fb10a3d02ab1b69e9a0d5b50c Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Thu, 26 Aug 2021 13:53:57 +0100 Subject: Remove
pushers when deleting 3pid from account (#10581) When a user deletes an email from their account it will now also remove all pushers for that email and that user (even if these pushers were created by a different client) --- CHANGES.md | 2 + changelog.d/10581.bugfix | 1 + docs/upgrade.md | 5 ++ synapse/handlers/auth.py | 5 +- synapse/storage/databases/main/pusher.py | 72 ++++++++++++++++++++++ .../delta/63/02delete_unlinked_email_pushers.sql | 20 ++++++ tests/push/test_email.py | 39 ++++++++++++ 7 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10581.bugfix create mode 100644 synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql diff --git a/CHANGES.md b/CHANGES.md index f8da8771aa..24f3d53a6d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,5 @@ +Users will stop receiving message updates via email for addresses that were previously linked to their account. + Synapse 1.41.0 (2021-08-24) =========================== diff --git a/changelog.d/10581.bugfix b/changelog.d/10581.bugfix new file mode 100644 index 0000000000..15c7da4497 --- /dev/null +++ b/changelog.d/10581.bugfix @@ -0,0 +1 @@ +Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted. \ No newline at end of file diff --git a/docs/upgrade.md b/docs/upgrade.md index 6d4b8cb48e..dcf0a7db5b 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -107,6 +107,11 @@ This may affect you if you make use of custom HTML templates for the The template is now provided an `error` variable if the authentication process failed. See the default templates linked above for an example. +# Upgrading to v1.42.0 + +## Removal of out-of-date email pushers +Users will stop receiving message updates via email for addresses that were +previously, but are no longer, linked to their account. # Upgrading to v1.41.0 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 98d3d2d97f..34725324a6 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1464,6 +1464,10 @@ class AuthHandler(BaseHandler): ) await self.store.user_delete_threepid(user_id, medium, address) + if medium == "email": + await self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id="m.email", pushkey=address, user_id=user_id + ) return result async def hash(self, password: str) -> str: @@ -1732,7 +1736,6 @@ class AuthHandler(BaseHandler): @attr.s(slots=True) class MacaroonGenerator: - hs = attr.ib() def generate_guest_access_token(self, user_id: str) -> str: diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index b48fe086d4..e47caa2125 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -48,6 +48,11 @@ class PusherWorkerStore(SQLBaseStore): self._remove_stale_pushers, ) + self.db_pool.updates.register_background_update_handler( + "remove_deleted_email_pushers", + self._remove_deleted_email_pushers, + ) + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: """JSON-decode the data in the rows returned from the `pushers` table @@ -388,6 +393,73 @@ class PusherWorkerStore(SQLBaseStore): return number_deleted + async def _remove_deleted_email_pushers( + self, progress: dict, batch_size: int + ) -> int: + """A background update that deletes all pushers for deleted email addresses. + + In previous versions of Synapse, when users deleted their email address, it didn't + also delete all the pushers for that email address.
This background update removes + those to prevent unwanted emails. This should only need to be run once (when users + upgrade to v1.42.0). + + Args: + progress: dict used to store progress of this background update + batch_size: the maximum number of rows to retrieve in a single select query + + Returns: + The number of deleted rows + """ + + last_pusher = progress.get("last_pusher", 0) + + def _delete_pushers(txn) -> int: + + sql = """ + SELECT p.id, p.user_name, p.app_id, p.pushkey + FROM pushers AS p + LEFT JOIN user_threepids AS t + ON t.user_id = p.user_name + AND t.medium = 'email' + AND t.address = p.pushkey + WHERE t.user_id is NULL + AND p.app_id = 'm.email' + AND p.id > ? + ORDER BY p.id ASC + LIMIT ? + """ + + txn.execute(sql, (last_pusher, batch_size)) + + last = None + num_deleted = 0 + for row in txn: + last = row[0] + num_deleted += 1 + self.db_pool.simple_delete_txn( + txn, + "pushers", + {"user_name": row[1], "app_id": row[2], "pushkey": row[3]}, + ) + + if last is not None: + self.db_pool.updates._background_update_progress_txn( + txn, "remove_deleted_email_pushers", {"last_pusher": last} + ) + + return num_deleted + + number_deleted = await self.db_pool.runInteraction( + "_remove_deleted_email_pushers", _delete_pushers + ) + + if number_deleted < batch_size: + await self.db_pool.updates._end_background_update( + "remove_deleted_email_pushers" + ) + + return number_deleted + class PusherStore(PusherWorkerStore): def get_pushers_stream_token(self) -> int: diff --git a/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql new file mode 100644 index 0000000000..611c4b95cf --- /dev/null +++ b/synapse/storage/schema/main/delta/63/02delete_unlinked_email_pushers.sql @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- We may not have deleted all pushers for emails that are no longer linked +-- to an account, so we set up a background job to delete them. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6302, 'remove_deleted_email_pushers', '{}'); diff --git a/tests/push/test_email.py b/tests/push/test_email.py index e0a3342088..eea07485a0 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -125,6 +125,8 @@ class EmailPusherTests(HomeserverTestCase): ) ) + self.auth_handler = hs.get_auth_handler() + def test_need_validated_email(self): """Test that we can only add an email pusher if the user has validated their email.
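# --------------------------------------------------------------------------
# [Editorial sketch, not part of the patch.] The `_remove_deleted_email_pushers`
# background update above finds orphaned email pushers with an anti-join:
# LEFT JOIN pushers onto user_threepids and keep only the rows with no match
# (t.user_id IS NULL), paging by pusher id so each batch resumes where the
# previous one stopped. A minimal runnable illustration of that query shape,
# using an in-memory SQLite database and made-up rows:
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE pushers (id INTEGER PRIMARY KEY, user_name TEXT,
                          app_id TEXT, pushkey TEXT);
    CREATE TABLE user_threepids (user_id TEXT, medium TEXT, address TEXT);
    -- one still-linked email pusher, one orphaned email pusher
    INSERT INTO pushers VALUES (1, '@u:hs', 'm.email', 'kept@example.com');
    INSERT INTO pushers VALUES (2, '@u:hs', 'm.email', 'gone@example.com');
    INSERT INTO user_threepids VALUES ('@u:hs', 'email', 'kept@example.com');
    """
)

last_pusher, batch_size = 0, 100
orphans = conn.execute(
    """
    SELECT p.id FROM pushers AS p
    LEFT JOIN user_threepids AS t
        ON t.user_id = p.user_name
        AND t.medium = 'email'
        AND t.address = p.pushkey
    WHERE t.user_id IS NULL AND p.app_id = 'm.email' AND p.id > ?
    ORDER BY p.id ASC LIMIT ?
    """,
    (last_pusher, batch_size),
).fetchall()

# Only the pusher with no matching threepid row is selected for deletion.
assert [row[0] for row in orphans] == [2]
# --------------------------------------------------------------------------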
@@ -305,6 +307,43 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about that message self._check_for_mail() + def test_no_email_sent_after_removed(self): + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, + src=self.user_id, + tok=self.access_token, + targ=self.others[0].id, + ) + self.helper.join( + room=room, + user=self.others[0].id, + tok=self.others[0].token, + ) + + # The other user sends a single message. + self.helper.send(room, body="Hi!", tok=self.others[0].token) + + # We should get emailed about that message + self._check_for_mail() + + # disassociate the user's email address + self.get_success( + self.auth_handler.delete_threepid( + user_id=self.user_id, + medium="email", + address="a@example.com", + ) + ) + + # check that the pusher for that email address has been deleted + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + ) + pushers = list(pushers) + self.assertEqual(len(pushers), 0) + def _check_for_mail(self): """Check that the user receives an email notification""" -- cgit 1.5.1 From 40f619eaa54d2391deccec473fc0f655c379e766 Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Thu, 26 Aug 2021 11:07:58 -0500 Subject: Validate new m.room.power_levels events (#10232) Signed-off-by: Aaron Raimist --- changelog.d/10232.bugfix | 1 + synapse/events/utils.py | 5 ++- synapse/events/validator.py | 77 ++++++++++++++++++++++++++++++- synapse/python_dependencies.py | 3 +- tests/rest/client/test_power_levels.py | 78 ++++++++++++++++++++++++++++++++++ 5 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10232.bugfix diff --git a/changelog.d/10232.bugfix b/changelog.d/10232.bugfix new file mode 100644 index 0000000000..7be72271e0 --- /dev/null +++ b/changelog.d/10232.bugfix @@ -0,0 +1 @@ +Validate new `m.room.power_levels` events. Contributed by @aaronraimist. \ No newline at end of file diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b6da2f60af..738a151cef 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -32,6 +32,9 @@ from . import EventBase # the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar" SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.") +CANONICALJSON_MAX_INT = 2 ** 53 +CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT + def prune_event(event: EventBase) -> EventBase: """Returns a pruned version of the given event, which removes all keys we @@ -505,7 +508,7 @@ def validate_canonicaljson(value: Any): * NaN, Infinity, -Infinity """ if isinstance(value, int): - if value <= -(2 ** 53) or 2 ** 53 <= value: + if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) elif isinstance(value, float): diff --git a/synapse/events/validator.py b/synapse/events/validator.py index fa6987d7cb..33954b4f62 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -11,16 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
- +import collections.abc from typing import Union +import jsonschema + from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import EventFormatVersions from synapse.config.homeserver import HomeServerConfig from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.events.utils import validate_canonicaljson +from synapse.events.utils import ( + CANONICALJSON_MAX_INT, + CANONICALJSON_MIN_INT, + validate_canonicaljson, +) from synapse.federation.federation_server import server_matches_acl_event from synapse.types import EventID, RoomID, UserID @@ -87,6 +93,29 @@ class EventValidator: 400, "Can't create an ACL event that denies the local server" ) + if event.type == EventTypes.PowerLevels: + try: + jsonschema.validate( + instance=event.content, + schema=POWER_LEVELS_SCHEMA, + cls=plValidator, + ) + except jsonschema.ValidationError as e: + if e.path: + # example: "users_default": '0' is not of type 'integer' + message = '"' + e.path[-1] + '": ' + e.message # noqa: B306 + # jsonschema.ValidationError.message is a valid attribute + else: + # example: '0' is not of type 'integer' + message = e.message # noqa: B306 + # jsonschema.ValidationError.message is a valid attribute + + raise SynapseError( + code=400, + msg=message, + errcode=Codes.BAD_JSON, + ) + def _validate_retention(self, event: EventBase): """Checks that an event that defines the retention policy for a room respects the format enforced by the spec. @@ -185,3 +214,47 @@ class EventValidator: def _ensure_state_event(self, event): if not event.is_state(): raise SynapseError(400, "'%s' must be state events" % (event.type,)) + + +POWER_LEVELS_SCHEMA = { + "type": "object", + "properties": { + "ban": {"$ref": "#/definitions/int"}, + "events": {"$ref": "#/definitions/objectOfInts"}, + "events_default": {"$ref": "#/definitions/int"}, + "invite": {"$ref": "#/definitions/int"}, + "kick": {"$ref": "#/definitions/int"}, + "notifications": {"$ref": "#/definitions/objectOfInts"}, + "redact": {"$ref": "#/definitions/int"}, + "state_default": {"$ref": "#/definitions/int"}, + "users": {"$ref": "#/definitions/objectOfInts"}, + "users_default": {"$ref": "#/definitions/int"}, + }, + "definitions": { + "int": { + "type": "integer", + "minimum": CANONICALJSON_MIN_INT, + "maximum": CANONICALJSON_MAX_INT, + }, + "objectOfInts": { + "type": "object", + "additionalProperties": {"$ref": "#/definitions/int"}, + }, + }, +} + + +def _create_power_level_validator(): + validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA) + + # by default jsonschema does not consider a frozendict to be an object so + # we need to use a custom type checker + # https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types + type_checker = validator.TYPE_CHECKER.redefine( + "object", lambda checker, thing: isinstance(thing, collections.abc.Mapping) + ) + + return jsonschema.validators.extend(validator, type_checker=type_checker) + + +plValidator = _create_power_level_validator() diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index cdcbdd772b..154e5b7028 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -48,7 +48,8 @@ logger = logging.getLogger(__name__) # [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers. 
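# --------------------------------------------------------------------------
# [Editorial sketch, not part of the patch.] The `jsonschema>=3.0.0` bump just
# below is needed because `_create_power_level_validator` above relies on
# `TYPE_CHECKER.redefine`, which first appeared in jsonschema 3.0.0. A minimal
# runnable illustration of that API, using `types.MappingProxyType` as a
# stand-in for frozendict (event content is a Mapping, not a plain dict):
import collections.abc
import types

import jsonschema

schema = {"type": "object", "additionalProperties": {"type": "integer"}}

base = jsonschema.validators.validator_for(schema)
# Teach the validator that any Mapping counts as a JSON "object".
type_checker = base.TYPE_CHECKER.redefine(
    "object", lambda checker, thing: isinstance(thing, collections.abc.Mapping)
)
MappingValidator = jsonschema.validators.extend(base, type_checker=type_checker)

frozen_like = types.MappingProxyType({"ban": 50})  # read-only Mapping

# The stock validator raises "... is not of type 'object'" for this instance;
# the extended one accepts it.
MappingValidator(schema).validate(frozen_like)
# --------------------------------------------------------------------------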
REQUIREMENTS = [ - "jsonschema>=2.5.1", + # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0 + "jsonschema>=3.0.0", "frozendict>=1", "unpaddedbase64>=1.1.0", "canonicaljson>=1.4.0", diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index 91d0762cb0..c0de4c93a8 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.errors import Codes +from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT from synapse.rest import admin from synapse.rest.client import login, room, sync @@ -203,3 +205,79 @@ class PowerLevelsTestCase(HomeserverTestCase): tok=self.admin_access_token, expect_code=200, # expect success ) + + def test_cannot_set_string_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL "0" + room_power_levels["users"].update({self.user_user_id: "0"}) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) + + def test_cannot_set_unsafe_large_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL above the max safe integer + room_power_levels["users"].update( + {self.user_user_id: CANONICALJSON_MAX_INT + 1} + ) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) + + def test_cannot_set_unsafe_small_power_levels(self): + room_power_levels = self.helper.get_state( + self.room_id, + "m.room.power_levels", + tok=self.admin_access_token, + ) + + # Update existing power levels with user at PL below the minimum safe integer + room_power_levels["users"].update( + {self.user_user_id: CANONICALJSON_MIN_INT - 1} + ) + + body = self.helper.send_state( + self.room_id, + "m.room.power_levels", + room_power_levels, + tok=self.admin_access_token, + expect_code=400, # expect failure + ) + + self.assertEqual( + body["errcode"], + Codes.BAD_JSON, + body, + ) -- cgit 1.5.1 From 96715d763362a7027c39d571cfde3aa8b7b82fcf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 26 Aug 2021 18:34:57 +0100 Subject: Make `backfill` and `get_missing_events` use the same codepath (#10645) Given that backfill and get_missing_events are basically the same thing, it's somewhat crazy that we have entirely separate code paths for them. This makes backfill use the existing get_missing_events code, and then clears up all the unused code. 
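[Editor's note: the diff below is easier to follow with the converged shape in mind. The following is a minimal runnable sketch, with simplified stand-in names rather than the real Synapse classes, of the single code path this commit leaves behind: both backfill and the missing-prev_events path fetch events remotely, then feed them through one depth-sorted loop, with a `backfilled` flag controlling downstream behaviour.]

import asyncio
from dataclasses import dataclass
from typing import Iterable, List, Tuple

@dataclass
class Event:
    event_id: str
    depth: int

processed: List[Tuple[str, bool]] = []

async def _process_pulled_event(origin: str, event: Event, backfilled: bool) -> None:
    # in Synapse this does the auth checks and persists the event
    processed.append((event.event_id, backfilled))

async def _process_pulled_events(
    origin: str, events: Iterable[Event], backfilled: bool
) -> None:
    # the shared path: process oldest (lowest depth) first
    for ev in sorted(events, key=lambda e: e.depth):
        await _process_pulled_event(origin, ev, backfilled)

async def backfill(dest: str, events: Iterable[Event]) -> None:
    # backfilled=True inhibits client notification further down the stack
    await _process_pulled_events(dest, events, backfilled=True)

async def on_missing_prev_events(origin: str, events: Iterable[Event]) -> None:
    await _process_pulled_events(origin, events, backfilled=False)

asyncio.run(backfill("remote.example.com", [Event("$b", 2), Event("$a", 1)]))
assert processed == [("$a", True), ("$b", True)]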
--- changelog.d/10645.misc | 1 + synapse/handlers/federation.py | 273 ++++--------------------- synapse/storage/databases/main/purge_events.py | 1 + 3 files changed, 42 insertions(+), 233 deletions(-) create mode 100644 changelog.d/10645.misc diff --git a/changelog.d/10645.misc b/changelog.d/10645.misc new file mode 100644 index 0000000000..ac19263cd8 --- /dev/null +++ b/changelog.d/10645.misc @@ -0,0 +1 @@ +Make `backfill` and `get_missing_events` use the same codepath. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 246df43501..6fa2fc8f52 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -65,6 +65,7 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator +from synapse.federation.federation_client import InvalidResponseError from synapse.handlers._base import BaseHandler from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( @@ -116,10 +117,6 @@ class _NewEventInfo: Attributes: event: the received event - state: the state at that event, according to /state_ids from a remote - homeserver. Only populated for backfilled events which are going to be a - new backwards extremity. - claimed_auth_event_map: a map of (type, state_key) => event for the event's claimed auth_events. @@ -134,7 +131,6 @@ """ event: EventBase - state: Optional[Sequence[EventBase]] claimed_auth_event_map: StateMap[EventBase] @@ -443,113 +439,7 @@ class FederationHandler(BaseHandler): return logger.info("Got %d prev_events", len(missing_events)) - await self._process_pulled_events(origin, missing_events) - - async def _get_state_for_room( - self, - destination: str, - room_id: str, - event_id: str, - ) -> List[EventBase]: - """Requests all of the room state at a given event from a remote - homeserver. - - Will also fetch any missing events reported in the `auth_chain_ids` - section of `/state_ids`. - - Args: - destination: The remote homeserver to query for the state. - room_id: The id of the room we're interested in. - event_id: The id of the event we want the state at. - - Returns: - A list of events in the state, not including the event itself. - """ - ( - state_event_ids, - auth_event_ids, - ) = await self.federation_client.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - # Fetch the state events from the DB, and check we have the auth events. - event_map = await self.store.get_events(state_event_ids, allow_rejected=True) - auth_events_in_store = await self.store.have_seen_events( - room_id, auth_event_ids - ) - - # Check for missing events. We handle state and auth events separately, - # as we want to pull the state from the DB, but we don't for the auth - # events. (Note: we likely won't use the majority of the auth chain, and - # it can be *huge* for large rooms, so it's worth ensuring that we don't - # unnecessarily pull it from the DB).
- missing_state_events = set(state_event_ids) - set(event_map) - missing_auth_events = set(auth_event_ids) - set(auth_events_in_store) - if missing_state_events or missing_auth_events: - await self._get_events_and_persist( - destination=destination, - room_id=room_id, - events=missing_state_events | missing_auth_events, - ) - - if missing_state_events: - new_events = await self.store.get_events( - missing_state_events, allow_rejected=True - ) - event_map.update(new_events) - - missing_state_events.difference_update(new_events) - - if missing_state_events: - logger.warning( - "Failed to fetch missing state events for %s %s", - event_id, - missing_state_events, - ) - - if missing_auth_events: - auth_events_in_store = await self.store.have_seen_events( - room_id, missing_auth_events - ) - missing_auth_events.difference_update(auth_events_in_store) - - if missing_auth_events: - logger.warning( - "Failed to fetch missing auth events for %s %s", - event_id, - missing_auth_events, - ) - - remote_state = list(event_map.values()) - - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B - - bad_events = [ - (event.event_id, event.room_id) - for event in remote_state - if event.room_id != room_id - ] - - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned auth/state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) - - if bad_events: - remote_state = [e for e in remote_state if e.room_id == room_id] - - return remote_state + await self._process_pulled_events(origin, missing_events, backfilled=False) async def _get_state_after_missing_prev_event( self, @@ -567,10 +457,6 @@ class FederationHandler(BaseHandler): Returns: A list of events in the state, including the event itself """ - # TODO: This function is basically the same as _get_state_for_room. Can - # we make backfill() use it, rather than having two code paths? I think the - # only difference is that backfill() persists the prev events separately. - ( state_event_ids, auth_event_ids, @@ -681,6 +567,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + backfilled: bool = False, ) -> None: """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -693,6 +580,9 @@ class FederationHandler(BaseHandler): state: Normally None, but if we are handling a gap in the graph (ie, we are missing one or more prev_events), the resolved state at the event + + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) 
""" logger.debug("Processing event: %s", event) @@ -700,10 +590,15 @@ class FederationHandler(BaseHandler): context = await self.state_handler.compute_event_context( event, old_state=state ) - await self._auth_and_persist_event(origin, event, context, state=state) + await self._auth_and_persist_event( + origin, event, context, state=state, backfilled=backfilled + ) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + if backfilled: + return + # For encrypted messages we check that we know about the sending device, # if we don't then we mark the device cache for that user as stale. if event.type == EventTypes.Encrypted: @@ -868,7 +763,7 @@ class FederationHandler(BaseHandler): @log_function async def backfill( self, dest: str, room_id: str, limit: int, extremities: List[str] - ) -> List[EventBase]: + ) -> None: """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -878,6 +773,9 @@ class FederationHandler(BaseHandler): sanity-checking on them. If any of the backfilled events are invalid, this method throws a SynapseError. + We might also raise an InvalidResponseError if the response from the remote + server is just bogus. + TODO: make this more useful to distinguish failures of the remote server from invalid events (there is probably no point in trying to re-fetch invalid events from every other HS in the room.) @@ -890,111 +788,18 @@ class FederationHandler(BaseHandler): ) if not events: - return [] - - # ideally we'd sanity check the events here for excess prev_events etc, - # but it's hard to reject events at this point without completely - # breaking backfill in the same way that it is currently broken by - # events whose signature we cannot verify (#3121). - # - # So for now we accept the events anyway. #3124 tracks this. - # - # for ev in events: - # self._sanity_check_event(ev) - - # Don't bother processing events we already have. - seen_events = await self.store.have_events_in_timeline( - {e.event_id for e in events} - ) - - events = [e for e in events if e.event_id not in seen_events] - - if not events: - return [] - - event_map = {e.event_id: e for e in events} - - event_ids = {e.event_id for e in events} - - # build a list of events whose prev_events weren't in the batch. - # (XXX: this will include events whose prev_events we already have; that doesn't - # sound right?) - edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] - - logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) - - # For each edge get the current state. - - state_events = {} - events_to_state = {} - for e_id in edges: - state = await self._get_state_for_room( - destination=dest, - room_id=room_id, - event_id=e_id, - ) - state_events.update({s.event_id: s for s in state}) - events_to_state[e_id] = state + return - required_auth = { - a_id - for event in events + list(state_events.values()) - for a_id in event.auth_event_ids() - } - auth_events = await self.store.get_events(required_auth, allow_rejected=True) - auth_events.update( - {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} - ) - - ev_infos = [] - - # Step 1: persist the events in the chunk we fetched state for (i.e. 
- # the backwards extremities), with custom auth events and state - for e_id in events_to_state: - # For paranoia we ensure that these events are marked as - # non-outliers - ev = event_map[e_id] - assert not ev.internal_metadata.is_outlier() - - ev_infos.append( - _NewEventInfo( - event=ev, - state=events_to_state[e_id], - claimed_auth_event_map={ - ( - auth_events[a_id].type, - auth_events[a_id].state_key, - ): auth_events[a_id] - for a_id in ev.auth_event_ids() - if a_id in auth_events - }, + # if there are any events in the wrong room, the remote server is buggy and + # should not be trusted. + for ev in events: + if ev.room_id != room_id: + raise InvalidResponseError( + f"Remote server {dest} returned event {ev.event_id} which is in " + f"room {ev.room_id}, when we were backfilling in {room_id}" ) - ) - - if ev_infos: - await self._auth_and_persist_events( - dest, room_id, ev_infos, backfilled=True - ) - - # Step 2: Persist the rest of the events in the chunk one by one - events.sort(key=lambda e: e.depth) - - for event in events: - if event in events_to_state: - continue - - # For paranoia we ensure that these events are marked as - # non-outliers - assert not event.internal_metadata.is_outlier() - - context = await self.state_handler.compute_event_context(event) - - # We store these one at a time since each event depends on the - # previous to work out the state. - # TODO: We can probably do something more clever here. - await self._auth_and_persist_event(dest, event, context, backfilled=True) - return events + await self._process_pulled_events(dest, events, backfilled=True) async def maybe_backfill( self, room_id: str, current_depth: int, limit: int @@ -1197,7 +1002,7 @@ class FederationHandler(BaseHandler): # appropriate stuff. # TODO: We can probably do something more intelligent here. return True - except SynapseError as e: + except (SynapseError, InvalidResponseError) as e: logger.info("Failed to backfill from %s because %s", dom, e) continue except HttpResponseException as e: @@ -1351,7 +1156,7 @@ class FederationHandler(BaseHandler): else: logger.info("Missing auth event %s", auth_event_id) - event_infos.append(_NewEventInfo(event, None, auth)) + event_infos.append(_NewEventInfo(event, auth)) if event_infos: await self._auth_and_persist_events( @@ -1361,7 +1166,7 @@ class FederationHandler(BaseHandler): ) async def _process_pulled_events( - self, origin: str, events: Iterable[EventBase] + self, origin: str, events: Iterable[EventBase], backfilled: bool ) -> None: """Process a batch of events we have pulled from a remote server @@ -1373,6 +1178,8 @@ class FederationHandler(BaseHandler): Params: origin: The server we received these events from events: The received events. + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) 
""" # We want to sort these by depth so we process them and @@ -1381,9 +1188,11 @@ class FederationHandler(BaseHandler): for ev in sorted_events: with nested_logging_context(ev.event_id): - await self._process_pulled_event(origin, ev) + await self._process_pulled_event(origin, ev, backfilled=backfilled) - async def _process_pulled_event(self, origin: str, event: EventBase) -> None: + async def _process_pulled_event( + self, origin: str, event: EventBase, backfilled: bool + ) -> None: """Process a single event that we have pulled from a remote server Pulls in any events required to auth the event, persists the received event, @@ -1400,6 +1209,8 @@ class FederationHandler(BaseHandler): Params: origin: The server we received this event from events: The received event + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) """ logger.info("Processing pulled event %s", event) @@ -1428,7 +1239,9 @@ class FederationHandler(BaseHandler): try: state = await self._resolve_state_at_missing_prevs(origin, event) - await self._process_received_pdu(origin, event, state=state) + await self._process_received_pdu( + origin, event, state=state, backfilled=backfilled + ) except FederationError as e: if e.code == 403: logger.warning("Pulled event %s failed history check.", event_id) @@ -2451,7 +2264,6 @@ class FederationHandler(BaseHandler): origin: str, room_id: str, event_infos: Collection[_NewEventInfo], - backfilled: bool = False, ) -> None: """Creates the appropriate contexts and persists events. The events should not depend on one another, e.g. this should be used to persist @@ -2467,16 +2279,12 @@ class FederationHandler(BaseHandler): async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context( - event, old_state=ev_info.state - ) + res = await self.state_handler.compute_event_context(event) res = await self._check_event_auth( origin, event, res, - state=ev_info.state, claimed_auth_event_map=ev_info.claimed_auth_event_map, - backfilled=backfilled, ) return res @@ -2493,7 +2301,6 @@ class FederationHandler(BaseHandler): (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) ], - backfilled=backfilled, ) async def _persist_auth_tree( diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 664c65dac5..bccff5e5b9 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -295,6 +295,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) + self._invalidate_get_event_cache(event_id) logger.info("[purge] done") -- cgit 1.5.1 From 1800aabfc226938036479d2ab1a750aa34cf3974 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 26 Aug 2021 21:41:44 +0100 Subject: Split `FederationHandler` in half (#10692) The idea here is to take anything to do with incoming events and move it out to a separate handler, as a way of making FederationHandler smaller. 
--- changelog.d/10692.misc | 1 + synapse/federation/federation_server.py | 7 +- synapse/handlers/federation.py | 1789 +------------------- synapse/handlers/federation_event.py | 1825 +++++++++++++++++++++ synapse/replication/http/federation.py | 4 +- synapse/server.py | 5 + tests/federation/transport/test_knocking.py | 2 +- tests/handlers/test_federation.py | 16 +- tests/handlers/test_presence.py | 4 +- tests/replication/test_federation_sender_shard.py | 2 +- tests/test_federation.py | 10 +- 11 files changed, 1884 insertions(+), 1781 deletions(-) create mode 100644 changelog.d/10692.misc create mode 100644 synapse/handlers/federation_event.py diff --git a/changelog.d/10692.misc b/changelog.d/10692.misc new file mode 100644 index 0000000000..a1b0def76b --- /dev/null +++ b/changelog.d/10692.misc @@ -0,0 +1 @@ +Split the event-processing methods in `FederationHandler` into a separate `FederationEventHandler`. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e1b58d40c5..214ee948fa 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -110,6 +110,7 @@ class FederationServer(FederationBase): super().__init__(hs) self.handler = hs.get_federation_handler() + self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -787,7 +788,9 @@ class FederationServer(FederationBase): event = await self._check_sigs_and_hash(room_version, event) - return await self.handler.on_send_membership_event(origin, event) + return await self._federation_event_handler.on_send_membership_event( + origin, event + ) async def on_event_auth( self, origin: str, room_id: str, event_id: str @@ -1005,7 +1008,7 @@ class FederationServer(FederationBase): async with lock: logger.info("handling received PDU: %s", event) try: - await self.handler.on_receive_pdu(origin, event) + await self._federation_event_handler.on_receive_pdu(origin, event) except FederationError as e: # XXX: Ideally we'd inform the remote we failed to process # the event, but we can't return an error in the transaction diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6fa2fc8f52..daf1d3bfb3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,23 +17,9 @@ import itertools import logging -from collections.abc import Container from http import HTTPStatus -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - Iterable, - List, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union -import attr -from prometheus_client import Counter from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -41,19 +27,12 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer from synapse import event_auth -from synapse.api.constants import ( - EventContentFields, - EventTypes, - Membership, - RejectedReason, - RoomEncryptionAlgorithms, -) +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.errors import ( AuthError, CodeMessageException, Codes, FederationDeniedError, - FederationError, HttpResponseException, NotFoundError, RequestSendFailed, @@ -61,7 +40,6 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions from 
synapse.crypto.event_signing import compute_event_signature -from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator @@ -75,28 +53,14 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.utils import log_function -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, - ReplicationFederationSendEventsRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, ) -from synapse.state import StateResolutionStore from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import ( - JsonDict, - MutableStateMap, - PersistedEventPosition, - RoomStreamToken, - StateMap, - UserID, - get_domain_from_id, -) -from synapse.util.async_helpers import Linearizer, concurrently_execute -from synapse.util.iterutils import batch_iter +from synapse.types import JsonDict, StateMap, get_domain_from_id +from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination -from synapse.util.stringutils import shortstr from synapse.visibility import filter_events_for_server if TYPE_CHECKING: @@ -104,45 +68,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -soft_failed_event_counter = Counter( - "synapse_federation_soft_failed_events_total", - "Events received over federation that we marked as soft_failed", -) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _NewEventInfo: - """Holds information about a received event, ready for passing to _auth_and_persist_events - - Attributes: - event: the received event - - claimed_auth_event_map: a map of (type, state_key) => event for the event's - claimed auth_events. - - This can include events which have not yet been persisted, in the case that - we are backfilling a batch of events. - - Note: May be incomplete: if we were unable to find all of the claimed auth - events. Also, treat the contents with caution: the events might also have - been rejected, might not yet have been authorized themselves, or they might - be in the wrong room. - - """ - - event: EventBase - claimed_auth_event_map: StateMap[EventBase] - class FederationHandler(BaseHandler): - """Handles events that originated from federation. - Responsible for: - a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resolutions) - b) converting events that were produced by local clients that may need - to be sent to remote homeservers. - c) doing the necessary dances to invite remote users and join remote - rooms. + """Handles general incoming federation requests + + Incoming events are *not* handled here, for which see FederationEventHandler. 
""" def __init__(self, hs: "HomeServer"): @@ -155,652 +85,35 @@ class FederationHandler(BaseHandler): self.state_store = self.storage.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() - self._state_resolution_handler = hs.get_state_resolution_handler() self.server_name = hs.hostname self.keyring = hs.get_keyring() - self.action_generator = hs.get_action_generator() self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() - self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_proxied_blacklisted_http_client() - self._instance_name = hs.get_instance_name() self._replication = hs.get_replication_data_handler() + self._federation_event_handler = hs.get_federation_event_handler() - self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client( hs ) if hs.config.worker_app: - self._user_device_resync = ( - ReplicationUserDevicesResyncRestServlet.make_client(hs) - ) self._maybe_store_room_on_outlier_membership = ( ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) ) else: - self._device_list_updater = hs.get_device_handler().device_list_updater self._maybe_store_room_on_outlier_membership = ( self.store.maybe_store_room_on_outlier_membership ) - # When joining a room we need to queue any events for that room up. - # For each room, a list of (pdu, origin) tuples. - self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {} - self._room_pdu_linearizer = Linearizer("fed_room_pdu") - self._room_backfill = Linearizer("room_backfill") self.third_party_event_rules = hs.get_third_party_event_rules() - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - - async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None: - """Process a PDU received via a federation /send/ transaction - - Args: - origin: server which initiated the /send/ transaction. Will - be used to fetch missing events or state. - pdu: received PDU - """ - - room_id = pdu.room_id - event_id = pdu.event_id - - # We reprocess pdus when we have seen them only as outliers - existing = await self.store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - - # FIXME: Currently we fetch an event again when we already have it - # if it has been marked as an outlier. - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "Ignoring received event %s which we have already seen", event_id - ) - return - if pdu.internal_metadata.is_outlier(): - logger.info( - "Ignoring received outlier %s which we already have as an outlier", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - - # do some initial sanity-checking of the event. In particular, make - # sure it doesn't have hundreds of prev_events or auth_events, which - # could cause a huge state resolution or cascade of event fetches. - try: - self._sanity_check_event(pdu) - except SynapseError as err: - logger.warning("Received event failed sanity checks") - raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id) - - # If we are currently in the process of joining this room, then we - # queue up events for later processing. 
- if room_id in self.room_queues: - logger.info( - "Queuing PDU from %s for now: join in progress", - origin, - ) - self.room_queues[room_id].append((pdu, origin)) - return - - # If we're not in the room just ditch the event entirely. This is - # probably an old server that has come back and thinks we're still in - # the room (or we've been rejoined to the room by a state reset). - # - # Note that if we were never in the room then we would have already - # dropped the event, since we wouldn't know the room version. - is_in_room = await self._event_auth_handler.check_host_in_room( - room_id, self.server_name - ) - if not is_in_room: - logger.info( - "Ignoring PDU from %s as we're not in the room", - origin, - ) - return None - - # Check that the event passes auth based on the state at the event. This is - # done for events that are to be added to the timeline (non-outliers). - # - # Get missing pdus if necessary: - # - Fetching any missing prev events to fill in gaps in the graph - # - Fetching state if we have a hole in the graph - if not pdu.internal_metadata.is_outlier(): - prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if missing_prevs: - # We only backfill backwards to the min depth. - min_depth = await self.get_min_depth_for_context(pdu.room_id) - logger.debug("min_depth: %d", min_depth) - - if min_depth is not None and pdu.depth > min_depth: - # If we're missing stuff, ensure we only fetch stuff one - # at a time. - logger.info( - "Acquiring room lock to fetch %d missing prev_events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - with (await self._room_pdu_linearizer.queue(pdu.room_id)): - logger.info( - "Acquired room lock to fetch %d missing prev_events", - len(missing_prevs), - ) - - try: - await self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) - except Exception as e: - raise Exception( - "Error fetching missing prev_events for %s: %s" - % (event_id, e) - ) from e - - # Update the set of things we've seen after trying to - # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - logger.info("Found all missing prev_events") - - if missing_prevs: - # since this event was pushed to us, it is possible for it to - # become the only forward-extremity in the room, and we would then - # trust its state to be the state for the whole room. This is very - # bad. Further, if the event was pushed to us, there is no excuse - # for us not to have all the prev_events. (XXX: apart from - # min_depth?) - # - # We therefore reject any such events. - logger.warning( - "Rejecting: failed to fetch %d prev events: %s", - len(missing_prevs), - shortstr(missing_prevs), - ) - raise FederationError( - "ERROR", - 403, - ( - "Your server isn't divulging details about prev_events " - "referenced in this event." - ), - affected=pdu.event_id, - ) - - await self._process_received_pdu(origin, pdu, state=None) - - async def _get_missing_events_for_pdu( - self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int - ) -> None: - """ - Args: - origin: Origin of the pdu. Will be called to get the missing events - pdu: received pdu - prevs: List of event ids which we are missing - min_depth: Minimum depth of events to return. 
- """ - - room_id = pdu.room_id - event_id = pdu.event_id - - seen = await self.store.have_events_in_timeline(prevs) - - if not prevs - seen: - return - - latest_list = await self.store.get_latest_event_ids_in_room(room_id) - - # We add the prev events that we have seen to the latest - # list to ensure the remote server doesn't give them to us - latest = set(latest_list) - latest |= seen - - logger.info( - "Requesting missing events between %s and %s", - shortstr(latest), - event_id, - ) - - # XXX: we set timeout to 10s to help workaround - # https://github.com/matrix-org/synapse/issues/1733. - # The reason is to avoid holding the linearizer lock - # whilst processing inbound /send transactions, causing - # FDs to stack up and block other inbound transactions - # which empirically can currently take up to 30 minutes. - # - # N.B. this explicitly disables retry attempts. - # - # N.B. this also increases our chances of falling back to - # fetching fresh state for the room if the missing event - # can't be found, which slightly reduces our security. - # it may also increase our DAG extremity count for the room, - # causing additional state resolution? See #1760. - # However, fetching state doesn't hold the linearizer lock - # apparently. - # - # see https://github.com/matrix-org/synapse/pull/1744 - # - # ---- - # - # Update richvdh 2018/09/18: There are a number of problems with timing this - # request out aggressively on the client side: - # - # - it plays badly with the server-side rate-limiter, which starts tarpitting you - # if you send too many requests at once, so you end up with the server carefully - # working through the backlog of your requests, which you have already timed - # out. - # - # - for this request in particular, we now (as of - # https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the - # server can't produce a plausible-looking set of prev_events - so we becone - # much more likely to reject the event. - # - # - contrary to what it says above, we do *not* fall back to fetching fresh state - # for the room if get_missing_events times out. Rather, we give up processing - # the PDU whose prevs we are missing, which then makes it much more likely that - # we'll end up back here for the *next* PDU in the list, which exacerbates the - # problem. - # - # - the aggressive 10s timeout was introduced to deal with incoming federation - # requests taking 8 hours to process. It's not entirely clear why that was going - # on; certainly there were other issues causing traffic storms which are now - # resolved, and I think in any case we may be more sensible about our locking - # now. We're *certainly* more sensible about our logging. - # - # All that said: Let's try increasing the timeout to 60s and see what happens. - - try: - missing_events = await self.federation_client.get_missing_events( - origin, - room_id, - earliest_events_ids=list(latest), - latest_events=[pdu], - limit=10, - min_depth=min_depth, - timeout=60000, - ) - except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e: - # We failed to get the missing events, but since we need to handle - # the case of `get_missing_events` not returning the necessary - # events anyway, it is safe to simply log the error and continue. 
- logger.warning("Failed to get prev_events: %s", e) - return - - logger.info("Got %d prev_events", len(missing_events)) - await self._process_pulled_events(origin, missing_events, backfilled=False) - - async def _get_state_after_missing_prev_event( - self, - destination: str, - room_id: str, - event_id: str, - ) -> List[EventBase]: - """Requests all of the room state at a given event from a remote homeserver. - - Args: - destination: The remote homeserver to query for the state. - room_id: The id of the room we're interested in. - event_id: The id of the event we want the state at. - - Returns: - A list of events in the state, including the event itself - """ - ( - state_event_ids, - auth_event_ids, - ) = await self.federation_client.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - logger.debug( - "state_ids returned %i state events, %i auth events", - len(state_event_ids), - len(auth_event_ids), - ) - - # start by just trying to fetch the events from the store - desired_events = set(state_event_ids) - desired_events.add(event_id) - logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self.store.get_events( - desired_events, allow_rejected=True - ) - - missing_desired_events = desired_events - fetched_events.keys() - logger.debug( - "We are missing %i events (got %i)", - len(missing_desired_events), - len(fetched_events), - ) - - # We probably won't need most of the auth events, so let's just check which - # we have for now, rather than thrashing the event cache with them all - # unnecessarily. - - # TODO: we probably won't actually need all of the auth events, since we - # already have a bunch of the state events. It would be nice if the - # federation api gave us a way of finding out which we actually need. - - missing_auth_events = set(auth_event_ids) - fetched_events.keys() - missing_auth_events.difference_update( - await self.store.have_seen_events(room_id, missing_auth_events) - ) - logger.debug("We are also missing %i auth events", len(missing_auth_events)) - - missing_events = missing_desired_events | missing_auth_events - logger.debug("Fetching %i events from remote", len(missing_events)) - await self._get_events_and_persist( - destination=destination, room_id=room_id, events=missing_events - ) - - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self.store.get_events(missing_desired_events, allow_rejected=True) - ) - - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B - - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] - - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) - - del fetched_events[bad_event_id] - - # if we couldn't get the prev event in question, that's a problem. 
- remote_event = fetched_events.get(event_id) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - - # missing state at that event is a warning, not a blocker - # XXX: this doesn't sound right? it means that we'll end up with incomplete - # state. - failed_to_fetch = desired_events - fetched_events.keys() - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state events for %s %s", - event_id, - failed_to_fetch, - ) - - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) - - return remote_state - - async def _process_received_pdu( - self, - origin: str, - event: EventBase, - state: Optional[Iterable[EventBase]], - backfilled: bool = False, - ) -> None: - """Called when we have a new pdu. We need to do auth checks and put it - through the StateHandler. - - Args: - origin: server sending the event - - event: event to be persisted - - state: Normally None, but if we are handling a gap in the graph - (ie, we are missing one or more prev_events), the resolved state at the - event - - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - logger.debug("Processing event: %s", event) - - try: - context = await self.state_handler.compute_event_context( - event, old_state=state - ) - await self._auth_and_persist_event( - origin, event, context, state=state, backfilled=backfilled - ) - except AuthError as e: - raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - - if backfilled: - return - - # For encrypted messages we check that we know about the sending device, - # if we don't then we mark the device cache for that user as stale. - if event.type == EventTypes.Encrypted: - device_id = event.content.get("device_id") - sender_key = event.content.get("sender_key") - - cached_devices = await self.store.get_cached_devices_for_user(event.sender) - - resync = False # Whether we should resync device lists. - - device = None - if device_id is not None: - device = cached_devices.get(device_id) - if device is None: - logger.info( - "Received event from remote device not in our cache: %s %s", - event.sender, - device_id, - ) - resync = True - - # We also check if the `sender_key` matches what we expect. - if sender_key is not None: - # Figure out what sender key we're expecting. If we know the - # device and recognize the algorithm then we can work out the - # exact key to expect. Otherwise check it matches any key we - # have for that device. - - current_keys: Container[str] = [] - - if device: - keys = device.get("keys", {}).get("keys", {}) - - if ( - event.content.get("algorithm") - == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2 - ): - # For this algorithm we expect a curve25519 key. - key_name = "curve25519:%s" % (device_id,) - current_keys = [keys.get(key_name)] - else: - # We don't know understand the algorithm, so we just - # check it matches a key for the device. - current_keys = keys.values() - elif device_id: - # We don't have any keys for the device ID. - pass - else: - # The event didn't include a device ID, so we just look for - # keys across all devices. - current_keys = [ - key - for device in cached_devices.values() - for key in device.get("keys", {}).get("keys", {}).values() - ] - - # We now check that the sender key matches (one of) the expected - # keys. 
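The expected-sender-key computation above is worth seeing in isolation. This sketch mirrors the removed code; the device-cache layout (`{"keys": {"keys": {"curve25519:DEV": ...}}}`) is assumed from that code and should be treated as illustrative:

```python
# Hedged sketch of the expected sender keys for an encrypted event,
# per the removed device-check code above.
from typing import Dict, List, Optional

MEGOLM = "m.megolm.v1.aes-sha2"

def expected_sender_keys(cached_devices: Dict[str, dict],
                         device_id: Optional[str],
                         algorithm: Optional[str]) -> List[Optional[str]]:
    device = cached_devices.get(device_id) if device_id else None
    if device:
        keys = device.get("keys", {}).get("keys", {})
        if algorithm == MEGOLM:
            # Megolm senders are identified by their curve25519 key.
            return [keys.get(f"curve25519:{device_id}")]
        # Unknown algorithm: accept any key we hold for the device.
        return list(keys.values())
    if device_id:
        return []  # unknown device: nothing to compare against
    # No device ID on the event: look across all cached devices.
    return [
        key
        for dev in cached_devices.values()
        for key in dev.get("keys", {}).get("keys", {}).values()
    ]
```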
- if sender_key not in current_keys: - logger.info( - "Received event from remote device with unexpected sender key: %s %s: %s", - event.sender, - device_id or "", - sender_key, - ) - resync = True - - if resync: - run_as_background_process( - "resync_device_due_to_pdu", self._resync_device, event.sender - ) - - await self._handle_marker_event(origin, event) - - async def _handle_marker_event(self, origin: str, marker_event: EventBase): - """Handles backfilling the insertion event when we receive a marker - event that points to one. - - Args: - origin: Origin of the event. Will be called to get the insertion event - marker_event: The event to process - """ - - if marker_event.type != EventTypes.MSC2716_MARKER: - # Not a marker event - return - - if marker_event.rejected_reason is not None: - # Rejected event - return - - # Skip processing a marker event if the room version doesn't - # support it. - room_version = await self.store.get_room_version(marker_event.room_id) - if not room_version.msc2716_historical: - return - - logger.debug("_handle_marker_event: received %s", marker_event) - - insertion_event_id = marker_event.content.get( - EventContentFields.MSC2716_MARKER_INSERTION - ) - - if insertion_event_id is None: - # Nothing to retrieve then (invalid marker) - return - - logger.debug( - "_handle_marker_event: backfilling insertion event %s", insertion_event_id - ) - - await self._get_events_and_persist( - origin, - marker_event.room_id, - [insertion_event_id], - ) - - insertion_event = await self.store.get_event( - insertion_event_id, allow_none=True - ) - if insertion_event is None: - logger.warning( - "_handle_marker_event: server %s didn't return insertion event %s for marker %s", - origin, - insertion_event_id, - marker_event.event_id, - ) - return - - logger.debug( - "_handle_marker_event: succesfully backfilled insertion event %s from marker event %s", - insertion_event, - marker_event, - ) - - await self.store.insert_insertion_extremity( - insertion_event_id, marker_event.room_id - ) - - logger.debug( - "_handle_marker_event: insertion extremity added for %s from marker event %s", - insertion_event, - marker_event, - ) - - async def _resync_device(self, sender: str) -> None: - """We have detected that the device list for the given user may be out - of sync, so we try and resync them. - """ - - try: - await self.store.mark_remote_user_device_cache_as_stale(sender) - - # Immediately attempt a resync in the background - if self.config.worker_app: - await self._user_device_resync(user_id=sender) - else: - await self._device_list_updater.user_device_resync(sender) - except Exception: - logger.exception("Failed to resync device for %s", sender) - - @log_function - async def backfill( - self, dest: str, room_id: str, limit: int, extremities: List[str] - ) -> None: - """Trigger a backfill request to `dest` for the given `room_id` - - This will attempt to get more events from the remote. If the other side - has no new events to offer, this will return an empty list. - - As the events are received, we check their signatures, and also do some - sanity-checking on them. If any of the backfilled events are invalid, - this method throws a SynapseError. - - We might also raise an InvalidResponseError if the response from the remote - server is just bogus. - - TODO: make this more useful to distinguish failures of the remote - server from invalid events (there is probably no point in trying to - re-fetch invalid events from every other HS in the room.) 
- """ - if dest == self.server_name: - raise SynapseError(400, "Can't backfill from self.") - - events = await self.federation_client.backfill( - dest, room_id, limit=limit, extremities=extremities - ) - - if not events: - return - - # if there are any events in the wrong room, the remote server is buggy and - # should not be trusted. - for ev in events: - if ev.room_id != room_id: - raise InvalidResponseError( - f"Remote server {dest} returned event {ev.event_id} which is in " - f"room {ev.room_id}, when we were backfilling in {room_id}" - ) - - await self._process_pulled_events(dest, events, backfilled=True) - async def maybe_backfill( self, room_id: str, current_depth: int, limit: int ) -> bool: @@ -995,7 +308,7 @@ class FederationHandler(BaseHandler): # TODO: Should we try multiple of these at a time? for dom in domains: try: - await self.backfill( + await self._federation_event_handler.backfill( dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the @@ -1084,316 +397,7 @@ class FederationHandler(BaseHandler): tried_domains.update(dom for dom, _ in likely_extremeties_domains) - return False - - async def _get_events_and_persist( - self, destination: str, room_id: str, events: Iterable[str] - ) -> None: - """Fetch the given events from a server, and persist them as outliers. - - This function *does not* recursively get missing auth events of the - newly fetched events. Callers must include in the `events` argument - any missing events from the auth chain. - - Logs a warning if we can't find the given event. - """ - - room_version = await self.store.get_room_version(room_id) - - event_map: Dict[str, EventBase] = {} - - async def get_event(event_id: str): - with nested_logging_context(event_id): - try: - event = await self.federation_client.get_pdu( - [destination], - event_id, - room_version, - outlier=True, - ) - if event is None: - logger.warning( - "Server %s didn't return event %s", - destination, - event_id, - ) - return - - event_map[event.event_id] = event - - except Exception as e: - logger.warning( - "Error fetching missing state/auth event %s: %s %s", - event_id, - type(e), - e, - ) - - await concurrently_execute(get_event, events, 5) - - # Make a map of auth events for each event. We do this after fetching - # all the events as some of the events' auth events will be in the list - # of requested events. - - auth_events = [ - aid - for event in event_map.values() - for aid in event.auth_event_ids() - if aid not in event_map - ] - persisted_events = await self.store.get_events( - auth_events, - allow_rejected=True, - ) - - event_infos = [] - for event in event_map.values(): - auth = {} - for auth_event_id in event.auth_event_ids(): - ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id) - if ae: - auth[(ae.type, ae.state_key)] = ae - else: - logger.info("Missing auth event %s", auth_event_id) - - event_infos.append(_NewEventInfo(event, auth)) - - if event_infos: - await self._auth_and_persist_events( - destination, - room_id, - event_infos, - ) - - async def _process_pulled_events( - self, origin: str, events: Iterable[EventBase], backfilled: bool - ) -> None: - """Process a batch of events we have pulled from a remote server - - Pulls in any events required to auth the events, persists the received events, - and notifies clients, if appropriate. - - Assumes the events have already had their signatures and hashes checked. 
- - Params: - origin: The server we received these events from - events: The received events. - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - - # We want to sort these by depth so we process them and - # tell clients about them in order. - sorted_events = sorted(events, key=lambda x: x.depth) - - for ev in sorted_events: - with nested_logging_context(ev.event_id): - await self._process_pulled_event(origin, ev, backfilled=backfilled) - - async def _process_pulled_event( - self, origin: str, event: EventBase, backfilled: bool - ) -> None: - """Process a single event that we have pulled from a remote server - - Pulls in any events required to auth the event, persists the received event, - and notifies clients, if appropriate. - - Assumes the event has already had its signatures and hashes checked. - - This is somewhat equivalent to on_receive_pdu, but applies somewhat different - logic in the case that we are missing prev_events (in particular, it just - requests the state at that point, rather than triggering a get_missing_events) - - so is appropriate when we have pulled the event from a remote server, rather - than having it pushed to us. - - Params: - origin: The server we received this event from - events: The received event - backfilled: True if this is part of a historical batch of events (inhibits - notification to clients, and validation of device keys.) - """ - logger.info("Processing pulled event %s", event) - - # these should not be outliers. - assert not event.internal_metadata.is_outlier() - - event_id = event.event_id - - existing = await self.store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "Ignoring received event %s which we have already seen", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - - try: - self._sanity_check_event(event) - except SynapseError as err: - logger.warning("Event %s failed sanity check: %s", event_id, err) - return - - try: - state = await self._resolve_state_at_missing_prevs(origin, event) - await self._process_received_pdu( - origin, event, state=state, backfilled=backfilled - ) - except FederationError as e: - if e.code == 403: - logger.warning("Pulled event %s failed history check.", event_id) - else: - raise - - async def _resolve_state_at_missing_prevs( - self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: - """Calculate the state at an event with missing prev_events. - - This is used when we have pulled a batch of events from a remote server, and - still don't have all the prev_events. - - If we already have all the prev_events for `event`, this method does nothing. - - Otherwise, the missing prevs become new backwards extremities, and we fall back - to asking the remote server for the state after each missing `prev_event`, - and resolving across them. - - That's ok provided we then resolve the state against other bits of the DAG - before using it - in other words, that the received event `event` is not going - to become the only forwards_extremity in the room (which will ensure that you - can't just take over a room by sending an event, withholding its prev_events, - and declaring yourself to be an admin in the subsequent state request). - - In other words: we should only call this method if `event` has been *pulled* - as part of a batch of missing prev events, or similar. 
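Earlier in this hunk, `_process_pulled_events` fixes the processing order before handing each event to `_process_pulled_event`; the invariant is small enough to state as code (hypothetical `Event` type):

```python
# Pulled events are processed (and announced to clients) in depth
# order, so later events never land before the events they build on.
from dataclasses import dataclass
from typing import Iterable, List

@dataclass
class Event:
    event_id: str
    depth: int

def in_processing_order(events: Iterable[Event]) -> List[Event]:
    return sorted(events, key=lambda e: e.depth)
```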
- - Params: - dest: the remote server to ask for state at the missing prevs. Typically, - this will be the server we got `event` from. - event: an event to check for missing prevs. - - Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. - """ - room_id = event.room_id - event_id = event.event_id - - prevs = set(event.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) - missing_prevs = prevs - seen - - if not missing_prevs: - return None - - logger.info( - "Event %s is missing prev_events %s: calculating state for a " - "backwards extremity", - event_id, - shortstr(missing_prevs), - ) - # Calculate the state after each of the previous events, and - # resolve them to find the correct state at the current event. - event_map = {event_id: event} - try: - # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids(room_id, seen) - - # state_maps is a list of mappings from (type, state_key) to event_id - state_maps: List[StateMap[str]] = list(ours.values()) - - # we don't need this any more, let's delete it. - del ours - - # Ask the remote server for the states we don't - # know about - for p in missing_prevs: - logger.info("Requesting state after missing prev_event %s", p) - - with nested_logging_context(p): - # note that if any of the missing prevs share missing state or - # auth events, the requests to fetch those events are deduped - # by the get_pdu_cache in federation_client. - remote_state = await self._get_state_after_missing_prev_event( - dest, room_id, p - ) - - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x - - room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), - ) - - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] - except Exception: - logger.warning( - "Error attempting to resolve state at missing prev_events", - exc_info=True, - ) - raise FederationError( - "ERROR", - 403, - "We can't get valid state history.", - affected=event_id, - ) - return state - - def _sanity_check_event(self, ev: EventBase) -> None: - """ - Do some early sanity checks of a received event - - In particular, checks it doesn't have an excessive number of - prev_events or auth_events, which could cause a huge state resolution - or cascade of event fetches. 
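Condensing the `state_maps` assembly in `_resolve_state_at_missing_prevs` above: one map per locally-known state group, plus one map per missing prev built from the remote's response, all resolved together. In this sketch `resolve` stands in for the state-resolution handler and all types are simplified:

```python
# Simplified shape of the state calculation above.
from typing import Callable, Dict, List, Tuple

StateMap = Dict[Tuple[str, str], str]  # (type, state_key) -> event_id

def remote_state_map(remote_state_events) -> StateMap:
    # One map per missing prev, from the remote's state response.
    return {(ev.type, ev.state_key): ev.event_id for ev in remote_state_events}

def state_at_event(local_maps: List[StateMap],
                   remote_maps: List[StateMap],
                   resolve: Callable[[List[StateMap]], StateMap]) -> StateMap:
    # Resolve across everything we know plus everything we fetched.
    return resolve(local_maps + remote_maps)
```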
- - Args: - ev: event to be checked - - Raises: - SynapseError if the event does not pass muster - """ - if len(ev.prev_event_ids()) > 20: - logger.warning( - "Rejecting event %s which has %i prev_events", - ev.event_id, - len(ev.prev_event_ids()), - ) - raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") - - if len(ev.auth_event_ids()) > 10: - logger.warning( - "Rejecting event %s which has %i auth_events", - ev.event_id, - len(ev.auth_event_ids()), - ) - raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") + return False async def send_invite(self, target_host: str, event: EventBase) -> EventBase: """Sends the invite to the remote server for signing. @@ -1460,9 +464,9 @@ class FederationHandler(BaseHandler): # This shouldn't happen, because the RoomMemberHandler has a # linearizer lock which only allows one operation per user per room # at a time - so this is just paranoia. - assert room_id not in self.room_queues + assert room_id not in self._federation_event_handler.room_queues - self.room_queues[room_id] = [] + self._federation_event_handler.room_queues[room_id] = [] await self._clean_room_for_join(room_id) @@ -1536,8 +540,8 @@ class FederationHandler(BaseHandler): logger.debug("Finished joining %s to %s", joinee, room_id) return event.event_id, max_stream_id finally: - room_queue = self.room_queues[room_id] - del self.room_queues[room_id] + room_queue = self._federation_event_handler.room_queues[room_id] + del self._federation_event_handler.room_queues[room_id] # we don't need to wait for the queued events to be processed - # it's just a best-effort thing at this point. We do want to do @@ -1613,7 +617,7 @@ class FederationHandler(BaseHandler): event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify( + stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) return event.event_id, stream_id @@ -1633,7 +637,7 @@ class FederationHandler(BaseHandler): p, ) with nested_logging_context(p.event_id): - await self.on_receive_pdu(origin, p) + await self._federation_event_handler.on_receive_pdu(origin, p) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e @@ -1726,7 +730,7 @@ class FederationHandler(BaseHandler): raise # Ensure the user can even join the room. - await self._check_join_restrictions(context, event) + await self._federation_event_handler.check_join_restrictions(context, event) # The remote hasn't signed it yet, obviously. 
We'll do the full checks # when we get the event back in `on_send_join_request` @@ -1803,7 +807,9 @@ class FederationHandler(BaseHandler): ) context = await self.state_handler.compute_event_context(event) - await self.persist_events_and_notify(event.room_id, [(event, context)]) + await self._federation_event_handler.persist_events_and_notify( + event.room_id, [(event, context)] + ) return event @@ -1830,7 +836,7 @@ class FederationHandler(BaseHandler): await self.federation_client.send_leave(host_list, event) context = await self.state_handler.compute_event_context(event) - stream_id = await self.persist_events_and_notify( + stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1973,116 +979,6 @@ class FederationHandler(BaseHandler): return event - @log_function - async def on_send_membership_event( - self, origin: str, event: EventBase - ) -> Tuple[EventBase, EventContext]: - """ - We have received a join/leave/knock event for a room via send_join/leave/knock. - - Verify that event and send it into the room on the remote homeserver's behalf. - - This is quite similar to on_receive_pdu, with the following principal - differences: - * only membership events are permitted (and only events with - sender==state_key -- ie, no kicks or bans) - * *We* send out the event on behalf of the remote server. - * We enforce the membership restrictions of restricted rooms. - * Rejected events result in an exception rather than being stored. - - There are also other differences, however it is not clear if these are by - design or omission. In particular, we do not attempt to backfill any missing - prev_events. - - Args: - origin: The homeserver of the remote (joining/invited/knocking) user. - event: The member event that has been signed by the remote homeserver. - - Returns: - The event and context of the event after inserting it into the room graph. - - Raises: - SynapseError if the event is not accepted into the room - """ - logger.debug( - "on_send_membership_event: Got event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - if get_domain_from_id(event.sender) != origin: - logger.info( - "Got send_membership request for user %r from different origin %s", - event.sender, - origin, - ) - raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - - if event.sender != event.state_key: - raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) - - assert not event.internal_metadata.outlier - - # Send this event on behalf of the other server. - # - # The remote server isn't a full participant in the room at this point, so - # may not have an up-to-date list of the other homeservers participating in - # the room, so we send it on their behalf. - event.internal_metadata.send_on_behalf_of = origin - - context = await self.state_handler.compute_event_context(event) - context = await self._check_event_auth(origin, event, context) - if context.rejected: - raise SynapseError( - 403, f"{event.membership} event was rejected", Codes.FORBIDDEN - ) - - # for joins, we need to check the restrictions of restricted rooms - if event.membership == Membership.JOIN: - await self._check_join_restrictions(context, event) - - # for knock events, we run the third-party event rules. It's not entirely clear - # why we don't do this for other sorts of membership events. 
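The two cheap validity checks at the top of the removed `on_send_membership_event` are easy to lose in the diff noise. A stand-alone sketch: Synapse uses `get_domain_from_id` and `SynapseError`, for which plain stand-ins are used here, and the MXID is assumed well-formed (`@user:domain`):

```python
# Sketch of the sender/origin checks in on_send_membership_event.
def validate_membership_event(origin: str, sender: str, state_key: str) -> None:
    # The sending server may only act for its own users.
    if sender.split(":", 1)[1] != origin:
        raise PermissionError("User not from origin")
    # Membership events must be self-referential: no kicks or bans here.
    if sender != state_key:
        raise ValueError("state_key and sender must match")
```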
- if event.membership == Membership.KNOCK: - event_allowed, _ = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of knock %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN - ) - - # all looks good, we can persist the event. - await self._run_push_actions_and_persist_event(event, context) - return event, context - - async def _check_join_restrictions( - self, context: EventContext, event: EventBase - ) -> None: - """Check that restrictions in restricted join rules are matched - - Called when we receive a join event via send_join. - - Raises an auth error if the restrictions are not matched. - """ - prev_state_ids = await context.get_prev_state_ids() - - # Check if the user is already in the room or invited to the room. - user_id = event.state_key - prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - prev_member_event = None - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - - # Check if the member should be allowed access via membership in a space. - await self._event_auth_handler.check_restricted_join_rules( - prev_state_ids, - event.room_version, - user_id, - prev_member_event, - ) - async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event.""" @@ -2183,126 +1079,6 @@ class FederationHandler(BaseHandler): else: return None - async def get_min_depth_for_context(self, context: str) -> int: - return await self.store.get_min_depth(context) - - async def _auth_and_persist_event( - self, - origin: str, - event: EventBase, - context: EventContext, - state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, - backfilled: bool = False, - ) -> None: - """ - Process an event by performing auth checks and then persisting to the database. - - Args: - origin: The host the event originates from. - event: The event itself. - context: - The event context. - - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly incomplete, and possibly including events that are not yet - persisted, or authed, or in the right room. - - Only populated where we may not already have persisted these events - - for example, when populating outliers. - - backfilled: True if the event was backfilled. - """ - context = await self._check_event_auth( - origin, - event, - context, - state=state, - claimed_auth_event_map=claimed_auth_event_map, - backfilled=backfilled, - ) - - await self._run_push_actions_and_persist_event(event, context, backfilled) - - async def _run_push_actions_and_persist_event( - self, event: EventBase, context: EventContext, backfilled: bool = False - ): - """Run the push actions for a received event, and persist it. - - Args: - event: The event itself. - context: The event context. - backfilled: True if the event was backfilled. 
- """ - try: - if ( - not event.internal_metadata.is_outlier() - and not backfilled - and not context.rejected - and (await self.store.get_min_depth(event.room_id)) <= event.depth - ): - await self.action_generator.handle_push_actions_for_event( - event, context - ) - - await self.persist_events_and_notify( - event.room_id, [(event, context)], backfilled=backfilled - ) - except Exception: - run_in_background( - self.store.remove_push_actions_from_staging, event.event_id - ) - raise - - async def _auth_and_persist_events( - self, - origin: str, - room_id: str, - event_infos: Collection[_NewEventInfo], - ) -> None: - """Creates the appropriate contexts and persists events. The events - should not depend on one another, e.g. this should be used to persist - a bunch of outliers, but not a chunk of individual events that depend - on each other for state calculations. - - Notifies about the events where appropriate. - """ - - if not event_infos: - return - - async def prep(ev_info: _NewEventInfo): - event = ev_info.event - with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context(event) - res = await self._check_event_auth( - origin, - event, - res, - claimed_auth_event_map=ev_info.claimed_auth_event_map, - ) - return res - - contexts = await make_deferred_yieldable( - defer.gatherResults( - [run_in_background(prep, ev_info) for ev_info in event_infos], - consumeErrors=True, - ) - ) - - await self.persist_events_and_notify( - room_id, - [ - (ev_info.event, context) - for ev_info, context in zip(event_infos, contexts) - ], - ) - async def _persist_auth_tree( self, origin: str, @@ -2400,7 +1176,7 @@ class FederationHandler(BaseHandler): events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR if auth_events or state: - await self.persist_events_and_notify( + await self._federation_event_handler.persist_events_and_notify( room_id, [ (e, events_to_context[e.event_id]) @@ -2412,108 +1188,10 @@ class FederationHandler(BaseHandler): event, old_state=state ) - return await self.persist_events_and_notify( + return await self._federation_event_handler.persist_events_and_notify( room_id, [(event, new_event_context)] ) - async def _check_for_soft_fail( - self, - event: EventBase, - state: Optional[Iterable[EventBase]], - backfilled: bool, - origin: str, - ) -> None: - """Checks if we should soft fail the event; if so, marks the event as - such. - - Args: - event - state: The state at the event if we don't have all the event's prev events - backfilled: Whether the event is from backfill - origin: The host the event originates from. - """ - # For new (non-backfilled and non-outlier) events we check if the event - # passes auth based on the current state. If it doesn't then we - # "soft-fail" the event. - if backfilled or event.internal_metadata.is_outlier(): - return - - extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) - extrem_ids = set(extrem_ids_list) - prev_event_ids = set(event.prev_event_ids()) - - if extrem_ids == prev_event_ids: - # If they're the same then the current state is the same as the - # state at the event, so no point rechecking auth for soft fail. - return - - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - # Calculate the "current state". - if state is not None: - # If we're explicitly given the state then we won't have all the - # prev events, and so we have a gap in the graph. 
In this case - # we want to be a little careful as we might have been down for - # a while and have an incorrect view of the current state, - # however we still want to do checks as gaps are easy to - # maliciously manufacture. - # - # So we use a "current state" that is actually a state - # resolution across the current forward extremities and the - # given state at the event. This should correctly handle cases - # like bans, especially with state res v2. - - state_sets_d = await self.state_store.get_state_groups( - event.room_id, extrem_ids - ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) - state_sets.append(state) - current_states = await self.state_handler.resolve_events( - room_version, state_sets, event - ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } - else: - current_state_ids = await self.state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids - ) - - logger.debug( - "Doing soft-fail check for %s: state %s", - event.event_id, - current_state_ids, - ) - - # Now check if event pass auth against said current state - auth_types = auth_types_for_event(room_version_obj, event) - current_state_ids_list = [ - e for k, e in current_state_ids.items() if k in auth_types - ] - - auth_events_map = await self.store.get_events(current_state_ids_list) - current_auth_events = { - (e.type, e.state_key): e for e in auth_events_map.values() - } - - try: - event_auth.check(room_version_obj, event, auth_events=current_auth_events) - except AuthError as e: - logger.warning( - "Soft-failing %r (from %s) because %s", - event, - e, - origin, - extra={ - "room_id": event.room_id, - "mxid": event.sender, - "hs": origin, - }, - ) - soft_failed_event_counter.inc() - event.internal_metadata.soft_failed = True - async def on_get_missing_events( self, origin: str, @@ -2542,334 +1220,6 @@ class FederationHandler(BaseHandler): return missing_events - async def _check_event_auth( - self, - origin: str, - event: EventBase, - context: EventContext, - state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, - backfilled: bool = False, - ) -> EventContext: - """ - Checks whether an event should be rejected (for failing auth checks). - - Args: - origin: The host the event originates from. - event: The event itself. - context: - The event context. - - state: - The state events used to check the event for soft-fail. If this is - not provided the current state events will be used. - - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly incomplete, and possibly including events that are not yet - persisted, or authed, or in the right room. - - Only populated where we may not already have persisted these events - - for example, when populating outliers, or the state for a backwards - extremity. - - backfilled: True if the event was backfilled. - - Returns: - The updated context object. - """ - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - - if claimed_auth_event_map: - # if we have a copy of the auth events from the event, use that as the - # basis for auth. 
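The soft-fail short-circuit in `_check_for_soft_fail` above rests on one observation: when the event's prev_events are exactly the current forward extremities, the state at the event *is* the current state, so re-checking auth against it cannot change the answer. As a predicate (hypothetical helper name):

```python
# Sketch of when the soft-fail re-check is actually needed.
from typing import AbstractSet

def soft_fail_check_needed(backfilled: bool, is_outlier: bool,
                           extremities: AbstractSet[str],
                           prev_event_ids: AbstractSet[str]) -> bool:
    # Backfilled events and outliers are never soft-failed.
    if backfilled or is_outlier:
        return False
    # Equal sets => state at event == current state => nothing to re-check.
    return set(extremities) != set(prev_event_ids)
```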
- auth_events = claimed_auth_event_map - else: - # otherwise, we calculate what the auth events *should* be, and use that - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self.store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} - - try: - ( - context, - auth_events_for_auth, - ) = await self._update_auth_events_and_context_for_auth( - origin, event, context, auth_events - ) - except Exception: - # We don't really mind if the above fails, so lets not fail - # processing if it does. However, it really shouldn't fail so - # let's still log as an exception since we'll still want to fix - # any bugs. - logger.exception( - "Failed to double check auth events for %s with remote. " - "Ignoring failure and continuing processing of event.", - event.event_id, - ) - auth_events_for_auth = auth_events - - try: - event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) - except AuthError as e: - logger.warning("Failed auth resolution for %r because %s", event, e) - context.rejected = RejectedReason.AUTH_ERROR - - if not context.rejected: - await self._check_for_soft_fail(event, state, backfilled, origin=origin) - - if event.type == EventTypes.GuestAccess and not context.rejected: - await self.maybe_kick_guest_users(event) - - # If we are going to send this event over federation we precaclculate - # the joined hosts. - if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event( - event, context - ) - - return context - - async def _update_auth_events_and_context_for_auth( - self, - origin: str, - event: EventBase, - context: EventContext, - input_auth_events: StateMap[EventBase], - ) -> Tuple[EventContext, StateMap[EventBase]]: - """Helper for _check_event_auth. See there for docs. - - Checks whether a given event has the expected auth events. If it - doesn't then we talk to the remote server to compare state to see if - we can come to a consensus (e.g. if one server missed some valid - state). - - This attempts to resolve any potential divergence of state between - servers, but is not essential and so failures should not block further - processing of the event. - - Args: - origin: - event: - context: - - input_auth_events: - Map from (event_type, state_key) to event - - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. - - Returns: - updated context, updated auth event map - """ - # take a copy of input_auth_events before we modify it. - auth_events: MutableStateMap[EventBase] = dict(input_auth_events) - - event_auth_events = set(event.auth_event_ids()) - - # missing_auth is the set of the event's auth_events which we don't yet have - # in auth_events. - missing_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - - # if we have missing events, we need to fetch those events from somewhere. - # - # we start by checking if they are in the store, and then try calling /event_auth/. 
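The auth-event reconciliation in this helper is driven by two set differences, extracted here for clarity (hypothetical helper names):

```python
# The set arithmetic behind _update_auth_events_and_context_for_auth.
from typing import AbstractSet, Set

def missing_auth_ids(claimed: AbstractSet[str],
                     known: AbstractSet[str]) -> Set[str]:
    # auth_events the event cites that we do not yet hold at all
    return set(claimed) - set(known)

def different_auth_ids(claimed: AbstractSet[str],
                       calculated: AbstractSet[str]) -> Set[str]:
    # auth_events the event cites that disagree with our own calculation
    return set(claimed) - set(calculated)
```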
- if missing_auth: - have_events = await self.store.have_seen_events(event.room_id, missing_auth) - logger.debug("Events %s are in the store", have_events) - missing_auth.difference_update(have_events) - - if missing_auth: - # If we don't have all the auth events, we need to get them. - logger.info("auth_events contains unknown events: %s", missing_auth) - try: - try: - remote_auth_chain = await self.federation_client.get_event_auth( - origin, event.room_id, event.event_id - ) - except RequestSendFailed as e1: - # The other side isn't around or doesn't implement the - # endpoint, so lets just bail out. - logger.info("Failed to get event auth from remote: %s", e1) - return context, auth_events - - seen_remotes = await self.store.have_seen_events( - event.room_id, [e.event_id for e in remote_auth_chain] - ) - - for e in remote_auth_chain: - if e.event_id in seen_remotes: - continue - - if e.event_id == event.event_id: - continue - - try: - auth_ids = e.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in remote_auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - e.internal_metadata.outlier = True - - logger.debug( - "_check_event_auth %s missing_auth: %s", - event.event_id, - e.event_id, - ) - missing_auth_event_context = ( - await self.state_handler.compute_event_context(e) - ) - await self._auth_and_persist_event( - origin, - e, - missing_auth_event_context, - claimed_auth_event_map=auth, - ) - - if e.event_id in event_auth_events: - auth_events[(e.type, e.state_key)] = e - except AuthError: - pass - - except Exception: - logger.exception("Failed to get auth chain") - - if event.internal_metadata.is_outlier(): - # XXX: given that, for an outlier, we'll be working with the - # event's *claimed* auth events rather than those we calculated: - # (a) is there any point in this test, since different_auth below will - # obviously be empty - # (b) alternatively, why don't we do it earlier? - logger.info("Skipping auth_event fetch for outlier") - return context, auth_events - - different_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - - if not different_auth: - return context, auth_events - - logger.info( - "auth_events refers to events which are not in our calculated auth " - "chain: %s", - different_auth, - ) - - # XXX: currently this checks for redactions but I'm not convinced that is - # necessary? - different_events = await self.store.get_events_as_list(different_auth) - - for d in different_events: - if d.room_id != event.room_id: - logger.warning( - "Event %s refers to auth_event %s which is in a different room", - event.event_id, - d.event_id, - ) - - # don't attempt to resolve the claimed auth events against our own - # in this case: just use our own auth events. - # - # XXX: should we reject the event in this case? It feels like we should, - # but then shouldn't we also do so if we've failed to fetch any of the - # auth events? - return context, auth_events - - # now we state-resolve between our own idea of the auth events, and the remote's - # idea of them. 
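The comment above announces the local-vs-remote resolution; its shape, heavily simplified, is to overlay the remote's differing auth events on a copy of our map and then state-resolve the two views:

```python
# Simplified shape of the two state sets fed to state resolution below.
from typing import Dict, Iterable, List, Tuple

Key = Tuple[str, str]  # (event type, state_key)

def auth_sets_for_resolution(local: Dict[Key, object],
                             remote_extra: Dict[Key, object]) -> List[Iterable[object]]:
    remote = dict(local)
    remote.update(remote_extra)  # remote's view wins where they differ
    return [list(local.values()), list(remote.values())]
```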
- - local_state = auth_events.values() - remote_auth_events = dict(auth_events) - remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) - remote_state = remote_auth_events.values() - - room_version = await self.store.get_room_version_id(event.room_id) - new_state = await self.state_handler.resolve_events( - room_version, (local_state, remote_state), event - ) - - logger.info( - "After state res: updating auth_events with new state %s", - { - (d.type, d.state_key): d.event_id - for d in new_state.values() - if auth_events.get((d.type, d.state_key)) != d - }, - ) - - auth_events.update(new_state) - - context = await self._update_context_for_auth_events( - event, context, auth_events - ) - - return context, auth_events - - async def _update_context_for_auth_events( - self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] - ) -> EventContext: - """Update the state_ids in an event context after auth event resolution, - storing the changes as a new state group. - - Args: - event: The event we're handling the context for - - context: initial event context - - auth_events: Events to update in the event context. - - Returns: - new event context - """ - # exclude the state key of the new event from the current_state in the context. - if event.is_state(): - event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) - else: - event_key = None - state_updates = { - k: a.event_id for k, a in auth_events.items() if k != event_key - } - - current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) # type: ignore - - current_state_ids.update(state_updates) - - prev_state_ids = await context.get_prev_state_ids() - prev_state_ids = dict(prev_state_ids) - - prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) - - # create a new state group as a delta from the existing one. - prev_group = context.state_group - state_group = await self.state_store.store_state_group( - event.event_id, - event.room_id, - prev_group=prev_group, - delta_ids=state_updates, - current_state_ids=current_state_ids, - ) - - return EventContext.with_state( - state_group=state_group, - state_group_before_event=context.state_group_before_event, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, - prev_group=prev_group, - delta_ids=state_updates, - ) - async def construct_auth_difference( self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] ) -> Dict: @@ -3256,99 +1606,6 @@ class FederationHandler(BaseHandler): if "valid" not in response or not response["valid"]: raise AuthError(403, "Third party certificate was invalid") - async def persist_events_and_notify( - self, - room_id: str, - event_and_contexts: Sequence[Tuple[EventBase, EventContext]], - backfilled: bool = False, - ) -> int: - """Persists events and tells the notifier/pushers about them, if - necessary. - - Args: - room_id: The room ID of events being persisted. - event_and_contexts: Sequence of events with their associated - context that should be persisted. All events must belong to - the same room. - backfilled: Whether these events are a result of - backfilling or not - - Returns: - The stream ID after which all events have been persisted. - """ - if not event_and_contexts: - return self.store.get_current_events_token() - - instance = self.config.worker.events_shard_config.get_instance(room_id) - if instance != self._instance_name: - # Limit the number of events sent over replication. 
We choose 200 - # here as that is what we default to in `max_request_body_size(..)` - for batch in batch_iter(event_and_contexts, 200): - result = await self._send_events( - instance_name=instance, - store=self.store, - room_id=room_id, - event_and_contexts=batch, - backfilled=backfilled, - ) - return result["max_stream_id"] - else: - assert self.storage.persistence - - # Note that this returns the events that were persisted, which may not be - # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self.storage.persistence.persist_events( - event_and_contexts, backfilled=backfilled - ) - - if self._ephemeral_messages_enabled: - for event in events: - # If there's an expiry timestamp on the event, schedule its expiry. - self._message_handler.maybe_schedule_expiry(event) - - if not backfilled: # Never notify for backfilled events - for event in events: - await self._notify_persisted_event(event, max_stream_token) - - return max_stream_token.stream - - async def _notify_persisted_event( - self, event: EventBase, max_stream_token: RoomStreamToken - ) -> None: - """Checks to see if notifier/pushers should be notified about the - event or not. - - Args: - event: - max_stream_id: The max_stream_id returned by persist_events - """ - - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - - # We notify for memberships if its an invite for one of our - # users - if event.internal_metadata.is_outlier(): - if event.membership != Membership.INVITE: - if not self.is_mine_id(target_user_id): - return - - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) - elif event.internal_metadata.is_outlier(): - return - - # the event has been persisted so it should have a stream ordering. - assert event.internal_metadata.stream_ordering - - event_pos = PersistedEventPosition( - self._instance_name, event.internal_metadata.stream_ordering - ) - self.notifier.on_new_room_event( - event, event_pos, max_stream_token, extra_users=extra_users - ) - async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the server to join the room. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py new file mode 100644 index 0000000000..9f055f00cf --- /dev/null +++ b/synapse/handlers/federation_event.py @@ -0,0 +1,1825 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
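Aside: the `persist_events_and_notify` removed just above (and reinstated on the new handler later in this patch) ships events over replication in batches of 200 via `batch_iter`. A minimal sketch of such a helper; the real `batch_iter` lives in `synapse.util.iterutils`:

```python
# Yield fixed-size chunks so no single replication request exceeds
# the request-body budget.
from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    it = iter(iterable)
    while True:
        batch = tuple(islice(it, size))
        if not batch:
            return
        yield batch
```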
+ +import logging +from http import HTTPStatus +from typing import ( + TYPE_CHECKING, + Collection, + Container, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) + +import attr +from prometheus_client import Counter + +from twisted.internet import defer + +from synapse import event_auth +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RejectedReason, + RoomEncryptionAlgorithms, +) +from synapse.api.errors import ( + AuthError, + Codes, + FederationError, + HttpResponseException, + RequestSendFailed, + SynapseError, +) +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.event_auth import auth_types_for_event +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.federation.federation_client import InvalidResponseError +from synapse.handlers._base import BaseHandler +from synapse.logging.context import ( + make_deferred_yieldable, + nested_logging_context, + run_in_background, +) +from synapse.logging.utils import log_function +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet +from synapse.replication.http.federation import ( + ReplicationFederationSendEventsRestServlet, +) +from synapse.state import StateResolutionStore +from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.types import ( + MutableStateMap, + PersistedEventPosition, + RoomStreamToken, + StateMap, + UserID, + get_domain_from_id, +) +from synapse.util.async_helpers import Linearizer, concurrently_execute +from synapse.util.iterutils import batch_iter +from synapse.util.retryutils import NotRetryingDestination +from synapse.util.stringutils import shortstr + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + +soft_failed_event_counter = Counter( + "synapse_federation_soft_failed_events_total", + "Events received over federation that we marked as soft_failed", +) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _NewEventInfo: + """Holds information about a received event, ready for passing to _auth_and_persist_events + + Attributes: + event: the received event + + claimed_auth_event_map: a map of (type, state_key) => event for the event's + claimed auth_events. + + This can include events which have not yet been persisted, in the case that + we are backfilling a batch of events. + + Note: May be incomplete: if we were unable to find all of the claimed auth + events. Also, treat the contents with caution: the events might also have + been rejected, might not yet have been authorized themselves, or they might + be in the wrong room. + + """ + + event: EventBase + claimed_auth_event_map: StateMap[EventBase] + + +class FederationEventHandler(BaseHandler): + """Handles events that originated from federation. 
+
+    Responsible for handling incoming events and passing them on to the rest
+    of the homeserver (including auth and state conflict resolutions)
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
+        self.state_store = self.storage.state
+
+        self.state_handler = hs.get_state_handler()
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self._event_auth_handler = hs.get_event_auth_handler()
+        self._message_handler = hs.get_message_handler()
+        self.action_generator = hs.get_action_generator()
+        self._state_resolution_handler = hs.get_state_resolution_handler()
+
+        self.federation_client = hs.get_federation_client()
+        self.third_party_event_rules = hs.get_third_party_event_rules()
+
+        self.is_mine_id = hs.is_mine_id
+        self._instance_name = hs.get_instance_name()
+
+        self.config = hs.config
+        self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
+
+        self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
+        if hs.config.worker_app:
+            self._user_device_resync = (
+                ReplicationUserDevicesResyncRestServlet.make_client(hs)
+            )
+        else:
+            self._device_list_updater = hs.get_device_handler().device_list_updater
+
+        # When joining a room we need to queue any events for that room up.
+        # For each room, a list of (pdu, origin) tuples.
+        # TODO: replace this with something more elegant, probably based around the
+        # federation event staging area.
+        self.room_queues: Dict[str, List[Tuple[EventBase, str]]] = {}
+
+        self._room_pdu_linearizer = Linearizer("fed_room_pdu")
+
+    async def on_receive_pdu(self, origin: str, pdu: EventBase) -> None:
+        """Process a PDU received via a federation /send/ transaction
+
+        Args:
+            origin: server which initiated the /send/ transaction. Will
+                be used to fetch missing events or state.
+            pdu: received PDU
+        """
+
+        room_id = pdu.room_id
+        event_id = pdu.event_id
+
+        # We reprocess pdus when we have seen them only as outliers
+        existing = await self.store.get_event(
+            event_id, allow_none=True, allow_rejected=True
+        )
+
+        # FIXME: Currently we fetch an event again when we already have it
+        # if it has been marked as an outlier.
+        if existing:
+            if not existing.internal_metadata.is_outlier():
+                logger.info(
+                    "Ignoring received event %s which we have already seen", event_id
+                )
+                return
+            if pdu.internal_metadata.is_outlier():
+                logger.info(
+                    "Ignoring received outlier %s which we already have as an outlier",
+                    event_id,
+                )
+                return
+            logger.info("De-outliering event %s", event_id)
+
+        # do some initial sanity-checking of the event. In particular, make
+        # sure it doesn't have hundreds of prev_events or auth_events, which
+        # could cause a huge state resolution or cascade of event fetches.
+        try:
+            self._sanity_check_event(pdu)
+        except SynapseError as err:
+            logger.warning("Received event failed sanity checks")
+            raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
+
+        # If we are currently in the process of joining this room, then we
+        # queue up events for later processing.
+        if room_id in self.room_queues:
+            logger.info(
+                "Queuing PDU from %s for now: join in progress",
+                origin,
+            )
+            self.room_queues[room_id].append((pdu, origin))
+            return
+
+        # If we're not in the room just ditch the event entirely. This is
+        # probably an old server that has come back and thinks we're still in
+        # the room (or we've been rejoined to the room by a state reset).
+ # + # Note that if we were never in the room then we would have already + # dropped the event, since we wouldn't know the room version. + is_in_room = await self._event_auth_handler.check_host_in_room( + room_id, self.server_name + ) + if not is_in_room: + logger.info( + "Ignoring PDU from %s as we're not in the room", + origin, + ) + return None + + # Check that the event passes auth based on the state at the event. This is + # done for events that are to be added to the timeline (non-outliers). + # + # Get missing pdus if necessary: + # - Fetching any missing prev events to fill in gaps in the graph + # - Fetching state if we have a hole in the graph + if not pdu.internal_metadata.is_outlier(): + prevs = set(pdu.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if missing_prevs: + # We only backfill backwards to the min depth. + min_depth = await self.get_min_depth_for_context(pdu.room_id) + logger.debug("min_depth: %d", min_depth) + + if min_depth is not None and pdu.depth > min_depth: + # If we're missing stuff, ensure we only fetch stuff one + # at a time. + logger.info( + "Acquiring room lock to fetch %d missing prev_events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + with (await self._room_pdu_linearizer.queue(pdu.room_id)): + logger.info( + "Acquired room lock to fetch %d missing prev_events", + len(missing_prevs), + ) + + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) from e + + # Update the set of things we've seen after trying to + # fetch the missing stuff + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + logger.info("Found all missing prev_events") + + if missing_prevs: + # since this event was pushed to us, it is possible for it to + # become the only forward-extremity in the room, and we would then + # trust its state to be the state for the whole room. This is very + # bad. Further, if the event was pushed to us, there is no excuse + # for us not to have all the prev_events. (XXX: apart from + # min_depth?) + # + # We therefore reject any such events. + logger.warning( + "Rejecting: failed to fetch %d prev events: %s", + len(missing_prevs), + shortstr(missing_prevs), + ) + raise FederationError( + "ERROR", + 403, + ( + "Your server isn't divulging details about prev_events " + "referenced in this event." + ), + affected=pdu.event_id, + ) + + await self._process_received_pdu(origin, pdu, state=None) + + @log_function + async def on_send_membership_event( + self, origin: str, event: EventBase + ) -> Tuple[EventBase, EventContext]: + """ + We have received a join/leave/knock event for a room via send_join/leave/knock. + + Verify that event and send it into the room on the remote homeserver's behalf. + + This is quite similar to on_receive_pdu, with the following principal + differences: + * only membership events are permitted (and only events with + sender==state_key -- ie, no kicks or bans) + * *We* send out the event on behalf of the remote server. + * We enforce the membership restrictions of restricted rooms. + * Rejected events result in an exception rather than being stored. + + There are also other differences, however it is not clear if these are by + design or omission. In particular, we do not attempt to backfill any missing + prev_events. 
+ + Args: + origin: The homeserver of the remote (joining/invited/knocking) user. + event: The member event that has been signed by the remote homeserver. + + Returns: + The event and context of the event after inserting it into the room graph. + + Raises: + SynapseError if the event is not accepted into the room + """ + logger.debug( + "on_send_membership_event: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got send_membership request for user %r from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + if event.sender != event.state_key: + raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) + + assert not event.internal_metadata.outlier + + # Send this event on behalf of the other server. + # + # The remote server isn't a full participant in the room at this point, so + # may not have an up-to-date list of the other homeservers participating in + # the room, so we send it on their behalf. + event.internal_metadata.send_on_behalf_of = origin + + context = await self.state_handler.compute_event_context(event) + context = await self._check_event_auth(origin, event, context) + if context.rejected: + raise SynapseError( + 403, f"{event.membership} event was rejected", Codes.FORBIDDEN + ) + + # for joins, we need to check the restrictions of restricted rooms + if event.membership == Membership.JOIN: + await self.check_join_restrictions(context, event) + + # for knock events, we run the third-party event rules. It's not entirely clear + # why we don't do this for other sorts of membership events. + if event.membership == Membership.KNOCK: + event_allowed, _ = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + # all looks good, we can persist the event. + await self._run_push_actions_and_persist_event(event, context) + return event, context + + async def check_join_restrictions( + self, context: EventContext, event: EventBase + ) -> None: + """Check that restrictions in restricted join rules are matched + + Called when we receive a join event via send_join. + + Raises an auth error if the restrictions are not matched. + """ + prev_state_ids = await context.get_prev_state_ids() + + # Check if the user is already in the room or invited to the room. + user_id = event.state_key + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + + # Check if the member should be allowed access via membership in a space. + await self._event_auth_handler.check_restricted_join_rules( + prev_state_ids, + event.room_version, + user_id, + prev_member_event, + ) + + @log_function + async def backfill( + self, dest: str, room_id: str, limit: int, extremities: List[str] + ) -> None: + """Trigger a backfill request to `dest` for the given `room_id` + + This will attempt to get more events from the remote. If the other side + has no new events to offer, this will return an empty list. + + As the events are received, we check their signatures, and also do some + sanity-checking on them. If any of the backfilled events are invalid, + this method throws a SynapseError. 
+
+        We might also raise an InvalidResponseError if the response from the remote
+        server is just bogus.
+
+        TODO: make this more useful to distinguish failures of the remote
+            server from invalid events (there is probably no point in trying to
+            re-fetch invalid events from every other HS in the room.)
+        """
+        if dest == self.server_name:
+            raise SynapseError(400, "Can't backfill from self.")
+
+        events = await self.federation_client.backfill(
+            dest, room_id, limit=limit, extremities=extremities
+        )
+
+        if not events:
+            return
+
+        # if there are any events in the wrong room, the remote server is buggy and
+        # should not be trusted.
+        for ev in events:
+            if ev.room_id != room_id:
+                raise InvalidResponseError(
+                    f"Remote server {dest} returned event {ev.event_id} which is in "
+                    f"room {ev.room_id}, when we were backfilling in {room_id}"
+                )
+
+        await self._process_pulled_events(dest, events, backfilled=True)
+
+    async def _get_missing_events_for_pdu(
+        self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+    ) -> None:
+        """
+        Args:
+            origin: Origin of the pdu. Will be called to get the missing events
+            pdu: received pdu
+            prevs: List of event ids which we are missing
+            min_depth: Minimum depth of events to return.
+        """
+
+        room_id = pdu.room_id
+        event_id = pdu.event_id
+
+        seen = await self.store.have_events_in_timeline(prevs)
+
+        if not prevs - seen:
+            return
+
+        latest_list = await self.store.get_latest_event_ids_in_room(room_id)
+
+        # We add the prev events that we have seen to the latest
+        # list to ensure the remote server doesn't give them to us
+        latest = set(latest_list)
+        latest |= seen
+
+        logger.info(
+            "Requesting missing events between %s and %s",
+            shortstr(latest),
+            event_id,
+        )
+
+        # XXX: we set timeout to 10s to help workaround
+        # https://github.com/matrix-org/synapse/issues/1733.
+        # The reason is to avoid holding the linearizer lock
+        # whilst processing inbound /send transactions, causing
+        # FDs to stack up and block other inbound transactions
+        # which empirically can currently take up to 30 minutes.
+        #
+        # N.B. this explicitly disables retry attempts.
+        #
+        # N.B. this also increases our chances of falling back to
+        # fetching fresh state for the room if the missing event
+        # can't be found, which slightly reduces our security.
+        # it may also increase our DAG extremity count for the room,
+        # causing additional state resolution? See #1760.
+        # However, fetching state doesn't hold the linearizer lock
+        # apparently.
+        #
+        # see https://github.com/matrix-org/synapse/pull/1744
+        #
+        # ----
+        #
+        # Update richvdh 2018/09/18: There are a number of problems with timing this
+        # request out aggressively on the client side:
+        #
+        # - it plays badly with the server-side rate-limiter, which starts tarpitting you
+        #   if you send too many requests at once, so you end up with the server carefully
+        #   working through the backlog of your requests, which you have already timed
+        #   out.
+        #
+        # - for this request in particular, we now (as of
+        #   https://github.com/matrix-org/synapse/pull/3456) reject any PDUs where the
+        #   server can't produce a plausible-looking set of prev_events - so we become
+        #   much more likely to reject the event.
+        #
+        # - contrary to what it says above, we do *not* fall back to fetching fresh state
+        #   for the room if get_missing_events times out. Rather, we give up processing
+        #   the PDU whose prevs we are missing, which then makes it much more likely that
+        #   we'll end up back here for the *next* PDU in the list, which exacerbates the
+        #   problem.
+        #
+        # - the aggressive 10s timeout was introduced to deal with incoming federation
+        #   requests taking 8 hours to process. It's not entirely clear why that was going
+        #   on; certainly there were other issues causing traffic storms which are now
+        #   resolved, and I think in any case we may be more sensible about our locking
+        #   now. We're *certainly* more sensible about our logging.
+        #
+        # All that said: Let's try increasing the timeout to 60s and see what happens.
+
+        try:
+            missing_events = await self.federation_client.get_missing_events(
+                origin,
+                room_id,
+                earliest_events_ids=list(latest),
+                latest_events=[pdu],
+                limit=10,
+                min_depth=min_depth,
+                timeout=60000,
+            )
+        except (RequestSendFailed, HttpResponseException, NotRetryingDestination) as e:
+            # We failed to get the missing events, but since we need to handle
+            # the case of `get_missing_events` not returning the necessary
+            # events anyway, it is safe to simply log the error and continue.
+            logger.warning("Failed to get prev_events: %s", e)
+            return
+
+        logger.info("Got %d prev_events", len(missing_events))
+        await self._process_pulled_events(origin, missing_events, backfilled=False)
+
+    async def _process_pulled_events(
+        self, origin: str, events: Iterable[EventBase], backfilled: bool
+    ) -> None:
+        """Process a batch of events we have pulled from a remote server
+
+        Pulls in any events required to auth the events, persists the received events,
+        and notifies clients, if appropriate.
+
+        Assumes the events have already had their signatures and hashes checked.
+
+        Params:
+            origin: The server we received these events from
+            events: The received events.
+            backfilled: True if this is part of a historical batch of events (inhibits
+                notification to clients, and validation of device keys.)
+        """
+
+        # We want to sort these by depth so we process them and
+        # tell clients about them in order.
+        sorted_events = sorted(events, key=lambda x: x.depth)
+
+        for ev in sorted_events:
+            with nested_logging_context(ev.event_id):
+                await self._process_pulled_event(origin, ev, backfilled=backfilled)
+
+    async def _process_pulled_event(
+        self, origin: str, event: EventBase, backfilled: bool
+    ) -> None:
+        """Process a single event that we have pulled from a remote server
+
+        Pulls in any events required to auth the event, persists the received event,
+        and notifies clients, if appropriate.
+
+        Assumes the event has already had its signatures and hashes checked.
+
+        This is somewhat equivalent to on_receive_pdu, but applies somewhat different
+        logic in the case that we are missing prev_events (in particular, it just
+        requests the state at that point, rather than triggering a get_missing_events) -
+        so is appropriate when we have pulled the event from a remote server, rather
+        than having it pushed to us.
+
+        Params:
+            origin: The server we received this event from
+            event: The received event
+            backfilled: True if this is part of a historical batch of events (inhibits
+                notification to clients, and validation of device keys.)
+        """
+        logger.info("Processing pulled event %s", event)
+
+        # these should not be outliers.
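+        # (Editor's note, not part of the original patch: an "outlier" is an
+        # event persisted without any knowledge of the room state at its
+        # position in the DAG, e.g. one fetched purely to validate an auth
+        # chain. Events pulled via backfill or get_missing_events are meant
+        # for the timeline, hence the assertion below.)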
+ assert not event.internal_metadata.is_outlier() + + event_id = event.event_id + + existing = await self.store.get_event( + event_id, allow_none=True, allow_rejected=True + ) + if existing: + if not existing.internal_metadata.is_outlier(): + logger.info( + "Ignoring received event %s which we have already seen", + event_id, + ) + return + logger.info("De-outliering event %s", event_id) + + try: + self._sanity_check_event(event) + except SynapseError as err: + logger.warning("Event %s failed sanity check: %s", event_id, err) + return + + try: + state = await self._resolve_state_at_missing_prevs(origin, event) + await self._process_received_pdu( + origin, event, state=state, backfilled=backfilled + ) + except FederationError as e: + if e.code == 403: + logger.warning("Pulled event %s failed history check.", event_id) + else: + raise + + async def _resolve_state_at_missing_prevs( + self, dest: str, event: EventBase + ) -> Optional[Iterable[EventBase]]: + """Calculate the state at an event with missing prev_events. + + This is used when we have pulled a batch of events from a remote server, and + still don't have all the prev_events. + + If we already have all the prev_events for `event`, this method does nothing. + + Otherwise, the missing prevs become new backwards extremities, and we fall back + to asking the remote server for the state after each missing `prev_event`, + and resolving across them. + + That's ok provided we then resolve the state against other bits of the DAG + before using it - in other words, that the received event `event` is not going + to become the only forwards_extremity in the room (which will ensure that you + can't just take over a room by sending an event, withholding its prev_events, + and declaring yourself to be an admin in the subsequent state request). + + In other words: we should only call this method if `event` has been *pulled* + as part of a batch of missing prev events, or similar. + + Params: + dest: the remote server to ask for state at the missing prevs. Typically, + this will be the server we got `event` from. + event: an event to check for missing prevs. + + Returns: + if we already had all the prev events, `None`. Otherwise, returns a list of + the events in the state at `event`. + """ + room_id = event.room_id + event_id = event.event_id + + prevs = set(event.prev_event_ids()) + seen = await self.store.have_events_in_timeline(prevs) + missing_prevs = prevs - seen + + if not missing_prevs: + return None + + logger.info( + "Event %s is missing prev_events %s: calculating state for a " + "backwards extremity", + event_id, + shortstr(missing_prevs), + ) + # Calculate the state after each of the previous events, and + # resolve them to find the correct state at the current event. + event_map = {event_id: event} + try: + # Get the state of the events we know about + ours = await self.state_store.get_state_groups_ids(room_id, seen) + + # state_maps is a list of mappings from (type, state_key) to event_id + state_maps: List[StateMap[str]] = list(ours.values()) + + # we don't need this any more, let's delete it. + del ours + + # Ask the remote server for the states we don't + # know about + for p in missing_prevs: + logger.info("Requesting state after missing prev_event %s", p) + + with nested_logging_context(p): + # note that if any of the missing prevs share missing state or + # auth events, the requests to fetch those events are deduped + # by the get_pdu_cache in federation_client. 
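+                    #
+                    # (Editor's sketch, not in the original patch: each map
+                    # appended to `state_maps` below has the shape
+                    #     {("m.room.create", ""): "$create-event-id",
+                    #      ("m.room.member", "@alice:example.com"): "$join-id"}
+                    # and the later resolve_events_with_store call folds all
+                    # of them into one consistent StateMap.)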
+ remote_state = await self._get_state_after_missing_prev_event( + dest, room_id, p + ) + + remote_state_map = { + (x.type, x.state_key): x.event_id for x in remote_state + } + state_maps.append(remote_state_map) + + for x in remote_state: + event_map[x.event_id] = x + + room_version = await self.store.get_room_version_id(room_id) + state_map = await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) + + # We need to give _process_received_pdu the actual state events + # rather than event ids, so generate that now. + + # First though we need to fetch all the events that are in + # state_map, so we can build up the state below. + evs = await self.store.get_events( + list(state_map.values()), + get_prev_content=False, + redact_behaviour=EventRedactBehaviour.AS_IS, + ) + event_map.update(evs) + + state = [event_map[e] for e in state_map.values()] + except Exception: + logger.warning( + "Error attempting to resolve state at missing prev_events", + exc_info=True, + ) + raise FederationError( + "ERROR", + 403, + "We can't get valid state history.", + affected=event_id, + ) + return state + + async def _get_state_after_missing_prev_event( + self, + destination: str, + room_id: str, + event_id: str, + ) -> List[EventBase]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + + Returns: + A list of events in the state, including the event itself + """ + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + logger.debug( + "state_ids returned %i state events, %i auth events", + len(state_event_ids), + len(auth_event_ids), + ) + + # start by just trying to fetch the events from the store + desired_events = set(state_event_ids) + desired_events.add(event_id) + logger.debug("Fetching %i events from cache/store", len(desired_events)) + fetched_events = await self.store.get_events( + desired_events, allow_rejected=True + ) + + missing_desired_events = desired_events - fetched_events.keys() + logger.debug( + "We are missing %i events (got %i)", + len(missing_desired_events), + len(fetched_events), + ) + + # We probably won't need most of the auth events, so let's just check which + # we have for now, rather than thrashing the event cache with them all + # unnecessarily. + + # TODO: we probably won't actually need all of the auth events, since we + # already have a bunch of the state events. It would be nice if the + # federation api gave us a way of finding out which we actually need. + + missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events.difference_update( + await self.store.have_seen_events(room_id, missing_auth_events) + ) + logger.debug("We are also missing %i auth events", len(missing_auth_events)) + + missing_events = missing_desired_events | missing_auth_events + logger.debug("Fetching %i events from remote", len(missing_events)) + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. 
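+        # (Editor's note, illustrative: the earlier get_events call may have
+        # been answered from cache before `_get_events_and_persist` ran, so
+        # only this fresh read is guaranteed to reflect the final
+        # rejected/accepted status of the newly persisted outliers.)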
+ fetched_events.update( + await self.store.get_events(missing_desired_events, allow_rejected=True) + ) + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + # if we couldn't get the prev event in question, that's a problem. + remote_event = fetched_events.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + + # missing state at that event is a warning, not a blocker + # XXX: this doesn't sound right? it means that we'll end up with incomplete + # state. + failed_to_fetch = desired_events - fetched_events.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events + ] + + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + return remote_state + + async def _process_received_pdu( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + backfilled: bool = False, + ) -> None: + """Called when we have a new pdu. We need to do auth checks and put it + through the StateHandler. + + Args: + origin: server sending the event + + event: event to be persisted + + state: Normally None, but if we are handling a gap in the graph + (ie, we are missing one or more prev_events), the resolved state at the + event + + backfilled: True if this is part of a historical batch of events (inhibits + notification to clients, and validation of device keys.) + """ + logger.debug("Processing event: %s", event) + + try: + context = await self.state_handler.compute_event_context( + event, old_state=state + ) + await self._auth_and_persist_event( + origin, event, context, state=state, backfilled=backfilled + ) + except AuthError as e: + raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) + + if backfilled: + return + + # For encrypted messages we check that we know about the sending device, + # if we don't then we mark the device cache for that user as stale. + if event.type == EventTypes.Encrypted: + device_id = event.content.get("device_id") + sender_key = event.content.get("sender_key") + + cached_devices = await self.store.get_cached_devices_for_user(event.sender) + + resync = False # Whether we should resync device lists. + + device = None + if device_id is not None: + device = cached_devices.get(device_id) + if device is None: + logger.info( + "Received event from remote device not in our cache: %s %s", + event.sender, + device_id, + ) + resync = True + + # We also check if the `sender_key` matches what we expect. + if sender_key is not None: + # Figure out what sender key we're expecting. 
+                # If we know the
+                # device and recognize the algorithm then we can work out the
+                # exact key to expect. Otherwise check it matches any key we
+                # have for that device.
+
+                current_keys: Container[str] = []
+
+                if device:
+                    keys = device.get("keys", {}).get("keys", {})
+
+                    if (
+                        event.content.get("algorithm")
+                        == RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2
+                    ):
+                        # For this algorithm we expect a curve25519 key.
+                        key_name = "curve25519:%s" % (device_id,)
+                        current_keys = [keys.get(key_name)]
+                    else:
+                        # We don't understand the algorithm, so we just
+                        # check it matches a key for the device.
+                        current_keys = keys.values()
+                elif device_id:
+                    # We don't have any keys for the device ID.
+                    pass
+                else:
+                    # The event didn't include a device ID, so we just look for
+                    # keys across all devices.
+                    current_keys = [
+                        key
+                        for device in cached_devices.values()
+                        for key in device.get("keys", {}).get("keys", {}).values()
+                    ]
+
+                # We now check that the sender key matches (one of) the expected
+                # keys.
+                if sender_key not in current_keys:
+                    logger.info(
+                        "Received event from remote device with unexpected sender key: %s %s: %s",
+                        event.sender,
+                        device_id or "",
+                        sender_key,
+                    )
+                    resync = True
+
+        if resync:
+            run_as_background_process(
+                "resync_device_due_to_pdu",
+                self._resync_device,
+                event.sender,
+            )
+
+        await self._handle_marker_event(origin, event)
+
+    async def _resync_device(self, sender: str) -> None:
+        """We have detected that the device list for the given user may be out
+        of sync, so we try and resync them.
+        """
+
+        try:
+            await self.store.mark_remote_user_device_cache_as_stale(sender)
+
+            # Immediately attempt a resync in the background
+            if self.config.worker_app:
+                await self._user_device_resync(user_id=sender)
+            else:
+                await self._device_list_updater.user_device_resync(sender)
+        except Exception:
+            logger.exception("Failed to resync device for %s", sender)
+
+    async def _handle_marker_event(self, origin: str, marker_event: EventBase):
+        """Handles backfilling the insertion event when we receive a marker
+        event that points to one.
+
+        Args:
+            origin: Origin of the event. Will be called to get the insertion event
+            marker_event: The event to process
+        """
+
+        if marker_event.type != EventTypes.MSC2716_MARKER:
+            # Not a marker event
+            return
+
+        if marker_event.rejected_reason is not None:
+            # Rejected event
+            return
+
+        # Skip processing a marker event if the room version doesn't
+        # support it.
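+        # (Editor's sketch, not part of the original patch: the MSC2716 flow
+        # handled below is, roughly,
+        #
+        #   marker event ---content points to---> insertion event
+        #       ---fetched from `origin` and persisted---> insertion extremity
+        #
+        # which later allows backfill to branch into the historical batch.)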
+
+        room_version = await self.store.get_room_version(marker_event.room_id)
+        if not room_version.msc2716_historical:
+            return
+
+        logger.debug("_handle_marker_event: received %s", marker_event)
+
+        insertion_event_id = marker_event.content.get(
+            EventContentFields.MSC2716_MARKER_INSERTION
+        )
+
+        if insertion_event_id is None:
+            # Nothing to retrieve then (invalid marker)
+            return
+
+        logger.debug(
+            "_handle_marker_event: backfilling insertion event %s", insertion_event_id
+        )
+
+        await self._get_events_and_persist(
+            origin,
+            marker_event.room_id,
+            [insertion_event_id],
+        )
+
+        insertion_event = await self.store.get_event(
+            insertion_event_id, allow_none=True
+        )
+        if insertion_event is None:
+            logger.warning(
+                "_handle_marker_event: server %s didn't return insertion event %s for marker %s",
+                origin,
+                insertion_event_id,
+                marker_event.event_id,
+            )
+            return
+
+        logger.debug(
+            "_handle_marker_event: successfully backfilled insertion event %s from marker event %s",
+            insertion_event,
+            marker_event,
+        )
+
+        await self.store.insert_insertion_extremity(
+            insertion_event_id, marker_event.room_id
+        )
+
+        logger.debug(
+            "_handle_marker_event: insertion extremity added for %s from marker event %s",
+            insertion_event,
+            marker_event,
+        )
+
+    async def _get_events_and_persist(
+        self, destination: str, room_id: str, events: Iterable[str]
+    ) -> None:
+        """Fetch the given events from a server, and persist them as outliers.
+
+        This function *does not* recursively get missing auth events of the
+        newly fetched events. Callers must include in the `events` argument
+        any missing events from the auth chain.
+
+        Logs a warning if we can't find the given event.
+        """
+
+        room_version = await self.store.get_room_version(room_id)
+
+        event_map: Dict[str, EventBase] = {}
+
+        async def get_event(event_id: str):
+            with nested_logging_context(event_id):
+                try:
+                    event = await self.federation_client.get_pdu(
+                        [destination],
+                        event_id,
+                        room_version,
+                        outlier=True,
+                    )
+                    if event is None:
+                        logger.warning(
+                            "Server %s didn't return event %s",
+                            destination,
+                            event_id,
+                        )
+                        return
+
+                    event_map[event.event_id] = event
+
+                except Exception as e:
+                    logger.warning(
+                        "Error fetching missing state/auth event %s: %s %s",
+                        event_id,
+                        type(e),
+                        e,
+                    )
+
+        await concurrently_execute(get_event, events, 5)
+
+        # Make a map of auth events for each event. We do this after fetching
+        # all the events as some of the events' auth events will be in the list
+        # of requested events.
+
+        auth_events = [
+            aid
+            for event in event_map.values()
+            for aid in event.auth_event_ids()
+            if aid not in event_map
+        ]
+        persisted_events = await self.store.get_events(
+            auth_events,
+            allow_rejected=True,
+        )
+
+        event_infos = []
+        for event in event_map.values():
+            auth = {}
+            for auth_event_id in event.auth_event_ids():
+                ae = persisted_events.get(auth_event_id) or event_map.get(auth_event_id)
+                if ae:
+                    auth[(ae.type, ae.state_key)] = ae
+                else:
+                    logger.info("Missing auth event %s", auth_event_id)
+
+            event_infos.append(_NewEventInfo(event, auth))
+
+        if event_infos:
+            await self._auth_and_persist_events(
+                destination,
+                room_id,
+                event_infos,
+            )
+
+    async def _auth_and_persist_events(
+        self,
+        origin: str,
+        room_id: str,
+        event_infos: Collection[_NewEventInfo],
+    ) -> None:
+        """Creates the appropriate contexts and persists events. The events
+        should not depend on one another, e.g.
this should be used to persist + a bunch of outliers, but not a chunk of individual events that depend + on each other for state calculations. + + Notifies about the events where appropriate. + """ + + if not event_infos: + return + + async def prep(ev_info: _NewEventInfo): + event = ev_info.event + with nested_logging_context(suffix=event.event_id): + res = await self.state_handler.compute_event_context(event) + res = await self._check_event_auth( + origin, + event, + res, + claimed_auth_event_map=ev_info.claimed_auth_event_map, + ) + return res + + contexts = await make_deferred_yieldable( + defer.gatherResults( + [run_in_background(prep, ev_info) for ev_info in event_infos], + consumeErrors=True, + ) + ) + + await self.persist_events_and_notify( + room_id, + [ + (ev_info.event, context) + for ev_info, context in zip(event_infos, contexts) + ], + ) + + async def _auth_and_persist_event( + self, + origin: str, + event: EventBase, + context: EventContext, + state: Optional[Iterable[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, + backfilled: bool = False, + ) -> None: + """ + Process an event by performing auth checks and then persisting to the database. + + Args: + origin: The host the event originates from. + event: The event itself. + context: + The event context. + + state: + The state events used to check the event for soft-fail. If this is + not provided the current state events will be used. + + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers. + + backfilled: True if the event was backfilled. + """ + context = await self._check_event_auth( + origin, + event, + context, + state=state, + claimed_auth_event_map=claimed_auth_event_map, + backfilled=backfilled, + ) + + await self._run_push_actions_and_persist_event(event, context, backfilled) + + async def _check_event_auth( + self, + origin: str, + event: EventBase, + context: EventContext, + state: Optional[Iterable[EventBase]] = None, + claimed_auth_event_map: Optional[StateMap[EventBase]] = None, + backfilled: bool = False, + ) -> EventContext: + """ + Checks whether an event should be rejected (for failing auth checks). + + Args: + origin: The host the event originates from. + event: The event itself. + context: + The event context. + + state: + The state events used to check the event for soft-fail. If this is + not provided the current state events will be used. + + claimed_auth_event_map: + A map of (type, state_key) => event for the event's claimed auth_events. + Possibly incomplete, and possibly including events that are not yet + persisted, or authed, or in the right room. + + Only populated where we may not already have persisted these events - + for example, when populating outliers, or the state for a backwards + extremity. + + backfilled: True if the event was backfilled. + + Returns: + The updated context object. + """ + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + if claimed_auth_event_map: + # if we have a copy of the auth events from the event, use that as the + # basis for auth. 
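+            # (Editor's illustration, not in the original patch: either branch
+            # yields a StateMap[EventBase] keyed on (type, state_key), e.g.
+            #     {("m.room.create", ""): <create event>,
+            #      ("m.room.member", "@sender:hs"): <sender's join event>}
+            # so the auth check further down sees a uniform structure.)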
+
+            auth_events = claimed_auth_event_map
+        else:
+            # otherwise, we calculate what the auth events *should* be, and use that
+            prev_state_ids = await context.get_prev_state_ids()
+            auth_events_ids = self._event_auth_handler.compute_auth_events(
+                event, prev_state_ids, for_verification=True
+            )
+            auth_events_x = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
+
+        try:
+            (
+                context,
+                auth_events_for_auth,
+            ) = await self._update_auth_events_and_context_for_auth(
+                origin, event, context, auth_events
+            )
+        except Exception:
+            # We don't really mind if the above fails, so let's not fail
+            # processing if it does. However, it really shouldn't fail so
+            # let's still log as an exception since we'll still want to fix
+            # any bugs.
+            logger.exception(
+                "Failed to double check auth events for %s with remote. "
+                "Ignoring failure and continuing processing of event.",
+                event.event_id,
+            )
+            auth_events_for_auth = auth_events
+
+        try:
+            event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
+        except AuthError as e:
+            logger.warning("Failed auth resolution for %r because %s", event, e)
+            context.rejected = RejectedReason.AUTH_ERROR
+
+        if not context.rejected:
+            await self._check_for_soft_fail(event, state, backfilled, origin=origin)
+
+        if event.type == EventTypes.GuestAccess and not context.rejected:
+            await self.maybe_kick_guest_users(event)
+
+        # If we are going to send this event over federation we precalculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(
+                event, context
+            )
+
+        return context
+
+    async def _check_for_soft_fail(
+        self,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]],
+        backfilled: bool,
+        origin: str,
+    ) -> None:
+        """Checks if we should soft fail the event; if so, marks the event as
+        such.
+
+        Args:
+            event
+            state: The state at the event if we don't have all the event's prev events
+            backfilled: Whether the event is from backfill
+            origin: The host the event originates from.
+        """
+        # For new (non-backfilled and non-outlier) events we check if the event
+        # passes auth based on the current state. If it doesn't then we
+        # "soft-fail" the event.
+        if backfilled or event.internal_metadata.is_outlier():
+            return
+
+        extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
+        extrem_ids = set(extrem_ids_list)
+        prev_event_ids = set(event.prev_event_ids())
+
+        if extrem_ids == prev_event_ids:
+            # If they're the same then the current state is the same as the
+            # state at the event, so no point rechecking auth for soft fail.
+            return
+
+        room_version = await self.store.get_room_version_id(event.room_id)
+        room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+        # Calculate the "current state".
+        if state is not None:
+            # If we're explicitly given the state then we won't have all the
+            # prev events, and so we have a gap in the graph. In this case
+            # we want to be a little careful as we might have been down for
+            # a while and have an incorrect view of the current state,
+            # however we still want to do checks as gaps are easy to
+            # maliciously manufacture.
+            #
+            # So we use a "current state" that is actually a state
+            # resolution across the current forward extremities and the
+            # given state at the event. This should correctly handle cases
+            # like bans, especially with state res v2.
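+            # (Editor's example, not part of the original patch: suppose the
+            # sender is banned in the current state but cites prev_events from
+            # before the ban. Against the state at the event it would pass
+            # auth; resolving that state with the current forward extremities
+            # re-applies the ban, so the event is soft-failed instead of being
+            # served to clients.)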
+
+            state_sets_d = await self.state_store.get_state_groups(
+                event.room_id, extrem_ids
+            )
+            state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
+            state_sets.append(state)
+            current_states = await self.state_handler.resolve_events(
+                room_version, state_sets, event
+            )
+            current_state_ids: StateMap[str] = {
+                k: e.event_id for k, e in current_states.items()
+            }
+        else:
+            current_state_ids = await self.state_handler.get_current_state_ids(
+                event.room_id, latest_event_ids=extrem_ids
+            )
+
+        logger.debug(
+            "Doing soft-fail check for %s: state %s",
+            event.event_id,
+            current_state_ids,
+        )
+
+        # Now check if the event passes auth against said current state
+        auth_types = auth_types_for_event(room_version_obj, event)
+        current_state_ids_list = [
+            e for k, e in current_state_ids.items() if k in auth_types
+        ]
+
+        auth_events_map = await self.store.get_events(current_state_ids_list)
+        current_auth_events = {
+            (e.type, e.state_key): e for e in auth_events_map.values()
+        }
+
+        try:
+            event_auth.check(room_version_obj, event, auth_events=current_auth_events)
+        except AuthError as e:
+            logger.warning(
+                "Soft-failing %r (from %s) because %s",
+                event,
+                origin,
+                e,
+                extra={
+                    "room_id": event.room_id,
+                    "mxid": event.sender,
+                    "hs": origin,
+                },
+            )
+            soft_failed_event_counter.inc()
+            event.internal_metadata.soft_failed = True
+
+    async def _update_auth_events_and_context_for_auth(
+        self,
+        origin: str,
+        event: EventBase,
+        context: EventContext,
+        input_auth_events: StateMap[EventBase],
+    ) -> Tuple[EventContext, StateMap[EventBase]]:
+        """Helper for _check_event_auth. See there for docs.
+
+        Checks whether a given event has the expected auth events. If it
+        doesn't then we talk to the remote server to compare state to see if
+        we can come to a consensus (e.g. if one server missed some valid
+        state).
+
+        This attempts to resolve any potential divergence of state between
+        servers, but is not essential and so failures should not block further
+        processing of the event.
+
+        Args:
+            origin:
+            event:
+            context:
+
+            input_auth_events:
+                Map from (event_type, state_key) to event
+
+                Normally, our calculated auth_events based on the state of the room
+                at the event's position in the DAG, though occasionally (eg if the
+                event is an outlier), may be the auth events claimed by the remote
+                server.
+
+        Returns:
+            updated context, updated auth event map
+        """
+        # take a copy of input_auth_events before we modify it.
+        auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
+
+        event_auth_events = set(event.auth_event_ids())
+
+        # missing_auth is the set of the event's auth_events which we don't yet have
+        # in auth_events.
+        missing_auth = event_auth_events.difference(
+            e.event_id for e in auth_events.values()
+        )
+
+        # if we have missing events, we need to fetch those events from somewhere.
+        #
+        # we start by checking if they are in the store, and then try calling /event_auth/.
+        if missing_auth:
+            have_events = await self.store.have_seen_events(event.room_id, missing_auth)
+            logger.debug("Events %s are in the store", have_events)
+            missing_auth.difference_update(have_events)
+
+        if missing_auth:
+            # If we don't have all the auth events, we need to get them.
+            logger.info("auth_events contains unknown events: %s", missing_auth)
+            try:
+                try:
+                    remote_auth_chain = await self.federation_client.get_event_auth(
+                        origin, event.room_id, event.event_id
+                    )
+                except RequestSendFailed as e1:
+                    # The other side isn't around or doesn't implement the
+                    # endpoint, so let's just bail out.
+ logger.info("Failed to get event auth from remote: %s", e1) + return context, auth_events + + seen_remotes = await self.store.have_seen_events( + event.room_id, [e.event_id for e in remote_auth_chain] + ) + + for e in remote_auth_chain: + if e.event_id in seen_remotes: + continue + + if e.event_id == event.event_id: + continue + + try: + auth_ids = e.auth_event_ids() + auth = { + (e.type, e.state_key): e + for e in remote_auth_chain + if e.event_id in auth_ids or e.type == EventTypes.Create + } + e.internal_metadata.outlier = True + + logger.debug( + "_check_event_auth %s missing_auth: %s", + event.event_id, + e.event_id, + ) + missing_auth_event_context = ( + await self.state_handler.compute_event_context(e) + ) + await self._auth_and_persist_event( + origin, + e, + missing_auth_event_context, + claimed_auth_event_map=auth, + ) + + if e.event_id in event_auth_events: + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + except Exception: + logger.exception("Failed to get auth chain") + + if event.internal_metadata.is_outlier(): + # XXX: given that, for an outlier, we'll be working with the + # event's *claimed* auth events rather than those we calculated: + # (a) is there any point in this test, since different_auth below will + # obviously be empty + # (b) alternatively, why don't we do it earlier? + logger.info("Skipping auth_event fetch for outlier") + return context, auth_events + + different_auth = event_auth_events.difference( + e.event_id for e in auth_events.values() + ) + + if not different_auth: + return context, auth_events + + logger.info( + "auth_events refers to events which are not in our calculated auth " + "chain: %s", + different_auth, + ) + + # XXX: currently this checks for redactions but I'm not convinced that is + # necessary? + different_events = await self.store.get_events_as_list(different_auth) + + for d in different_events: + if d.room_id != event.room_id: + logger.warning( + "Event %s refers to auth_event %s which is in a different room", + event.event_id, + d.event_id, + ) + + # don't attempt to resolve the claimed auth events against our own + # in this case: just use our own auth events. + # + # XXX: should we reject the event in this case? It feels like we should, + # but then shouldn't we also do so if we've failed to fetch any of the + # auth events? + return context, auth_events + + # now we state-resolve between our own idea of the auth events, and the remote's + # idea of them. + + local_state = auth_events.values() + remote_auth_events = dict(auth_events) + remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) + remote_state = remote_auth_events.values() + + room_version = await self.store.get_room_version_id(event.room_id) + new_state = await self.state_handler.resolve_events( + room_version, (local_state, remote_state), event + ) + + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id + for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) + + auth_events.update(new_state) + + context = await self._update_context_for_auth_events( + event, context, auth_events + ) + + return context, auth_events + + async def _update_context_for_auth_events( + self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] + ) -> EventContext: + """Update the state_ids in an event context after auth event resolution, + storing the changes as a new state group. 
+ + Args: + event: The event we're handling the context for + + context: initial event context + + auth_events: Events to update in the event context. + + Returns: + new event context + """ + # exclude the state key of the new event from the current_state in the context. + if event.is_state(): + event_key: Optional[Tuple[str, str]] = (event.type, event.state_key) + else: + event_key = None + state_updates = { + k: a.event_id for k, a in auth_events.items() if k != event_key + } + + current_state_ids = await context.get_current_state_ids() + current_state_ids = dict(current_state_ids) # type: ignore + + current_state_ids.update(state_updates) + + prev_state_ids = await context.get_prev_state_ids() + prev_state_ids = dict(prev_state_ids) + + prev_state_ids.update({k: a.event_id for k, a in auth_events.items()}) + + # create a new state group as a delta from the existing one. + prev_group = context.state_group + state_group = await self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=prev_group, + delta_ids=state_updates, + current_state_ids=current_state_ids, + ) + + return EventContext.with_state( + state_group=state_group, + state_group_before_event=context.state_group_before_event, + current_state_ids=current_state_ids, + prev_state_ids=prev_state_ids, + prev_group=prev_group, + delta_ids=state_updates, + ) + + async def _run_push_actions_and_persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ): + """Run the push actions for a received event, and persist it. + + Args: + event: The event itself. + context: The event context. + backfilled: True if the event was backfilled. + """ + try: + if ( + not event.internal_metadata.is_outlier() + and not backfilled + and not context.rejected + and (await self.store.get_min_depth(event.room_id)) <= event.depth + ): + await self.action_generator.handle_push_actions_for_event( + event, context + ) + + await self.persist_events_and_notify( + event.room_id, [(event, context)], backfilled=backfilled + ) + except Exception: + run_in_background( + self.store.remove_push_actions_from_staging, event.event_id + ) + raise + + async def persist_events_and_notify( + self, + room_id: str, + event_and_contexts: Sequence[Tuple[EventBase, EventContext]], + backfilled: bool = False, + ) -> int: + """Persists events and tells the notifier/pushers about them, if + necessary. + + Args: + room_id: The room ID of events being persisted. + event_and_contexts: Sequence of events with their associated + context that should be persisted. All events must belong to + the same room. + backfilled: Whether these events are a result of + backfilling or not + + Returns: + The stream ID after which all events have been persisted. + """ + if not event_and_contexts: + return self.store.get_current_events_token() + + instance = self.config.worker.events_shard_config.get_instance(room_id) + if instance != self._instance_name: + # Limit the number of events sent over replication. We choose 200 + # here as that is what we default to in `max_request_body_size(..)` + for batch in batch_iter(event_and_contexts, 200): + result = await self._send_events( + instance_name=instance, + store=self.store, + room_id=room_id, + event_and_contexts=batch, + backfilled=backfilled, + ) + return result["max_stream_id"] + else: + assert self.storage.persistence + + # Note that this returns the events that were persisted, which may not be + # the same as were passed in if some were deduplicated due to transaction IDs. 
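+            # (Editor's note, illustrative: e.g. if a local client already sent
+            # the same event with a matching transaction ID, persistence may
+            # return the previously persisted event, which is why the returned
+            # `events` list -- not the input -- is used below.)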
+            events, max_stream_token = await self.storage.persistence.persist_events(
+                event_and_contexts, backfilled=backfilled
+            )
+
+            if self._ephemeral_messages_enabled:
+                for event in events:
+                    # If there's an expiry timestamp on the event, schedule its expiry.
+                    self._message_handler.maybe_schedule_expiry(event)
+
+            if not backfilled:  # Never notify for backfilled events
+                for event in events:
+                    await self._notify_persisted_event(event, max_stream_token)
+
+            return max_stream_token.stream
+
+    async def _notify_persisted_event(
+        self, event: EventBase, max_stream_token: RoomStreamToken
+    ) -> None:
+        """Checks to see if notifier/pushers should be notified about the
+        event or not.
+
+        Args:
+            event:
+            max_stream_token: The max_stream_id returned by persist_events
+        """
+
+        extra_users = []
+        if event.type == EventTypes.Member:
+            target_user_id = event.state_key
+
+            # We notify for memberships if it's an invite for one of our
+            # users
+            if event.internal_metadata.is_outlier():
+                if event.membership != Membership.INVITE:
+                    if not self.is_mine_id(target_user_id):
+                        return
+
+            target_user = UserID.from_string(target_user_id)
+            extra_users.append(target_user)
+        elif event.internal_metadata.is_outlier():
+            return
+
+        # the event has been persisted so it should have a stream ordering.
+        assert event.internal_metadata.stream_ordering
+
+        event_pos = PersistedEventPosition(
+            self._instance_name, event.internal_metadata.stream_ordering
+        )
+        self.notifier.on_new_room_event(
+            event, event_pos, max_stream_token, extra_users=extra_users
+        )
+
+    def _sanity_check_event(self, ev: EventBase) -> None:
+        """
+        Do some early sanity checks of a received event
+
+        In particular, checks it doesn't have an excessive number of
+        prev_events or auth_events, which could cause a huge state resolution
+        or cascade of event fetches.
+ + Args: + ev: event to be checked + + Raises: + SynapseError if the event does not pass muster + """ + if len(ev.prev_event_ids()) > 20: + logger.warning( + "Rejecting event %s which has %i prev_events", + ev.event_id, + len(ev.prev_event_ids()), + ) + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many prev_events") + + if len(ev.auth_event_ids()) > 10: + logger.warning( + "Rejecting event %s which has %i auth_events", + ev.event_id, + len(ev.auth_event_ids()), + ) + raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") + + async def get_min_depth_for_context(self, context: str) -> int: + return await self.store.get_min_depth(context) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 79cadb7b57..a0b3145f4e 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -62,7 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() self.storage = hs.get_storage() self.clock = hs.get_clock() - self.federation_handler = hs.get_federation_handler() + self.federation_event_handler = hs.get_federation_event_handler() @staticmethod async def _serialize_payload(store, room_id, event_and_contexts, backfilled): @@ -127,7 +127,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - max_stream_id = await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_event_handler.persist_events_and_notify( room_id, event_and_contexts, backfilled ) diff --git a/synapse/server.py b/synapse/server.py index de6517663e..5adeeff61a 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -76,6 +76,7 @@ from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler +from synapse.handlers.federation_event import FederationEventHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler from synapse.handlers.identity import IdentityHandler from synapse.handlers.initial_sync import InitialSyncHandler @@ -546,6 +547,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_federation_handler(self) -> FederationHandler: return FederationHandler(self) + @cache_in_self + def get_federation_event_handler(self) -> FederationEventHandler: + return FederationEventHandler(self) + @cache_in_self def get_identity_handler(self) -> IdentityHandler: return IdentityHandler(self) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 383214ab50..663960ff53 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -208,7 +208,7 @@ class FederationKnockingTestCase( async def _check_event_auth(origin, event, context, *args, **kwargs): return context - homeserver.get_federation_handler()._check_event_auth = _check_event_auth + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth return super().prepare(reactor, clock, homeserver) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index c72a8972a3..6c67a16de9 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -130,7 +130,9 @@ class FederationTestCase(unittest.HomeserverTestCase): ) with 
LoggingContext("send_rejected"): - d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + d = run_in_background( + self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev + ) self.get_success(d) # that should have been rejected @@ -182,7 +184,9 @@ class FederationTestCase(unittest.HomeserverTestCase): ) with LoggingContext("send_rejected"): - d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) + d = run_in_background( + self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev + ) self.get_success(d) # that should have been rejected @@ -311,7 +315,9 @@ class FederationTestCase(unittest.HomeserverTestCase): with LoggingContext("receive_pdu"): # Fake the OTHER_SERVER federating the message event over to our local homeserver d = run_in_background( - self.handler.on_receive_pdu, OTHER_SERVER, message_event + self.hs.get_federation_event_handler().on_receive_pdu, + OTHER_SERVER, + message_event, ) self.get_success(d) @@ -382,7 +388,9 @@ class FederationTestCase(unittest.HomeserverTestCase): join_event.signatures[other_server] = {"x": "y"} with LoggingContext("send_join"): d = run_in_background( - self.handler.on_send_membership_event, other_server, join_event + self.hs.get_federation_event_handler().on_send_membership_event, + other_server, + join_event, ) self.get_success(d) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 0a52bc8b72..671dc7d083 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -885,7 +885,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.federation_sender = hs.get_federation_sender() self.event_builder_factory = hs.get_event_builder_factory() - self.federation_handler = hs.get_federation_handler() + self.federation_event_handler = hs.get_federation_event_handler() self.presence_handler = hs.get_presence_handler() # self.event_builder_for_2 = EventBuilderFactory(hs) @@ -1026,7 +1026,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) ) - self.get_success(self.federation_handler.on_receive_pdu(hostname, event)) + self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) # Check that it was successfully persisted. 
         self.get_success(self.store.get_event(event.event_id))
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index af5dfca752..92a5b53e11 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -205,7 +205,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
     def create_room_with_remote_server(self, user, token, remote_server="other_server"):
         room = self.helper.create_room_as(user, tok=token)
         store = self.hs.get_datastore()
-        federation = self.hs.get_federation_handler()
+        federation = self.hs.get_federation_event_handler()

         prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
         room_version = self.get_success(store.get_room_version(room))
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 348fcb72a7..61c9d7c2ef 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -75,7 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )

         self.handler = self.homeserver.get_federation_handler()
-        self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
+        federation_event_handler = self.homeserver.get_federation_event_handler()
+        federation_event_handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
             context
         )
         self.client = self.homeserver.get_federation_client()
@@ -85,7 +86,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase):

         # Send the join, it should return None (which is not an error)
         self.assertEqual(
-            self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
+            self.get_success(
+                federation_event_handler.on_receive_pdu("test.serv", join_event)
+            ),
             None,
         )
@@ -129,9 +132,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )

+        federation_event_handler = self.homeserver.get_federation_event_handler()
         with LoggingContext("test-context"):
             failure = self.get_failure(
-                self.handler.on_receive_pdu("test.serv", lying_event),
+                federation_event_handler.on_receive_pdu("test.serv", lying_event),
                 FederationError,
             )
--
cgit 1.5.1


From c4fa4f37cbc734f9cd6354a5f2661efc30d73cac Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 27 Aug 2021 10:15:50 +0100
Subject: Fix perf of fetching the same events many times. (#10703)

The code to deduplicate repeated fetches of the same set of events was
N^2 (over the number of events requested), which could lead to a process
being completely wedged.

The main fix is to deduplicate the returned deferreds so we only await
on a deferred once rather than many times.

Separately, when handling the returned events from the deferred we only
add the events we care about to the event map to be returned (so that
we don't pay the price of inserting extraneous events into the dict).
---
 changelog.d/10703.bugfix                        |  1 +
 synapse/storage/databases/main/events_worker.py | 29 ++++++++++++++++++++-----
 2 files changed, 24 insertions(+), 6 deletions(-)
 create mode 100644 changelog.d/10703.bugfix

diff --git a/changelog.d/10703.bugfix b/changelog.d/10703.bugfix
new file mode 100644
index 0000000000..a5a4ecf8ee
--- /dev/null
+++ b/changelog.d/10703.bugfix
@@ -0,0 +1 @@
+Fix a regression introduced in v1.41.0 which affected the performance of concurrent fetches of large sets of events, in extreme cases causing the process to hang.
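The quadratic pattern this commit describes is easy to demonstrate outside of Synapse. The following is a minimal, self-contained sketch (an editor's addition with made-up data, not code from the patch): one in-flight fetch covering N events means N references to the same result dict, and merging that dict once per reference costs O(N^2) inserts, while deduplicating the references first keeps the merge O(N).

# Illustrative sketch only; names and data are invented.
shared_result = {f"$event{i}": object() for i in range(1000)}
fetches = {event_id: shared_result for event_id in shared_result}  # N refs, one dict

# Old shape: merge the same dict once per reference -> O(N^2) inserts.
merged_old: dict = {}
for result in fetches.values():
    merged_old.update(result)

# Fixed shape: deduplicate by identity, merge each distinct result once -> O(N).
merged_new: dict = {}
for result in {id(r): r for r in fetches.values()}.values():
    merged_new.update(result)

assert merged_old == merged_new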
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 375463e4e9..9501f00f3b 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -520,16 +520,26 @@ class EventsWorkerStore(SQLBaseStore): # We now look up if we're already fetching some of the events in the DB, # if so we wait for those lookups to finish instead of pulling the same # events out of the DB multiple times. - already_fetching: Dict[str, defer.Deferred] = {} + # + # Note: we might get the same `ObservableDeferred` back for multiple + # events we're already fetching, so we deduplicate the deferreds to + # avoid extraneous work (if we don't do this we can end up in an n^2 mode + # when we wait on the same Deferred N times, then try and merge the + # same dict into itself N times). + already_fetching_ids: Set[str] = set() + already_fetching_deferreds: Set[ + ObservableDeferred[Dict[str, _EventCacheEntry]] + ] = set() for event_id in missing_events_ids: deferred = self._current_event_fetches.get(event_id) if deferred is not None: # We're already pulling the event out of the DB. Add the deferred # to the collection of deferreds to wait on. - already_fetching[event_id] = deferred.observe() + already_fetching_ids.add(event_id) + already_fetching_deferreds.add(deferred) - missing_events_ids.difference_update(already_fetching) + missing_events_ids.difference_update(already_fetching_ids) if missing_events_ids: log_ctx = current_context() @@ -569,18 +579,25 @@ class EventsWorkerStore(SQLBaseStore): with PreserveLoggingContext(): fetching_deferred.callback(missing_events) - if already_fetching: + if already_fetching_deferreds: # Wait for the other event requests to finish and add their results # to ours. results = await make_deferred_yieldable( defer.gatherResults( - already_fetching.values(), + (d.observe() for d in already_fetching_deferreds), consumeErrors=True, ) ).addErrback(unwrapFirstError) for result in results: - event_entry_map.update(result) + # We filter out events that we haven't asked for as we might get + # a *lot* of superfluous events back, and there is no point + # going through and inserting them all (which can take time). + event_entry_map.update( + (event_id, entry) + for event_id, entry in result.items() + if event_id in already_fetching_ids + ) if not allow_rejected: event_entry_map = { -- cgit 1.5.1 From e62cdbef1a499f428e48f98167b2b709d16c671d Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 27 Aug 2021 11:16:40 +0200 Subject: Improve ServerNoticeServlet to avoid duplicate requests (#10679) Fixes: #9544 --- changelog.d/10679.bugfix | 1 + synapse/rest/admin/__init__.py | 5 +- synapse/rest/admin/server_notice_servlet.py | 19 +- synapse/server_notices/server_notices_manager.py | 17 +- tests/rest/admin/test_server_notice.py | 450 +++++++++++++++++++++++ 5 files changed, 475 insertions(+), 17 deletions(-) create mode 100644 changelog.d/10679.bugfix create mode 100644 tests/rest/admin/test_server_notice.py diff --git a/changelog.d/10679.bugfix b/changelog.d/10679.bugfix new file mode 100644 index 0000000000..5c4061f6d5 --- /dev/null +++ b/changelog.d/10679.bugfix @@ -0,0 +1 @@ +Improve ServerNoticeServlet to avoid duplicate requests and add unit tests.
\ No newline at end of file diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 6e1c8736e1..b2514d9d0d 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -223,7 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) - SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) @@ -247,6 +246,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) + # Some servlets only get registered for the main process. + if hs.config.worker_app is None: + SendServerNoticeServlet(hs).register(http_server) + def register_servlets_for_client_rest_resource( hs: "HomeServer", http_server: HttpServer diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index b5e4c474ef..42201afc86 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.constants import EventTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() + self.server_notices_manager = hs.get_server_notices_manager() + self.admin_handler = hs.get_admin_handler() self.txns = HttpTransactionCache(hs) def register(self, json_resource: HttpServer): @@ -79,19 +81,22 @@ class SendServerNoticeServlet(RestServlet): # We grab the server notices manager here as its initialisation has a check for worker processes, # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the # admin api). - if not self.hs.get_server_notices_manager().is_enabled(): + if not self.server_notices_manager.is_enabled(): raise SynapseError(400, "Server notices are not enabled on this server") - user_id = body["user_id"] - UserID.from_string(user_id) - if not self.hs.is_mine_id(user_id): + target_user = UserID.from_string(body["user_id"]) + if not self.hs.is_mine(target_user): raise SynapseError(400, "Server notices can only be sent to local users") - event = await self.hs.get_server_notices_manager().send_notice( - user_id=body["user_id"], + if not await self.admin_handler.get_user(target_user): + raise NotFoundError("User not found") + + event = await self.server_notices_manager.send_notice( + user_id=target_user.to_string(), type=event_type, state_key=state_key, event_content=body["content"], + txn_id=txn_id, ) return 200, {"event_id": event.event_id} diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index f19075b760..d87a538917 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -12,26 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.events import EventBase from synapse.types import UserID, create_requester from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) SERVER_NOTICE_ROOM_TAG = "m.server_notice" class ServerNoticesManager: - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ - + def __init__(self, hs: "HomeServer"): self._store = hs.get_datastore() self._config = hs.config self._account_data_handler = hs.get_account_data_handler() @@ -58,6 +55,7 @@ class ServerNoticesManager: event_content: dict, type: str = EventTypes.Message, state_key: Optional[str] = None, + txn_id: Optional[str] = None, ) -> EventBase: """Send a notice to the given user @@ -68,6 +66,7 @@ class ServerNoticesManager: event_content: content of event to send type: type of event is_state_event: Is the event a state event + txn_id: The transaction ID. """ room_id = await self.get_or_create_notice_room_for_user(user_id) await self.maybe_invite_user_to_room(user_id, room_id) @@ -90,7 +89,7 @@ class ServerNoticesManager: event_dict["state_key"] = state_key event, _ = await self._event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, ratelimit=False + requester, event_dict, ratelimit=False, txn_id=txn_id ) return event diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py new file mode 100644 index 0000000000..fbceba3254 --- /dev/null +++ b/tests/rest/admin/test_server_notice.py @@ -0,0 +1,450 @@ +# Copyright 2021 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client import login, room, sync +from synapse.storage.roommember import RoomsForUser +from synapse.types import JsonDict + +from tests import unittest +from tests.unittest import override_config + + +class ServerNoticeTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.room_shutdown_handler = hs.get_room_shutdown_handler() + self.pagination_handler = hs.get_pagination_handler() + self.server_notices_manager = self.hs.get_server_notices_manager() + + # Create user + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + self.url = "/_synapse/admin/v1/send_server_notice" + + def test_no_auth(self): + """Try to send a server notice without authentication.""" + channel = self.make_request("POST", self.url) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """If the user is not a server admin, an error is returned.""" + channel = self.make_request( + "POST", + self.url, + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_user_does_not_exist(self): + """Tests that a lookup for a user that does not exist returns a 404""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": "@unknown_person:test", "content": ""}, + ) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local user returns a 400 + """ + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": "@unknown_person:unknown_domain", + "content": "", + }, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual( + "Server notices can only be sent to local users", channel.json_body["error"] + ) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_invalid_parameter(self): + """If parameters are invalid, an error is returned.""" + + # no content, no user + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) + + # no content + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + + # no body + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id":
self.other_user, "content": ""}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("'body' not in content", channel.json_body["error"]) + + # no msgtype + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={"user_id": self.other_user, "content": {"body": ""}}, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual("'msgtype' not in content", channel.json_body["error"]) + + def test_server_notice_disabled(self): + """Tests that server returns error if server notice is disabled""" + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": "", + }, + ) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual( + "Server notices are not enabled on this server", channel.json_body["error"] + ) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice(self): + """ + Tests that sending two server notices is successfully, + the server uses the same room and do not send messages twice. + """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # invalidate cache of server notices room_ids + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has no new invites or memberships + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(room_id, self.other_user_token) + + self.assertEqual(len(messages), 2) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + self.assertEqual(messages[1]["content"]["body"], "test msg two") + self.assertEqual(messages[1]["sender"], "@notices:test") + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice_leave_room(self): + """ + Tests that sending a server notices is successfully. + The user leaves the room and the second message appears + in a new room. 
+ """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + first_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(first_room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # user leaves the romm + self.helper.leave( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + + # user is not member anymore + self._check_invite_and_join_status(self.other_user, 0, 0) + + # invalidate cache of server notices room_ids + # if server tries to send to a cached room_id the user gets the message + # in old room + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + second_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=second_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(second_room_id, self.other_user_token) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg two") + self.assertEqual(messages[0]["sender"], "@notices:test") + # room has the same id + self.assertNotEqual(first_room_id, second_room_id) + + @override_config({"server_notices": {"system_mxid_localpart": "notices"}}) + def test_send_server_notice_delete_room(self): + """ + Tests that the user get server notice in a new room + after the first server notice room was deleted. 
+ """ + # user has no room memberships + self._check_invite_and_join_status(self.other_user, 0, 0) + + # send first message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg one"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + first_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=first_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get messages + messages = self._sync_and_get_messages(first_room_id, self.other_user_token) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg one") + self.assertEqual(messages[0]["sender"], "@notices:test") + + # shut down and purge room + self.get_success( + self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user) + ) + self.get_success(self.pagination_handler.purge_room(first_room_id)) + + # user is not member anymore + self._check_invite_and_join_status(self.other_user, 0, 0) + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(first_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) + + # invalidate cache of server notices room_ids + # if server tries to send to a cached room_id it gives an error + self.get_success( + self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all() + ) + + # send second message + channel = self.make_request( + "POST", + self.url, + access_token=self.admin_user_tok, + content={ + "user_id": self.other_user, + "content": {"msgtype": "m.text", "body": "test msg two"}, + }, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # user has one invite + invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0) + second_room_id = invited_rooms[0].room_id + + # user joins the room and is member now + self.helper.join( + room=second_room_id, user=self.other_user, tok=self.other_user_token + ) + self._check_invite_and_join_status(self.other_user, 0, 1) + + # get message + messages = self._sync_and_get_messages(second_room_id, self.other_user_token) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["content"]["body"], "test msg two") + self.assertEqual(messages[0]["sender"], "@notices:test") + # second room has new ID + self.assertNotEqual(first_room_id, second_room_id) + + def _check_invite_and_join_status( + self, user_id: str, expected_invites: int, expected_memberships: int + ) -> RoomsForUser: + """Check invite and room membership status of a user. 
+ + Args: + user_id: user to check + expected_invites: number of expected invites of this user + expected_memberships: number of expected room memberships of this user + Returns: + room_ids of the rooms to which the user is invited + """ + + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(user_id) + ) + self.assertEqual(expected_invites, len(invited_rooms)) + + room_ids = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(expected_memberships, len(room_ids)) + + return invited_rooms + + def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]: + """ + Do a sync and get messages of a room. + + Args: + room_id: room that contains the messages + token: access token of user + + Returns: + list of messages contained in the room + """ + channel = self.make_request( + "GET", "/_matrix/client/r0/sync", access_token=token + ) + self.assertEqual(channel.code, 200) + + # Get the messages + room = channel.json_body["rooms"]["join"][room_id] + messages = [ + x for x in room["timeline"]["events"] if x["type"] == "m.room.message" + ] + return messages -- cgit 1.5.1 From 029b7ad7b94d167b19d63a5dc777a806b0e073f3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 Aug 2021 07:08:02 -0400 Subject: Remove unused `compare_digest` function. (#10706) --- changelog.d/10706.misc | 1 + synapse/rest/client/register.py | 13 ------------- 2 files changed, 1 insertion(+), 13 deletions(-) create mode 100644 changelog.d/10706.misc diff --git a/changelog.d/10706.misc b/changelog.d/10706.misc new file mode 100644 index 0000000000..eed4aa58d6 --- /dev/null +++ b/changelog.d/10706.misc @@ -0,0 +1 @@ +Remove unused `compare_digest` function. diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 2781a0ea96..7b5f49d635 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import hmac import logging import random from typing import List, Union @@ -60,18 +59,6 @@ from synapse.util.threepids import ( from ._base import client_patterns, interactive_auth_handler -# We ought to be using hmac.compare_digest() but on older pythons it doesn't -# exist. It's a _really minor_ security flaw to use plain string comparison -# because the timing attack is so obscured by all the other code here it's -# unlikely to make much difference -if hasattr(hmac, "compare_digest"): - compare_digest = hmac.compare_digest -else: - - def compare_digest(a, b): - return a == b - - logger = logging.getLogger(__name__) -- cgit 1.5.1 From 051ddac53b733e5768488bac7548a0c31bf68982 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 27 Aug 2021 12:54:21 +0100 Subject: Clarifications to reverse_proxy.md (#10708) * Update reverse_proxy.md * Create 10708.doc --- changelog.d/10708.doc | 1 + docs/reverse_proxy.md | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10708.doc diff --git a/changelog.d/10708.doc b/changelog.d/10708.doc new file mode 100644 index 0000000000..99f9d69288 --- /dev/null +++ b/changelog.d/10708.doc @@ -0,0 +1 @@ +Minor clarifications to the documentation for reverse proxies.
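As background to the `compare_digest` removal above: `hmac.compare_digest` has been in the standard library since Python 3.3, which is why the fallback shim was dead code on every Python that Synapse supports. A minimal sketch of what it provides (illustrative only, not part of the patch):

```python
import hmac

def check_mac(expected: bytes, received: bytes) -> bool:
    # Unlike `==`, which may short-circuit on the first mismatching byte
    # (leaking, via timing, how much of a guess was correct),
    # compare_digest takes the same time regardless of where inputs differ.
    return hmac.compare_digest(expected, received)

# e.g. check_mac(b"secret-mac", b"secret-mac") -> True
```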
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 5f8d20129e..bc351d604e 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -64,6 +64,9 @@ server { server_name matrix.example.com; location ~* ^(\/_matrix|\/_synapse\/client) { + # note: do not add a path (even a single /) after the port in `proxy_pass`, + # otherwise nginx will canonicalise the URI and cause signature verification + # errors. proxy_pass http://localhost:8008; proxy_set_header X-Forwarded-For $remote_addr; proxy_set_header X-Forwarded-Proto $scheme; @@ -76,10 +79,7 @@ server { } ``` -**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will -canonicalise/normalise the URI. - -### Caddy 1 +### Caddy v1 ``` matrix.example.com { @@ -99,7 +99,7 @@ example.com:8448 { } ``` -### Caddy 2 +### Caddy v2 ``` matrix.example.com { -- cgit 1.5.1 From 54aa7047ebf0d2605e31bdd4933effc4eb63813b Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Fri, 27 Aug 2021 15:19:17 +0100 Subject: Removed page summaries from the top of installation and contributing doc pages (#10711) - Removed page summaries from CONTRIBUTING and installation pages as this information was already in the table of contents on the right hand side - Fixed some broken links in CONTRIBUTING - Added margin-right tag for when table of contents is being shown (otherwise the text in the page sometimes overlaps with it) --- CONTRIBUTING.md | 49 +++++++++----------------------- changelog.d/10711.doc | 1 + docs/setup/installation.md | 39 ------------------------- docs/website_files/table-of-contents.css | 7 ++++- 4 files changed, 21 insertions(+), 75 deletions(-) create mode 100644 changelog.d/10711.doc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cd6c34df85..31d0a47fdf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,30 +2,6 @@ Welcome to Synapse This document aims to get you started with contributing to this repo! -- [1. Who can contribute to Synapse?](#1-who-can-contribute-to-synapse) -- [2. What do I need?](#2-what-do-i-need) -- [3. Get the source.](#3-get-the-source) -- [4. Install the dependencies](#4-install-the-dependencies) - * [Under Unix (macOS, Linux, BSD, ...)](#under-unix-macos-linux-bsd-) - * [Under Windows](#under-windows) -- [5. Get in touch.](#5-get-in-touch) -- [6. Pick an issue.](#6-pick-an-issue) -- [7. Turn coffee and documentation into code and documentation!](#7-turn-coffee-and-documentation-into-code-and-documentation) -- [8. Test, test, test!](#8-test-test-test) - * [Run the linters.](#run-the-linters) - * [Run the unit tests.](#run-the-unit-tests-twisted-trial) - * [Run the integration tests (SyTest).](#run-the-integration-tests-sytest) - * [Run the integration tests (Complement).](#run-the-integration-tests-complement) -- [9. Submit your patch.](#9-submit-your-patch) - * [Changelog](#changelog) - + [How do I know what to call the changelog file before I create the PR?](#how-do-i-know-what-to-call-the-changelog-file-before-i-create-the-pr) - + [Debian changelog](#debian-changelog) - * [Sign off](#sign-off) -- [10. Turn feedback into better code.](#10-turn-feedback-into-better-code) -- [11. Find a new issue.](#11-find-a-new-issue) -- [Notes for maintainers on merging PRs etc](#notes-for-maintainers-on-merging-prs-etc) -- [Conclusion](#conclusion) - # 1. Who can contribute to Synapse? 
Everyone is welcome to contribute code to [matrix.org @@ -35,7 +11,7 @@ follow a simple 'inbound=outbound' model for contributions: the act of submitting an 'inbound' contribution means that the contributor agrees to license the code under the same terms as the project's overall 'outbound' license - in our case, this is almost always Apache Software License v2 (see -[LICENSE](LICENSE)). +[LICENSE](https://github.com/matrix-org/synapse/blob/develop/LICENSE)). # 2. What do I need? @@ -98,17 +74,20 @@ to work on. # 7. Turn coffee and documentation into code and documentation! -Synapse's code style is documented [here](docs/code_style.md). Please follow -it, including the conventions for the [sample configuration -file](docs/code_style.md#configuration-file-format). +Synapse's code style is documented +[here](https://matrix-org.github.io/synapse/develop/code_style.html). +Please follow it, including the conventions for the +[sample configuration file](https://matrix-org.github.io/synapse/develop/code_style.html#configuration-file-format). -There is a growing amount of documentation located in the [docs](docs) +There is a growing amount of documentation located in the +[docs](https://github.com/matrix-org/synapse/tree/develop/docs) directory. This documentation is intended primarily for sysadmins running their -own Synapse instance, as well as developers interacting externally with -Synapse. [docs/dev](docs/dev) exists primarily to house documentation for -Synapse developers. [docs/admin_api](docs/admin_api) houses documentation -regarding Synapse's Admin API, which is used mostly by sysadmins and external -service developers. +own Synapse instance, as well as developers interacting externally with Synapse. +[docs/development](https://github.com/matrix-org/synapse/tree/develop/docs/development) +exists primarily to house documentation for Synapse developers. +[docs/admin_api](https://github.com/matrix-org/synapse/tree/develop/docs/admin_api) +houses documentation regarding Synapse's Admin API, which is used mostly by sysadmins +and external service developers. If you add new files added to either of these folders, please use [GitHub-Flavoured Markdown](https://guides.github.com/features/mastering-markdown/). @@ -431,7 +410,7 @@ By now, you know the drill! # Notes for maintainers on merging PRs etc There are some notes for those with commit access to the project on how we -manage git [here](docs/development/git.md). +manage git [here](https://matrix-org.github.io/synapse/develop/development/git.html). # Conclusion diff --git a/changelog.d/10711.doc b/changelog.d/10711.doc new file mode 100644 index 0000000000..c495f98be8 --- /dev/null +++ b/changelog.d/10711.doc @@ -0,0 +1 @@ +Removed table of contents from the top of installation and contributing documentation pages. \ No newline at end of file diff --git a/docs/setup/installation.md b/docs/setup/installation.md index 8540a7b0c1..06f869cd75 100644 --- a/docs/setup/installation.md +++ b/docs/setup/installation.md @@ -1,44 +1,5 @@ # Installation Instructions -There are 3 steps to follow under **Installation Instructions**. 
- -- [Installation Instructions](#installation-instructions) - - [Choosing your server name](#choosing-your-server-name) - - [Installing Synapse](#installing-synapse) - - [Installing from source](#installing-from-source) - - [Platform-specific prerequisites](#platform-specific-prerequisites) - - [Debian/Ubuntu/Raspbian](#debianubunturaspbian) - - [ArchLinux](#archlinux) - - [CentOS/Fedora](#centosfedora) - - [macOS](#macos) - - [OpenSUSE](#opensuse) - - [OpenBSD](#openbsd) - - [Windows](#windows) - - [Prebuilt packages](#prebuilt-packages) - - [Docker images and Ansible playbooks](#docker-images-and-ansible-playbooks) - - [Debian/Ubuntu](#debianubuntu) - - [Matrix.org packages](#matrixorg-packages) - - [Downstream Debian packages](#downstream-debian-packages) - - [Downstream Ubuntu packages](#downstream-ubuntu-packages) - - [Fedora](#fedora) - - [OpenSUSE](#opensuse-1) - - [SUSE Linux Enterprise Server](#suse-linux-enterprise-server) - - [ArchLinux](#archlinux-1) - - [Void Linux](#void-linux) - - [FreeBSD](#freebsd) - - [OpenBSD](#openbsd-1) - - [NixOS](#nixos) - - [Setting up Synapse](#setting-up-synapse) - - [Using PostgreSQL](#using-postgresql) - - [TLS certificates](#tls-certificates) - - [Client Well-Known URI](#client-well-known-uri) - - [Email](#email) - - [Registering a user](#registering-a-user) - - [Setting up a TURN server](#setting-up-a-turn-server) - - [URL previews](#url-previews) - - [Troubleshooting Installation](#troubleshooting-installation) - - ## Choosing your server name It is important to choose the name for your server before you install Synapse, diff --git a/docs/website_files/table-of-contents.css b/docs/website_files/table-of-contents.css index d16bb3b988..1b6f44b66a 100644 --- a/docs/website_files/table-of-contents.css +++ b/docs/website_files/table-of-contents.css @@ -1,3 +1,7 @@ +:root { + --pagetoc-width: 250px; +} + @media only screen and (max-width:1439px) { .sidetoc { display: none; @@ -8,6 +12,7 @@ main { position: relative; margin-left: 100px !important; + margin-right: var(--pagetoc-width) !important; } .sidetoc { margin-left: auto; @@ -18,7 +23,7 @@ } .pagetoc { position: fixed; - width: 250px; + width: var(--pagetoc-width); overflow: auto; right: 20px; height: calc(100% - var(--menu-bar-height)); -- cgit 1.5.1 From 8f98260552f4f39f003bc1fbf6da159d9138081d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 27 Aug 2021 16:33:41 +0100 Subject: Fix incompatibility with Twisted < 21. (#10713) Turns out that the functionality added in #10546 to skip TLS was incompatible with older Twisted versions, so we need to be a bit more inventive. Also, add a test to (hopefully) not break this in future. Sadly, testing TLS is really hard. --- changelog.d/10713.bugfix | 1 + mypy.ini | 1 + synapse/handlers/send_email.py | 65 ++++++++++++++++------ tests/handlers/test_send_email.py | 112 ++++++++++++++++++++++++++++++++++++++ tests/server.py | 15 ++++- 5 files changed, 173 insertions(+), 21 deletions(-) create mode 100644 changelog.d/10713.bugfix create mode 100644 tests/handlers/test_send_email.py diff --git a/changelog.d/10713.bugfix b/changelog.d/10713.bugfix new file mode 100644 index 0000000000..e8caf3d23a --- /dev/null +++ b/changelog.d/10713.bugfix @@ -0,0 +1 @@ +Fix a regression introduced in Synapse 1.41 which broke email transmission on systems using older versions of the Twisted library.
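The fix below hinges on comparing the installed Twisted version. A short, self-contained sketch (not part of the patch) of why `pkg_resources.parse_version` is used for that comparison rather than comparing version strings directly:

```python
import twisted
from pkg_resources import parse_version

# parse_version compares release segments numerically; naive string
# comparison would not (e.g. "9.0" sorts after "21" lexically).
assert parse_version("9.0") < parse_version("21")
assert "9.0" > "21"  # the lexical comparison gives the wrong answer
assert parse_version("20.3.0") < parse_version("21") < parse_version("21.2.0")

# The gate the patch below relies on: parse the installed Twisted version
# once, and pick a code path compatible with it.
_is_old_twisted = parse_version(twisted.__version__) < parse_version("21")
```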
diff --git a/mypy.ini b/mypy.ini index e1b9405daa..349efe37bb 100644 --- a/mypy.ini +++ b/mypy.ini @@ -87,6 +87,7 @@ files = tests/test_utils, tests/handlers/test_password_providers.py, tests/handlers/test_room_summary.py, + tests/handlers/test_send_email.py, tests/rest/client/v1/test_login.py, tests/rest/client/v2_alpha/test_auth.py, tests/util/test_itertools.py, diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index dda9659c11..a31fe3e3c7 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -19,9 +19,12 @@ from email.mime.text import MIMEText from io import BytesIO from typing import TYPE_CHECKING, Optional +from pkg_resources import parse_version + +import twisted from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IReactorTCP -from twisted.mail.smtp import ESMTPSenderFactory +from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP +from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory from synapse.logging.context import make_deferred_yieldable @@ -30,6 +33,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +_is_old_twisted = parse_version(twisted.__version__) < parse_version("21") + + +class _NoTLSESMTPSender(ESMTPSender): + """Extend ESMTPSender to disable TLS + + Unfortunately, before Twisted 21.2, ESMTPSender doesn't give an easy way to disable + TLS, so we override the internal method it uses to generate a context factory. + """ + + def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]: + return None + async def _sendmail( reactor: IReactorTCP, @@ -42,7 +58,7 @@ async def _sendmail( password: Optional[bytes] = None, require_auth: bool = False, require_tls: bool = False, - tls_hostname: Optional[str] = None, + enable_tls: bool = True, ) -> None: """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests @@ -57,24 +73,37 @@ async def _sendmail( password: password to give when authenticating require_auth: if auth is not offered, fail the request require_tls: if TLS is not offered, fail the request - tls_hostname: TLS hostname to check for. None to disable TLS. + enable_tls: True to enable TLS. If this is False and require_tls is True, + the request will fail. """ msg = BytesIO(msg_bytes) d: "Deferred[object]" = Deferred() - factory = ESMTPSenderFactory( - username, - password, - from_addr, - to_addr, - msg, - d, - heloFallback=True, - requireAuthentication=require_auth, - requireTransportSecurity=require_tls, - hostname=tls_hostname, - ) + def build_sender_factory(**kwargs) -> ESMTPSenderFactory: + return ESMTPSenderFactory( + username, + password, + from_addr, + to_addr, + msg, + d, + heloFallback=True, + requireAuthentication=require_auth, + requireTransportSecurity=require_tls, + **kwargs, + ) + + if _is_old_twisted: + # before twisted 21.2, we have to override the ESMTPSender protocol to disable + # TLS + factory = build_sender_factory() + + if not enable_tls: + factory.protocol = _NoTLSESMTPSender + else: + # for twisted 21.2 and later, there is a 'hostname' parameter which we should + # set to enable TLS.
+ factory = build_sender_factory(hostname=smtphost if enable_tls else None) # the IReactorTCP interface claims host has to be a bytes, which seems to be wrong reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type] @@ -154,5 +183,5 @@ class SendEmailHandler: password=self._smtp_pass, require_auth=self._smtp_user is not None, require_tls=self._require_transport_security, - tls_hostname=self._smtp_host if self._enable_tls else None, + enable_tls=self._enable_tls, ) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py new file mode 100644 index 0000000000..6f77b1237c --- /dev/null +++ b/tests/handlers/test_send_email.py @@ -0,0 +1,112 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Tuple + +from zope.interface import implementer + +from twisted.internet import defer +from twisted.internet.address import IPv4Address +from twisted.internet.defer import ensureDeferred +from twisted.mail import interfaces, smtp + +from tests.server import FakeTransport +from tests.unittest import HomeserverTestCase + + +@implementer(interfaces.IMessageDelivery) +class _DummyMessageDelivery: + def __init__(self): + # (recipient, message) tuples + self.messages: List[Tuple[smtp.Address, bytes]] = [] + + def receivedHeader(self, helo, origin, recipients): + return None + + def validateFrom(self, helo, origin): + return origin + + def record_message(self, recipient: smtp.Address, message: bytes): + self.messages.append((recipient, message)) + + def validateTo(self, user: smtp.User): + return lambda: _DummyMessage(self, user) + + +@implementer(interfaces.IMessageSMTP) +class _DummyMessage: + """IMessageSMTP implementation which saves the message delivered to it + to the _DummyMessageDelivery object. 
+ """ + + def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User): + self._delivery = delivery + self._user = user + self._buffer: List[bytes] = [] + + def lineReceived(self, line): + self._buffer.append(line) + + def eomReceived(self): + message = b"\n".join(self._buffer) + b"\n" + self._delivery.record_message(self._user.dest, message) + return defer.succeed(b"saved") + + def connectionLost(self): + pass + + +class SendEmailHandlerTestCase(HomeserverTestCase): + def test_send_email(self): + """Happy-path test that we can send email to a non-TLS server.""" + h = self.hs.get_send_email_handler() + d = ensureDeferred( + h.send_email( + "foo@bar.com", "test subject", "Tests", "HTML content", "Text content" + ) + ) + # there should be an attempt to connect to localhost:25 + self.assertEqual(len(self.reactor.tcpClients), 1) + (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ + 0 + ] + self.assertEqual(host, "localhost") + self.assertEqual(port, 25) + + # wire it up to an SMTP server + message_delivery = _DummyMessageDelivery() + server_protocol = smtp.ESMTP() + server_protocol.delivery = message_delivery + # make sure that the server uses the test reactor to set timeouts + server_protocol.callLater = self.reactor.callLater # type: ignore[assignment] + + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor)) + server_protocol.makeConnection( + FakeTransport( + client_protocol, + self.reactor, + peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + ) + ) + + # the message should now get delivered + self.get_success(d, by=0.1) + + # check it arrived + self.assertEqual(len(message_delivery.messages), 1) + user, msg = message_delivery.messages.pop() + self.assertEqual(str(user), "foo@bar.com") + self.assertIn(b"Subject: test subject", msg) diff --git a/tests/server.py b/tests/server.py index 6fddd3b305..b861c7b866 100644 --- a/tests/server.py +++ b/tests/server.py @@ -10,9 +10,10 @@ from zope.interface import implementer from twisted.internet import address, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier -from twisted.internet.defer import Deferred, fail, succeed +from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( + IAddress, IHostnameResolver, IProtocol, IPullProducer, @@ -511,6 +512,9 @@ class FakeTransport: will get called back for connectionLost() notifications etc. """ + _peer_address: Optional[IAddress] = attr.ib(default=None) + """The value to be returend by getPeer""" + disconnecting = False disconnected = False connected = True @@ -519,7 +523,7 @@ class FakeTransport: autoflush = attr.ib(default=True) def getPeer(self): - return None + return self._peer_address def getHost(self): return None @@ -572,7 +576,12 @@ class FakeTransport: self.producerStreaming = streaming def _produce(): - d = self.producer.resumeProducing() + if not self.producer: + # we've been unregistered + return + # some implementations of IProducer (for example, FileSender) + # don't return a deferred. 
+ d = maybeDeferred(self.producer.resumeProducing) d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) if not streaming: -- cgit 1.5.1 From 52c7a51cfc568ed61d800df99dcca25dc3b5fe3e Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 31 Aug 2021 10:09:58 +0100 Subject: Merge pull request from GHSA-3x4c-pq33-4w3q * Add some tests to characterise the problem Some failing. Current states: RoomsMemberListTestCase test_get_member_list ... [OK] test_get_member_list_mixed_memberships ... [OK] test_get_member_list_no_permission ... [OK] test_get_member_list_no_permission_former_member ... [OK] test_get_member_list_no_permission_former_member_with_at_token ... [FAIL] test_get_member_list_no_room ... [OK] test_get_member_list_no_permission_with_at_token ... [FAIL] * Correct the tests * Check user is/was member before divulging room membership * Pull out only the 1 membership event we want. * Update tests/rest/client/v1/test_rooms.py Co-authored-by: Erik Johnston * Fixup tests (following apply review suggestion) Co-authored-by: Erik Johnston --- synapse/handlers/message.py | 23 +++++++++-- tests/rest/client/v1/test_rooms.py | 84 +++++++++++++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 8a0024ce84..101a29c6d3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -183,20 +183,37 @@ class MessageHandler: if not last_events: raise NotFoundError("Can't find event for token %s" % (at_token,)) + last_event = last_events[0] + + # check whether the user is in the room at that time to determine + # whether they should be treated as peeking. + state_map = await self.state_store.get_state_for_event( + last_event.event_id, + StateFilter.from_types([(EventTypes.Member, user_id)]), + ) + + joined = False + membership_event = state_map.get((EventTypes.Member, user_id)) + if membership_event: + joined = membership_event.membership == Membership.JOIN + + is_peeking = not joined visible_events = await filter_events_for_client( self.storage, user_id, last_events, filter_send_to_client=False, + is_peeking=is_peeking, ) - event = last_events[0] if visible_events: room_state_events = await self.state_store.get_state_for_events( - [event.event_id], state_filter=state_filter + [last_event.event_id], state_filter=state_filter ) - room_state: Mapping[Any, EventBase] = room_state_events[event.event_id] + room_state: Mapping[Any, EventBase] = room_state_events[ + last_event.event_id + ] else: raise AuthError( 403, diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 0c9cbb9aff..50100a5ae4 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -29,7 +29,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin -from synapse.rest.client import account, directory, login, profile, room +from synapse.rest.client import account, directory, login, profile, room, sync from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string @@ -381,6 +381,8 @@ class RoomPermissionsTestCase(RoomBase): class RoomsMemberListTestCase(RoomBase): """Tests /rooms/$room_id/members/list REST events.""" + servlets = RoomBase.servlets + [sync.register_servlets] + user_id = 
"@sid1:red" def test_get_member_list(self): @@ -397,6 +399,86 @@ class RoomsMemberListTestCase(RoomBase): channel = self.make_request("GET", "/rooms/%s/members" % room_id) self.assertEquals(403, channel.code, msg=channel.result["body"]) + def test_get_member_list_no_permission_with_at_token(self): + """ + Tests that a stranger to the room cannot get the member list + (in the case that they use an at token). + """ + room_id = self.helper.create_room_as("@someone.else:red") + + # first sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEquals(200, channel.code) + sync_token = channel.json_body["next_batch"] + + # check that permission is denied for @sid1:red to get the + # memberships of @someone.else:red's room. + channel = self.make_request( + "GET", + f"/rooms/{room_id}/members?at={sync_token}", + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_permission_former_member(self): + """ + Tests that a former member of the room can not get the member list. + """ + # create a room, invite the user and the user joins + room_id = self.helper.create_room_as("@alice:red") + self.helper.invite(room_id, "@alice:red", self.user_id) + self.helper.join(room_id, self.user_id) + + # check that the user can see the member list to start with + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # ban the user + self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") + + # check the user can no longer see the member list + channel = self.make_request("GET", "/rooms/%s/members" % room_id) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + + def test_get_member_list_no_permission_former_member_with_at_token(self): + """ + Tests that a former member of the room can not get the member list + (in the case that they use an at token). + """ + # create a room, invite the user and the user joins + room_id = self.helper.create_room_as("@alice:red") + self.helper.invite(room_id, "@alice:red", self.user_id) + self.helper.join(room_id, self.user_id) + + # sync to get an at token + channel = self.make_request("GET", "/sync") + self.assertEquals(200, channel.code) + sync_token = channel.json_body["next_batch"] + + # check that the user can see the member list to start with + channel = self.make_request( + "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) + ) + self.assertEquals(200, channel.code, msg=channel.result["body"]) + + # ban the user (Note: the user is actually allowed to see this event and + # state so that they know they're banned!) 
+ self.helper.change_membership(room_id, "@alice:red", self.user_id, "ban") + + # invite a third user and let them join + self.helper.invite(room_id, "@alice:red", "@bob:red") + self.helper.join(room_id, "@bob:red") + + # now, with the original user, sync again to get a new at token + channel = self.make_request("GET", "/sync") + self.assertEquals(200, channel.code) + sync_token = channel.json_body["next_batch"] + + # check the user can no longer see the updated member list + channel = self.make_request( + "GET", "/rooms/%s/members?at=%s" % (room_id, sync_token) + ) + self.assertEquals(403, channel.code, msg=channel.result["body"]) + def test_get_member_list_mixed_memberships(self): room_creator = "@some_other_guy:red" room_id = self.helper.create_room_as(room_creator) -- cgit 1.5.1 From cb35df940a828bc40b96daed997b5ad4c7842fd3 Mon Sep 17 00:00:00 2001 From: reivilibre <38398653+reivilibre@users.noreply.github.com> Date: Tue, 31 Aug 2021 11:24:09 +0100 Subject: Merge pull request from GHSA-jj53-8fmw-f2w2 --- synapse/groups/groups_server.py | 18 +++++++++++-- tests/rest/client/v2_alpha/test_groups.py | 43 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 tests/rest/client/v2_alpha/test_groups.py diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 3dc55ab861..d6b75ac27f 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -332,6 +332,13 @@ class GroupsServerWorkerHandler: requester_user_id, group_id ) + # Note! room_results["is_public"] is about whether the room is considered + # public from the group's point of view. (i.e. whether non-group members + # should be able to see that the room is in the group). + # This is not the same as whether the room itself is public (in the sense + # of being visible in the room directory). + # As such, room_results["is_public"] itself is not sufficient to determine + # whether any given user is permitted to see the room's metadata.
room_results = await self.store.get_rooms_in_group( group_id, include_private=is_user_in_group ) @@ -341,8 +348,15 @@ class GroupsServerWorkerHandler: room_id = room_result["room_id"] joined_users = await self.store.get_users_in_room(room_id) + + # check the user is actually allowed to see the room before showing it to them + allow_private = requester_user_id in joined_users + entry = await self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True + room_id, + len(joined_users), + with_alias=False, + allow_private=allow_private, ) if not entry: @@ -354,7 +368,7 @@ class GroupsServerWorkerHandler: chunk.sort(key=lambda e: -e["num_joined_members"]) - return {"chunk": chunk, "total_room_count_estimate": len(room_results)} + return {"chunk": chunk, "total_room_count_estimate": len(chunk)} class GroupsServerHandler(GroupsServerWorkerHandler): diff --git a/tests/rest/client/v2_alpha/test_groups.py b/tests/rest/client/v2_alpha/test_groups.py new file mode 100644 index 0000000000..bfa9336baa --- /dev/null +++ b/tests/rest/client/v2_alpha/test_groups.py @@ -0,0 +1,43 @@ +from synapse.rest.client.v1 import room +from synapse.rest.client.v2_alpha import groups + +from tests import unittest +from tests.unittest import override_config + + +class GroupsTestCase(unittest.HomeserverTestCase): + user_id = "@alice:test" + room_creator_user_id = "@bob:test" + + servlets = [room.register_servlets, groups.register_servlets] + + @override_config({"enable_group_creation": True}) + def test_rooms_limited_by_visibility(self): + group_id = "+spqr:test" + + # Alice creates a group + channel = self.make_request("POST", "/create_group", {"localpart": "spqr"}) + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals(channel.json_body, {"group_id": group_id}) + + # Bob creates a private room + room_id = self.helper.create_room_as(self.room_creator_user_id, is_public=False) + self.helper.auth_user_id = self.room_creator_user_id + self.helper.send_state( + room_id, "m.room.name", {"name": "bob's secret room"}, tok=None + ) + self.helper.auth_user_id = self.user_id + + # Alice adds the room to her group. + channel = self.make_request( + "PUT", f"/groups/{group_id}/admin/rooms/{room_id}", {} + ) + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals(channel.json_body, {}) + + # Alice now tries to retrieve the room list of the group. + channel = self.make_request("GET", f"/groups/{group_id}/rooms") + self.assertEquals(channel.code, 200, msg=channel.text_body) + self.assertEquals( + channel.json_body, {"chunk": [], "total_room_count_estimate": 0} + ) -- cgit 1.5.1 From 46ff99ef95592cd10f2c86ea4f4434c25707bea0 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 31 Aug 2021 11:29:27 +0100 Subject: Advertise matrix-org.github.io/synapse docs (#10595) Point to the book where possible, and use hyperlinks to github to refer to files not included in the book.
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- CONTRIBUTING.md | 423 +------------------------------- README.rst | 37 ++- changelog.d/10595.doc | 1 + docs/development/contributing_guide.md | 430 ++++++++++++++++++++++++++++++++- 4 files changed, 459 insertions(+), 432 deletions(-) create mode 100644 changelog.d/10595.doc diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 31d0a47fdf..2c85edf712 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,422 +1,3 @@ -Welcome to Synapse +# Welcome to Synapse -This document aims to get you started with contributing to this repo! - -# 1. Who can contribute to Synapse? - -Everyone is welcome to contribute code to [matrix.org -projects](https://github.com/matrix-org), provided that they are willing to -license their contributions under the same license as the project itself. We -follow a simple 'inbound=outbound' model for contributions: the act of -submitting an 'inbound' contribution means that the contributor agrees to -license the code under the same terms as the project's overall 'outbound' -license - in our case, this is almost always Apache Software License v2 (see -[LICENSE](https://github.com/matrix-org/synapse/blob/develop/LICENSE)). - -# 2. What do I need? - -The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://wiki.python.org/moin/BeginnersGuide/Download). - -The source code of Synapse is hosted on GitHub. You will also need [a recent version of git](https://github.com/git-guides/install-git). - -For some tests, you will need [a recent version of Docker](https://docs.docker.com/get-docker/). - - -# 3. Get the source. - -The preferred and easiest way to contribute changes is to fork the relevant -project on GitHub, and then [create a pull request]( -https://help.github.com/articles/using-pull-requests/) to ask us to pull your -changes into our repo. - -Please base your changes on the `develop` branch. - -```sh -git clone git@github.com:YOUR_GITHUB_USER_NAME/synapse.git -git checkout develop -``` - -If you need help getting started with git, this is beyond the scope of the document, but you -can find many good git tutorials on the web. - -# 4. Install the dependencies - -## Under Unix (macOS, Linux, BSD, ...) - -Once you have installed Python 3 and added the source, please open a terminal and -setup a *virtualenv*, as follows: - -```sh -cd path/where/you/have/cloned/the/repository -python3 -m venv ./env -source ./env/bin/activate -pip install -e ".[all,lint,mypy,test]" -pip install tox -``` - -This will install the developer dependencies for the project. - -## Under Windows - -TBD - - -# 5. Get in touch. - -Join our developer community on Matrix: #synapse-dev:matrix.org ! - - -# 6. Pick an issue. - -Fix your favorite problem or perhaps find a [Good First Issue](https://github.com/matrix-org/synapse/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+First+Issue%22) -to work on. - - -# 7. Turn coffee and documentation into code and documentation! - -Synapse's code style is documented -[here](https://matrix-org.github.io/synapse/develop/code_style.html). -Please follow it, including the conventions for the -[sample configuration file](https://matrix-org.github.io/synapse/develop/code_style.html#configuration-file-format). - -There is a growing amount of documentation located in the -[docs](https://github.com/matrix-org/synapse/tree/develop/docs) -directory. 
This documentation is intended primarily for sysadmins running their -own Synapse instance, as well as developers interacting externally with Synapse. -[docs/development](https://github.com/matrix-org/synapse/tree/develop/docs/development) -exists primarily to house documentation for Synapse developers. -[docs/admin_api](https://github.com/matrix-org/synapse/tree/develop/docs/admin_api) -houses documentation regarding Synapse's Admin API, which is used mostly by sysadmins -and external service developers. - -If you add new files added to either of these folders, please use [GitHub-Flavoured -Markdown](https://guides.github.com/features/mastering-markdown/). - -Some documentation also exists in [Synapse's GitHub -Wiki](https://github.com/matrix-org/synapse/wiki), although this is primarily -contributed to by community authors. - - -# 8. Test, test, test! - - -While you're developing and before submitting a patch, you'll -want to test your code. - -## Run the linters. - -The linters look at your code and do two things: - -- ensure that your code follows the coding style adopted by the project; -- catch a number of errors in your code. - -They're pretty fast, don't hesitate! - -```sh -source ./env/bin/activate -./scripts-dev/lint.sh -``` - -Note that this script *will modify your files* to fix styling errors. -Make sure that you have saved all your files. - -If you wish to restrict the linters to only the files changed since the last commit -(much faster!), you can instead run: - -```sh -source ./env/bin/activate -./scripts-dev/lint.sh -d -``` - -Or if you know exactly which files you wish to lint, you can instead run: - -```sh -source ./env/bin/activate -./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder -``` - -## Run the unit tests (Twisted trial). - -The unit tests run parts of Synapse, including your changes, to see if anything -was broken. They are slower than the linters but will typically catch more errors. - -```sh -source ./env/bin/activate -trial tests -``` - -If you wish to only run *some* unit tests, you may specify -another module instead of `tests` - or a test class or a method: - -```sh -source ./env/bin/activate -trial tests.rest.admin.test_room tests.handlers.test_admin.ExfiltrateData.test_invite -``` - -If your tests fail, you may wish to look at the logs (the default log level is `ERROR`): - -```sh -less _trial_temp/test.log -``` - -To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`: - -```sh -SYNAPSE_TEST_LOG_LEVEL=DEBUG trial tests -``` - - -## Run the integration tests ([Sytest](https://github.com/matrix-org/sytest)). - -The integration tests are a more comprehensive suite of tests. They -run a full version of Synapse, including your changes, to check if -anything was broken. They are slower than the unit tests but will -typically catch more errors. - -The following command will let you run the integration test with the most common -configuration: - -```sh -$ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:buster -``` - -This configuration should generally cover your needs. For more details about other configurations, see [documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md). - - -## Run the integration tests ([Complement](https://github.com/matrix-org/complement)). 
- -[Complement](https://github.com/matrix-org/complement) is a suite of black box tests that can be run on any homeserver implementation. It can also be thought of as end-to-end (e2e) tests. - -It's often nice to develop on Synapse and write Complement tests at the same time. -Here is how to run your local Synapse checkout against your local Complement checkout. - -(checkout [`complement`](https://github.com/matrix-org/complement) alongside your `synapse` checkout) -```sh -COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh -``` - -To run a specific test file, you can pass the test name at the end of the command. The name passed comes from the naming structure in your Complement tests. If you're unsure of the name, you can do a full run and copy it from the test output: - -```sh -COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh TestBackfillingHistory -``` - -To run a specific test, you can specify the whole name structure: - -```sh -COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh TestBackfillingHistory/parallel/Backfilled_historical_events_resolve_with_proper_state_in_correct_order -``` - - -### Access database for homeserver after Complement test runs. - -If you're curious what the database looks like after you run some tests, here are some steps to get you going in Synapse: - - 1. In your Complement test comment out `defer deployment.Destroy(t)` and replace with `defer time.Sleep(2 * time.Hour)` to keep the homeserver running after the tests complete - 1. Start the Complement tests - 1. Find the name of the container, `docker ps -f name=complement_` (this will filter for just the Compelement related Docker containers) - 1. Access the container replacing the name with what you found in the previous step: `docker exec -it complement_1_hs_with_application_service.hs1_2 /bin/bash` - 1. Install sqlite (database driver), `apt-get update && apt-get install -y sqlite3` - 1. Then run `sqlite3` and open the database `.open /conf/homeserver.db` (this db path comes from the Synapse homeserver.yaml) - - -# 9. Submit your patch. - -Once you're happy with your patch, it's time to prepare a Pull Request. - -To prepare a Pull Request, please: - -1. verify that [all the tests pass](#test-test-test), including the coding style; -2. [sign off](#sign-off) your contribution; -3. `git push` your commit to your fork of Synapse; -4. on GitHub, [create the Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request); -5. add a [changelog entry](#changelog) and push it to your Pull Request; -6. for most contributors, that's all - however, if you are a member of the organization `matrix-org`, on GitHub, please request a review from `matrix.org / Synapse Core`. -7. if you need to update your PR, please avoid rebasing and just add new commits to your branch. - - -## Changelog - -All changes, even minor ones, need a corresponding changelog / newsfragment -entry. These are managed by [Towncrier](https://github.com/hawkowl/towncrier). - -To create a changelog entry, make a new file in the `changelog.d` directory named -in the format of `PRnumber.type`. 
The type can be one of the following: - -* `feature` -* `bugfix` -* `docker` (for updates to the Docker image) -* `doc` (for updates to the documentation) -* `removal` (also used for deprecations) -* `misc` (for internal-only changes) - -This file will become part of our [changelog]( -https://github.com/matrix-org/synapse/blob/master/CHANGES.md) at the next -release, so the content of the file should be a short description of your -change in the same style as the rest of the changelog. The file can contain Markdown -formatting, and should end with a full stop (.) or an exclamation mark (!) for -consistency. - -Adding credits to the changelog is encouraged, we value your -contributions and would like to have you shouted out in the release notes! - -For example, a fix in PR #1234 would have its changelog entry in -`changelog.d/1234.bugfix`, and contain content like: - -> The security levels of Florbs are now validated when received -> via the `/federation/florb` endpoint. Contributed by Jane Matrix. - -If there are multiple pull requests involved in a single bugfix/feature/etc, -then the content for each `changelog.d` file should be the same. Towncrier will -merge the matching files together into a single changelog entry when we come to -release. - -### How do I know what to call the changelog file before I create the PR? - -Obviously, you don't know if you should call your newsfile -`1234.bugfix` or `5678.bugfix` until you create the PR, which leads to a -chicken-and-egg problem. - -There are two options for solving this: - - 1. Open the PR without a changelog file, see what number you got, and *then* - add the changelog file to your branch (see [Updating your pull - request](#updating-your-pull-request)), or: - - 1. Look at the [list of all - issues/PRs](https://github.com/matrix-org/synapse/issues?q=), add one to the - highest number you see, and quickly open the PR before somebody else claims - your number. - - [This - script](https://github.com/richvdh/scripts/blob/master/next_github_number.sh) - might be helpful if you find yourself doing this a lot. - -Sorry, we know it's a bit fiddly, but it's *really* helpful for us when we come -to put together a release! - -### Debian changelog - -Changes which affect the debian packaging files (in `debian`) are an -exception to the rule that all changes require a `changelog.d` file. - -In this case, you will need to add an entry to the debian changelog for the -next release. For this, run the following command: - -``` -dch -``` - -This will make up a new version number (if there isn't already an unreleased -version in flight), and open an editor where you can add a new changelog entry. -(Our release process will ensure that the version number and maintainer name is -corrected for the release.) - -If your change affects both the debian packaging *and* files outside the debian -directory, you will need both a regular newsfragment *and* an entry in the -debian changelog. (Though typically such changes should be submitted as two -separate pull requests.) 
- -## Sign off - -In order to have a concrete record that your contribution is intentional -and you agree to license it under the same terms as the project's license, we've adopted the -same lightweight approach that the Linux Kernel -[submitting patches process]( -https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin>), -[Docker](https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other -projects use: the DCO (Developer Certificate of Origin: -http://developercertificate.org/). This is a simple declaration that you wrote -the contribution or otherwise have the right to contribute it to Matrix: - -``` -Developer Certificate of Origin -Version 1.1 - -Copyright (C) 2004, 2006 The Linux Foundation and its contributors. -660 York Street, Suite 102, -San Francisco, CA 94110 USA - -Everyone is permitted to copy and distribute verbatim copies of this -license document, but changing it is not allowed. - -Developer's Certificate of Origin 1.1 - -By making a contribution to this project, I certify that: - -(a) The contribution was created in whole or in part by me and I - have the right to submit it under the open source license - indicated in the file; or - -(b) The contribution is based upon previous work that, to the best - of my knowledge, is covered under an appropriate open source - license and I have the right under that license to submit that - work with modifications, whether created in whole or in part - by me, under the same open source license (unless I am - permitted to submit under a different license), as indicated - in the file; or - -(c) The contribution was provided directly to me by some other - person who certified (a), (b) or (c) and I have not modified - it. - -(d) I understand and agree that this project and the contribution - are public and that a record of the contribution (including all - personal information I submit with it, including my sign-off) is - maintained indefinitely and may be redistributed consistent with - this project or the open source license(s) involved. -``` - -If you agree to this for your contribution, then all that's needed is to -include the line in your commit or pull request comment: - -``` -Signed-off-by: Your Name -``` - -We accept contributions under a legally identifiable name, such as -your name on government documentation or common-law names (names -claimed by legitimate usage or repute). Unfortunately, we cannot -accept anonymous contributions at this time. - -Git allows you to add this signoff automatically when using the `-s` -flag to `git commit`, which uses the name and email set in your -`user.name` and `user.email` git configs. - - -# 10. Turn feedback into better code. - -Once the Pull Request is opened, you will see a few things: - -1. our automated CI (Continuous Integration) pipeline will run (again) the linters, the unit tests, the integration tests and more; -2. one or more of the developers will take a look at your Pull Request and offer feedback. - -From this point, you should: - -1. Look at the results of the CI pipeline. - - If there is any error, fix the error. -2. If a developer has requested changes, make these changes and let us know if it is ready for a developer to review again. -3. Create a new commit with the changes. - - Please do NOT overwrite the history. New commits make the reviewer's life easier. - - Push this commits to your Pull Request. -4. Back to 1. 
- 
-Once both the CI and the developers are happy, the patch will be merged into Synapse and released shortly!
-
-# 11. Find a new issue.
-
-By now, you know the drill!
-
-# Notes for maintainers on merging PRs etc
-
-There are some notes for those with commit access to the project on how we
-manage git [here](https://matrix-org.github.io/synapse/develop/development/git.html).
-
-# Conclusion
-
-That's it! Matrix is a very open and collaborative project as you might expect
-given our obsession with open communication. If we're going to successfully
-matrix together all the fragmented communication technologies out there we are
-reliant on contributions and collaboration from the community to do so. So
-please get involved - and we hope you have as much fun hacking on Matrix as we
-do!
+Please see the [contributors' guide](https://matrix-org.github.io/synapse/latest/development/contributing_guide.html) in our rendered documentation.
diff --git a/README.rst b/README.rst
index 0ae05616e7..db977c025f 100644
--- a/README.rst
+++ b/README.rst
@@ -1,6 +1,6 @@
-=========================================================
-Synapse |support| |development| |license| |pypi| |python|
-=========================================================
+=========================================================================
+Synapse |support| |development| |documentation| |license| |pypi| |python|
+=========================================================================
 
 .. contents::
 
@@ -85,9 +85,14 @@ For support installing or managing Synapse, please join |room|_ (from a matrix.o
 account if necessary) and ask questions there. We do not use GitHub issues for
 support requests, only for bug reports and feature requests.
 
+Synapse's documentation is `nicely rendered on GitHub Pages <https://matrix-org.github.io/synapse/latest/>`_,
+with its source available in |docs|_.
+
 .. |room| replace:: ``#synapse:matrix.org``
 .. _room: https://matrix.to/#/#synapse:matrix.org
 
+.. |docs| replace:: ``docs``
+.. _docs: docs
 
 Synapse Installation
 ====================
@@ -263,7 +268,23 @@ Then update the ``users`` table in the database::
 
 Synapse Development
 ===================
 
-Join our developer community on Matrix: `#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_
+The best place to get started is our
+`guide for contributors <https://matrix-org.github.io/synapse/latest/development/contributing_guide.html>`_.
+This is part of our larger `documentation <https://matrix-org.github.io/synapse/latest/>`_, which includes
+information for synapse developers as well as synapse administrators.
+
+Developers might be particularly interested in:
+
+* `Synapse's database schema <https://matrix-org.github.io/synapse/latest/development/database_schema.html>`_,
+* `notes on Synapse's implementation details <https://matrix-org.github.io/synapse/latest/development/internal_documentation/index.html>`_, and
+* `how we use git <https://matrix-org.github.io/synapse/latest/development/git.html>`_.
+
+Alongside all that, join our developer community on Matrix:
+`#synapse-dev:matrix.org <https://matrix.to/#/#synapse-dev:matrix.org>`_, featuring real humans!
+
+
+Quick start
+-----------
 
 Before setting up a development environment for synapse, make sure you have the
 system dependencies (such as the python header files) installed - see
@@ -308,7 +329,7 @@ If you just want to start a single instance of the app and run it directly::
 
 
 Running the unit tests
-======================
+----------------------
 
 After getting up and running, you may wish to run Synapse's unit tests to
 check that everything is installed correctly::
@@ -327,7 +348,7 @@ to see the logging output, see the `CONTRIBUTING doc <CONTRIBUTING.md>`_.
 
 
 Running the integration tests
-=============================
+-----------------------------
 
 Synapse is accompanied by `SyTest <https://github.com/matrix-org/sytest>`_,
 a Matrix homeserver integration testing suite, which uses HTTP requests to
@@ -445,6 +466,10 @@ This is normally caused by a misconfiguration in your reverse-proxy. See
    :alt: (discuss development on #synapse-dev:matrix.org)
    :target: https://matrix.to/#/#synapse-dev:matrix.org
 
+.. 
|documentation| image:: https://img.shields.io/badge/documentation-%E2%9C%93-success + :alt: (Rendered documentation on GitHub Pages) + :target: https://matrix-org.github.io/synapse/latest/ + .. |license| image:: https://img.shields.io/github/license/matrix-org/synapse :alt: (check license in LICENSE file) :target: LICENSE diff --git a/changelog.d/10595.doc b/changelog.d/10595.doc new file mode 100644 index 0000000000..4823146d6b --- /dev/null +++ b/changelog.d/10595.doc @@ -0,0 +1 @@ +Advertise https://matrix-org.github.io/synapse docs in README and CONTRIBUTING files. diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md index ddf0887123..97352b0f26 100644 --- a/docs/development/contributing_guide.md +++ b/docs/development/contributing_guide.md @@ -1,7 +1,427 @@ - # Contributing -{{#include ../../CONTRIBUTING.md}} +This document aims to get you started with contributing to Synapse! + +# 1. Who can contribute to Synapse? + +Everyone is welcome to contribute code to [matrix.org +projects](https://github.com/matrix-org), provided that they are willing to +license their contributions under the same license as the project itself. We +follow a simple 'inbound=outbound' model for contributions: the act of +submitting an 'inbound' contribution means that the contributor agrees to +license the code under the same terms as the project's overall 'outbound' +license - in our case, this is almost always Apache Software License v2 (see +[LICENSE](https://github.com/matrix-org/synapse/blob/develop/LICENSE)). + +# 2. What do I need? + +The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://wiki.python.org/moin/BeginnersGuide/Download). + +The source code of Synapse is hosted on GitHub. You will also need [a recent version of git](https://github.com/git-guides/install-git). + +For some tests, you will need [a recent version of Docker](https://docs.docker.com/get-docker/). + + +# 3. Get the source. + +The preferred and easiest way to contribute changes is to fork the relevant +project on GitHub, and then [create a pull request]( +https://help.github.com/articles/using-pull-requests/) to ask us to pull your +changes into our repo. + +Please base your changes on the `develop` branch. + +```sh +git clone git@github.com:YOUR_GITHUB_USER_NAME/synapse.git +git checkout develop +``` + +If you need help getting started with git, this is beyond the scope of the document, but you +can find many good git tutorials on the web. + +# 4. Install the dependencies + +## Under Unix (macOS, Linux, BSD, ...) + +Once you have installed Python 3 and added the source, please open a terminal and +setup a *virtualenv*, as follows: + +```sh +cd path/where/you/have/cloned/the/repository +python3 -m venv ./env +source ./env/bin/activate +pip install -e ".[all,lint,mypy,test]" +pip install tox +``` + +This will install the developer dependencies for the project. + +## Under Windows + +TBD + + +# 5. Get in touch. + +Join our developer community on Matrix: #synapse-dev:matrix.org ! + + +# 6. Pick an issue. + +Fix your favorite problem or perhaps find a [Good First Issue](https://github.com/matrix-org/synapse/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+First+Issue%22) +to work on. + + +# 7. Turn coffee into code and documentation! 
+ +There is a growing amount of documentation located in the +[`docs`](https://github.com/matrix-org/synapse/tree/develop/docs) +directory, with a rendered version [available online](https://matrix-org.github.io/synapse). +This documentation is intended primarily for sysadmins running their +own Synapse instance, as well as developers interacting externally with +Synapse. +[`docs/development`](https://github.com/matrix-org/synapse/tree/develop/docs/development) +exists primarily to house documentation for +Synapse developers. +[`docs/admin_api`](https://github.com/matrix-org/synapse/tree/develop/docs/admin_api) houses documentation +regarding Synapse's Admin API, which is used mostly by sysadmins and external +service developers. + +Synapse's code style is documented [here](../code_style.md). Please follow +it, including the conventions for the [sample configuration +file](../code_style.md#configuration-file-format). + +We welcome improvements and additions to our documentation itself! When +writing new pages, please +[build `docs` to a book](https://github.com/matrix-org/synapse/tree/develop/docs#adding-to-the-documentation) +to check that your contributions render correctly. The docs are written in +[GitHub-Flavoured Markdown](https://guides.github.com/features/mastering-markdown/). + +Some documentation also exists in [Synapse's GitHub +Wiki](https://github.com/matrix-org/synapse/wiki), although this is primarily +contributed to by community authors. + + +# 8. Test, test, test! + + +While you're developing and before submitting a patch, you'll +want to test your code. + +## Run the linters. + +The linters look at your code and do two things: + +- ensure that your code follows the coding style adopted by the project; +- catch a number of errors in your code. + +They're pretty fast, don't hesitate! + +```sh +source ./env/bin/activate +./scripts-dev/lint.sh +``` + +Note that this script *will modify your files* to fix styling errors. +Make sure that you have saved all your files. + +If you wish to restrict the linters to only the files changed since the last commit +(much faster!), you can instead run: + +```sh +source ./env/bin/activate +./scripts-dev/lint.sh -d +``` + +Or if you know exactly which files you wish to lint, you can instead run: + +```sh +source ./env/bin/activate +./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder +``` + +## Run the unit tests (Twisted trial). + +The unit tests run parts of Synapse, including your changes, to see if anything +was broken. They are slower than the linters but will typically catch more errors. + +```sh +source ./env/bin/activate +trial tests +``` + +If you wish to only run *some* unit tests, you may specify +another module instead of `tests` - or a test class or a method: + +```sh +source ./env/bin/activate +trial tests.rest.admin.test_room tests.handlers.test_admin.ExfiltrateData.test_invite +``` + +If your tests fail, you may wish to look at the logs (the default log level is `ERROR`): + +```sh +less _trial_temp/test.log +``` + +To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`: + +```sh +SYNAPSE_TEST_LOG_LEVEL=DEBUG trial tests +``` + + +## Run the integration tests ([Sytest](https://github.com/matrix-org/sytest)). + +The integration tests are a more comprehensive suite of tests. They +run a full version of Synapse, including your changes, to check if +anything was broken. They are slower than the unit tests but will +typically catch more errors. 
+ 
+The following command will let you run the integration test with the most common
+configuration:
+
+```sh
+$ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v /path/to/where/you/want/logs\:/logs matrixdotorg/sytest-synapse:buster
+```
+
+This configuration should generally cover your needs. For more details about other configurations, see [documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md).
+
+
+## Run the integration tests ([Complement](https://github.com/matrix-org/complement)).
+
+[Complement](https://github.com/matrix-org/complement) is a suite of black box tests that can be run on any homeserver implementation. It can also be thought of as end-to-end (e2e) tests.
+
+It's often nice to develop on Synapse and write Complement tests at the same time.
+Here is how to run your local Synapse checkout against your local Complement checkout.
+
+(checkout [`complement`](https://github.com/matrix-org/complement) alongside your `synapse` checkout)
+```sh
+COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh
+```
+
+To run a specific test file, you can pass the test name at the end of the command. The name passed comes from the naming structure in your Complement tests. If you're unsure of the name, you can do a full run and copy it from the test output:
+
+```sh
+COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh TestBackfillingHistory
+```
+
+To run a specific test, you can specify the whole name structure:
+
+```sh
+COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh TestBackfillingHistory/parallel/Backfilled_historical_events_resolve_with_proper_state_in_correct_order
+```
+
+
+### Access database for homeserver after Complement test runs.
+
+If you're curious what the database looks like after you run some tests, here are some steps to get you going in Synapse:
+
+1. In your Complement test comment out `defer deployment.Destroy(t)` and replace with `defer time.Sleep(2 * time.Hour)` to keep the homeserver running after the tests complete
+1. Start the Complement tests
+1. Find the name of the container, `docker ps -f name=complement_` (this will filter for just the Complement related Docker containers)
+1. Access the container replacing the name with what you found in the previous step: `docker exec -it complement_1_hs_with_application_service.hs1_2 /bin/bash`
+1. Install sqlite (database driver), `apt-get update && apt-get install -y sqlite3`
+1. Then run `sqlite3` and open the database `.open /conf/homeserver.db` (this db path comes from the Synapse homeserver.yaml)
+
+
+# 9. Submit your patch.
+
+Once you're happy with your patch, it's time to prepare a Pull Request.
+
+To prepare a Pull Request, please:
+
+1. verify that [all the tests pass](#test-test-test), including the coding style;
+2. [sign off](#sign-off) your contribution;
+3. `git push` your commit to your fork of Synapse;
+4. on GitHub, [create the Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request);
+5. add a [changelog entry](#changelog) and push it to your Pull Request;
+6. for most contributors, that's all - however, if you are a member of the organization `matrix-org`, on GitHub, please request a review from `matrix.org / Synapse Core`.
+7. if you need to update your PR, please avoid rebasing and just add new commits to your branch.
+
+
+## Changelog
+
+All changes, even minor ones, need a corresponding changelog / newsfragment
+entry. 
These are managed by [Towncrier](https://github.com/hawkowl/towncrier).
+
+To create a changelog entry, make a new file in the `changelog.d` directory named
+in the format of `PRnumber.type`. The type can be one of the following:
+
+* `feature`
+* `bugfix`
+* `docker` (for updates to the Docker image)
+* `doc` (for updates to the documentation)
+* `removal` (also used for deprecations)
+* `misc` (for internal-only changes)
+
+This file will become part of our [changelog](
+https://github.com/matrix-org/synapse/blob/master/CHANGES.md) at the next
+release, so the content of the file should be a short description of your
+change in the same style as the rest of the changelog. The file can contain Markdown
+formatting, and should end with a full stop (.) or an exclamation mark (!) for
+consistency.
+
+Adding credits to the changelog is encouraged; we value your
+contributions and would like to have you shouted out in the release notes!
+
+For example, a fix in PR #1234 would have its changelog entry in
+`changelog.d/1234.bugfix`, and contain content like:
+
+> The security levels of Florbs are now validated when received
+> via the `/federation/florb` endpoint. Contributed by Jane Matrix.
+
+If there are multiple pull requests involved in a single bugfix/feature/etc,
+then the content for each `changelog.d` file should be the same. Towncrier will
+merge the matching files together into a single changelog entry when we come to
+release.
+
+### How do I know what to call the changelog file before I create the PR?
+
+Obviously, you don't know if you should call your newsfile
+`1234.bugfix` or `5678.bugfix` until you create the PR, which leads to a
+chicken-and-egg problem.
+
+There are two options for solving this:
+
+1. Open the PR without a changelog file, see what number you got, and *then*
+   add the changelog file to your branch (see [Updating your pull
+   request](#updating-your-pull-request)), or:
+
+1. Look at the [list of all
+   issues/PRs](https://github.com/matrix-org/synapse/issues?q=), add one to the
+   highest number you see, and quickly open the PR before somebody else claims
+   your number.
+
+   [This
+   script](https://github.com/richvdh/scripts/blob/master/next_github_number.sh)
+   might be helpful if you find yourself doing this a lot.
+
+Sorry, we know it's a bit fiddly, but it's *really* helpful for us when we come
+to put together a release!
+
+### Debian changelog
+
+Changes which affect the debian packaging files (in `debian`) are an
+exception to the rule that all changes require a `changelog.d` file.
+
+In this case, you will need to add an entry to the debian changelog for the
+next release. For this, run the following command:
+
+```
+dch
+```
+
+This will make up a new version number (if there isn't already an unreleased
+version in flight), and open an editor where you can add a new changelog entry.
+(Our release process will ensure that the version number and maintainer name are
+corrected for the release.)
+
+If your change affects both the debian packaging *and* files outside the debian
+directory, you will need both a regular newsfragment *and* an entry in the
+debian changelog. (Though typically such changes should be submitted as two
+separate pull requests.)
+ 
+## Sign off
+
+In order to have a concrete record that your contribution is intentional
+and you agree to license it under the same terms as the project's license, we've adopted the
+same lightweight approach that the Linux Kernel
+[submitting patches process](
+https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin),
+[Docker](https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other
+projects use: the DCO (Developer Certificate of Origin:
+http://developercertificate.org/). This is a simple declaration that you wrote
+the contribution or otherwise have the right to contribute it to Matrix:
+
+```
+Developer Certificate of Origin
+Version 1.1
+
+Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
+660 York Street, Suite 102,
+San Francisco, CA 94110 USA
+
+Everyone is permitted to copy and distribute verbatim copies of this
+license document, but changing it is not allowed.
+
+Developer's Certificate of Origin 1.1
+
+By making a contribution to this project, I certify that:
+
+(a) The contribution was created in whole or in part by me and I
+    have the right to submit it under the open source license
+    indicated in the file; or
+
+(b) The contribution is based upon previous work that, to the best
+    of my knowledge, is covered under an appropriate open source
+    license and I have the right under that license to submit that
+    work with modifications, whether created in whole or in part
+    by me, under the same open source license (unless I am
+    permitted to submit under a different license), as indicated
+    in the file; or
+
+(c) The contribution was provided directly to me by some other
+    person who certified (a), (b) or (c) and I have not modified
+    it.
+
+(d) I understand and agree that this project and the contribution
+    are public and that a record of the contribution (including all
+    personal information I submit with it, including my sign-off) is
+    maintained indefinitely and may be redistributed consistent with
+    this project or the open source license(s) involved.
+```
+
+If you agree to this for your contribution, then all that's needed is to
+include the line in your commit or pull request comment:
+
+```
+Signed-off-by: Your Name <your@email.example.org>
+```
+
+We accept contributions under a legally identifiable name, such as
+your name on government documentation or common-law names (names
+claimed by legitimate usage or repute). Unfortunately, we cannot
+accept anonymous contributions at this time.
+
+Git allows you to add this signoff automatically when using the `-s`
+flag to `git commit`, which uses the name and email set in your
+`user.name` and `user.email` git configs.
+
+
+# 10. Turn feedback into better code.
+
+Once the Pull Request is opened, you will see a few things:
+
+1. our automated CI (Continuous Integration) pipeline will run (again) the linters, the unit tests, the integration tests and more;
+2. one or more of the developers will take a look at your Pull Request and offer feedback.
+
+From this point, you should:
+
+1. Look at the results of the CI pipeline.
+   - If there is any error, fix the error.
+2. If a developer has requested changes, make these changes and let us know if it is ready for a developer to review again.
+3. Create a new commit with the changes.
+   - Please do NOT overwrite the history. New commits make the reviewer's life easier.
+   - Push these commits to your Pull Request.
+4. Back to 1. 
+ +Once both the CI and the developers are happy, the patch will be merged into Synapse and released shortly! + +# 11. Find a new issue. + +By now, you know the drill! + +# Notes for maintainers on merging PRs etc + +There are some notes for those with commit access to the project on how we +manage git [here](git.md). + +# Conclusion + +That's it! Matrix is a very open and collaborative project as you might expect +given our obsession with open communication. If we're going to successfully +matrix together all the fragmented communication technologies out there we are +reliant on contributions and collaboration from the community to do so. So +please get involved - and we hope you have as much fun hacking on Matrix as we +do! -- cgit 1.5.1 From 8c26f16c76b475e0ace7b58920d90368b180454c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 31 Aug 2021 12:56:22 +0100 Subject: Fix up unit tests (#10723) These were broken in an incorrect merge of GHSA-jj53-8fmw-f2w2 (cb35df9) --- changelog.d/10723.bugfix | 1 + tests/rest/client/v2_alpha/test_groups.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10723.bugfix diff --git a/changelog.d/10723.bugfix b/changelog.d/10723.bugfix new file mode 100644 index 0000000000..e6ffdc9512 --- /dev/null +++ b/changelog.d/10723.bugfix @@ -0,0 +1 @@ +Fix unauthorised exposure of room metadata to communities. diff --git a/tests/rest/client/v2_alpha/test_groups.py b/tests/rest/client/v2_alpha/test_groups.py index bfa9336baa..ad0425ae65 100644 --- a/tests/rest/client/v2_alpha/test_groups.py +++ b/tests/rest/client/v2_alpha/test_groups.py @@ -1,5 +1,18 @@ -from synapse.rest.client.v1 import room -from synapse.rest.client.v2_alpha import groups +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest.client import groups, room from tests import unittest from tests.unittest import override_config -- cgit 1.5.1 From a4c8a2f08b735266fbbe2f259e640f00dc5e3a00 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 31 Aug 2021 13:42:51 +0100 Subject: 1.41.1 --- CHANGES.md | 32 ++++++++++++++++++++++++++++++++ debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index f8da8771aa..fab27b874e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,35 @@ +Synapse 1.41.1 (2021-08-31) +=========================== + +Due to the two security issues highlighted below, server administrators are encouraged to update Synapse. We are not aware of these vulnerabilities being exploited in the wild. + +Security advisory +----------------- + +The following issues are fixed in v1.41.1. 
+ 
+- **[GHSA-3x4c-pq33-4w3q](https://github.com/matrix-org/synapse/security/advisories/GHSA-3x4c-pq33-4w3q) / [CVE-2021-39164](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-39164): Enumerating a private room's list of members and their display names.**
+
+  If an unauthorized user both knows the Room ID of a private room *and* that room's history visibility is set to `shared`, then they may be able to enumerate the room's members, including their display names.
+
+  The unauthorized user must be on the same homeserver as a user who is a member of the target room.
+
+  Fixed by [52c7a51cf](https://github.com/matrix-org/synapse/commit/52c7a51cf).
+
+- **[GHSA-jj53-8fmw-f2w2](https://github.com/matrix-org/synapse/security/advisories/GHSA-jj53-8fmw-f2w2) / [CVE-2021-39163](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-39163): Disclosing a private room's name, avatar, topic, and number of members.**
+
+  If an unauthorized user knows the Room ID of a private room, then its name, avatar, topic, and number of members may be disclosed through Group / Community features.
+
+  The unauthorized user must be on the same homeserver as a user who is a member of the target room, and their homeserver must allow non-administrators to create groups (`enable_group_creation` in the Synapse configuration; off by default).
+
+  Fixed by [cb35df940a](https://github.com/matrix-org/synapse/commit/cb35df940a), [\#10723](https://github.com/matrix-org/synapse/issues/10723).
+
+Bugfixes
+--------
+
+- Fix a regression introduced in Synapse 1.41 which broke email transmission on systems using older versions of the Twisted library. ([\#10713](https://github.com/matrix-org/synapse/issues/10713))
+
+
 Synapse 1.41.0 (2021-08-24)
 ===========================
 
diff --git a/debian/changelog b/debian/changelog
index 4da4bc018c..5f7a795b6e 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+matrix-synapse-py3 (1.41.1) stable; urgency=high
+
+  * New synapse release 1.41.1.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 31 Aug 2021 12:59:10 +0100
+
 matrix-synapse-py3 (1.41.0) stable; urgency=medium
 
   * New synapse release 1.41.0.
diff --git a/synapse/__init__.py b/synapse/__init__.py
index ef3770262e..06d80f79b3 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.41.0"
+__version__ = "1.41.1"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
-- 
cgit 1.5.1


From 78e590d473df6a225157dfa7460341b05e52bc26 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Tue, 31 Aug 2021 12:38:43 -0400
Subject: Move the sessions delta to the latest schema version. (#10725)

This was erroneously put under schema version 62 instead of 63.
---
 changelog.d/10725.feature                          |  1 +
 .../schema/main/delta/62/02session_store.sql       | 23 ----------------------
 .../schema/main/delta/63/03session_store.sql       | 23 ++++++++++++++++++++++
 3 files changed, 24 insertions(+), 23 deletions(-)
 create mode 100644 changelog.d/10725.feature
 delete mode 100644 synapse/storage/schema/main/delta/62/02session_store.sql
 create mode 100644 synapse/storage/schema/main/delta/63/03session_store.sql

diff --git a/changelog.d/10725.feature b/changelog.d/10725.feature
new file mode 100644
index 0000000000..ffc4e4289c
--- /dev/null
+++ b/changelog.d/10725.feature
@@ -0,0 +1 @@
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). 
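Why the directory move matters: deltas live under per-version directories, and a database only applies deltas from versions newer than the schema version it has already reached, so a new delta filed under an old version directory is never picked up by servers that have passed that version. The sketch below illustrates that selection rule only; it is not Synapse's actual loader (the real logic lives in `synapse/storage/prepare_database.py` and handles more cases), and the function name here is made up.

```python
# Illustrative sketch of directory-versioned schema deltas (assumed names,
# not Synapse's real implementation). A database at schema version
# `current` only looks inside delta directories current+1 .. SCHEMA_VERSION.
import os

SCHEMA_VERSION = 63  # version shipped with this release

def deltas_to_apply(delta_root: str, current: int) -> list:
    """Collect delta scripts for every version newer than the database's."""
    scripts = []
    for version in range(current + 1, SCHEMA_VERSION + 1):
        version_dir = os.path.join(delta_root, str(version))
        if not os.path.isdir(version_dir):
            continue  # no schema changes shipped for this version
        scripts.extend(
            os.path.join(version_dir, name)
            for name in sorted(os.listdir(version_dir))
        )
    return scripts

# A deployment already on schema 62 or later never re-reads delta/62/,
# which is why 02session_store.sql had to move under delta/63/.
print(deltas_to_apply("synapse/storage/schema/main/delta", 62))
```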
diff --git a/synapse/storage/schema/main/delta/62/02session_store.sql b/synapse/storage/schema/main/delta/62/02session_store.sql deleted file mode 100644 index 535fb34c10..0000000000 --- a/synapse/storage/schema/main/delta/62/02session_store.sql +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2021 The Matrix.org Foundation C.I.C. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -CREATE TABLE IF NOT EXISTS sessions( - session_type TEXT NOT NULL, -- The unique key for this type of session. - session_id TEXT NOT NULL, -- The session ID passed to the client. - value TEXT NOT NULL, -- A JSON dictionary to persist. - expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds). - UNIQUE (session_type, session_id) -); diff --git a/synapse/storage/schema/main/delta/63/03session_store.sql b/synapse/storage/schema/main/delta/63/03session_store.sql new file mode 100644 index 0000000000..535fb34c10 --- /dev/null +++ b/synapse/storage/schema/main/delta/63/03session_store.sql @@ -0,0 +1,23 @@ +/* + * Copyright 2021 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS sessions( + session_type TEXT NOT NULL, -- The unique key for this type of session. + session_id TEXT NOT NULL, -- The session ID passed to the client. + value TEXT NOT NULL, -- A JSON dictionary to persist. + expiry_time_ms BIGINT NOT NULL, -- The time this session will expire (epoch time in milliseconds). + UNIQUE (session_type, session_id) +); -- cgit 1.5.1 From 287918e2d4437d890b2c0078b12859fae268dacc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 31 Aug 2021 13:22:29 -0400 Subject: Additional type hints for the client REST servlets (part 3). (#10707) --- changelog.d/10707.misc | 1 + synapse/rest/client/account_data.py | 37 ++++-- synapse/rest/client/groups.py | 12 +- synapse/rest/client/receipts.py | 15 ++- synapse/rest/client/register.py | 78 +++++------- synapse/rest/client/relations.py | 80 ++++++++++--- synapse/rest/client/room.py | 233 +++++++++++++++++++++++++----------- 7 files changed, 306 insertions(+), 150 deletions(-) create mode 100644 changelog.d/10707.misc diff --git a/changelog.d/10707.misc b/changelog.d/10707.misc new file mode 100644 index 0000000000..39a37b90b1 --- /dev/null +++ b/changelog.d/10707.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. 
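The file-by-file changes below all apply one pattern: take the homeserver as `hs: "HomeServer"` under a `TYPE_CHECKING` guard (avoiding a circular import at runtime), annotate `request` as `SynapseRequest`, and declare handlers as returning `Tuple[int, JsonDict]`. A condensed sketch of that shape, using a hypothetical servlet and path rather than any of the real classes in this commit:

```python
# Minimal sketch of the annotation pattern applied in this commit.
# ExampleServlet and its path are made up for illustration.
from typing import TYPE_CHECKING, Tuple

from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict

from ._base import client_patterns

if TYPE_CHECKING:
    # Only imported for type checking, so synapse.server is not a runtime dependency.
    from synapse.server import HomeServer


class ExampleServlet(RestServlet):
    PATTERNS = client_patterns("/example/(?P<room_id>[^/]*)$")

    def __init__(self, hs: "HomeServer"):
        super().__init__()
        self.auth = hs.get_auth()

    async def on_GET(
        self, request: SynapseRequest, room_id: str
    ) -> Tuple[int, JsonDict]:
        # Handlers now declare both their parameters and their (status, body) result.
        await self.auth.get_user_by_req(request)
        return 200, {"room_id": room_id}


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    ExampleServlet(hs).register(http_server)
```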
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index 7517e9304e..d1badbdf3b 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -13,12 +13,19 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.errors import AuthError, NotFoundError, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet):
         "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.handler = hs.get_account_data_handler()
 
-    async def on_PUT(self, request, user_id, account_data_type):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str, account_data_type: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet):
 
         return 200, {}
 
-    async def on_GET(self, request, user_id, account_data_type):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str, account_data_type: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot get account data for other users.")
@@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet):
         "/account_data/(?P<account_data_type>[^/]*)"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.handler = hs.get_account_data_handler()
 
-    async def on_PUT(self, request, user_id, room_id, account_data_type):
+    async def on_PUT(
+        self,
+        request: SynapseRequest,
+        user_id: str,
+        room_id: str,
+        account_data_type: str,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet):
 
         return 200, {}
 
-    async def on_GET(self, request, user_id, room_id, account_data_type):
+    async def on_GET(
+        self,
+        request: SynapseRequest,
+        user_id: str,
+        room_id: str,
+        account_data_type: str,
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot get account data for other users.")
@@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet):
 
         return 200, event
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     AccountDataServlet(hs).register(http_server)
     RoomAccountDataServlet(hs).register(http_server)
diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py
index c3667ff8aa..0950f43f2f 100644
--- a/synapse/rest/client/groups.py
+++ b/synapse/rest/client/groups.py
@@ -156,7 +156,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         group_id: str,
         category_id: Optional[str],
         room_id: str,
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await 
self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -188,7 +188,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
     @_validate_group_id
     async def on_DELETE(
         self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -451,7 +451,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
     @_validate_group_id
     async def on_DELETE(
         self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -674,7 +674,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
     @_validate_group_id
     async def on_PUT(
         self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
-    ):
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
@@ -706,7 +706,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: SynapseRequest, group_id, user_id
+        self, request: SynapseRequest, group_id: str, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -738,7 +738,7 @@ class GroupAdminUsersKickServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: SynapseRequest, group_id, user_id
+        self, request: SynapseRequest, group_id: str, user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index d9ab836cd8..9770413c61 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -13,13 +13,20 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.constants import ReadReceiptEventFields
 from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet):
         "/(?P<event_id>[^/]*)$"
     )
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
         self.presence_handler = hs.get_presence_handler()
 
-    async def on_POST(self, request, room_id, receipt_type, event_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
         if receipt_type != "m.read":
@@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet):
         return 200, {}
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 7b5f49d635..a28acd4041 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -14,7 +14,9 @@
 # limitations under the License. 
 import logging
 import random
-from typing import List, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.server import Request
 
 import synapse
 import synapse.api.auth
@@ -29,15 +31,13 @@ from synapse.api.errors import (
 )
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.config import ConfigError
-from synapse.config.captcha import CaptchaConfig
-from synapse.config.consent import ConsentConfig
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.config.homeserver import HomeServerConfig
 from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
-from synapse.http.server import finish_request, respond_with_html
+from synapse.http.server import HttpServer, finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -45,6 +45,7 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.metrics import threepid_send_requests
 from synapse.push.mailer import Mailer
 from synapse.types import JsonDict
@@ -59,17 +60,16 @@ from synapse.util.threepids import (
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class EmailRegisterRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/register/email/requestToken$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
@@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
             template_text=self.config.email_registration_template_text,
         )
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
             if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
                 logger.warning(
@@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
 class MsisdnRegisterRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/register/msisdn/requestToken$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.identity_handler = hs.get_identity_handler()
 
-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         body = parse_json_object_from_request(request)
 
         assert_params_in_dict(
@@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
         "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
     )
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
             self.config.email_registration_template_failure_html
         )
 
-    async def on_GET(self, request, medium):
+    async def on_GET(self, request: Request, medium: str) -> None:
         if medium != "email":
             raise SynapseError(
                 400, "This medium is currently not supported for registration"
@@ -326,11 
+318,7 @@ class RegistrationSubmitTokenServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet): PATTERNS = client_patterns("/register/available") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.registration_handler = hs.get_registration_handler() @@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.hs.config.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN @@ -419,11 +407,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): class RegisterRestServlet(RestServlet): PATTERNS = client_patterns("/register$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -445,23 +429,21 @@ class RegisterRestServlet(RestServlet): ) @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) client_addr = request.getClientIP() await self.ratelimiter.ratelimit(None, client_addr, update=False) - kind = b"user" - if b"kind" in request.args: - kind = request.args[b"kind"][0] + kind = parse_string(request, "kind", default="user") - if kind == b"guest": + if kind == "guest": ret = await self._do_guest_registration(body, address=client_addr) return ret - elif kind != b"user": + elif kind != "user": raise UnrecognizedRequestError( - "Do not understand membership kind: %s" % (kind.decode("utf8"),) + f"Do not understand membership kind: {kind}", ) if self._msc2918_enabled: @@ -749,7 +731,7 @@ class RegisterRestServlet(RestServlet): async def _do_appservice_registration( self, username, as_token, body, should_issue_refresh_token: bool = False - ): + ) -> JsonDict: user_id = await self.registration_handler.appservice_register( username, as_token ) @@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet): params: JsonDict, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, - ): + ) -> JsonDict: """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. 
@@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet): return result - async def _do_guest_registration(self, params, address=None): + async def _do_guest_registration( + self, params: JsonDict, address: Optional[str] = None + ) -> Tuple[int, JsonDict]: if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( @@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet): def _calculate_registration_flows( - # technically `config` has to provide *all* of these interfaces, not just one - config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], - auth_handler: AuthHandler, + config: HomeServerConfig, auth_handler: AuthHandler ) -> List[List[str]]: """Get a suitable flows list for registration @@ -929,7 +911,7 @@ def _calculate_registration_flows( return flows -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 0821cd285f..0b0711c03c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,25 +19,32 @@ any time to reflect changes in the MSC. """ import logging +from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import ShadowBanError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.stringutils import random_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet): "/(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.event_creation_handler = hs.get_event_creation_handler() self.txns = HttpTransactionCache(hs) - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: http_server.register_paths( "POST", client_patterns(self.PATTERN + "$", releases=()), @@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet): self.__class__.__name__, ) - def on_PUT(self, request, *args, **kwargs): + def on_PUT( + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( - request, self.on_PUT_or_POST, request, *args, **kwargs + request, + self.on_PUT_or_POST, + request, + room_id, + parent_id, + relation_type, + event_type, + txn_id, ) async def on_PUT_or_POST( - self, request, room_id, parent_id, relation_type, event_type, txn_id=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await 
self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: @@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet): self.event_handler = hs.get_event_handler() async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet): # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") limit = parse_integer(request, "limit", default=5) from_token_str = parse_string(request, "from") @@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. 
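# get_event returns an Optional event, so once on_GET is annotated as
# returning Tuple[int, JsonDict], mypy requires the None case to be handled
# explicitly; hence the 404 below.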
event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + async def on_GET( + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + key: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): return 200, return_value -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index c5c54564be..9b0c546505 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -16,9 +16,11 @@ """ This module contains REST servlets to do with rooms: /rooms/ """ import logging import re -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from urllib import parse as urlparse +from twisted.web.server import Request + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -30,6 +32,7 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 +from synapse.http.server import HttpServer from synapse.http.servlet import ( ResolveRoomIdMixin, RestServlet, @@ -57,7 +60,7 @@ logger = logging.getLogger(__name__) class TransactionRestServlet(RestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.txns = HttpTransactionCache(hs) @@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet): class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) - def on_PUT(self, request, txn_id): + def on_PUT( + self, request: SynapseRequest, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) info, _ = await self._room_creation_handler.create_room( @@ -87,21 +92,21 @@ class 
RoomCreateRestServlet(TransactionRestServlet): return 200, info - def get_room_config(self, request): + def get_room_config(self, request: Request) -> JsonDict: user_supplied_config = parse_json_object_from_request(request) return user_supplied_config # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P[^/]*)/state/(?P[^/]*)$" @@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet): self.__class__.__name__, ) - def on_GET_no_state_key(self, request, room_id, event_type): + def on_GET_no_state_key( + self, request: SynapseRequest, room_id: str, event_type: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_GET(request, room_id, event_type, "") - def on_PUT_no_state_key(self, request, room_id, event_type): + def on_PUT_no_state_key( + self, request: SynapseRequest, room_id: str, event_type: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_PUT(request, room_id, event_type, "") - async def on_GET(self, request, room_id, event_type, state_key): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_type: str, state_key: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] @@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + # Format must be event or content, per the parse_string call above. + raise RuntimeError(f"Unknown format: {format:r}.") + + async def on_PUT( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + state_key: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if txn_id: @@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = "/rooms/(?P[^/]*)/send/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - async def on_POST(self, request, room_id, event_type, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) - event_dict = { + event_dict: JsonDict = { "type": event_type, "content": content, "room_id": room_id, "sender": requester.user.to_string(), } + # Twisted will have processed the args by now. 
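+ # (In the Twisted annotations request.args is Optional, populated only once
+ # the request has been read, so this assert narrows the type for mypy rather
+ # than guarding anything at runtime.)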
+ assert request.args is not None if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) @@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_GET(self, request, room_id, event_type, txn_id): + def on_GET( + self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str + ) -> Tuple[int, str]: return 200, "Not implemented" - def on_PUT(self, request, room_id, event_type, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /join/$room_identifier[/$txn_id] PATTERNS = "/join/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) @@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): request: SynapseRequest, room_identifier: str, txn_id: Optional[str] = None, - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) try: @@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): return 200, {"room_id": room_id} - def on_PUT(self, request, room_identifier, txn_id): + def on_PUT( + self, request: SynapseRequest, room_identifier: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): class PublicRoomListRestServlet(TransactionRestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: server = parse_string(request, "server") try: @@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): return 200, data - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server") @@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet): class RoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/members$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: # TODO support Pagination stream API (limit/tokens) requester = await self.auth.get_user_by_req(request, allow_guest=True) handler = self.message_handler @@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet): class 
JoinedRoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/joined_members$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) users_with_profile = await self.message_handler.get_joined_members( @@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet): class RoomMessageListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/messages$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request( self.store, request, default_limit=10 ) + # Twisted will have processed the args by now. + assert request.args is not None as_client_event = b"raw" not in request.args filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: @@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet): class RoomStateRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/state$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, List[JsonDict]]: requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room events = await self.message_handler.get_state_events( @@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet): class RoomInitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P[^/]*)/initialSync$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( @@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet): "/rooms/(?P[^/]*)/event/(?P[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) try: event = await self.event_handler.get_event( @@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = 
await self._event_serializer.serialize_event(event, time_now) - return 200, event + event_dict = await self._event_serializer.serialize_event(event, time_now) + return 200, event_dict - return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) class RoomEventContextServlet(RestServlet): @@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet): "/rooms/(?P[^/]*)/context/(?P[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet): class RoomForgetRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, txn_id=None): + async def on_POST( + self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} - def on_PUT(self, request, room_id, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/[invite|join|leave] PATTERNS = ( "/rooms/(?P[^/]*)/" @@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, membership_action, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + membership_action: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { @@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet): return 200, return_value - def _has_3pid_invite_keys(self, content): + def _has_3pid_invite_keys(self, content: JsonDict) -> bool: for key in {"id_server", "medium", "address"}: if key not in content: return False return True - def on_PUT(self, request, room_id, membership_action, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: 
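+ # Not itself a coroutine: it hands back the awaitable from
+ # fetch_or_execute_request unawaited, which is why the return type is
+ # Awaitable[Tuple[int, JsonDict]] rather than the tuple itself.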
set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, event_id, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_id: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_PUT(self, request, room_id, event_id, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet): hs.config.worker.writers.typing == hs.get_instance_name() ) - async def on_PUT(self, request, room_id, user_id): + async def on_PUT( + self, request: SynapseRequest, room_id: str, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not self._is_typing_writer: @@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet): self.auth = hs.get_auth() self.directory_handler = hs.get_directory_handler() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) alias_list = await self.directory_handler.get_aliases_for_room( @@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet): class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.search_handler = hs.get_search_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet): class JoinedRoomsRestServlet(RestServlet): PATTERNS = client_patterns("/joined_rooms$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} -def register_txn_path(servlet, regex_string, http_server, with_get=False): +def register_txn_path( + servlet: RestServlet, + regex_string: str, + http_server: HttpServer, + with_get: bool = False, +) -> None: """Registers a transaction-based path. 
This registers two paths: @@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): POST regex_string Args: - regex_string (str): The regex string to register. Must NOT have a - trailing $ as this string will be appended to. - http_server : The http_server to register paths with. + regex_string: The regex string to register. Must NOT have a + trailing $ as this string will be appended to. + http_server: The http_server to register paths with. with_get: True to also register respective GET paths for the PUTs. """ + on_POST = getattr(servlet, "on_POST", None) + on_PUT = getattr(servlet, "on_PUT", None) + if on_POST is None or on_PUT is None: + raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path") http_server.register_paths( "POST", client_patterns(regex_string + "$", v1=True), - servlet.on_POST, + on_POST, servlet.__class__.__name__, ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_PUT, + on_PUT, servlet.__class__.__name__, ) + on_GET = getattr(servlet, "on_GET", None) if with_get: + if on_GET is None: + raise RuntimeError( + "register_txn_path called with with_get = True, but no on_GET method exists" + ) http_server.register_paths( "GET", client_patterns(regex_string + "/(?P[^/]*)$", v1=True), - servlet.on_GET, + on_GET, servlet.__class__.__name__, ) @@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): ) -def register_servlets(hs: "HomeServer", http_server, is_worker=False): +def register_servlets( + hs: "HomeServer", http_server: HttpServer, is_worker: bool = False +) -> None: RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomForgetRestServlet(hs).register(http_server) -def register_deprecated_servlets(hs, http_server): +def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomInitialSyncRestServlet(hs).register(http_server) -- cgit 1.5.1 From e2481dbe9325321d460037a2efe9b9ea2ac78057 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 31 Aug 2021 18:37:07 -0400 Subject: Allow configuration of the oEmbed URLs. (#10714) This adds configuration options (under an `oembed` section) to configure which URLs are matched to use oEmbed for URL previews. --- changelog.d/10714.feature | 1 + docs/sample_config.yaml | 21 +++ synapse/config/homeserver.py | 2 + synapse/config/oembed.py | 180 ++++++++++++++++++++++ synapse/res/providers.json | 17 +++ synapse/rest/media/v1/oembed.py | 135 ++++++++++++++++ synapse/rest/media/v1/preview_url_resource.py | 147 +----------------- tests/rest/media/v1/test_url_preview.py | 212 +++++++++++++------------- 8 files changed, 463 insertions(+), 252 deletions(-) create mode 100644 changelog.d/10714.feature create mode 100644 synapse/config/oembed.py create mode 100644 synapse/res/providers.json create mode 100644 synapse/rest/media/v1/oembed.py diff --git a/changelog.d/10714.feature b/changelog.d/10714.feature new file mode 100644 index 0000000000..7d18f5c133 --- /dev/null +++ b/changelog.d/10714.feature @@ -0,0 +1 @@ +Allow configuration of the oEmbed URLs used for URL previews. 
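The PUT variants registered above via register_txn_path are backed by HttpTransactionCache: a retried PUT carrying the same transaction ID gets the stored response instead of re-running the handler. A minimal sketch of that idempotency idea, ignoring the per-requester scoping and entry expiry the real cache applies:

from typing import Awaitable, Callable, Dict, Tuple

class TransactionCacheSketch:
    def __init__(self) -> None:
        self._results: Dict[str, Tuple[int, dict]] = {}

    async def fetch_or_execute(
        self, txn_id: str, fn: Callable[[], Awaitable[Tuple[int, dict]]]
    ) -> Tuple[int, dict]:
        # Only the first request with a given txn_id runs the handler (and,
        # say, actually sends the event); any retry gets the cached response.
        if txn_id not in self._results:
            self._results[txn_id] = await fn()
        return self._results[txn_id]

This is what makes a client's PUT safe to replay over a flaky connection without duplicating the side effect.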
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 935841dbfa..e155b978d8 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1075,6 +1075,27 @@ url_preview_accept_language: # - en +# oEmbed allows for easier embedding content from a website. It can be +# used for generating URLs previews of services which support it. +# +oembed: + # A default list of oEmbed providers is included with Synapse. + # + # Uncomment the following to disable using these default oEmbed URLs. + # Defaults to 'false'. + # + #disable_default_providers: true + + # Additional files with oEmbed configuration (each should be in the + # form of providers.json). + # + # By default, this list is empty (so only the default providers.json + # is used). + # + #additional_providers: + # - oembed/my_providers.json + + ## Captcha ## # See docs/CAPTCHA_SETUP.md for full details of configuring this. diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 1f42a51857..442f1b9ac0 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -30,6 +30,7 @@ from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig from .modules import ModulesConfig +from .oembed import OembedConfig from .oidc import OIDCConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig @@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig): LoggingConfig, RatelimitConfig, ContentRepositoryConfig, + OembedConfig, CaptchaConfig, VoipConfig, RegistrationConfig, diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py new file mode 100644 index 0000000000..09267b5eef --- /dev/null +++ b/synapse/config/oembed.py @@ -0,0 +1,180 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import re +from typing import Any, Dict, Iterable, List, Pattern +from urllib import parse as urlparse + +import attr +import pkg_resources + +from synapse.types import JsonDict + +from ._base import Config, ConfigError +from ._util import validate_config + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class OEmbedEndpointConfig: + # The API endpoint to fetch. + api_endpoint: str + # The patterns to match. + url_patterns: List[Pattern] + + +class OembedConfig(Config): + """oEmbed Configuration""" + + section = "oembed" + + def read_config(self, config, **kwargs): + oembed_config: Dict[str, Any] = config.get("oembed") or {} + + # A list of patterns which will be used. + self.oembed_patterns: List[OEmbedEndpointConfig] = list( + self._parse_and_validate_providers(oembed_config) + ) + + def _parse_and_validate_providers( + self, oembed_config: dict + ) -> Iterable[OEmbedEndpointConfig]: + """Extract and parse the oEmbed providers from the given JSON file. + + Returns a generator which yields the OidcProviderConfig objects + """ + # Whether to use the packaged providers.json file. 
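+ # (Note the precedence: this condition reads as
+ # `(not oembed_config.get(...)) or False`, so the trailing `or False`
+ # has no effect.)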
+ if not oembed_config.get("disable_default_providers") or False: + providers = json.load( + pkg_resources.resource_stream("synapse", "res/providers.json") + ) + yield from self._parse_and_validate_provider( + providers, config_path=("oembed",) + ) + + # The JSON files which includes additional provider information. + for i, file in enumerate(oembed_config.get("additional_providers") or []): + # TODO Error checking. + with open(file) as f: + providers = json.load(f) + + yield from self._parse_and_validate_provider( + providers, + config_path=( + "oembed", + "additional_providers", + f"", + ), + ) + + def _parse_and_validate_provider( + self, providers: List[JsonDict], config_path: Iterable[str] + ) -> Iterable[OEmbedEndpointConfig]: + # Ensure it is the proper form. + validate_config( + _OEMBED_PROVIDER_SCHEMA, + providers, + config_path=config_path, + ) + + # Parse it and yield each result. + for provider in providers: + # Each provider might have multiple API endpoints, each which + # might have multiple patterns to match. + for endpoint in provider["endpoints"]: + api_endpoint = endpoint["url"] + patterns = [ + self._glob_to_pattern(glob, config_path) + for glob in endpoint["schemes"] + ] + yield OEmbedEndpointConfig(api_endpoint, patterns) + + def _glob_to_pattern(self, glob: str, config_path: Iterable[str]) -> Pattern: + """ + Convert the glob into a sane regular expression to match against. The + rules followed will be slightly different for the domain portion vs. + the rest. + + 1. The scheme must be one of HTTP / HTTPS (and have no globs). + 2. The domain can have globs, but we limit it to characters that can + reasonably be a domain part. + TODO: This does not attempt to handle Unicode domain names. + TODO: The domain should not allow wildcard TLDs. + 3. Other parts allow a glob to be any one, or more, characters. + """ + results = urlparse.urlparse(glob) + + # Ensure the scheme does not have wildcards (and is a sane scheme). + if results.scheme not in {"http", "https"}: + raise ConfigError(f"Insecure oEmbed scheme: {results.scheme}", config_path) + + pattern = urlparse.urlunparse( + [ + results.scheme, + re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"), + ] + + [re.escape(part).replace("\\*", ".+") for part in results[2:]] + ) + return re.compile(pattern) + + def generate_config_section(self, **kwargs): + return """\ + # oEmbed allows for easier embedding content from a website. It can be + # used for generating URLs previews of services which support it. + # + oembed: + # A default list of oEmbed providers is included with Synapse. + # + # Uncomment the following to disable using these default oEmbed URLs. + # Defaults to 'false'. + # + #disable_default_providers: true + + # Additional files with oEmbed configuration (each should be in the + # form of providers.json). + # + # By default, this list is empty (so only the default providers.json + # is used). 
+ # + #additional_providers: + # - oembed/my_providers.json + """ + + +_OEMBED_PROVIDER_SCHEMA = { + "type": "array", + "items": { + "type": "object", + "properties": { + "provider_name": {"type": "string"}, + "provider_url": {"type": "string"}, + "endpoints": { + "type": "array", + "items": { + "type": "object", + "properties": { + "schemes": { + "type": "array", + "items": {"type": "string"}, + }, + "url": {"type": "string"}, + "formats": {"type": "array", "items": {"type": "string"}}, + "discovery": {"type": "boolean"}, + }, + "required": ["schemes", "url"], + }, + }, + }, + "required": ["provider_name", "provider_url", "endpoints"], + }, +} diff --git a/synapse/res/providers.json b/synapse/res/providers.json new file mode 100644 index 0000000000..f1838f9559 --- /dev/null +++ b/synapse/res/providers.json @@ -0,0 +1,17 @@ +[ + { + "provider_name": "Twitter", + "provider_url": "http://www.twitter.com/", + "endpoints": [ + { + "schemes": [ + "https://twitter.com/*/status/*", + "https://*.twitter.com/*/status/*", + "https://twitter.com/*/moments/*", + "https://*.twitter.com/*/moments/*" + ], + "url": "https://publish.twitter.com/oembed" + } + ] + } +] \ No newline at end of file diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py new file mode 100644 index 0000000000..afe41823e4 --- /dev/null +++ b/synapse/rest/media/v1/oembed.py @@ -0,0 +1,135 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Optional + +import attr + +from synapse.http.client import SimpleHttpClient + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, auto_attribs=True) +class OEmbedResult: + # Either HTML content or URL must be provided. + html: Optional[str] + url: Optional[str] + title: Optional[str] + # Number of seconds to cache the content. + cache_age: int + + +class OEmbedError(Exception): + """An error occurred processing the oEmbed object.""" + + +class OEmbedProvider: + """ + A helper for accessing oEmbed content. + + It can be used to check if a URL should be accessed via oEmbed and for + requesting/parsing oEmbed content. + """ + + def __init__(self, hs: "HomeServer", client: SimpleHttpClient): + self._oembed_patterns = {} + for oembed_endpoint in hs.config.oembed.oembed_patterns: + for pattern in oembed_endpoint.url_patterns: + self._oembed_patterns[pattern] = oembed_endpoint.api_endpoint + self._client = client + + def get_oembed_url(self, url: str) -> Optional[str]: + """ + Check whether the URL should be downloaded as oEmbed content instead. + + Args: + url: The URL to check. + + Returns: + A URL to use instead or None if the original URL should be used. + """ + for url_pattern, endpoint in self._oembed_patterns.items(): + if url_pattern.fullmatch(url): + return endpoint + + # No match. 
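+ # A URL that matches no provider (e.g. an ordinary article page) falls
+ # back to the generic HTML preview path.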
+ return None + + async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult: + """ + Request content from an oEmbed endpoint. + + Args: + endpoint: The oEmbed API endpoint. + url: The URL to pass to the API. + + Returns: + An object representing the metadata returned. + + Raises: + OEmbedError if fetching or parsing of the oEmbed information fails. + """ + try: + logger.debug("Trying to get oEmbed content for url '%s'", url) + result = await self._client.get_json( + endpoint, + # TODO Specify max height / width. + # Note that only the JSON format is supported. + args={"url": url}, + ) + + # Ensure there's a version of 1.0. + if result.get("version") != "1.0": + raise OEmbedError("Invalid version: %s" % (result.get("version"),)) + + oembed_type = result.get("type") + + # Ensure the cache age is None or an int. + cache_age = result.get("cache_age") + if cache_age: + cache_age = int(cache_age) + + oembed_result = OEmbedResult(None, None, result.get("title"), cache_age) + + # HTML content. + if oembed_type == "rich": + oembed_result.html = result.get("html") + return oembed_result + + if oembed_type == "photo": + oembed_result.url = result.get("url") + return oembed_result + + # TODO Handle link and video types. + + if "thumbnail_url" in result: + oembed_result.url = result.get("thumbnail_url") + return oembed_result + + raise OEmbedError("Incompatible oEmbed information.") + + except OEmbedError as e: + # Trap OEmbedErrors first so we can directly re-raise them. + logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) + raise + + except Exception as e: + # Trap any exception and let the code follow as usual. + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) + raise OEmbedError() from e diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0f051d4041..317d333b12 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -25,8 +25,6 @@ import traceback from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union from urllib import parse as urlparse -import attr - from twisted.internet.error import DNSLookupError from twisted.web.server import Request @@ -43,6 +41,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage +from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -71,63 +70,6 @@ OG_TAG_VALUE_MAXLEN = 1000 ONE_HOUR = 60 * 60 * 1000 -# A map of globs to API endpoints. -_oembed_globs = { - # Twitter. - "https://publish.twitter.com/oembed": [ - "https://twitter.com/*/status/*", - "https://*.twitter.com/*/status/*", - "https://twitter.com/*/moments/*", - "https://*.twitter.com/*/moments/*", - # Include the HTTP versions too. - "http://twitter.com/*/status/*", - "http://*.twitter.com/*/status/*", - "http://twitter.com/*/moments/*", - "http://*.twitter.com/*/moments/*", - ], -} -# Convert the globs to regular expressions. 
-_oembed_patterns = {} -for endpoint, globs in _oembed_globs.items(): - for glob in globs: - # Convert the glob into a sane regular expression to match against. The - # rules followed will be slightly different for the domain portion vs. - # the rest. - # - # 1. The scheme must be one of HTTP / HTTPS (and have no globs). - # 2. The domain can have globs, but we limit it to characters that can - # reasonably be a domain part. - # TODO: This does not attempt to handle Unicode domain names. - # 3. Other parts allow a glob to be any one, or more, characters. - results = urlparse.urlparse(glob) - - # Ensure the scheme does not have wildcards (and is a sane scheme). - if results.scheme not in {"http", "https"}: - raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,)) - - pattern = urlparse.urlunparse( - [ - results.scheme, - re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"), - ] - + [re.escape(part).replace("\\*", ".+") for part in results[2:]] - ) - _oembed_patterns[re.compile(pattern)] = endpoint - - -@attr.s(slots=True) -class OEmbedResult: - # Either HTML content or URL must be provided. - html = attr.ib(type=Optional[str]) - url = attr.ib(type=Optional[str]) - title = attr.ib(type=Optional[str]) - # Number of seconds to cache the content. - cache_age = attr.ib(type=int) - - -class OEmbedError(Exception): - """An error occurred processing the oEmbed object.""" - class PreviewUrlResource(DirectServeJsonResource): isLeaf = True @@ -157,6 +99,8 @@ class PreviewUrlResource(DirectServeJsonResource): self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage + self._oembed = OEmbedProvider(hs, self.client) + # We run the background jobs if we're the instance specified (or no # instance is specified, where we assume there is only one instance # serving media). @@ -367,87 +311,6 @@ class PreviewUrlResource(DirectServeJsonResource): return jsonog.encode("utf8") - def _get_oembed_url(self, url: str) -> Optional[str]: - """ - Check whether the URL should be downloaded as oEmbed content instead. - - Args: - url: The URL to check. - - Returns: - A URL to use instead or None if the original URL should be used. - """ - for url_pattern, endpoint in _oembed_patterns.items(): - if url_pattern.fullmatch(url): - return endpoint - - # No match. - return None - - async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult: - """ - Request content from an oEmbed endpoint. - - Args: - endpoint: The oEmbed API endpoint. - url: The URL to pass to the API. - - Returns: - An object representing the metadata returned. - - Raises: - OEmbedError if fetching or parsing of the oEmbed information fails. - """ - try: - logger.debug("Trying to get oEmbed content for url '%s'", url) - result = await self.client.get_json( - endpoint, - # TODO Specify max height / width. - # Note that only the JSON format is supported. - args={"url": url}, - ) - - # Ensure there's a version of 1.0. - if result.get("version") != "1.0": - raise OEmbedError("Invalid version: %s" % (result.get("version"),)) - - oembed_type = result.get("type") - - # Ensure the cache age is None or an int. - cache_age = result.get("cache_age") - if cache_age: - cache_age = int(cache_age) - - oembed_result = OEmbedResult(None, None, result.get("title"), cache_age) - - # HTML content. 
- if oembed_type == "rich": - oembed_result.html = result.get("html") - return oembed_result - - if oembed_type == "photo": - oembed_result.url = result.get("url") - return oembed_result - - # TODO Handle link and video types. - - if "thumbnail_url" in result: - oembed_result.url = result.get("thumbnail_url") - return oembed_result - - raise OEmbedError("Incompatible oEmbed information.") - - except OEmbedError as e: - # Trap OEmbedErrors first so we can directly re-raise them. - logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) - raise - - except Exception as e: - # Trap any exception and let the code follow as usual. - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) - raise OEmbedError() from e - async def _download_url(self, url: str, user: str) -> Dict[str, Any]: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a @@ -459,11 +322,11 @@ class PreviewUrlResource(DirectServeJsonResource): # If this URL can be accessed via oEmbed, use that instead. url_to_download: Optional[str] = url - oembed_url = self._get_oembed_url(url) + oembed_url = self._oembed.get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. try: - oembed_result = await self._get_oembed_content(oembed_url, url) + oembed_result = await self._oembed.get_oembed_content(oembed_url, url) if oembed_result.url: url_to_download = oembed_result.url elif oembed_result.html: diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index d3ef7bb4c6..7fa9027227 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -14,13 +14,14 @@ import json import os import re -from unittest.mock import patch from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError from twisted.test.proto_helpers import AccumulatingProtocol +from synapse.config.oembed import OEmbedEndpointConfig + from tests import unittest from tests.server import FakeTransport @@ -81,6 +82,19 @@ class URLPreviewTests(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) + # After the hs is created, modify the parsed oEmbed config (to avoid + # messing with files). + # + # Note that HTTP URLs are used to avoid having to deal with TLS in tests. + hs.config.oembed.oembed_patterns = [ + OEmbedEndpointConfig( + api_endpoint="http://publish.twitter.com/oembed", + url_patterns=[ + re.compile(r"http://twitter\.com/.+/status/.+"), + ], + ) + ] + return hs def prepare(self, reactor, clock, hs): @@ -544,123 +558,101 @@ class URLPreviewTests(unittest.HomeserverTestCase): def test_oembed_photo(self): """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL.""" - # Route the HTTP version to an HTTP endpoint so that the tests work. 
- with patch.dict( - "synapse.rest.media.v1.preview_url_resource._oembed_patterns", - { - re.compile( - r"http://twitter\.com/.+/status/.+" - ): "http://publish.twitter.com/oembed", - }, - clear=True, - ): - - self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] - self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] - - result = { - "version": "1.0", - "type": "photo", - "url": "http://cdn.twitter.com/matrixdotorg", - } - oembed_content = json.dumps(result).encode("utf-8") - - end_content = ( - b"" - b"Some Title" - b'' - b"" - ) + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] - channel = self.make_request( - "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: application/json; charset="utf8"\r\n\r\n' - ) - % (len(oembed_content),) - + oembed_content - ) + result = { + "version": "1.0", + "type": "photo", + "url": "http://cdn.twitter.com/matrixdotorg", + } + oembed_content = json.dumps(result).encode("utf-8") - self.pump() - - client = self.reactor.tcpClients[1][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: text/html; charset="utf8"\r\n\r\n' - ) - % (len(end_content),) - + end_content + end_content = ( + b"" + b"Some Title" + b'' + b"" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' ) + % (len(oembed_content),) + + oembed_content + ) - self.pump() + self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, {"og:title": "Some Title", "og:description": "hi"} + client = self.reactor.tcpClients[1][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' ) + % (len(end_content),) + + end_content + ) + + self.pump() + + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, {"og:title": "Some Title", "og:description": "hi"} + ) def test_oembed_rich(self): """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" - # Route the HTTP version to an HTTP endpoint so that the tests work. 
- with patch.dict( - "synapse.rest.media.v1.preview_url_resource._oembed_patterns", - { - re.compile( - r"http://twitter\.com/.+/status/.+" - ): "http://publish.twitter.com/oembed", - }, - clear=True, - ): - - self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] - - result = { - "version": "1.0", - "type": "rich", - "html": "
<div>Content Preview</div>
", - } - end_content = json.dumps(result).encode("utf-8") - - channel = self.make_request( - "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", - shorthand=False, - await_result=False, - ) - self.pump() - - client = self.reactor.tcpClients[0][2].buildProtocol(None) - server = AccumulatingProtocol() - server.makeConnection(FakeTransport(client, self.reactor)) - client.makeConnection(FakeTransport(server, self.reactor)) - client.dataReceived( - ( - b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" - b'Content-Type: application/json; charset="utf8"\r\n\r\n' - ) - % (len(end_content),) - + end_content - ) + self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + result = { + "version": "1.0", + "type": "rich", + "html": "
<div>Content Preview</div>
", + } + end_content = json.dumps(result).encode("utf-8") + + channel = self.make_request( + "GET", + "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + shorthand=False, + await_result=False, + ) + self.pump() - self.pump() - self.assertEqual(channel.code, 200) - self.assertEqual( - channel.json_body, - {"og:title": None, "og:description": "Content Preview"}, + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: application/json; charset="utf8"\r\n\r\n' ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, + {"og:title": None, "og:description": "Content Preview"}, + ) -- cgit 1.5.1 From 3693ea61f5f56f4a49cce7e2b3ecd304f014d8cc Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 1 Sep 2021 10:13:01 +0100 Subject: Fix iteration in _remove_deleted_email_pushers background job. (#10734) --- changelog.d/10734.bugfix | 1 + synapse/storage/databases/main/pusher.py | 3 ++- tests/push/test_email.py | 44 ++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10734.bugfix diff --git a/changelog.d/10734.bugfix b/changelog.d/10734.bugfix new file mode 100644 index 0000000000..15c7da4497 --- /dev/null +++ b/changelog.d/10734.bugfix @@ -0,0 +1 @@ +Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted. \ No newline at end of file diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index e47caa2125..63ac09c61d 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -430,10 +430,11 @@ class PusherWorkerStore(SQLBaseStore): """ txn.execute(sql, (last_pusher, batch_size)) + rows = txn.fetchall() last = None num_deleted = 0 - for row in txn: + for row in rows: last = row[0] num_deleted += 1 self.db_pool.simple_delete_txn( diff --git a/tests/push/test_email.py b/tests/push/test_email.py index eea07485a0..c4ba13a6b2 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -344,6 +344,50 @@ class EmailPusherTests(HomeserverTestCase): pushers = list(pushers) self.assertEqual(len(pushers), 0) + def test_remove_unlinked_pushers_background_job(self): + """Checks that all existing pushers associated with unlinked email addresses are removed + upon running the remove_deleted_email_pushers background update. + """ + # disassociate the user's email address manually (without deleting the pusher). + # This resembles the old behaviour, which the background update below is intended + # to clean up. + self.get_success( + self.hs.get_datastore().user_delete_threepid( + self.user_id, "email", "a@example.com" + ) + ) + + # Run the "remove_deleted_email_pushers" background job + self.get_success( + self.hs.get_datastore().db_pool.simple_insert( + table="background_updates", + values={ + "update_name": "remove_deleted_email_pushers", + "progress_json": "{}", + "depends_on": None, + }, + ) + ) + + # ... 
and tell the DataStore that it hasn't finished all updates yet + self.hs.get_datastore().db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success( + self.hs.get_datastore().db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.hs.get_datastore().db_pool.updates.do_next_background_update(100), + by=0.1, + ) + + # Check that all pushers with unlinked addresses were deleted + pushers = self.get_success( + self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) + ) + pushers = list(pushers) + self.assertEqual(len(pushers), 0) + def _check_for_mail(self): """Check that the user receives an email notification""" -- cgit 1.5.1 From 6b2aca473a66ae60803208911f97fbe9789dc1ac Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Sep 2021 11:47:24 +0100 Subject: 1.42.0rc1 --- CHANGES.md | 69 ++++++++++++++++++++++++++++++++++++++++++++++- changelog.d/10142.feature | 1 - changelog.d/10192.doc | 1 - changelog.d/10232.bugfix | 1 - changelog.d/10452.feature | 1 - changelog.d/10524.feature | 1 - changelog.d/10561.bugfix | 1 - changelog.d/10581.bugfix | 1 - changelog.d/10593.bugfix | 1 - changelog.d/10595.doc | 1 - changelog.d/10608.misc | 1 - changelog.d/10613.feature | 1 - changelog.d/10614.misc | 1 - changelog.d/10615.misc | 1 - changelog.d/10621.misc | 1 - changelog.d/10624.misc | 1 - changelog.d/10627.misc | 1 - changelog.d/10629.misc | 1 - changelog.d/10630.misc | 1 - changelog.d/10639.doc | 1 - changelog.d/10640.misc | 1 - changelog.d/10642.misc | 1 - changelog.d/10644.bugfix | 1 - changelog.d/10645.misc | 1 - changelog.d/10647.misc | 1 - changelog.d/10651.misc | 1 - changelog.d/10654.bugfix | 1 - changelog.d/10662.misc | 1 - changelog.d/10664.misc | 1 - changelog.d/10665.misc | 1 - changelog.d/10666.misc | 1 - changelog.d/10667.misc | 1 - changelog.d/10672.misc | 1 - changelog.d/10674.misc | 1 - changelog.d/10677.bugfix | 1 - changelog.d/10679.bugfix | 1 - changelog.d/10684.bugfix | 1 - changelog.d/10686.misc | 1 - changelog.d/10692.misc | 1 - changelog.d/10703.bugfix | 1 - changelog.d/10706.misc | 1 - changelog.d/10708.doc | 1 - changelog.d/10711.doc | 1 - changelog.d/10713.bugfix | 1 - changelog.d/10723.bugfix | 1 - changelog.d/10725.feature | 1 - changelog.d/10734.bugfix | 1 - changelog.d/8830.removal | 1 - debian/changelog | 6 +++++ docs/upgrade.md | 6 ++--- synapse/__init__.py | 2 +- 51 files changed, 78 insertions(+), 52 deletions(-) delete mode 100644 changelog.d/10142.feature delete mode 100644 changelog.d/10192.doc delete mode 100644 changelog.d/10232.bugfix delete mode 100644 changelog.d/10452.feature delete mode 100644 changelog.d/10524.feature delete mode 100644 changelog.d/10561.bugfix delete mode 100644 changelog.d/10581.bugfix delete mode 100644 changelog.d/10593.bugfix delete mode 100644 changelog.d/10595.doc delete mode 100644 changelog.d/10608.misc delete mode 100644 changelog.d/10613.feature delete mode 100644 changelog.d/10614.misc delete mode 100644 changelog.d/10615.misc delete mode 100644 changelog.d/10621.misc delete mode 100644 changelog.d/10624.misc delete mode 100644 changelog.d/10627.misc delete mode 100644 changelog.d/10629.misc delete mode 100644 changelog.d/10630.misc delete mode 100644 changelog.d/10639.doc delete mode 100644 changelog.d/10640.misc delete mode 100644 changelog.d/10642.misc delete mode 100644 changelog.d/10644.bugfix delete mode 100644 changelog.d/10645.misc delete mode 100644 changelog.d/10647.misc delete mode 100644 
changelog.d/10651.misc delete mode 100644 changelog.d/10654.bugfix delete mode 100644 changelog.d/10662.misc delete mode 100644 changelog.d/10664.misc delete mode 100644 changelog.d/10665.misc delete mode 100644 changelog.d/10666.misc delete mode 100644 changelog.d/10667.misc delete mode 100644 changelog.d/10672.misc delete mode 100644 changelog.d/10674.misc delete mode 100644 changelog.d/10677.bugfix delete mode 100644 changelog.d/10679.bugfix delete mode 100644 changelog.d/10684.bugfix delete mode 100644 changelog.d/10686.misc delete mode 100644 changelog.d/10692.misc delete mode 100644 changelog.d/10703.bugfix delete mode 100644 changelog.d/10706.misc delete mode 100644 changelog.d/10708.doc delete mode 100644 changelog.d/10711.doc delete mode 100644 changelog.d/10713.bugfix delete mode 100644 changelog.d/10723.bugfix delete mode 100644 changelog.d/10725.feature delete mode 100644 changelog.d/10734.bugfix delete mode 100644 changelog.d/8830.removal diff --git a/CHANGES.md b/CHANGES.md index 7046c336a0..57ab44faa7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,4 +1,71 @@ -Users will stop receiving message updates via email for addresses that were previously linked to their account +Synapse 1.42.0rc1 (2021-09-01) +============================== + +As of this release, users will stop receiving message updates via email for addresses that were previously linked to their account (but are not linked anymore). + + +Features +-------- + +- Add support for [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231): Token authenticated registration. Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown. ([\#10142](https://github.com/matrix-org/synapse/issues/10142)) +- Add support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): Expose enable_set_displayname in capabilities. ([\#10452](https://github.com/matrix-org/synapse/issues/10452)) +- Port the PresenceRouter module interface to the new generic interface. ([\#10524](https://github.com/matrix-org/synapse/issues/10524)) +- Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10613](https://github.com/matrix-org/synapse/issues/10613), [\#10725](https://github.com/matrix-org/synapse/issues/10725)) + + +Bugfixes +-------- + +- Validate new `m.room.power_levels` events. Contributed by @aaronraimist. ([\#10232](https://github.com/matrix-org/synapse/issues/10232)) +- Display an error on User-Interactive Authentication fallback pages when authentication fails. Contributed by Callum Brown. ([\#10561](https://github.com/matrix-org/synapse/issues/10561)) +- Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted. ([\#10581](https://github.com/matrix-org/synapse/issues/10581), [\#10734](https://github.com/matrix-org/synapse/issues/10734)) +- Reject Client-Server `/keys/query` requests which provide `device_ids` incorrectly. ([\#10593](https://github.com/matrix-org/synapse/issues/10593)) +- Rooms with unsupported room versions are no longer returned via `/sync`. ([\#10644](https://github.com/matrix-org/synapse/issues/10644)) +- Enforce the maximum length for per-room display names and avatar URLs. ([\#10654](https://github.com/matrix-org/synapse/issues/10654)) +- Fix a bug which caused the `synapse_user_logins_total` Prometheus metric not to be correctly initialised on restart. 
([\#10677](https://github.com/matrix-org/synapse/issues/10677)) +- Improve ServerNoticeServlet to avoid duplicate requests and add unit tests. ([\#10679](https://github.com/matrix-org/synapse/issues/10679)) +- Fix long-standing issue which caused an error when a thumbnail is requested and there are multiple thumbnails with the same quality rating. ([\#10684](https://github.com/matrix-org/synapse/issues/10684)) +- Fix a regression introduced in v1.41.0 which affected the performance of concurrent fetches of large sets of events, in extreme cases causing the process to hang. ([\#10703](https://github.com/matrix-org/synapse/issues/10703)) +- Fix a regression introduced in Synapse 1.41 which broke email transmission on Systems using older versions of the Twisted library. ([\#10713](https://github.com/matrix-org/synapse/issues/10713)) +- Fix unauthorised exposure of room metadata to communities. ([\#10723](https://github.com/matrix-org/synapse/issues/10723)) + + +Improved Documentation +---------------------- + +- Add documentation on how to connect Django with synapse using oidc and django-oauth-toolkit. Contributed by @HugoDelval. ([\#10192](https://github.com/matrix-org/synapse/issues/10192)) +- Advertise https://matrix-org.github.io/synapse docs in README and CONTRIBUTING files. ([\#10595](https://github.com/matrix-org/synapse/issues/10595)) +- Fix some of the titles not rendering in the OIDC documentation. ([\#10639](https://github.com/matrix-org/synapse/issues/10639)) +- Minor clarifications to the documentation for reverse proxies. ([\#10708](https://github.com/matrix-org/synapse/issues/10708)) +- Removed table of contents from the top of installation and contributing documentation pages. ([\#10711](https://github.com/matrix-org/synapse/issues/10711)) + + +Deprecations and Removals +------------------------- + +- Remove deprecated Shutdown Room and Purge Room Admin API. ([\#8830](https://github.com/matrix-org/synapse/issues/8830)) + + +Internal Changes +---------------- + +- Improve type hints for the proxy agent and SRV resolver modules. Contributed by @dklimpel. ([\#10608](https://github.com/matrix-org/synapse/issues/10608)) +- Clean up some of the federation event authentication code for clarity. ([\#10614](https://github.com/matrix-org/synapse/issues/10614), [\#10615](https://github.com/matrix-org/synapse/issues/10615), [\#10624](https://github.com/matrix-org/synapse/issues/10624), [\#10640](https://github.com/matrix-org/synapse/issues/10640)) +- Add a comment asking developers to leave a reason when bumping the database schema version. ([\#10621](https://github.com/matrix-org/synapse/issues/10621)) +- Remove not needed database updates in modify user admin API. ([\#10627](https://github.com/matrix-org/synapse/issues/10627)) +- Convert room member storage tuples to `attrs` classes. ([\#10629](https://github.com/matrix-org/synapse/issues/10629), [\#10642](https://github.com/matrix-org/synapse/issues/10642)) +- Use auto-attribs for the attrs classes used in sync. ([\#10630](https://github.com/matrix-org/synapse/issues/10630)) +- Make `backfill` and `get_missing_events` use the same codepath. ([\#10645](https://github.com/matrix-org/synapse/issues/10645)) +- Improve the performance of the `/hierarchy` API (from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)) by caching responses received over federation. ([\#10647](https://github.com/matrix-org/synapse/issues/10647)) +- Run a nightly CI build against Twisted trunk. 
([\#10651](https://github.com/matrix-org/synapse/issues/10651), [\#10672](https://github.com/matrix-org/synapse/issues/10672)) +- Do not print out stack traces for network errors when fetching data over federation. ([\#10662](https://github.com/matrix-org/synapse/issues/10662)) +- Simplify tests for device admin rest API. ([\#10664](https://github.com/matrix-org/synapse/issues/10664)) +- Add missing type hints to REST servlets. ([\#10665](https://github.com/matrix-org/synapse/issues/10665), [\#10666](https://github.com/matrix-org/synapse/issues/10666), [\#10674](https://github.com/matrix-org/synapse/issues/10674)) +- Flatten the `tests.synapse.rests` package by moving the contents of `v1` and `v2_alpha` into the parent. ([\#10667](https://github.com/matrix-org/synapse/issues/10667)) +- Update `complement.sh` to rebuild the base Docker image when run with workers. ([\#10686](https://github.com/matrix-org/synapse/issues/10686)) +- Split the event-processing methods in `FederationHandler` into a separate `FederationEventHandler`. ([\#10692](https://github.com/matrix-org/synapse/issues/10692)) +- Remove unused `compare_digest` function. ([\#10706](https://github.com/matrix-org/synapse/issues/10706)) + Synapse 1.41.1 (2021-08-31) =========================== diff --git a/changelog.d/10142.feature b/changelog.d/10142.feature deleted file mode 100644 index 5353f6269d..0000000000 --- a/changelog.d/10142.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for [MSC3231 - Token authenticated registration](https://github.com/matrix-org/matrix-doc/pull/3231). Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown. diff --git a/changelog.d/10192.doc b/changelog.d/10192.doc deleted file mode 100644 index 3dd00537e8..0000000000 --- a/changelog.d/10192.doc +++ /dev/null @@ -1 +0,0 @@ -Add documentation on how to connect Django with synapse using oidc and django-oauth-toolkit. Contributed by @HugoDelval. diff --git a/changelog.d/10232.bugfix b/changelog.d/10232.bugfix deleted file mode 100644 index 7be72271e0..0000000000 --- a/changelog.d/10232.bugfix +++ /dev/null @@ -1 +0,0 @@ -Validate new `m.room.power_levels` events. Contributed by @aaronraimist. \ No newline at end of file diff --git a/changelog.d/10452.feature b/changelog.d/10452.feature deleted file mode 100644 index f332b383e3..0000000000 --- a/changelog.d/10452.feature +++ /dev/null @@ -1 +0,0 @@ -Add support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): Expose enable_set_displayname in capabilities. \ No newline at end of file diff --git a/changelog.d/10524.feature b/changelog.d/10524.feature deleted file mode 100644 index 288c9bd74e..0000000000 --- a/changelog.d/10524.feature +++ /dev/null @@ -1 +0,0 @@ -Port the PresenceRouter module interface to the new generic interface. \ No newline at end of file diff --git a/changelog.d/10561.bugfix b/changelog.d/10561.bugfix deleted file mode 100644 index 2e4f53508c..0000000000 --- a/changelog.d/10561.bugfix +++ /dev/null @@ -1 +0,0 @@ -Display an error on User-Interactive Authentication fallback pages when authentication fails. Contributed by Callum Brown. diff --git a/changelog.d/10581.bugfix b/changelog.d/10581.bugfix deleted file mode 100644 index 15c7da4497..0000000000 --- a/changelog.d/10581.bugfix +++ /dev/null @@ -1 +0,0 @@ -Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted. 
\ No newline at end of file diff --git a/changelog.d/10593.bugfix b/changelog.d/10593.bugfix deleted file mode 100644 index af910bfa4d..0000000000 --- a/changelog.d/10593.bugfix +++ /dev/null @@ -1 +0,0 @@ -Reject Client-Server `/keys/query` requests which provide `device_ids` incorrectly. diff --git a/changelog.d/10595.doc b/changelog.d/10595.doc deleted file mode 100644 index 4823146d6b..0000000000 --- a/changelog.d/10595.doc +++ /dev/null @@ -1 +0,0 @@ -Advertise https://matrix-org.github.io/synapse docs in README and CONTRIBUTING files. diff --git a/changelog.d/10608.misc b/changelog.d/10608.misc deleted file mode 100644 index 875bdd2fd0..0000000000 --- a/changelog.d/10608.misc +++ /dev/null @@ -1 +0,0 @@ -Improve type hints for the proxy agent and SRV resolver modules. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/10613.feature b/changelog.d/10613.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10613.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10614.misc b/changelog.d/10614.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/10614.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10615.misc b/changelog.d/10615.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/10615.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10621.misc b/changelog.d/10621.misc deleted file mode 100644 index b8de2e1911..0000000000 --- a/changelog.d/10621.misc +++ /dev/null @@ -1 +0,0 @@ -Add a comment asking developers to leave a reason when bumping the database schema version. \ No newline at end of file diff --git a/changelog.d/10624.misc b/changelog.d/10624.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/10624.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10627.misc b/changelog.d/10627.misc deleted file mode 100644 index e6d314976e..0000000000 --- a/changelog.d/10627.misc +++ /dev/null @@ -1 +0,0 @@ -Remove not needed database updates in modify user admin API. \ No newline at end of file diff --git a/changelog.d/10629.misc b/changelog.d/10629.misc deleted file mode 100644 index cca1eb6c57..0000000000 --- a/changelog.d/10629.misc +++ /dev/null @@ -1 +0,0 @@ -Convert room member storage tuples to `attrs` classes. diff --git a/changelog.d/10630.misc b/changelog.d/10630.misc deleted file mode 100644 index 7d01e00e48..0000000000 --- a/changelog.d/10630.misc +++ /dev/null @@ -1 +0,0 @@ -Use auto-attribs for the attrs classes used in sync. diff --git a/changelog.d/10639.doc b/changelog.d/10639.doc deleted file mode 100644 index acbac4aad8..0000000000 --- a/changelog.d/10639.doc +++ /dev/null @@ -1 +0,0 @@ -Fix some of the titles not rendering in the OIDC documentation. diff --git a/changelog.d/10640.misc b/changelog.d/10640.misc deleted file mode 100644 index 9a765435db..0000000000 --- a/changelog.d/10640.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some of the federation event authentication code for clarity. 
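The `_remove_deleted_email_pushers` fix earlier in this series is a one-line change with a subtle cause: the loop issued `simple_delete_txn` on the same `txn` cursor it was iterating, and re-executing a statement on a DB-API cursor resets its result set, so only the first row was ever processed. A minimal sqlite3 sketch of the failure mode (table and column names are illustrative, not Synapse's schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE pushers (id INTEGER PRIMARY KEY)")
conn.executemany("INSERT INTO pushers (id) VALUES (?)", [(i,) for i in range(5)])

cur = conn.cursor()
cur.execute("SELECT id FROM pushers")

deleted = 0
for (pusher_id,) in cur:
    # Re-using `cur` here resets the outer SELECT, so iteration stops
    # after this first row -- the bug the patch above fixes.
    cur.execute("DELETE FROM pushers WHERE id = ?", (pusher_id,))
    deleted += 1

assert deleted == 1  # not 5!

# Materialising the result set first, as the patch does, processes every row.
cur.execute("SELECT id FROM pushers")
for (pusher_id,) in cur.fetchall():
    cur.execute("DELETE FROM pushers WHERE id = ?", (pusher_id,))
assert conn.execute("SELECT COUNT(*) FROM pushers").fetchone()[0] == 0
```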
diff --git a/changelog.d/10642.misc b/changelog.d/10642.misc deleted file mode 100644 index cca1eb6c57..0000000000 --- a/changelog.d/10642.misc +++ /dev/null @@ -1 +0,0 @@ -Convert room member storage tuples to `attrs` classes. diff --git a/changelog.d/10644.bugfix b/changelog.d/10644.bugfix deleted file mode 100644 index d88a81fd82..0000000000 --- a/changelog.d/10644.bugfix +++ /dev/null @@ -1 +0,0 @@ -Rooms with unsupported room versions are no longer returned via `/sync`. diff --git a/changelog.d/10645.misc b/changelog.d/10645.misc deleted file mode 100644 index ac19263cd8..0000000000 --- a/changelog.d/10645.misc +++ /dev/null @@ -1 +0,0 @@ -Make `backfill` and `get_missing_events` use the same codepath. diff --git a/changelog.d/10647.misc b/changelog.d/10647.misc deleted file mode 100644 index 4407a9030d..0000000000 --- a/changelog.d/10647.misc +++ /dev/null @@ -1 +0,0 @@ -Improve the performance of the `/hierarchy` API (from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)) by caching responses received over federation. diff --git a/changelog.d/10651.misc b/changelog.d/10651.misc deleted file mode 100644 index 7104c121e0..0000000000 --- a/changelog.d/10651.misc +++ /dev/null @@ -1 +0,0 @@ -Run a nightly CI build against Twisted trunk. diff --git a/changelog.d/10654.bugfix b/changelog.d/10654.bugfix deleted file mode 100644 index b0bd78453f..0000000000 --- a/changelog.d/10654.bugfix +++ /dev/null @@ -1 +0,0 @@ -Enforce the maximum length for per-room display names and avatar URLs. \ No newline at end of file diff --git a/changelog.d/10662.misc b/changelog.d/10662.misc deleted file mode 100644 index 593f9ceaad..0000000000 --- a/changelog.d/10662.misc +++ /dev/null @@ -1 +0,0 @@ -Do not print out stack traces for network errors when fetching data over federation. diff --git a/changelog.d/10664.misc b/changelog.d/10664.misc deleted file mode 100644 index cebd5e9a96..0000000000 --- a/changelog.d/10664.misc +++ /dev/null @@ -1 +0,0 @@ -Simplify tests for device admin rest API. \ No newline at end of file diff --git a/changelog.d/10665.misc b/changelog.d/10665.misc deleted file mode 100644 index 39a37b90b1..0000000000 --- a/changelog.d/10665.misc +++ /dev/null @@ -1 +0,0 @@ -Add missing type hints to REST servlets. diff --git a/changelog.d/10666.misc b/changelog.d/10666.misc deleted file mode 100644 index 39a37b90b1..0000000000 --- a/changelog.d/10666.misc +++ /dev/null @@ -1 +0,0 @@ -Add missing type hints to REST servlets. diff --git a/changelog.d/10667.misc b/changelog.d/10667.misc deleted file mode 100644 index c92846ae26..0000000000 --- a/changelog.d/10667.misc +++ /dev/null @@ -1 +0,0 @@ -Flatten the `tests.synapse.rests` package by moving the contents of `v1` and `v2_alpha` into the parent. \ No newline at end of file diff --git a/changelog.d/10672.misc b/changelog.d/10672.misc deleted file mode 100644 index 7104c121e0..0000000000 --- a/changelog.d/10672.misc +++ /dev/null @@ -1 +0,0 @@ -Run a nightly CI build against Twisted trunk. diff --git a/changelog.d/10674.misc b/changelog.d/10674.misc deleted file mode 100644 index 39a37b90b1..0000000000 --- a/changelog.d/10674.misc +++ /dev/null @@ -1 +0,0 @@ -Add missing type hints to REST servlets. diff --git a/changelog.d/10677.bugfix b/changelog.d/10677.bugfix deleted file mode 100644 index 9964afaaee..0000000000 --- a/changelog.d/10677.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a bug which caused the `synapse_user_logins_total` Prometheus metric not to be correctly initialised on restart. 
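Both new tests in this series (the pusher test above and the `rooms.creator` test below) drive background updates with the same boilerplate: insert a row into `background_updates`, clear the updater's `_all_done` flag, then pump `do_next_background_update` until `has_completed_background_updates` reports done. As a helper, the pattern looks roughly like this (a sketch against the test APIs used in this series, not an existing utility):

```python
def run_background_update(testcase, update_name: str) -> None:
    """Queue `update_name` and pump background updates until all complete.

    `testcase` is assumed to be a HomeserverTestCase, as in the tests above.
    """
    store = testcase.hs.get_datastore()

    # Queue the update by hand, exactly as the tests do.
    testcase.get_success(
        store.db_pool.simple_insert(
            table="background_updates",
            values={"update_name": update_name, "progress_json": "{}"},
        )
    )

    # ... and tell the DataStore that it hasn't finished all updates yet.
    store.db_pool.updates._all_done = False

    # Now drive the updates to completion, 100 "items" per batch.
    while not testcase.get_success(
        store.db_pool.updates.has_completed_background_updates()
    ):
        testcase.get_success(
            store.db_pool.updates.do_next_background_update(100), by=0.1
        )
```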
diff --git a/changelog.d/10679.bugfix b/changelog.d/10679.bugfix deleted file mode 100644 index 5c4061f6d5..0000000000 --- a/changelog.d/10679.bugfix +++ /dev/null @@ -1 +0,0 @@ -Improve ServerNoticeServlet to avoid duplicate requests and add unit tests. \ No newline at end of file diff --git a/changelog.d/10684.bugfix b/changelog.d/10684.bugfix deleted file mode 100644 index 311b17601a..0000000000 --- a/changelog.d/10684.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix long-standing issue which caused an error when a thumbnail is requested and there are multiple thumbnails with the same quality rating. diff --git a/changelog.d/10686.misc b/changelog.d/10686.misc deleted file mode 100644 index b76908d74e..0000000000 --- a/changelog.d/10686.misc +++ /dev/null @@ -1 +0,0 @@ -Update `complement.sh` to rebuild the base Docker image when run with workers. diff --git a/changelog.d/10692.misc b/changelog.d/10692.misc deleted file mode 100644 index a1b0def76b..0000000000 --- a/changelog.d/10692.misc +++ /dev/null @@ -1 +0,0 @@ -Split the event-processing methods in `FederationHandler` into a separate `FederationEventHandler`. diff --git a/changelog.d/10703.bugfix b/changelog.d/10703.bugfix deleted file mode 100644 index a5a4ecf8ee..0000000000 --- a/changelog.d/10703.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a regression introduced in v1.41.0 which affected the performance of concurrent fetches of large sets of events, in extreme cases causing the process to hang. diff --git a/changelog.d/10706.misc b/changelog.d/10706.misc deleted file mode 100644 index eed4aa58d6..0000000000 --- a/changelog.d/10706.misc +++ /dev/null @@ -1 +0,0 @@ -Remove unused `compare_digest` function. diff --git a/changelog.d/10708.doc b/changelog.d/10708.doc deleted file mode 100644 index 99f9d69288..0000000000 --- a/changelog.d/10708.doc +++ /dev/null @@ -1 +0,0 @@ -Minor clarifications to the documentation for reverse proxies. diff --git a/changelog.d/10711.doc b/changelog.d/10711.doc deleted file mode 100644 index c495f98be8..0000000000 --- a/changelog.d/10711.doc +++ /dev/null @@ -1 +0,0 @@ -Removed table of contents from the top of installation and contributing documentation pages. \ No newline at end of file diff --git a/changelog.d/10713.bugfix b/changelog.d/10713.bugfix deleted file mode 100644 index e8caf3d23a..0000000000 --- a/changelog.d/10713.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix a regression introduced in Synapse 1.41 which broke email transmission on Systems using older versions of the Twisted library. diff --git a/changelog.d/10723.bugfix b/changelog.d/10723.bugfix deleted file mode 100644 index e6ffdc9512..0000000000 --- a/changelog.d/10723.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix unauthorised exposure of room metadata to communities. diff --git a/changelog.d/10725.feature b/changelog.d/10725.feature deleted file mode 100644 index ffc4e4289c..0000000000 --- a/changelog.d/10725.feature +++ /dev/null @@ -1 +0,0 @@ -Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). diff --git a/changelog.d/10734.bugfix b/changelog.d/10734.bugfix deleted file mode 100644 index 15c7da4497..0000000000 --- a/changelog.d/10734.bugfix +++ /dev/null @@ -1 +0,0 @@ -Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted. 
\ No newline at end of file diff --git a/changelog.d/8830.removal b/changelog.d/8830.removal deleted file mode 100644 index b3a93a9af2..0000000000 --- a/changelog.d/8830.removal +++ /dev/null @@ -1 +0,0 @@ -Remove deprecated Shutdown Room and Purge Room Admin API. \ No newline at end of file diff --git a/debian/changelog b/debian/changelog index 5f7a795b6e..0f7dbdf71e 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.42.0~rc1) stable; urgency=medium + + * New synapse release 1.42.0rc1. + + -- Synapse Packaging team Wed, 01 Sep 2021 11:37:48 +0100 + matrix-synapse-py3 (1.41.1) stable; urgency=high * New synapse release 1.41.1. diff --git a/docs/upgrade.md b/docs/upgrade.md index dcf0a7db5b..453dbbabe7 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,7 +85,7 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` -# Upgrading to v1.xx.0 +# Upgrading to v1.42.0 ## Removal of old Room Admin API @@ -107,12 +107,12 @@ This may affect you if you make use of custom HTML templates for the The template is now provided an `error` variable if the authentication process failed. See the default templates linked above for an example. -# Upgrading to v1.42.0 - ## Removal of out-of-date email pushers + Users will stop receiving message updates via email for addresses that were once, but not still, linked to their account. + # Upgrading to v1.41.0 ## Add support for routing outbound HTTP requests via a proxy for federation diff --git a/synapse/__init__.py b/synapse/__init__.py index 06d80f79b3..e5b075c53b 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.41.1" +__version__ = "1.42.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From f8bf83b811cf54b4ce668bcda6043d14c4980f00 Mon Sep 17 00:00:00 2001 From: Sean Date: Wed, 1 Sep 2021 11:55:31 +0100 Subject: Skip the final GC on shutdown to improve restart times (#10712) Use `gc.freeze()` on exit to exclude all existing objects from the final GC. In testing, this sped up shutdown by up to a few seconds. `gc.freeze()` runs in constant time, so there is little chance of performance regression. Signed-off-by: Sean Quah --- changelog.d/10712.feature | 1 + synapse/app/_base.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 changelog.d/10712.feature diff --git a/changelog.d/10712.feature b/changelog.d/10712.feature new file mode 100644 index 0000000000..d04db6f26f --- /dev/null +++ b/changelog.d/10712.feature @@ -0,0 +1 @@ +Skip final GC at shutdown to improve restart performance. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 39e28aff9f..6fc14930d1 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import atexit import gc import logging import os @@ -403,6 +404,12 @@ async def start(hs: "HomeServer"): gc.collect() gc.freeze() + # Speed up shutdowns by freezing all allocated objects. This moves everything + # into the permanent generation and excludes them from the final GC. 
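`gc.freeze()` (CPython 3.7+) moves every object the collector is currently tracking into a permanent generation that later collections, including the implicit one at interpreter exit, skip entirely. A minimal standalone sketch of the same register-at-exit pattern (the allocation count is illustrative):

```python
import atexit
import gc
import platform
import sys

# Long-lived allocations that the final collection would otherwise sweep.
cache = [object() for _ in range(100_000)]

def report() -> None:
    # atexit handlers run last-in first-out, so this fires *after* the
    # freeze registered below and sees the populated permanent generation.
    print("objects excluded from the final GC:", gc.get_freeze_count())

if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7):
    atexit.register(report)
    atexit.register(gc.freeze)
```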
+ # Unfortunately only works on Python 3.7 + if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7): + atexit.register(gc.freeze) + def setup_sentry(hs): """Enable sentry integration, if enabled in configuration -- cgit 1.5.1 From 70bef88731a1f16a50b126b9f6c8fba009f9130c Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Sep 2021 12:04:08 +0100 Subject: Improve changelog --- CHANGES.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 57ab44faa7..0229d26981 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,15 +1,15 @@ Synapse 1.42.0rc1 (2021-09-01) ============================== -As of this release, users will stop receiving message updates via email for addresses that were previously linked to their account (but are not linked anymore). +This release fixes a bug where users who once added and then removed an e-mail address from their account would continue to receive e-mail notifications. Features -------- - Add support for [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231): Token authenticated registration. Users can be required to submit a token during registration to authenticate themselves. Contributed by Callum Brown. ([\#10142](https://github.com/matrix-org/synapse/issues/10142)) -- Add support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): Expose enable_set_displayname in capabilities. ([\#10452](https://github.com/matrix-org/synapse/issues/10452)) -- Port the PresenceRouter module interface to the new generic interface. ([\#10524](https://github.com/matrix-org/synapse/issues/10524)) +- Add support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): Expose `enable_set_displayname` in capabilities. ([\#10452](https://github.com/matrix-org/synapse/issues/10452)) +- Port the `PresenceRouter` module interface to the new generic interface. ([\#10524](https://github.com/matrix-org/synapse/issues/10524)) - Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#10613](https://github.com/matrix-org/synapse/issues/10613), [\#10725](https://github.com/matrix-org/synapse/issues/10725)) @@ -23,21 +23,20 @@ Bugfixes - Rooms with unsupported room versions are no longer returned via `/sync`. ([\#10644](https://github.com/matrix-org/synapse/issues/10644)) - Enforce the maximum length for per-room display names and avatar URLs. ([\#10654](https://github.com/matrix-org/synapse/issues/10654)) - Fix a bug which caused the `synapse_user_logins_total` Prometheus metric not to be correctly initialised on restart. ([\#10677](https://github.com/matrix-org/synapse/issues/10677)) -- Improve ServerNoticeServlet to avoid duplicate requests and add unit tests. ([\#10679](https://github.com/matrix-org/synapse/issues/10679)) +- Improve `ServerNoticeServlet` to avoid duplicate requests and add unit tests. ([\#10679](https://github.com/matrix-org/synapse/issues/10679)) - Fix long-standing issue which caused an error when a thumbnail is requested and there are multiple thumbnails with the same quality rating. ([\#10684](https://github.com/matrix-org/synapse/issues/10684)) - Fix a regression introduced in v1.41.0 which affected the performance of concurrent fetches of large sets of events, in extreme cases causing the process to hang. 
([\#10703](https://github.com/matrix-org/synapse/issues/10703)) - Fix a regression introduced in Synapse 1.41 which broke email transmission on Systems using older versions of the Twisted library. ([\#10713](https://github.com/matrix-org/synapse/issues/10713)) -- Fix unauthorised exposure of room metadata to communities. ([\#10723](https://github.com/matrix-org/synapse/issues/10723)) Improved Documentation ---------------------- -- Add documentation on how to connect Django with synapse using oidc and django-oauth-toolkit. Contributed by @HugoDelval. ([\#10192](https://github.com/matrix-org/synapse/issues/10192)) +- Add documentation on how to connect Django with synapse using OIDC and django-oauth-toolkit. Contributed by @HugoDelval. ([\#10192](https://github.com/matrix-org/synapse/issues/10192)) - Advertise https://matrix-org.github.io/synapse docs in README and CONTRIBUTING files. ([\#10595](https://github.com/matrix-org/synapse/issues/10595)) - Fix some of the titles not rendering in the OIDC documentation. ([\#10639](https://github.com/matrix-org/synapse/issues/10639)) - Minor clarifications to the documentation for reverse proxies. ([\#10708](https://github.com/matrix-org/synapse/issues/10708)) -- Removed table of contents from the top of installation and contributing documentation pages. ([\#10711](https://github.com/matrix-org/synapse/issues/10711)) +- Remove table of contents from the top of installation and contributing documentation pages. ([\#10711](https://github.com/matrix-org/synapse/issues/10711)) Deprecations and Removals -- cgit 1.5.1 From 940d4d3ac1dc3aad458364d55f790c5539f3ce0e Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Sep 2021 12:07:33 +0100 Subject: Improve changelog Expand OIDC to OpenID Connect. --- CHANGES.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 0229d26981..7d5a6e2b43 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -32,9 +32,9 @@ Bugfixes Improved Documentation ---------------------- -- Add documentation on how to connect Django with synapse using OIDC and django-oauth-toolkit. Contributed by @HugoDelval. ([\#10192](https://github.com/matrix-org/synapse/issues/10192)) -- Advertise https://matrix-org.github.io/synapse docs in README and CONTRIBUTING files. ([\#10595](https://github.com/matrix-org/synapse/issues/10595)) -- Fix some of the titles not rendering in the OIDC documentation. ([\#10639](https://github.com/matrix-org/synapse/issues/10639)) +- Add documentation on how to connect Django with Synapse using OpenID Connect and django-oauth-toolkit. Contributed by @HugoDelval. ([\#10192](https://github.com/matrix-org/synapse/issues/10192)) +- Advertise https://matrix-org.github.io/synapse documentation in the `README` and `CONTRIBUTING` files. ([\#10595](https://github.com/matrix-org/synapse/issues/10595)) +- Fix some of the titles not rendering in the OpenID Connect documentation. ([\#10639](https://github.com/matrix-org/synapse/issues/10639)) - Minor clarifications to the documentation for reverse proxies. ([\#10708](https://github.com/matrix-org/synapse/issues/10708)) - Remove table of contents from the top of installation and contributing documentation pages. 
([\#10711](https://github.com/matrix-org/synapse/issues/10711)) -- cgit 1.5.1 From d9069388f3aa22d548b4d51c069d42bd644b7ff4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 1 Sep 2021 13:48:41 +0100 Subject: Correctly include room avatars in email notifications (#10658) Judging by the template, this was intended ages ago, but we never actually passed an avatar URL to the template. So let's provide one. Closes #1546. Co-authored-by: Patrick Cloke --- changelog.d/10658.bugfix | 1 + synapse/push/mailer.py | 24 +++++++++++++++++++++- tests/push/test_email.py | 52 +++++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 71 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10658.bugfix diff --git a/changelog.d/10658.bugfix b/changelog.d/10658.bugfix new file mode 100644 index 0000000000..a59d402933 --- /dev/null +++ b/changelog.d/10658.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where room avatars were not included in email notifications. diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 941fb238b7..b0834720ad 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -258,7 +258,7 @@ class Mailer: # actually sort our so-called rooms_in_order list, most recent room first rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) - rooms = [] + rooms: List[Dict[str, Any]] = [] for r in rooms_in_order: roomvars = await self._get_room_vars( @@ -362,6 +362,7 @@ class Mailer: "notifs": [], "invite": is_invite, "link": self._make_room_link(room_id), + "avatar_url": await self._get_room_avatar(room_state_ids), } if not is_invite: @@ -393,6 +394,27 @@ class Mailer: return room_vars + async def _get_room_avatar( + self, + room_state_ids: StateMap[str], + ) -> Optional[str]: + """ + Retrieve the avatar url for this room---if it exists. + + Args: + room_state_ids: The event IDs of the current room state. + + Returns: + room's avatar url if it's present and a string; otherwise None. + """ + event_id = room_state_ids.get((EventTypes.RoomAvatar, "")) + if event_id: + ev = await self.store.get_event(event_id) + url = ev.content.get("url") + if isinstance(url, str): + return url + return None + async def _get_notif_vars( self, notif: Dict[str, Any], diff --git a/tests/push/test_email.py b/tests/push/test_email.py index eea07485a0..2bed7302cf 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import email.message import os +from typing import Dict, List, Sequence, Tuple import attr import pkg_resources @@ -70,9 +71,10 @@ class EmailPusherTests(HomeserverTestCase): hs = self.setup_test_homeserver(config=config) # List[Tuple[Deferred, args, kwargs]] - self.email_attempts = [] + self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = [] def sendmail(*args, **kwargs): + # This mocks out synapse.reactor.send_email._sendmail. d = Deferred() self.email_attempts.append((d, args, kwargs)) return d @@ -255,6 +257,39 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about those messages self._check_for_mail() + def test_room_notifications_include_avatar(self): + # Create a room and set its avatar. + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.send_state( + room, "m.room.avatar", {"url": "mxc://DUMMY_MEDIA_ID"}, self.access_token + ) + + # Invite two other uses. 
+ for other in self.others: + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=other.id + ) + self.helper.join(room=room, user=other.id, tok=other.token) + + # The other users send some messages. + # TODO It seems that two messages are required to trigger an email? + self.helper.send(room, body="Alpha", tok=self.others[0].token) + self.helper.send(room, body="Beta", tok=self.others[1].token) + + # We should get emailed about those messages + args, kwargs = self._check_for_mail() + + # That email should contain the room's avatar + msg: bytes = args[5] + # Multipart: plain text, base 64 encoded; html, base 64 encoded + html = ( + email.message_from_bytes(msg) + .get_payload()[1] + .get_payload(decode=True) + .decode() + ) + self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html) + def test_empty_room(self): """All users leaving a room shouldn't cause the pusher to break.""" # Create a simple room with two users @@ -344,9 +379,14 @@ class EmailPusherTests(HomeserverTestCase): pushers = list(pushers) self.assertEqual(len(pushers), 0) - def _check_for_mail(self): - """Check that the user receives an email notification""" + def _check_for_mail(self) -> Tuple[Sequence, Dict]: + """ + Assert that synapse sent off exactly one email notification. + Returns: + args and kwargs passed to synapse.reactor.send_email._sendmail for + that notification. + """ # Get the stream ordering before it gets sent pushers = self.get_success( self.hs.get_datastore().get_pushers_by({"user_name": self.user_id}) @@ -369,8 +409,9 @@ class EmailPusherTests(HomeserverTestCase): # One email was attempted to be sent self.assertEqual(len(self.email_attempts), 1) + deferred, sendmail_args, sendmail_kwargs = self.email_attempts[0] # Make the email succeed - self.email_attempts[0][0].callback(True) + deferred.callback(True) self.pump() # One email was attempted to be sent @@ -386,3 +427,4 @@ class EmailPusherTests(HomeserverTestCase): # Reset the attempts. self.email_attempts = [] + return sendmail_args, sendmail_kwargs -- cgit 1.5.1 From c6e103c1a6c2059a7a48c59bd3b443bed8891610 Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Wed, 1 Sep 2021 13:49:16 +0100 Subject: Make minor changes to changelog --- CHANGES.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 7d5a6e2b43..986efbba0d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,7 +1,7 @@ Synapse 1.42.0rc1 (2021-09-01) ============================== -This release fixes a bug where users who once added and then removed an e-mail address from their account would continue to receive e-mail notifications. +Server administrators are reminded to read [the upgrade notes](docs/upgrade.md#upgrading-to-v1420). Features @@ -18,7 +18,7 @@ Bugfixes - Validate new `m.room.power_levels` events. Contributed by @aaronraimist. ([\#10232](https://github.com/matrix-org/synapse/issues/10232)) - Display an error on User-Interactive Authentication fallback pages when authentication fails. Contributed by Callum Brown. ([\#10561](https://github.com/matrix-org/synapse/issues/10561)) -- Remove pushers when deleting a 3pid from an account. Pushers for old unlinked emails will also be deleted. ([\#10581](https://github.com/matrix-org/synapse/issues/10581), [\#10734](https://github.com/matrix-org/synapse/issues/10734)) +- Remove pushers when deleting an e-mail address from an account. Pushers for old unlinked emails will also be deleted. 
([\#10581](https://github.com/matrix-org/synapse/issues/10581), [\#10734](https://github.com/matrix-org/synapse/issues/10734)) - Reject Client-Server `/keys/query` requests which provide `device_ids` incorrectly. ([\#10593](https://github.com/matrix-org/synapse/issues/10593)) - Rooms with unsupported room versions are no longer returned via `/sync`. ([\#10644](https://github.com/matrix-org/synapse/issues/10644)) - Enforce the maximum length for per-room display names and avatar URLs. ([\#10654](https://github.com/matrix-org/synapse/issues/10654)) -- cgit 1.5.1 From dc75fb7f0552e6a9903a7c173672c96610219ec0 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 1 Sep 2021 10:27:58 -0500 Subject: Populate `rooms.creator` field for easy lookup (#10697) Part of https://github.com/matrix-org/synapse/pull/10566 - Fill in creator whenever we insert into the rooms table - Add background update to backfill any missing creator values --- changelog.d/10697.misc | 1 + synapse/api/constants.py | 3 + synapse/handlers/federation.py | 1 + synapse/storage/databases/main/room.py | 97 ++++++++++++++++++++- .../main/delta/63/02populate-rooms-creator.sql | 17 ++++ tests/storage/databases/main/test_room.py | 98 ++++++++++++++++++++++ 6 files changed, 213 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10697.misc create mode 100644 synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql create mode 100644 tests/storage/databases/main/test_room.py diff --git a/changelog.d/10697.misc b/changelog.d/10697.misc new file mode 100644 index 0000000000..a9ad17faf2 --- /dev/null +++ b/changelog.d/10697.misc @@ -0,0 +1 @@ +Ensure `rooms.creator` field is always populated for easy lookup in [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) usage later. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 829061c870..5e34eb7e13 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -198,6 +198,9 @@ class EventContentFields: # cf https://github.com/matrix-org/matrix-doc/pull/1772 ROOM_TYPE = "type" + # The creator of the room, as used in `m.room.create` events. 
+ ROOM_CREATOR = "creator" + # Used on normal messages to indicate they were historically imported after the fact MSC2716_HISTORICAL = "org.matrix.msc2716.historical" # For "insertion" events to indicate what the next chunk ID should be in diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index daf1d3bfb3..77df9185f6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -507,6 +507,7 @@ class FederationHandler(BaseHandler): await self.store.upsert_room_on_join( room_id=room_id, room_version=room_version_obj, + auth_events=auth_chain, ) max_stream_id = await self._persist_auth_tree( diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index f98b892598..6e7312266d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -19,9 +19,10 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from synapse.api.constants import EventTypes, JoinRules +from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore @@ -1013,6 +1014,7 @@ class _BackgroundUpdates: ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2" REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth" + POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column" _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( @@ -1054,6 +1056,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_replace_room_depth_min_depth, ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, + self._background_populate_rooms_creator_column, + ) + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. @@ -1273,7 +1280,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): keyvalues={"room_id": room_id}, retcol="MAX(stream_ordering)", allow_none=True, - desc="upsert_room_on_join", + desc="has_auth_chain_index_fallback", ) return max_ordering is None @@ -1343,6 +1350,65 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return 0 + async def _background_populate_rooms_creator_column( + self, progress: dict, batch_size: int + ): + """Background update to go and add creator information to `rooms` + table from `current_state_events` table. + """ + + last_room_id = progress.get("room_id", "") + + def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): + sql = """ + SELECT room_id, json FROM event_json + INNER JOIN rooms AS room USING (room_id) + INNER JOIN current_state_events AS state_event USING (room_id, event_id) + WHERE room_id > ? AND (room.creator IS NULL OR room.creator = '') AND state_event.type = 'm.room.create' AND state_event.state_key = '' + ORDER BY room_id + LIMIT ? 
+ """ + + txn.execute(sql, (last_room_id, batch_size)) + room_id_to_create_event_results = txn.fetchall() + + new_last_room_id = "" + for room_id, event_json in room_id_to_create_event_results: + event_dict = db_to_json(event_json) + + creator = event_dict.get("content").get(EventContentFields.ROOM_CREATOR) + + self.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"creator": creator}, + ) + new_last_room_id = room_id + + if new_last_room_id == "": + return True + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, + {"room_id": new_last_room_id}, + ) + + return False + + end = await self.db_pool.runInteraction( + "_background_populate_rooms_creator_column", + _background_populate_rooms_creator_column_txn, + ) + + if end: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN + ) + + return batch_size + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -1350,7 +1416,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.config = hs.config - async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): + async def upsert_room_on_join( + self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] + ): """Ensure that the room is stored in the table Called when we join a room over federation, and overwrites any room version @@ -1361,6 +1429,24 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): # mark the room as having an auth chain cover index. has_auth_chain_index = await self.has_auth_chain_index(room_id) + create_event = None + for e in auth_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + if create_event is None: + # If the state doesn't have a create event then the room is + # invalid, and it would fail auth checks anyway. + raise StoreError(400, "No create event in state") + + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + + if not isinstance(room_creator, str): + # If the create event does not have a creator then the room is + # invalid, and it would fail auth checks anyway. + raise StoreError(400, "No creator defined on the create event") + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", @@ -1368,7 +1454,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): values={"room_version": room_version.identifier}, insertion_values={ "is_public": False, - "creator": "", + "creator": room_creator, "has_auth_chain_index": has_auth_chain_index, }, # rooms has a unique constraint on room_id, so no need to lock when doing an @@ -1396,6 +1482,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): insertion_values={ "room_version": room_version.identifier, "is_public": False, + # We don't worry about setting the `creator` here because + # we don't process any messages in a room while a user is + # invited (only after the join). 
"creator": "", "has_auth_chain_index": has_auth_chain_index, }, diff --git a/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql b/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql new file mode 100644 index 0000000000..f7c0b31261 --- /dev/null +++ b/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) + VALUES (6302, 'populate_rooms_creator_column', '{}'); diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py new file mode 100644 index 0000000000..ffee707153 --- /dev/null +++ b/tests/storage/databases/main/test_room.py @@ -0,0 +1,98 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.storage.databases.main.room import _BackgroundUpdates + +from tests.unittest import HomeserverTestCase + + +class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + + def _generate_room(self) -> str: + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + return room_id + + def test_background_populate_rooms_creator_column(self): + """Test that the background update to populate the rooms creator column + works properly. + """ + + # Insert a room without the creator + room_id = self._generate_room() + self.get_success( + self.store.db_pool.simple_update( + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"creator": None}, + desc="test", + ) + ) + + # Make sure the test is starting out with a room without a creator + room_creator_before = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="creator", + allow_none=True, + ) + ) + self.assertEqual(room_creator_before, None) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + { + "update_name": _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, + "progress_json": "{}", + }, + ) + ) + + # ... 
and tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Make sure the background update filled in the room creator + room_creator_after = self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="rooms", + keyvalues={"room_id": room_id}, + retcol="creator", + allow_none=True, + ) + ) + self.assertEqual(room_creator_after, self.user_id) -- cgit 1.5.1 From d1f1b46c2cfe612790aebdd54765953951c94e45 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Sep 2021 11:59:32 -0400 Subject: Additional type hints for client REST servlets (part 4) (#10728) --- changelog.d/10728.misc | 1 + synapse/rest/client/account.py | 82 ++++++++++++++++++------------------- synapse/rest/client/knock.py | 6 ++- synapse/rest/client/register.py | 14 +++---- synapse/rest/client/report_event.py | 15 +++++-- synapse/rest/client/room_batch.py | 31 ++++++++++---- synapse/rest/client/room_keys.py | 53 +++++++++++++----------- synapse/rest/client/sendtodevice.py | 27 +++++++----- synapse/rest/client/sync.py | 16 +++++++- 9 files changed, 145 insertions(+), 100 deletions(-) create mode 100644 changelog.d/10728.misc diff --git a/changelog.d/10728.misc b/changelog.d/10728.misc new file mode 100644 index 0000000000..39a37b90b1 --- /dev/null +++ b/changelog.d/10728.misc @@ -0,0 +1 @@ +Add missing type hints to REST servlets. diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index fb5ad2906e..aefaaa8ae8 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -16,9 +16,11 @@ import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from urllib.parse import urlparse +from twisted.web.server import Request + from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, @@ -28,15 +30,17 @@ from synapse.api.errors import ( ) from synapse.config.emailconfig import ThreepidBehaviour from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html +from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed, validate_email @@ -68,7 +72,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): template_text=self.config.email_password_reset_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -159,7 +163,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") - def __init__(self, hs): + 
def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -169,7 +173,7 @@ class PasswordRestServlet(RestServlet): self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these @@ -190,6 +194,7 @@ class PasswordRestServlet(RestServlet): # # In the second case, we require a password to confirm their identity. + requester = None if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) try: @@ -206,16 +211,15 @@ class PasswordRestServlet(RestServlet): # If a password is available now, hash the provided password and # store it for later. if new_password: - password_hash = await self.auth_handler.hash(new_password) + new_password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + new_password_hash, ) raise user_id = requester.user.to_string() else: - requester = None try: result, params, session_id = await self.auth_handler.check_ui_auth( [[LoginType.EMAIL_IDENTITY]], @@ -230,11 +234,11 @@ class PasswordRestServlet(RestServlet): # If a password is available now, hash the provided password and # store it for later. if new_password: - password_hash = await self.auth_handler.hash(new_password) + new_password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + new_password_hash, ) raise @@ -264,7 +268,7 @@ class PasswordRestServlet(RestServlet): # If we have a password in this request, prefer it. Otherwise, use the # password hash from an earlier request. 
if new_password: - password_hash = await self.auth_handler.hash(new_password) + password_hash: Optional[str] = await self.auth_handler.hash(new_password) elif session_id is not None: password_hash = await self.auth_handler.get_session_data( session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None @@ -288,7 +292,7 @@ class PasswordRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -296,7 +300,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -338,7 +342,7 @@ class DeactivateAccountRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.config = hs.config @@ -353,7 +357,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): template_text=self.config.email_add_threepid_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -449,7 +453,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict( body, ["client_secret", "country", "phone_number", "send_attempt"] @@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): "/add_threepid/email/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() @@ -539,7 +539,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.config.email_add_threepid_template_failure_html ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> None: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): "/add_threepid/msisdn/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: if not self.config.account_threepid_delegate_msisdn: raise SynapseError( 400, @@ -632,7 +628,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): class ThreepidRestServlet(RestServlet): PATTERNS = 
client_patterns("/account/3pid$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -640,14 +636,14 @@ class ThreepidRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -688,7 +684,7 @@ class ThreepidRestServlet(RestServlet): class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -696,7 +692,7 @@ class ThreepidAddRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -738,13 +734,13 @@ class ThreepidAddRestServlet(RestServlet): class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) @@ -767,14 +763,14 @@ class ThreepidBindRestServlet(RestServlet): class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ @@ -798,13 +794,13 @@ class ThreepidUnbindRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -835,7 +831,7 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} -def assert_valid_next_link(hs: "HomeServer", next_link: str): +def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None: """ Raises a SynapseError if 
a given next_link value is invalid @@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str): class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) response = {"user_id": requester.user.to_string()} @@ -894,7 +890,7 @@ class WhoamiRestServlet(RestServlet): return 200, response -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 68fb08d0ba..0152a0c66a 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from twisted.web.server import Request @@ -96,7 +96,9 @@ class KnockRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} - def on_PUT(self, request: Request, room_identifier: str, txn_id: str): + def on_PUT( + self, request: Request, room_identifier: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index a28acd4041..8f3dd2a101 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -375,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): unstable=True, ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.store = hs.get_datastore() @@ -390,7 +386,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) if not self.hs.config.enable_registration: @@ -730,7 +726,11 @@ class RegisterRestServlet(RestServlet): return 200, return_dict async def _do_appservice_registration( - self, username, as_token, body, should_issue_refresh_token: bool = False + self, + username: str, + as_token: str, + body: JsonDict, + should_issue_refresh_token: bool = False, ) -> JsonDict: user_id = await self.registration_handler.appservice_register( username, as_token diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index 07ea39a8a3..d4a4adb50c 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -14,26 +14,35 @@ import logging from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base 
import client_patterns

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)


 class ReportEventRestServlet(RestServlet):
     PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()

-    async def on_POST(self, request, room_id, event_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str, event_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         user_id = requester.user.to_string()

@@ -64,5 +73,5 @@ class ReportEventRestServlet(RestServlet):
         return 200, {}


-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReportEventRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index 3172aba605..ed96978448 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -14,10 +14,14 @@

 import logging
 import re
+from typing import TYPE_CHECKING, Awaitable, List, Tuple
+
+from twisted.web.server import Request

 from synapse.api.constants import EventContentFields, EventTypes
 from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.appservice import ApplicationService
+from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -25,10 +29,14 @@ from synapse.http.servlet import (
     parse_string,
     parse_strings_from_args,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util.stringutils import random_string

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)


@@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         ),
     )

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.store = hs.get_datastore()
@@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.txns = HttpTransactionCache(hs)

-    async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+    async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int:
         (
             most_recent_prev_event_id,
             most_recent_prev_event_depth,
@@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet):

     def _create_insertion_event_dict(
         self, sender: str, room_id: str, origin_server_ts: int
-    ):
+    ) -> JsonDict:
         """Creates an event dict for an "insertion" event with the proper fields
         and a random chunk ID.

@@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
             origin_server_ts: Timestamp when the event was sent

         Returns:
-            Tuple of event ID and stream ordering position
+            The new event dictionary to insert.
""" next_chunk_id = random_string(8) @@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet): return create_requester(user_id, app_service=app_service) - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) if not requester.app_service: @@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["state_events_at_start", "events"]) + assert request.args is not None prev_events_from_query = parse_strings_from_args(request.args, "prev_event") chunk_id_from_query = parse_string(request, "chunk_id") @@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet): ], } - def on_GET(self, request, room_id): + def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: return 501, "Not implemented" - def on_PUT(self, request, room_id): + def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_id ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: msc2716_enabled = hs.config.experimental.msc2716_enabled if msc2716_enabled: diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py index 263596be86..37e39570f6 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -13,16 +13,23 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -31,16 +38,14 @@ class RoomKeysServlet(RestServlet): "/room_keys/keys(/(?P[^/]+))?(/(?P[^/]+))?$" ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_PUT(self, request, room_id, session_id): + async def on_PUT( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -133,7 +138,9 @@ class RoomKeysServlet(RestServlet): ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) return 200, ret - async def on_GET(self, request, room_id, session_id): + async def on_GET( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. 
@@ -215,7 +222,9 @@ class RoomKeysServlet(RestServlet):

         return 200, room_keys

-    async def on_DELETE(self, request, room_id, session_id):
+    async def on_DELETE(
+        self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Deletes one or more encrypted E2E room keys for a user for backup purposes.

@@ -242,16 +251,12 @@ class RoomKeysServlet(RestServlet):
 class RoomKeysNewVersionServlet(RestServlet):
     PATTERNS = client_patterns("/room_keys/version$")

-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()

-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Create a new backup version for this user's room_keys with the given
         info. The version is allocated by the server and returned to the user
@@ -295,16 +300,14 @@ class RoomKeysVersionServlet(RestServlet):
 class RoomKeysVersionServlet(RestServlet):
     PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")

-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()

-    async def on_GET(self, request, version):
+    async def on_GET(
+        self, request: SynapseRequest, version: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Retrieve the version information about a given version of the user's
         room_keys backup. If the version part is missing, returns info about the
@@ -332,7 +335,9 @@ class RoomKeysVersionServlet(RestServlet):
             raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
         return 200, info

-    async def on_DELETE(self, request, version):
+    async def on_DELETE(
+        self, request: SynapseRequest, version: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Delete the information about a given version of the user's
         room_keys backup. If the version part is missing, deletes the most
@@ -351,7 +356,9 @@ class RoomKeysVersionServlet(RestServlet):
         await self.e2e_room_keys_handler.delete_version(user_id, version)
         return 200, {}

-    async def on_PUT(self, request, version):
+    async def on_PUT(
+        self, request: SynapseRequest, version: Optional[str]
+    ) -> Tuple[int, JsonDict]:
         """
         Update the information about a given version of the user's room_keys backup.

@@ -385,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet):
     return 200, {}


-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomKeysServlet(hs).register(http_server)
     RoomKeysVersionServlet(hs).register(http_server)
     RoomKeysNewVersionServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index d537d811d8..3322c8ef48 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -13,15 +13,21 @@
 # limitations under the License.
 import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Awaitable, Tuple

 from synapse.http import servlet
+from synapse.http.server import HttpServer
 from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import set_tag, trace
 from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import JsonDict

 from ._base import client_patterns

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)


@@ -30,11 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
     )

-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -42,14 +44,18 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         self.device_message_handler = hs.get_device_message_handler()

     @trace(opname="sendToDevice")
-    def on_PUT(self, request, message_type, txn_id):
+    def on_PUT(
+        self, request: SynapseRequest, message_type: str, txn_id: str
+    ) -> Awaitable[Tuple[int, JsonDict]]:
         set_tag("message_type", message_type)
         set_tag("txn_id", txn_id)
         return self.txns.fetch_or_execute_request(
             request, self._put, request, message_type, txn_id
         )

-    async def _put(self, request, message_type, txn_id):
+    async def _put(
+        self, request: SynapseRequest, message_type: str, txn_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)

         content = parse_json_object_from_request(request)
@@ -59,9 +65,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
             requester, message_type, content["messages"]
         )

-        response: Tuple[int, dict] = (200, {})
-        return response
+        return 200, {}


-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     SendToDeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 65c37be3e9..1259058b9b 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,12 +14,24 @@
 import itertools
 import logging
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)

 from synapse.api.constants import Membership, PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
 from synapse.api.presence import UserPresenceState
+from synapse.events import EventBase
 from synapse.events.utils import (
     format_event_for_client_v2_without_room_id,
     format_event_raw,
@@ -504,7 +516,7 @@ class SyncRestServlet(RestServlet):
             The room, encoded in our response format
         """

-        def serialize(events):
+        def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]:
             return self._event_serializer.serialize_events(
                 events,
                 time_now=time_now,
-- 
cgit 1.5.1


From 6258730ebe54f72b8869e8181d032b67ff9fd6e4 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 1 Sep 2021 12:59:52 -0400
Subject: Consider the `origin_server_ts` of the `m.space.child` event when
 ordering rooms. (#10730)

This updates the ordering of the returned events from the spaces
summary API to that defined in MSC2946 (which updates MSC1772).
Previously a step was skipped causing ordering to be inconsistent with clients. --- changelog.d/10730.bugfix | 1 + synapse/handlers/room_summary.py | 15 ++++++++------- tests/handlers/test_room_summary.py | 18 +++++++++++++----- 3 files changed, 22 insertions(+), 12 deletions(-) create mode 100644 changelog.d/10730.bugfix diff --git a/changelog.d/10730.bugfix b/changelog.d/10730.bugfix new file mode 100644 index 0000000000..f1612d3c08 --- /dev/null +++ b/changelog.d/10730.bugfix @@ -0,0 +1 @@ +Fix a bug where the ordering algorithm was skipping the `origin_server_ts` step in the spaces summary resulting in unstable room orderings. diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 906985c754..d1b6f3253e 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -1139,25 +1139,26 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool: _INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") -def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: +def _child_events_comparison_key( + child: EventBase, +) -> Tuple[bool, Optional[str], int, str]: """ Generate a value for comparing two child events for ordering. - The rules for ordering are supposed to be: + The rules for ordering are: 1. The 'order' key, if it is valid. - 2. The 'origin_server_ts' of the 'm.room.create' event. + 2. The 'origin_server_ts' of the 'm.space.child' event. 3. The 'room_id'. - But we skip step 2 since we may not have any state from the room. - Args: child: The event for generating a comparison key. Returns: The comparison key as a tuple of: False if the ordering is valid. - The ordering field. + The 'order' field or None if it is not given or invalid. + The 'origin_server_ts' field. The room ID. """ order = child.content.get("order") @@ -1168,4 +1169,4 @@ def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], order = None # Items without an order come last. 
-    return (order is None, order, child.room_id)
+    return (order is None, order, child.origin_server_ts, child.room_id)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index ac800afa7d..449ba89e5a 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -35,10 +35,11 @@ from synapse.types import JsonDict, UserID
 from tests import unittest


-def _create_event(room_id: str, order: Optional[Any] = None):
-    result = mock.Mock()
+def _create_event(room_id: str, order: Optional[Any] = None, origin_server_ts: int = 0):
+    result = mock.Mock(name=room_id)
     result.room_id = room_id
     result.content = {}
+    result.origin_server_ts = origin_server_ts
     if order is not None:
         result.content["order"] = order
     return result
@@ -63,10 +64,17 @@ class TestSpaceSummarySort(unittest.TestCase):

         self.assertEqual([ev2, ev1], _order(ev1, ev2))

+    def test_order_origin_server_ts(self):
+        """Origin server ts is a tie-breaker for ordering."""
+        ev1 = _create_event("!abc:test", origin_server_ts=10)
+        ev2 = _create_event("!xyz:test", origin_server_ts=30)
+
+        self.assertEqual([ev1, ev2], _order(ev1, ev2))
+
     def test_order_room_id(self):
-        """Room ID is a tie-breaker for ordering."""
-        ev1 = _create_event("!abc:test", "abc")
-        ev2 = _create_event("!xyz:test", "abc")
+        """Room ID is a final tie-breaker for ordering."""
+        ev1 = _create_event("!abc:test")
+        ev2 = _create_event("!xyz:test")

         self.assertEqual([ev1, ev2], _order(ev1, ev2))
-- 
cgit 1.5.1


From c586d6803a2adebcdd486be2d9eac1f62fd7d4ab Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Wed, 1 Sep 2021 13:01:08 -0400
Subject: Ignore rooms with unknown room versions in the spaces summary.
 (#10727)

This avoids breaking the entire endpoint if a room with an unsupported
room version is encountered.
---
 changelog.d/10727.misc              |  1 +
 synapse/handlers/room_summary.py    | 16 ++++++++++++++--
 tests/handlers/test_room_summary.py | 25 +++++++++++++++++++++++++
 3 files changed, 40 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/10727.misc

diff --git a/changelog.d/10727.misc b/changelog.d/10727.misc
new file mode 100644
index 0000000000..63fe6e5c7d
--- /dev/null
+++ b/changelog.d/10727.misc
@@ -0,0 +1 @@
+Do not include rooms with unknown room versions in the spaces summary results.
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index d1b6f3253e..4bc9c73e6e 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,7 +28,14 @@ from synapse.api.constants import (
     Membership,
     RoomTypes,
 )
-from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    NotFoundError,
+    StoreError,
+    SynapseError,
+    UnsupportedRoomVersionError,
+)
 from synapse.events import EventBase
 from synapse.events.utils import format_event_for_client_v2
 from synapse.types import JsonDict
@@ -814,7 +821,12 @@ class RoomSummaryHandler:
             logger.info("room %s is unknown, omitting from summary", room_id)
             return False

-        room_version = await self._store.get_room_version(room_id)
+        try:
+            room_version = await self._store.get_room_version(room_id)
+        except UnsupportedRoomVersionError:
+            # If a room with an unsupported room version is encountered, ignore
+            # it to avoid breaking the entire summary response.
+            return False

         # Include the room if it has join rules of public or knock.
         join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""))
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 449ba89e5a..d3d0bf1ac5 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -581,6 +581,31 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         ]
         self._assert_hierarchy(result, expected)

+    def test_unknown_room_version(self):
+        """
+        If a room with an unknown room version is encountered it should not cause
+        the entire summary to skip.
+        """
+        # Poke the database and update the room version to an unknown one.
+        self.get_success(
+            self.hs.get_datastores().main.db_pool.simple_update(
+                "rooms",
+                keyvalues={"room_id": self.room},
+                updatevalues={"room_version": "unknown-room-version"},
+                desc="updated-room-version",
+            )
+        )
+
+        result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+        # The result should have only the space, along with a link from space -> room.
+        expected = [(self.space, [self.room])]
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+        self._assert_hierarchy(result, expected)
+
     def test_fed_complex(self):
         """
         Return data over federation and ensure that it is handled properly.
-- 
cgit 1.5.1


From 00640ee71affeab0f6ed7e784925522437f076c5 Mon Sep 17 00:00:00 2001
From: cuttingedge1109 <53085803+cuttingedge1109@users.noreply.github.com>
Date: Thu, 2 Sep 2021 15:07:53 +0200
Subject: Fix documentation of directory name for remote thumbnails (#10556)

---
 changelog.d/10556.doc    | 1 +
 docs/media_repository.md | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/10556.doc

diff --git a/changelog.d/10556.doc b/changelog.d/10556.doc
new file mode 100644
index 0000000000..7526ae11db
--- /dev/null
+++ b/changelog.d/10556.doc
@@ -0,0 +1 @@
+Minor fix to the `media_repository` developer documentation. Contributed by @cuttingedge1109.
\ No newline at end of file
diff --git a/docs/media_repository.md b/docs/media_repository.md
index 1bf8f16f55..99ee8f1ef7 100644
--- a/docs/media_repository.md
+++ b/docs/media_repository.md
@@ -27,4 +27,4 @@ Remote content is cached under `"remote_content"` directory. Each item of
 remote content is assigned a local `"filesystem_id"` to ensure that the
 directory structure `"remote_content/server_name/aa/bb/ccccccccdddddddddddd"`
 is appropriate.  Thumbnails for remote content are stored under
-`"remote_thumbnails/server_name/..."`
+`"remote_thumbnail/server_name/..."`
-- 
cgit 1.5.1


From f58d202e3fc2c17bbe4ee24dd07f09888f73ef23 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Fri, 3 Sep 2021 10:59:25 +0100
Subject: Fix bug with reusing 'txn' when persisting event. (#10743)

This will only happen when a server has multiple out of band membership
events in a single room.
---
 changelog.d/10743.bugfix                 | 1 +
 synapse/storage/databases/main/events.py | 8 +++++++-
 2 files changed, 8 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/10743.bugfix

diff --git a/changelog.d/10743.bugfix b/changelog.d/10743.bugfix
new file mode 100644
index 0000000000..d597a19870
--- /dev/null
+++ b/changelog.d/10743.bugfix
@@ -0,0 +1 @@
+Fix edge case when persisting events into a room where there are multiple events we previously hadn't calculated auth chains for (and hadn't marked as needing to be calculated).
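The events.py hunk below replaces lazy iteration over the live cursor with an
upfront fetchall(). The underlying problem is a general DB-API pitfall:
iterating a cursor pulls rows on demand, so issuing another execute() on the
same cursor inside the loop discards the remaining rows of the outer query. A
sketch of the failure mode and the fix (the queries here are illustrative, not
the actual ones from the patch):

    # Buggy: iterate the live cursor while the loop body reuses it.
    txn.execute("SELECT event_id, chain_id FROM event_auth_chains")
    for event_id, chain_id in txn:
        # This nested execute() resets the shared cursor, silently
        # truncating the outer loop partway through its rows.
        txn.execute("SELECT 1 FROM events WHERE event_id = ?", (event_id,))

    # Fixed: materialise the result set first, freeing the cursor for reuse.
    txn.execute("SELECT event_id, chain_id FROM event_auth_chains")
    for event_id, chain_id in txn.fetchall():
        txn.execute("SELECT 1 FROM events WHERE event_id = ?", (event_id,))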
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 40b53274fb..096ae28788 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -575,7 +575,13 @@ class PersistEventsStore:

                 missing_auth_chains.clear()

-                for auth_id, event_type, state_key, chain_id, sequence_number in txn:
+                for (
+                    auth_id,
+                    event_type,
+                    state_key,
+                    chain_id,
+                    sequence_number,
+                ) in txn.fetchall():
                     event_to_types[auth_id] = (event_type, state_key)

                     if chain_id is None:
-- 
cgit 1.5.1


From ecbfa4fe4fa7625dec14ef8f9bd06cc4ad141de0 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Fri, 3 Sep 2021 09:22:22 -0400
Subject: Additional type hints for client REST servlets (part 5) (#10736)

Additionally this enforces type hints on all function signatures inside
of the synapse.rest.client package.
---
 changelog.d/10736.misc                      |   1 +
 mypy.ini                                    |   3 +
 synapse/http/servlet.py                     |  19 +++++
 synapse/rest/admin/server_notice_servlet.py |   6 +-
 synapse/rest/client/_base.py                |  11 ++-
 synapse/rest/client/groups.py               |  10 ++-
 synapse/rest/client/push_rule.py            | 112 +++++++++++++++++-----------
 synapse/rest/client/transactions.py         |  52 +++++++++----
 8 files changed, 146 insertions(+), 68 deletions(-)
 create mode 100644 changelog.d/10736.misc

diff --git a/changelog.d/10736.misc b/changelog.d/10736.misc
new file mode 100644
index 0000000000..39a37b90b1
--- /dev/null
+++ b/changelog.d/10736.misc
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.
diff --git a/mypy.ini b/mypy.ini
index f6de668edd..03dc6b8697 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -98,6 +98,9 @@ files =
   tests/util/test_itertools.py,
   tests/util/test_stream_change_cache.py

+[mypy-synapse.rest.client.*]
+disallow_untyped_defs = True
+
 [mypy-pymacaroons.*]
 ignore_missing_imports = True

diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a12fa30bfd..91ba93372c 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -572,6 +572,25 @@ def parse_string_from_args(
     return strings[0]


+@overload
+def parse_json_value_from_request(request: Request) -> JsonDict:
+    ...
+
+
+@overload
+def parse_json_value_from_request(
+    request: Request, allow_empty_body: Literal[False]
+) -> JsonDict:
+    ...
+
+
+@overload
+def parse_json_value_from_request(
+    request: Request, allow_empty_body: bool = False
+) -> Optional[JsonDict]:
+    ...
+
+
 def parse_json_value_from_request(
     request: Request, allow_empty_body: bool = False
 ) -> Optional[JsonDict]:
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 42201afc86..f5a38c2670 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError, SynapseError @@ -101,7 +101,9 @@ class SendServerNoticeServlet(RestServlet): return 200, {"event_id": event.event_id} - def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]: + def on_PUT( + self, request: SynapseRequest, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, txn_id ) diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index 0443f4571c..a0971ce994 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -16,7 +16,7 @@ """ import logging import re -from typing import Iterable, Pattern +from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX @@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) ) -def interactive_auth_handler(orig): +C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]]) + + +def interactive_auth_handler(orig: C) -> C: """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors Takes a on_POST method which returns an Awaitable (errcode, body) response @@ -91,10 +94,10 @@ def interactive_auth_handler(orig): await self.auth_handler.check_auth """ - async def wrapped(*args, **kwargs): + async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]: try: return await orig(*args, **kwargs) except InteractiveAuthIncompleteError as e: return 401, e.result - return wrapped + return cast(C, wrapped) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index 0950f43f2f..a7e9aa3e9b 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -15,7 +15,7 @@ import logging from functools import wraps -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple from twisted.web.server import Request @@ -43,14 +43,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _validate_group_id(f): +def _validate_group_id( + f: Callable[..., Awaitable[Tuple[int, JsonDict]]] +) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]: """Wrapper to validate the form of the group ID. Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. """ @wraps(f) - def wrapper(self, request: Request, group_id: str, *args, **kwargs): + def wrapper( + self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any + ) -> Awaitable[Tuple[int, JsonDict]]: if not GroupID.is_valid(group_id): raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 702b351d18..fb3211bf3a 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,22 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union + +import attr + from synapse.api.errors import ( NotFoundError, StoreError, SynapseError, UnrecognizedRequestError, ) +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_value_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RuleSpec: + scope: str + template: str + rule_id: str + attr: Optional[str] class PushRuleRestServlet(RestServlet): @@ -36,7 +54,7 @@ class PushRuleRestServlet(RestServlet): "Unrecognised request: You probably wanted a trailing slash" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -45,7 +63,7 @@ class PushRuleRestServlet(RestServlet): self._users_new_default_push_rules = hs.config.users_new_default_push_rules - async def on_PUT(self, request, path): + async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -57,25 +75,25 @@ class PushRuleRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) - if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: + if "/" in spec.rule_id or "\\" in spec.rule_id: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) user_id = requester.user.to_string() - if "attr" in spec: + if spec.attr: await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} - if spec["rule_id"].startswith("."): + if spec.rule_id.startswith("."): # Rule ids starting with '.' are reserved for server default rules. raise SynapseError(400, "cannot add new rule_ids that start with '.'") try: (conditions, actions) = _rule_tuple_from_request_object( - spec["template"], spec["rule_id"], content + spec.template, spec.rule_id, content ) except InvalidRuleException as e: raise SynapseError(400, str(e)) @@ -106,7 +124,9 @@ class PushRuleRestServlet(RestServlet): return 200, {} - async def on_DELETE(self, request, path): + async def on_DELETE( + self, request: SynapseRequest, path: str + ) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") @@ -127,7 +147,7 @@ class PushRuleRestServlet(RestServlet): else: raise - async def on_GET(self, request, path): + async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -138,40 +158,42 @@ class PushRuleRestServlet(RestServlet): rules = format_push_rules_for_user(requester.user, rules) - path = path.split("/")[1:] + path_parts = path.split("/")[1:] - if path == []: + if path_parts == []: # we're a reference impl: pedantry is our job. 
            raise UnrecognizedRequestError(
                 PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
             )

-        if path[0] == "":
+        if path_parts[0] == "":
             return 200, rules
-        elif path[0] == "global":
-            result = _filter_ruleset_with_path(rules["global"], path[1:])
+        elif path_parts[0] == "global":
+            result = _filter_ruleset_with_path(rules["global"], path_parts[1:])
             return 200, result
         else:
             raise UnrecognizedRequestError()

-    def notify_user(self, user_id):
+    def notify_user(self, user_id: str) -> None:
         stream_id = self.store.get_max_push_rules_stream_id()
         self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])

-    async def set_rule_attr(self, user_id, spec, val):
-        if spec["attr"] not in ("enabled", "actions"):
+    async def set_rule_attr(
+        self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict]
+    ) -> None:
+        if spec.attr not in ("enabled", "actions"):
             # for the sake of potential future expansion, shouldn't report
             # 404 in the case of an unknown request so check it corresponds to
             # a known attribute first.
             raise UnrecognizedRequestError()

         namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-        rule_id = spec["rule_id"]
+        rule_id = spec.rule_id
         is_default_rule = rule_id.startswith(".")
         if is_default_rule:
             if namespaced_rule_id not in BASE_RULE_IDS:
                 raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
-        if spec["attr"] == "enabled":
+        if spec.attr == "enabled":
             if isinstance(val, dict) and "enabled" in val:
                 val = val["enabled"]
             if not isinstance(val, bool):
@@ -179,14 +201,18 @@ class PushRuleRestServlet(RestServlet):
                 # This should *actually* take a dict, but many clients pass
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
-            return await self.store.set_push_rule_enabled(
+            await self.store.set_push_rule_enabled(
                 user_id, namespaced_rule_id, val, is_default_rule
             )
-        elif spec["attr"] == "actions":
+        elif spec.attr == "actions":
+            if not isinstance(val, dict):
+                raise SynapseError(400, "Value must be a dict")
             actions = val.get("actions")
+            if not isinstance(actions, list):
+                raise SynapseError(400, "Value for 'actions' must be a list")
             _check_actions(actions)
             namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            rule_id = spec["rule_id"]
+            rule_id = spec.rule_id
             is_default_rule = rule_id.startswith(".")
             if is_default_rule:
                 if user_id in self._users_new_default_push_rules:
@@ -196,22 +222,21 @@ class PushRuleRestServlet(RestServlet):
                 if namespaced_rule_id not in rule_ids:
                     raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
-            return await self.store.set_push_rule_actions(
+            await self.store.set_push_rule_actions(
                 user_id, namespaced_rule_id, actions, is_default_rule
             )
         else:
             raise UnrecognizedRequestError()


-def _rule_spec_from_path(path):
+def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec:
     """Turn a sequence of path components into a rule spec

     Args:
-        path (sequence[unicode]): the URL path components.
+        path: the URL path components.

     Returns:
-        dict: rule spec dict, containing scope/template/rule_id entries,
-        and possibly attr.
+        rule spec, containing scope/template/rule_id entries, and possibly attr.

     Raises:
         UnrecognizedRequestError if the path components cannot be parsed.
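As a concrete illustration of the mapping this helper performs: a request such
as PUT /_matrix/client/r0/pushrules/global/override/.m.rule.example/enabled
(the rule ID is made up) yields path components that become a RuleSpec. A
sketch, assuming the usual splitting of the captured path on "/":

    components = ["pushrules", "global", "override", ".m.rule.example", "enabled"]
    spec = _rule_spec_from_path(components)
    # spec == RuleSpec(scope="global", template="override",
    #                  rule_id=".m.rule.example", attr="enabled")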
@@ -237,17 +262,18 @@ def _rule_spec_from_path(path): rule_id = path[0] - spec = {"scope": scope, "template": template, "rule_id": rule_id} - path = path[1:] + attr = None if len(path) > 0 and len(path[0]) > 0: - spec["attr"] = path[0] + attr = path[0] - return spec + return RuleSpec(scope, template, rule_id, attr) -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): +def _rule_tuple_from_request_object( + rule_template: str, rule_id: str, req_obj: JsonDict +) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]: if rule_template in ["override", "underride"]: if "conditions" not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -277,7 +303,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): return conditions, actions -def _check_actions(actions): +def _check_actions(actions: List[Union[str, JsonDict]]) -> None: if not isinstance(actions, list): raise InvalidRuleException("No actions found") @@ -290,7 +316,7 @@ def _check_actions(actions): raise InvalidRuleException("Unrecognised action") -def _filter_ruleset_with_path(ruleset, path): +def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict: if path == []: raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR @@ -315,7 +341,7 @@ def _filter_ruleset_with_path(ruleset, path): if r["rule_id"] == rule_id: the_rule = r if the_rule is None: - raise NotFoundError + raise NotFoundError() path = path[1:] if len(path) == 0: @@ -330,25 +356,25 @@ def _filter_ruleset_with_path(ruleset, path): raise UnrecognizedRequestError() -def _priority_class_from_spec(spec): - if spec["template"] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec["template"])) - pc = PRIORITY_CLASS_MAP[spec["template"]] +def _priority_class_from_spec(spec: RuleSpec) -> int: + if spec.template not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec.template)) + pc = PRIORITY_CLASS_MAP[spec.template] return pc -def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec["rule_id"]) +def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str: + return _namespaced_rule_id(spec, spec.rule_id) -def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec["template"], rule_id) +def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str: + return "global/%s/%s" % (spec.template, rule_id) class InvalidRuleException(Exception): pass -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 94ff3719ce..914fb3acf5 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -15,28 +15,37 @@ """This module contains logic for storing HTTP PUT transactions. 
This is used to ensure idempotency when performing PUTs using the REST API.""" import logging +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple + +from twisted.python.failure import Failure +from twisted.web.server import Request from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import JsonDict from synapse.util.async_helpers import ObservableDeferred +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = self.hs.get_auth() self.clock = self.hs.get_clock() - self.transactions = { - # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) - } + # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) + self.transactions: Dict[ + str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int] + ] = {} # Try to clean entries every 30 mins. This means entries will exist # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) - def _get_transaction_key(self, request): + def _get_transaction_key(self, request: Request) -> str: """A helper function which returns a transaction key that can be used with TransactionCache for idempotent requests. @@ -45,15 +54,21 @@ class HttpTransactionCache: path and the access_token for the requesting user. Args: - request (twisted.web.http.Request): The incoming request. Must - contain an access_token. + request: The incoming request. Must contain an access_token. Returns: - str: A transaction key + A transaction key """ + assert request.path is not None token = self.auth.get_access_token_from_request(request) return request.path.decode("utf8") + "/" + token - def fetch_or_execute_request(self, request, fn, *args, **kwargs): + def fetch_or_execute_request( + self, + request: Request, + fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], + *args: Any, + **kwargs: Any, + ) -> Awaitable[Tuple[int, JsonDict]]: """A helper function for fetch_or_execute which extracts a transaction key from the given request. @@ -64,15 +79,20 @@ class HttpTransactionCache: self._get_transaction_key(request), fn, *args, **kwargs ) - def fetch_or_execute(self, txn_key, fn, *args, **kwargs): + def fetch_or_execute( + self, + txn_key: str, + fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], + *args: Any, + **kwargs: Any, + ) -> Awaitable[Tuple[int, JsonDict]]: """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. Args: - txn_key (str): A key to ensure idempotency should fetch_or_execute be - called again at a later point in time. - fn (function): A function which returns a tuple of - (response_code, response_dict). + txn_key: A key to ensure idempotency should fetch_or_execute be + called again at a later point in time. + fn: A function which returns a tuple of (response_code, response_dict). *args: Arguments to pass to fn. **kwargs: Keyword arguments to pass to fn. Returns: @@ -90,7 +110,7 @@ class HttpTransactionCache: # if the request fails with an exception, remove it # from the transaction map. This is done to ensure that we don't # cache transient errors like rate-limiting errors, etc. 
- def remove_from_map(err): + def remove_from_map(err: Failure) -> None: self.transactions.pop(txn_key, None) # we deliberately do not propagate the error any further, as we # expect the observers to have reported it. @@ -99,7 +119,7 @@ class HttpTransactionCache: return make_deferred_yieldable(observable.observe()) - def _cleanup(self): + def _cleanup(self) -> None: now = self.clock.time_msec() for key in list(self.transactions): ts = self.transactions[key][1] -- cgit 1.5.1 From 2cb85bdf7586f5094656dce6030922849fbbdb87 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Sep 2021 09:46:18 -0400 Subject: Raise an error if an unknown preset is used to create a room. (#10738) Raises a 400 error instead of a 500 if an unknown preset is passed from a client to create a room. --- changelog.d/10738.misc | 1 + synapse/handlers/room.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10738.misc diff --git a/changelog.d/10738.misc b/changelog.d/10738.misc new file mode 100644 index 0000000000..cef54153dc --- /dev/null +++ b/changelog.d/10738.misc @@ -0,0 +1 @@ +Additional error checking for the `preset` field when creating a room. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b33fe09f77..ed780bb41f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -909,7 +909,12 @@ class RoomCreationHandler(BaseHandler): ) return last_stream_id - config = self._presets_dict[preset_config] + try: + config = self._presets_dict[preset_config] + except KeyError: + raise SynapseError( + 400, f"'{preset_config}' is not a valid preset", errcode=Codes.BAD_JSON + ) creation_content.update({"creator": creator_id}) await send(etype=EventTypes.Create, content=creation_content) -- cgit 1.5.1 From 0eae330a26be63716c338009d3ae562f075942aa Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Sep 2021 16:35:49 +0100 Subject: Use `execute_values` more in PostgreSQL (#10754) `execute_values` is a faster version of `execute_batch`. --- changelog.d/10754.misc | 1 + synapse/storage/database.py | 61 +++++++++++++++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 19 deletions(-) create mode 100644 changelog.d/10754.misc diff --git a/changelog.d/10754.misc b/changelog.d/10754.misc new file mode 100644 index 0000000000..3b7acff03f --- /dev/null +++ b/changelog.d/10754.misc @@ -0,0 +1 @@ +Minor speed ups when joining large rooms over federation. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 95d2caff62..0084d9f96c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -280,18 +280,18 @@ class LoggingTransaction: else: self.executemany(sql, args) - def execute_values(self, sql: str, *args: Any) -> List[Tuple]: + def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]: """Corresponds to psycopg2.extras.execute_values. Only available when using postgres. - Always sets fetch=True when caling `execute_values`, so will return the - results. + The `fetch` parameter must be set to False if the query does not return + rows (e.g. INSERTs). 
""" assert isinstance(self.database_engine, PostgresEngine) from psycopg2.extras import execute_values # type: ignore return self._do_execute( - lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args + lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args ) def execute(self, sql: str, *args: Any) -> None: @@ -920,13 +920,23 @@ class DatabasePool: if k != keys[0]: raise RuntimeError("All items must have the same keys") - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), - ) + if isinstance(txn.database_engine, PostgresEngine): + # We use `execute_values` as it can be a lot faster than `execute_batch`, + # but it's only available on postgres. + sql = "INSERT INTO %s (%s) VALUES ?" % ( + table, + ", ".join(k for k in keys[0]), + ) - txn.execute_batch(sql, vals) + txn.execute_values(sql, vals, fetch=False) + else: + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys[0]), + ", ".join("?" for _ in keys[0]), + ) + + txn.execute_batch(sql, vals) async def simple_upsert( self, @@ -1281,20 +1291,33 @@ class DatabasePool: k + "=EXCLUDED." + k for k in value_names ) - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( - table, - ", ".join(k for k in allnames), - ", ".join("?" for _ in allnames), - ", ".join(key_names), - latter, - ) - args = [] for x, y in zip(key_values, value_values): args.append(tuple(x) + tuple(y)) - return txn.execute_batch(sql, args) + if isinstance(txn.database_engine, PostgresEngine): + # We use `execute_values` as it can be a lot faster than `execute_batch`, + # but it's only available on postgres. + sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join(key_names), + latter, + ) + + txn.execute_values(sql, args, fetch=False) + + else: + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + return txn.execute_batch(sql, args) @overload async def simple_select_one( -- cgit 1.5.1 From 924276f4821c71b788378e1ed6a0bd384771844f Mon Sep 17 00:00:00 2001 From: Sean Date: Fri, 3 Sep 2021 17:16:56 +0100 Subject: Add a partial index to `presence_stream` to speed up startups (#10748) Signed-off-by: Sean Quah --- changelog.d/10748.misc | 1 + scripts/synapse_port_db | 2 ++ synapse/storage/databases/main/presence.py | 23 +++++++++++++++++++++- .../63/04add_presence_stream_not_offline_index.sql | 18 +++++++++++++++++ 4 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10748.misc create mode 100644 synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql diff --git a/changelog.d/10748.misc b/changelog.d/10748.misc new file mode 100644 index 0000000000..b9e2c46087 --- /dev/null +++ b/changelog.d/10748.misc @@ -0,0 +1 @@ +Add an index to `presence_stream` to hopefully speed up startups a little. 
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 2bbaf5557d..fa6ac6d93a 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -46,6 +46,7 @@ from synapse.storage.databases.main.events_bg_updates import ( from synapse.storage.databases.main.media_repository import ( MediaRepositoryBackgroundUpdateStore, ) +from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, @@ -179,6 +180,7 @@ class Store( EndToEndKeyBackgroundStore, StatsStore, PusherWorkerStore, + PresenceBackgroundUpdateStore, ): def execute(self, f, *args, **kwargs): return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 1388771c40..12cf6995eb 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -29,7 +29,26 @@ if TYPE_CHECKING: from synapse.server import HomeServer -class PresenceStore(SQLBaseStore): +class PresenceBackgroundUpdateStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: Connection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + # Used by `PresenceStore._get_active_presence()` + self.db_pool.updates.register_background_index_update( + "presence_stream_not_offline_index", + index_name="presence_stream_state_not_offline_idx", + table="presence_stream", + columns=["state"], + where_clause="state != 'offline'", + ) + + +class PresenceStore(PresenceBackgroundUpdateStore): def __init__( self, database: DatabasePool, @@ -332,6 +351,8 @@ class PresenceStore(SQLBaseStore): the appropriate time outs. """ + # The `presence_stream_state_not_offline_idx` index should be used for this + # query. sql = ( "SELECT user_id, state, last_active_ts, last_federation_update_ts," " last_user_sync_ts, status_msg, currently_active FROM presence_stream" diff --git a/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql b/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql new file mode 100644 index 0000000000..b90856004b --- /dev/null +++ b/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql @@ -0,0 +1,18 @@ +/* + * Copyright 2021 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6304, 'presence_stream_not_offline_index', '{}'); -- cgit 1.5.1 From ae3c16318bf8082c34fa91dc32b30ab9887b3b24 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Sep 2021 12:51:15 -0400 Subject: Support MSC3375: room version 9. 
(#10747) --- changelog.d/10747.feature | 1 + synapse/api/room_versions.py | 31 +++++++++++++++++++++++++++++ synapse/events/utils.py | 2 ++ tests/events/test_utils.py | 46 +++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10747.feature diff --git a/changelog.d/10747.feature b/changelog.d/10747.feature new file mode 100644 index 0000000000..c1cca9a17b --- /dev/null +++ b/changelog.d/10747.feature @@ -0,0 +1 @@ +Support room version 9 from [MSC3375](https://github.com/matrix-org/matrix-doc/pull/3375). diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 8abcdfd4fd..a19be6707a 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -70,6 +70,9 @@ class RoomVersion: msc2176_redaction_rules = attr.ib(type=bool) # MSC3083: Support the 'restricted' join_rule. msc3083_join_rules = attr.ib(type=bool) + # MSC3375: Support for the proper redaction rules for MSC3083. This mustn't + # be enabled if MSC3083 is not. + msc3375_redaction_rules = attr.ib(type=bool) # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending # m.room.membership event with membership 'knock'. msc2403_knocking = attr.ib(type=bool) @@ -92,6 +95,7 @@ class RoomVersions: limit_notifications_power_levels=False, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=False, msc2716_historical=False, msc2716_redactions=False, @@ -107,6 +111,7 @@ class RoomVersions: limit_notifications_power_levels=False, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=False, msc2716_historical=False, msc2716_redactions=False, @@ -122,6 +127,7 @@ class RoomVersions: limit_notifications_power_levels=False, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=False, msc2716_historical=False, msc2716_redactions=False, @@ -137,6 +143,7 @@ class RoomVersions: limit_notifications_power_levels=False, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=False, msc2716_historical=False, msc2716_redactions=False, @@ -152,6 +159,7 @@ class RoomVersions: limit_notifications_power_levels=False, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=False, msc2716_historical=False, msc2716_redactions=False, @@ -167,6 +175,7 @@ class RoomVersions: limit_notifications_power_levels=True, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=False, msc2716_historical=False, msc2716_redactions=False, @@ -182,6 +191,7 @@ class RoomVersions: limit_notifications_power_levels=True, msc2176_redaction_rules=True, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=False, msc2716_historical=False, msc2716_redactions=False, @@ -197,6 +207,7 @@ class RoomVersions: limit_notifications_power_levels=True, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=True, msc2716_historical=False, msc2716_redactions=False, @@ -212,6 +223,23 @@ class RoomVersions: limit_notifications_power_levels=True, msc2176_redaction_rules=False, msc3083_join_rules=True, + msc3375_redaction_rules=False, + msc2403_knocking=True, + msc2716_historical=False, + msc2716_redactions=False, + ) + V9 = RoomVersion( + "9", + RoomDisposition.STABLE, + 
EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + msc3083_join_rules=True, + msc3375_redaction_rules=True, msc2403_knocking=True, msc2716_historical=False, msc2716_redactions=False, @@ -227,6 +255,7 @@ class RoomVersions: limit_notifications_power_levels=True, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=True, msc2716_historical=True, msc2716_redactions=False, @@ -242,6 +271,7 @@ class RoomVersions: limit_notifications_power_levels=True, msc2176_redaction_rules=False, msc3083_join_rules=False, + msc3375_redaction_rules=False, msc2403_knocking=True, msc2716_historical=True, msc2716_redactions=True, @@ -261,6 +291,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.V7, RoomVersions.MSC2716, RoomVersions.V8, + RoomVersions.V9, ) } diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 738a151cef..fb22337e27 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -104,6 +104,8 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: if event_type == EventTypes.Member: add_fields("membership") + if room_version.msc3375_redaction_rules: + add_fields("join_authorised_via_users_server") elif event_type == EventTypes.Create: # MSC2176 rules state that create events cannot be redacted. if room_version.msc2176_redaction_rules: diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 7a826c086e..5446fda5e7 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -322,7 +322,7 @@ class PruneEventTestCase(unittest.TestCase): }, ) - # After MSC3083, alias events have no special behavior. + # After MSC3083, the allow key is protected from redaction. self.run_test( { "type": "m.room.join_rules", @@ -344,6 +344,50 @@ class PruneEventTestCase(unittest.TestCase): room_version=RoomVersions.V8, ) + def test_member(self): + """Member events have changed behavior starting with MSC3375.""" + self.run_test( + { + "type": "m.room.member", + "event_id": "$test:domain", + "content": { + "membership": "join", + "join_authorised_via_users_server": "@user:domain", + "other_key": "stripped", + }, + }, + { + "type": "m.room.member", + "event_id": "$test:domain", + "content": {"membership": "join"}, + "signatures": {}, + "unsigned": {}, + }, + ) + + # After MSC3375, the join_authorised_via_users_server key is protected + # from redaction. + self.run_test( + { + "type": "m.room.member", + "content": { + "membership": "join", + "join_authorised_via_users_server": "@user:domain", + "other_key": "stripped", + }, + }, + { + "type": "m.room.member", + "content": { + "membership": "join", + "join_authorised_via_users_server": "@user:domain", + }, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.V9, + ) + class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): -- cgit 1.5.1 From 92b6ac31b22fd2fafc5a68602357212fd87b0607 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 3 Sep 2021 18:23:46 +0100 Subject: Speed up MultiWriterIdGenerator when lots of IDs are in flight. 
(#10755) --- changelog.d/10755.misc | 1 + stubs/sortedcontainers/__init__.pyi | 2 + stubs/sortedcontainers/sortedset.pyi | 118 ++++++++++++++++++++++++++++++++++ synapse/storage/util/id_generators.py | 5 +- 4 files changed, 124 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10755.misc create mode 100644 stubs/sortedcontainers/sortedset.pyi diff --git a/changelog.d/10755.misc b/changelog.d/10755.misc new file mode 100644 index 0000000000..3b7acff03f --- /dev/null +++ b/changelog.d/10755.misc @@ -0,0 +1 @@ +Minor speed ups when joining large rooms over federation. diff --git a/stubs/sortedcontainers/__init__.pyi b/stubs/sortedcontainers/__init__.pyi index fa307483fe..0602a4fa90 100644 --- a/stubs/sortedcontainers/__init__.pyi +++ b/stubs/sortedcontainers/__init__.pyi @@ -1,5 +1,6 @@ from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView from .sortedlist import SortedKeyList, SortedList, SortedListWithKey +from .sortedset import SortedSet __all__ = [ "SortedDict", @@ -9,4 +10,5 @@ __all__ = [ "SortedKeyList", "SortedList", "SortedListWithKey", + "SortedSet", ] diff --git a/stubs/sortedcontainers/sortedset.pyi b/stubs/sortedcontainers/sortedset.pyi new file mode 100644 index 0000000000..f9c2908386 --- /dev/null +++ b/stubs/sortedcontainers/sortedset.pyi @@ -0,0 +1,118 @@ +# stub for SortedSet. This is a lightly edited copy of +# https://github.com/grantjenks/python-sortedcontainers/blob/d0a225d7fd0fb4c54532b8798af3cbeebf97e2d5/sortedcontainers/sortedset.pyi +# (from https://github.com/grantjenks/python-sortedcontainers/pull/107) + +from typing import ( + AbstractSet, + Any, + Callable, + Generic, + Hashable, + Iterable, + Iterator, + List, + MutableSet, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +# --- Global + +_T = TypeVar("_T", bound=Hashable) +_S = TypeVar("_S", bound=Hashable) +_SS = TypeVar("_SS", bound=SortedSet) +_Key = Callable[[_T], Any] + +class SortedSet(MutableSet[_T], Sequence[_T]): + def __init__( + self, + iterable: Optional[Iterable[_T]] = ..., + key: Optional[_Key[_T]] = ..., + ) -> None: ... + @classmethod + def _fromset( + cls, values: Set[_T], key: Optional[_Key[_T]] = ... + ) -> SortedSet[_T]: ... + @property + def key(self) -> Optional[_Key[_T]]: ... + def __contains__(self, value: Any) -> bool: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> List[_T]: ... + def __delitem__(self, index: Union[int, slice]) -> None: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __lt__(self, other: Iterable[_T]) -> bool: ... + def __gt__(self, other: Iterable[_T]) -> bool: ... + def __le__(self, other: Iterable[_T]) -> bool: ... + def __ge__(self, other: Iterable[_T]) -> bool: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[_T]: ... + def __reversed__(self) -> Iterator[_T]: ... + def add(self, value: _T) -> None: ... + def _add(self, value: _T) -> None: ... + def clear(self) -> None: ... + def copy(self: _SS) -> _SS: ... + def __copy__(self: _SS) -> _SS: ... + def count(self, value: _T) -> int: ... + def discard(self, value: _T) -> None: ... + def _discard(self, value: _T) -> None: ... + def pop(self, index: int = ...) -> _T: ... + def remove(self, value: _T) -> None: ... + def difference(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __sub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... 
+ def difference_update( + self, *iterables: Iterable[_S] + ) -> SortedSet[Union[_T, _S]]: ... + def __isub__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def intersection(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __and__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __rand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def intersection_update( + self, *iterables: Iterable[_S] + ) -> SortedSet[Union[_T, _S]]: ... + def __iand__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def symmetric_difference(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __xor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __rxor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def symmetric_difference_update( + self, other: Iterable[_S] + ) -> SortedSet[Union[_T, _S]]: ... + def __ixor__(self, other: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def union(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __or__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __ror__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __ior__(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def _update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ... + def __reduce__( + self, + ) -> Tuple[Type[SortedSet[_T]], Set[_T], Callable[[_T], Any]]: ... + def __repr__(self) -> str: ... + def _check(self) -> None: ... + def bisect_left(self, value: _T) -> int: ... + def bisect_right(self, value: _T) -> int: ... + def islice( + self, + start: Optional[int] = ..., + stop: Optional[int] = ..., + reverse=bool, + ) -> Iterator[_T]: ... + def irange( + self, + minimum: Optional[_T] = ..., + maximum: Optional[_T] = ..., + inclusive: Tuple[bool, bool] = ..., + reverse: bool = ..., + ) -> Iterator[_T]: ... + def index( + self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ... + ) -> int: ... + def _reset(self, load: int) -> None: ... diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index c768fdea56..6f7cbe40f4 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -19,6 +19,7 @@ from contextlib import contextmanager from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import attr +from sortedcontainers import SortedSet from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction @@ -240,7 +241,7 @@ class MultiWriterIdGenerator: # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). - self._unfinished_ids: Set[int] = set() + self._unfinished_ids: SortedSet[int] = SortedSet() # Set of local IDs that we've processed that are larger than the current # position, due to there being smaller unpersisted IDs. 
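For intuition on the data-structure swap above (a sketch, not Synapse code): `min()` over a plain `set` is O(n) and was being recomputed every time an in-flight stream ID completed, whereas `SortedSet` keeps its elements ordered, so the minimum is simply element 0:

```python
from sortedcontainers import SortedSet

unfinished = SortedSet()
for stream_id in (7, 3, 9, 4):
    unfinished.add(stream_id)  # O(log n) insert keeps the elements ordered

assert unfinished[0] == 3      # cheap min lookup; min() on a plain set scans everything
unfinished.discard(3)          # an in-flight ID finishes...
assert unfinished[0] == 4      # ...and the next minimum is immediately available
```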
@@ -473,7 +474,7 @@ class MultiWriterIdGenerator: finished = set() - min_unfinshed = min(self._unfinished_ids) + min_unfinshed = self._unfinished_ids[0] for s in self._finished_ids: if s < min_unfinshed: if new_cur is None or new_cur < s: -- cgit 1.5.1 From 1ca70fd312a860a6b486406de3b38ef60ac4abe6 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Sat, 4 Sep 2021 00:58:49 -0500 Subject: Allow room creator to send MSC2716 related events in existing room versions (#10566) * Allow room creator to send MSC2716 related events in existing room versions Discussed at https://github.com/matrix-org/matrix-doc/pull/2716/#discussion_r682474869 Restoring `get_create_event_for_room_txn` from, https://github.com/matrix-org/synapse/pull/10245/commits/44bb3f0cf5cb365ef9281554daceeecfb17cc94d * Add changelog * Stop people from trying to redact MSC2716 events in unsupported room versions * Populate rooms.creator column for easy lookup > From some [out of band discussion](https://matrix.to/#/!UytJQHLQYfvYWsGrGY:jki.re/$p2fKESoFst038x6pOOmsY0C49S2gLKMr0jhNMz_JJz0?via=jki.re&via=matrix.org), my plan is to use `rooms.creator`. But currently, we don't fill in `creator` for remote rooms when a user is invited to a room for example. So we need to add some code to fill in `creator` wherever we add to the `rooms` table. And also add a background update to fill in the rows missing `creator` (we can use the same logic that `get_create_event_for_room_txn` is doing by looking in the state events to get the `creator`). > > https://github.com/matrix-org/synapse/pull/10566#issuecomment-901616642 * Remove and switch away from get_create_event_for_room_txn * Fix no create event being found because no state events persisted yet * Fix and add tests for rooms creator bg update * Populate rooms.creator field for easy lookup Part of https://github.com/matrix-org/synapse/pull/10566 - Fill in creator whenever we insert into the rooms table - Add background update to backfill any missing creator values * Add changelog * Fix usage * Remove extra delta already included in #10697 * Don't worry about setting creator for invite * Only iterate over rows missing the creator See https://github.com/matrix-org/synapse/pull/10697#discussion_r695940898 * Use constant to fetch room creator field See https://github.com/matrix-org/synapse/pull/10697#discussion_r696803029 * More protection from other random types See https://github.com/matrix-org/synapse/pull/10697#discussion_r696806853 * Move new background update to end of list See https://github.com/matrix-org/synapse/pull/10697#discussion_r696814181 * Fix query casing * Fix ambiguity iterating over cursor instead of list Fix `psycopg2.ProgrammingError: no results to fetch` error when tests run with Postgres. ``` SYNAPSE_POSTGRES=1 SYNAPSE_TEST_LOG_LEVEL=INFO python -m twisted.trial tests.storage.databases.main.test_room ``` --- We use `txn.fetchall` because it will return the results as a list or an empty list when there are no results. Docs: > `cursor` objects are iterable, so, instead of calling explicitly fetchone() in a loop, the object itself can be used: > > https://www.psycopg.org/docs/cursor.html#cursor-iterable And I'm guessing iterating over a raw cursor does something weird when there are no results. 
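A minimal sketch of the resulting pattern (the table, column, and function names are illustrative, not the actual background update): snapshot the result set with `fetchall()` before reusing the cursor, so there is no lazy cursor iteration left to go wrong.

```python
def _populate_creator_txn(txn):
    # Illustrative: fetch a batch of rooms still missing `creator`.
    txn.execute("SELECT room_id, creator FROM rooms_missing_creator LIMIT 100")
    rows = txn.fetchall()  # always a list; [] when there is nothing to do

    # Iterating the snapshot (not the raw cursor) stays safe even though we
    # issue further statements on the same cursor inside the loop.
    for room_id, creator in rows:
        txn.execute("UPDATE rooms SET creator = ? WHERE room_id = ?", (creator, room_id))
    return len(rows)
```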
--- Test CI failure: https://github.com/matrix-org/synapse/pull/10697/checks?check_run_id=3468916530 ``` tests.test_visibility.FilterEventsForServerTestCase.test_large_room =============================================================================== [FAIL] Traceback (most recent call last): File "/home/runner/work/synapse/synapse/tests/storage/databases/main/test_room.py", line 85, in test_background_populate_rooms_creator_column self.get_success( File "/home/runner/work/synapse/synapse/tests/unittest.py", line 500, in get_success return self.successResultOf(d) File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/trial/_synctest.py", line 700, in successResultOf self.fail( twisted.trial.unittest.FailTest: Success result expected on , found failure result instead: Traceback (most recent call last): File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/internet/defer.py", line 701, in errback self._startRunCallbacks(fail) File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/internet/defer.py", line 764, in _startRunCallbacks self._runCallbacks() File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/internet/defer.py", line 858, in _runCallbacks current.result = callback( # type: ignore[misc] File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/internet/defer.py", line 1751, in gotResult current_context.run(_inlineCallbacks, r, gen, status) --- --- File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/internet/defer.py", line 1657, in _inlineCallbacks result = current_context.run( File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/python/failure.py", line 500, in throwExceptionIntoGenerator return g.throw(self.type, self.value, self.tb) File "/home/runner/work/synapse/synapse/synapse/storage/background_updates.py", line 224, in do_next_background_update await self._do_background_update(desired_duration_ms) File "/home/runner/work/synapse/synapse/synapse/storage/background_updates.py", line 261, in _do_background_update items_updated = await update_handler(progress, batch_size) File "/home/runner/work/synapse/synapse/synapse/storage/databases/main/room.py", line 1399, in _background_populate_rooms_creator_column end = await self.db_pool.runInteraction( File "/home/runner/work/synapse/synapse/synapse/storage/database.py", line 686, in runInteraction result = await self.runWithConnection( File "/home/runner/work/synapse/synapse/synapse/storage/database.py", line 791, in runWithConnection return await make_deferred_yieldable( File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/internet/defer.py", line 858, in _runCallbacks current.result = callback( # type: ignore[misc] File "/home/runner/work/synapse/synapse/tests/server.py", line 425, in d.addCallback(lambda x: function(*args, **kwargs)) File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/enterprise/adbapi.py", line 293, in _runWithConnection compat.reraise(excValue, excTraceback) File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/python/deprecate.py", line 298, in deprecatedFunction return function(*args, **kwargs) File "/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/python/compat.py", line 404, in reraise raise exception.with_traceback(traceback) File 
"/home/runner/work/synapse/synapse/.tox/py/lib/python3.9/site-packages/twisted/enterprise/adbapi.py", line 284, in _runWithConnection result = func(conn, *args, **kw) File "/home/runner/work/synapse/synapse/synapse/storage/database.py", line 786, in inner_func return func(db_conn, *args, **kwargs) File "/home/runner/work/synapse/synapse/synapse/storage/database.py", line 554, in new_transaction r = func(cursor, *args, **kwargs) File "/home/runner/work/synapse/synapse/synapse/storage/databases/main/room.py", line 1375, in _background_populate_rooms_creator_column_txn for room_id, event_json in txn: psycopg2.ProgrammingError: no results to fetch ``` * Move code not under the MSC2716 room version underneath an experimental config option See https://github.com/matrix-org/synapse/pull/10566#issuecomment-906437909 * Add ordering to rooms creator background update See https://github.com/matrix-org/synapse/pull/10697#discussion_r696815277 * Add comment to better document constant See https://github.com/matrix-org/synapse/pull/10697#discussion_r699674458 * Use constant field --- changelog.d/10566.feature | 1 + synapse/handlers/federation_event.py | 10 ++++++++-- synapse/handlers/message.py | 28 +++++++++++++++++++++++++--- synapse/storage/databases/main/events.py | 32 +++++++++++++++++++++++++++----- 4 files changed, 61 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10566.feature diff --git a/changelog.d/10566.feature b/changelog.d/10566.feature new file mode 100644 index 0000000000..04575d76a9 --- /dev/null +++ b/changelog.d/10566.feature @@ -0,0 +1 @@ +Allow room creators to send historical events specified by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) in existing room versions. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9f055f00cf..b622e3ae2d 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1023,9 +1023,15 @@ class FederationEventHandler(BaseHandler): return # Skip processing a marker event if the room version doesn't - # support it. + # support it or the event is not from the room creator. room_version = await self.store.get_room_version(marker_event.room_id) - if not room_version.msc2716_historical: + create_event = await self.store.get_create_event_for_room(marker_event.room_id) + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + if ( + not room_version.msc2716_historical + or not self.hs.config.experimental.msc2716_enabled + or marker_event.sender != room_creator + ): return logger.debug("_handle_marker_event: received %s", marker_event) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 101a29c6d3..9d2c897341 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1393,6 +1393,9 @@ class EventCreationHandler: allow_none=True, ) + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + # we can make some additional checks now if we have the original event. if original_event: if original_event.type == EventTypes.Create: @@ -1404,6 +1407,28 @@ class EventCreationHandler: if original_event.type == EventTypes.ServerACL: raise AuthError(403, "Redacting server ACL events is not permitted") + # Add a little safety stop-gap to prevent people from trying to + # redact MSC2716 related events when they're in a room version + # which does not support it yet. 
We allow people to use MSC2716 + # events in existing room versions but only from the room + # creator since it does not require any changes to the auth + # rules and, in effect, none to the redaction algorithm. In the + # supported room version, we add the `historical` power level to + # auth the MSC2716 related events and adjust the redaction + # algorithm to keep the `historical` field around (redacting an + # event should only strip fields which don't affect the + # structural protocol level). + is_msc2716_event = ( + original_event.type == EventTypes.MSC2716_INSERTION + or original_event.type == EventTypes.MSC2716_CHUNK + or original_event.type == EventTypes.MSC2716_MARKER + ) + if not room_version_obj.msc2716_historical and is_msc2716_event: + raise AuthError( + 403, + "Redacting MSC2716 events is not supported in this room version", + ) + prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True @@ -1411,9 +1436,6 @@ class EventCreationHandler: auth_events_map = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if event_auth.check_redaction( room_version_obj, event, auth_events=auth_events ): diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 096ae28788..c0ff4c0633 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1776,10 +1776,21 @@ class PersistEventsStore: # Not a insertion event return - # Skip processing a insertion event if the room version doesn't - # support it. + # Skip processing an insertion event if the room version doesn't + # support it or the event is not from the room creator. room_version = self.store.get_room_version_txn(txn, event.room_id) + room_creator = self.db_pool.simple_select_one_onecol_txn( + txn, + table="rooms", + keyvalues={"room_id": event.room_id}, + retcol="creator", + allow_none=True, + ) + if ( + not room_version.msc2716_historical + or not self.hs.config.experimental.msc2716_enabled + or event.sender != room_creator + ): return next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID) @@ -1828,9 +1839,20 @@ class PersistEventsStore: return # Skip processing a chunk event if the room version doesn't - # support it. + # support it or the event is not from the room creator.
room_version = self.store.get_room_version_txn(txn, event.room_id) - if not room_version.msc2716_historical: + room_creator = self.db_pool.simple_select_one_onecol_txn( + txn, + table="rooms", + keyvalues={"room_id": event.room_id}, + retcol="creator", + allow_none=True, + ) + if ( + not room_version.msc2716_historical + or not self.hs.config.experimental.msc2716_enabled + or event.sender != room_creator + ): return chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID) -- cgit 1.5.1 From 2ca0d64854b0d369785f1cb76915a480df445792 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Sep 2021 10:14:07 +0100 Subject: Speed up persisting redacted events (#10756) --- changelog.d/10756.misc | 1 + synapse/storage/databases/main/events.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 11 deletions(-) create mode 100644 changelog.d/10756.misc diff --git a/changelog.d/10756.misc b/changelog.d/10756.misc new file mode 100644 index 0000000000..3b7acff03f --- /dev/null +++ b/changelog.d/10756.misc @@ -0,0 +1 @@ +Minor speed ups when joining large rooms over federation. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index c0ff4c0633..f07e288056 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1385,18 +1385,18 @@ class PersistEventsStore: # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - sql = "UPDATE redactions SET have_censored = ? WHERE redacts = ?" - txn.execute_batch( - sql, - ( - ( - False, - event.event_id, - ) - for event, _ in events_and_contexts - if not event.internal_metadata.is_redacted() - ), + unredacted_events = [ + event.event_id + for event, _ in events_and_contexts + if not event.internal_metadata.is_redacted() + ] + sql = "UPDATE redactions SET have_censored = ? WHERE " + clause, args = make_in_list_sql_clause( + self.database_engine, + "redacts", + unredacted_events, ) + txn.execute(sql + clause, [False] + args) state_events_and_contexts = [ ec for ec in events_and_contexts if ec[0].is_state() -- cgit 1.5.1 From 5e9b382505b9587983cb1db2606ccc559de7f121 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 6 Sep 2021 11:37:54 +0100 Subject: Pull out encrypted_by_default tests from user_directory tests (#10752) --- changelog.d/10752.misc | 1 + mypy.ini | 1 + tests/handlers/test_room.py | 108 ++++++++++++++++++++++++++++++++++ tests/handlers/test_user_directory.py | 96 +----------------------------- 4 files changed, 111 insertions(+), 95 deletions(-) create mode 100644 changelog.d/10752.misc create mode 100644 tests/handlers/test_room.py diff --git a/changelog.d/10752.misc b/changelog.d/10752.misc new file mode 100644 index 0000000000..5f9aa23018 --- /dev/null +++ b/changelog.d/10752.misc @@ -0,0 +1 @@ +Move tests relating to rooms having encryption out of the user_directory tests. 
\ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 03dc6b8697..4096f72241 100644 --- a/mypy.ini +++ b/mypy.ini @@ -90,6 +90,7 @@ files = tests/test_event_auth.py, tests/test_utils, tests/handlers/test_password_providers.py, + tests/handlers/test_room.py, tests/handlers/test_room_summary.py, tests/handlers/test_send_email.py, tests/handlers/test_sync.py, diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py new file mode 100644 index 0000000000..fcde5dab72 --- /dev/null +++ b/tests/handlers/test_room.py @@ -0,0 +1,108 @@ +import synapse +from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms +from synapse.rest.client import login, room + +from tests import unittest +from tests.unittest import override_config + + +class EncryptedByDefaultTestCase(unittest.HomeserverTestCase): + servlets = [ + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + ] + + @override_config({"encryption_enabled_by_default_for_room_type": "all"}) + def test_encrypted_by_default_config_option_all(self): + """Tests that invite-only and non-invite-only rooms have encryption enabled by + default when the config option encryption_enabled_by_default_for_room_type is "all". + """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + @override_config({"encryption_enabled_by_default_for_room_type": "invite"}) + def test_encrypted_by_default_config_option_invite(self): + """Tests that only new, invite-only rooms have encryption enabled by default when + the config option encryption_enabled_by_default_for_room_type is "invite". 
+ """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room has an encryption state event + event_content = self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + ) + self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) + + @override_config({"encryption_enabled_by_default_for_room_type": "off"}) + def test_encrypted_by_default_config_option_off(self): + """Tests that neither new invite-only nor non-invite-only rooms have encryption + enabled by default when the config option + encryption_enabled_by_default_for_room_type is "off". + """ + # Create a user + user = self.register_user("user", "pass") + user_token = self.login(user, "pass") + + # Create an invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) + + # Create a non invite-only room as that user + room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) + + # Check that the room does not have an encryption state event + self.helper.get_state( + room_id=room_id, + event_type=EventTypes.RoomEncryption, + tok=user_token, + expect_code=404, + ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index e44bf2b3b1..a91d31ce61 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -16,7 +16,7 @@ from unittest.mock import Mock from twisted.internet import defer import synapse.rest.admin -from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes +from synapse.api.constants import UserTypes from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.rest.client import login, room, user_directory from synapse.storage.roommember import ProfileInfo @@ -187,100 +187,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) - @override_config({"encryption_enabled_by_default_for_room_type": "all"}) - def test_encrypted_by_default_config_option_all(self): - """Tests that invite-only and non-invite-only rooms have encryption enabled by - default when the config option encryption_enabled_by_default_for_room_type is "all". 
- """ - # Create a user - user = self.register_user("user", "pass") - user_token = self.login(user, "pass") - - # Create an invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) - - # Check that the room has an encryption state event - event_content = self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - ) - self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) - - # Create a non invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) - - # Check that the room has an encryption state event - event_content = self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - ) - self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) - - @override_config({"encryption_enabled_by_default_for_room_type": "invite"}) - def test_encrypted_by_default_config_option_invite(self): - """Tests that only new, invite-only rooms have encryption enabled by default when - the config option encryption_enabled_by_default_for_room_type is "invite". - """ - # Create a user - user = self.register_user("user", "pass") - user_token = self.login(user, "pass") - - # Create an invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) - - # Check that the room has an encryption state event - event_content = self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - ) - self.assertEqual(event_content, {"algorithm": RoomEncryptionAlgorithms.DEFAULT}) - - # Create a non invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) - - # Check that the room does not have an encryption state event - self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - expect_code=404, - ) - - @override_config({"encryption_enabled_by_default_for_room_type": "off"}) - def test_encrypted_by_default_config_option_off(self): - """Tests that neither new invite-only nor non-invite-only rooms have encryption - enabled by default when the config option - encryption_enabled_by_default_for_room_type is "off". - """ - # Create a user - user = self.register_user("user", "pass") - user_token = self.login(user, "pass") - - # Create an invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=False, tok=user_token) - - # Check that the room does not have an encryption state event - self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - expect_code=404, - ) - - # Create a non invite-only room as that user - room_id = self.helper.create_room_as(user, is_public=True, tok=user_token) - - # Check that the room does not have an encryption state event - self.helper.get_state( - room_id=room_id, - event_type=EventTypes.RoomEncryption, - tok=user_token, - expect_code=404, - ) - def test_spam_checker(self): """ A user which fails the spam checks will not appear in search results. -- cgit 1.5.1 From 56e2a306348e10b2abfa4a615ebf9c6d43177f6e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Sep 2021 12:17:16 +0100 Subject: Move `maybe_kick_guest_users` out of `BaseHandler` (#10744) This is part of my ongoing war against BaseHandler. 
I've moved kick_guest_users into RoomMemberHandler (since it calls out to that handler anyway), and split maybe_kick_guest_users into the two places it is called. --- changelog.d/10744.misc | 1 + synapse/api/constants.py | 9 +++++ synapse/handlers/_base.py | 68 ------------------------------------ synapse/handlers/federation_event.py | 17 +++++++-- synapse/handlers/message.py | 27 ++++++++++++-- synapse/handlers/room.py | 5 ++- synapse/handlers/room_list.py | 12 +++++-- synapse/handlers/room_member.py | 65 ++++++++++++++++++++++++++++++---- synapse/handlers/stats.py | 6 ++-- 9 files changed, 125 insertions(+), 85 deletions(-) create mode 100644 changelog.d/10744.misc diff --git a/changelog.d/10744.misc b/changelog.d/10744.misc new file mode 100644 index 0000000000..f0f789ce3c --- /dev/null +++ b/changelog.d/10744.misc @@ -0,0 +1 @@ +Move `kick_guest_users` into `RoomMemberHandler`. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 5e34eb7e13..5f0f34119b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -201,6 +201,9 @@ class EventContentFields: # The creator of the room, as used in `m.room.create` events. ROOM_CREATOR = "creator" + # Used in m.room.guest_access events. + GUEST_ACCESS = "guest_access" + # Used on normal messages to indicate they were historically imported after the fact MSC2716_HISTORICAL = "org.matrix.msc2716.historical" # For "insertion" events to indicate what the next chunk ID should be in @@ -235,5 +238,11 @@ class HistoryVisibility: WORLD_READABLE = "world_readable" +class GuestAccess: + CAN_JOIN = "can_join" + # anything that is not "can_join" is considered "forbidden", but for completeness: + FORBIDDEN = "forbidden" + + class ReadReceiptEventFields: MSC2285_HIDDEN = "org.matrix.msc2285.hidden" diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6a05a65305..955cfa2207 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -15,10 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional -import synapse.types -from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter -from synapse.types import UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -115,68 +112,3 @@ class BaseHandler: burst_count=burst_count, update=update, ) - - async def maybe_kick_guest_users(self, event, context=None): - # Technically this function invalidates current_state by changing it. - # Hopefully this isn't that important to the caller. 
- if event.type == EventTypes.GuestAccess: - guest_access = event.content.get("guest_access", "forbidden") - if guest_access != "can_join": - if context: - current_state_ids = await context.get_current_state_ids() - current_state_dict = await self.store.get_events( - list(current_state_ids.values()) - ) - current_state = list(current_state_dict.values()) - else: - current_state_map = await self.state_handler.get_current_state( - event.room_id - ) - current_state = list(current_state_map.values()) - - logger.info("maybe_kick_guest_users %r", current_state) - await self.kick_guest_users(current_state) - - async def kick_guest_users(self, current_state): - for member_event in current_state: - try: - if member_event.type != EventTypes.Member: - continue - - target_user = UserID.from_string(member_event.state_key) - if not self.hs.is_mine(target_user): - continue - - if member_event.content["membership"] not in { - Membership.JOIN, - Membership.INVITE, - }: - continue - - if ( - "kind" not in member_event.content - or member_event.content["kind"] != "guest" - ): - continue - - # We make the user choose to leave, rather than have the - # event-sender kick them. This is partially because we don't - # need to worry about power levels, and partially because guest - # users are a concept which doesn't hugely work over federation, - # and having homeservers have their own users leave keeps more - # of that decision-making and control local to the guest-having - # homeserver. - requester = synapse.types.create_requester( - target_user, is_guest=True, authenticated_entity=self.server_name - ) - handler = self.hs.get_room_member_handler() - await handler.update_membership( - requester, - target_user, - member_event.room_id, - "leave", - ratelimit=False, - require_consent=False, - ) - except Exception as e: - logger.exception("Error kicking guest user: %s" % (e,)) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index b622e3ae2d..3414747f49 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -36,6 +36,7 @@ from synapse import event_auth from synapse.api.constants import ( EventContentFields, EventTypes, + GuestAccess, Membership, RejectedReason, RoomEncryptionAlgorithms, @@ -1327,9 +1328,7 @@ class FederationEventHandler(BaseHandler): if not context.rejected: await self._check_for_soft_fail(event, state, backfilled, origin=origin) - - if event.type == EventTypes.GuestAccess and not context.rejected: - await self.maybe_kick_guest_users(event) + await self._maybe_kick_guest_users(event) # If we are going to send this event over federation we precaclculate # the joined hosts. 
@@ -1340,6 +1339,18 @@ class FederationEventHandler(BaseHandler): return context + async def _maybe_kick_guest_users(self, event: EventBase) -> None: + if event.type != EventTypes.GuestAccess: + return + + guest_access = event.content.get(EventContentFields.GUEST_ACCESS) + if guest_access == GuestAccess.CAN_JOIN: + return + + current_state_map = await self.state_handler.get_current_state(event.room_id) + current_state = list(current_state_map.values()) + await self.hs.get_room_member_handler().kick_guest_users(current_state) + async def _check_for_soft_fail( self, event: EventBase, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d2c897341..bf0fef1510 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -27,6 +27,7 @@ from synapse import event_auth from synapse.api.constants import ( EventContentFields, EventTypes, + GuestAccess, Membership, RelationTypes, UserTypes, @@ -426,7 +427,7 @@ class EventCreationHandler: self.send_event = ReplicationSendEventRestServlet.make_client(hs) - # This is only used to get at ratelimit function, and maybe_kick_guest_users + # This is only used to get at ratelimit function self.base_handler = BaseHandler(hs) # We arbitrarily limit concurrent event creation for a room to 5. @@ -1306,7 +1307,7 @@ class EventCreationHandler: requester, is_admin_redaction=is_admin_redaction ) - await self.base_handler.maybe_kick_guest_users(event, context) + await self._maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Validate a newly added alias or newly added alt_aliases. @@ -1493,6 +1494,28 @@ class EventCreationHandler: return event + async def _maybe_kick_guest_users( + self, event: EventBase, context: EventContext + ) -> None: + if event.type != EventTypes.GuestAccess: + return + + guest_access = event.content.get(EventContentFields.GUEST_ACCESS) + if guest_access == GuestAccess.CAN_JOIN: + return + + current_state_ids = await context.get_current_state_ids() + + # since this is a client-generated event, it cannot be an outlier and we must + # therefore have the state ids. 
+ assert current_state_ids is not None + current_state_dict = await self.store.get_events( + list(current_state_ids.values()) + ) + current_state = list(current_state_dict.values()) + logger.info("maybe_kick_guest_users %r", current_state) + await self.hs.get_room_member_handler().kick_guest_users(current_state) + async def _bump_active_time(self, user: UserID) -> None: try: presence = self.hs.get_presence_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ed780bb41f..0235fd09b4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,7 +25,9 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple from synapse.api.constants import ( + EventContentFields, EventTypes, + GuestAccess, HistoryVisibility, JoinRules, Membership, @@ -993,7 +995,8 @@ class RoomCreationHandler(BaseHandler): if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: last_sent_stream_id = await send( - etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} + etype=EventTypes.GuestAccess, + content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, ) for (etype, state_key), content in initial_state.items(): diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 6d433fad41..92bb75c848 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -19,7 +19,13 @@ from typing import TYPE_CHECKING, Optional, Tuple import msgpack from unpaddedbase64 import decode_base64, encode_base64 -from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules +from synapse.api.constants import ( + EventContentFields, + EventTypes, + GuestAccess, + HistoryVisibility, + JoinRules, +) from synapse.api.errors import ( Codes, HttpResponseException, @@ -336,8 +342,8 @@ class RoomListHandler(BaseHandler): guest_event = current_state.get((EventTypes.GuestAccess, "")) guest = None if guest_event: - guest = guest_event.content.get("guest_access", None) - result["guest_can_join"] = guest == "can_join" + guest = guest_event.content.get(EventContentFields.GUEST_ACCESS) + result["guest_can_join"] = guest == GuestAccess.CAN_JOIN avatar_event = current_state.get(("m.room.avatar", "")) if avatar_event: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 401b84aad1..4390201641 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -23,6 +23,7 @@ from synapse.api.constants import ( AccountDataTypes, EventContentFields, EventTypes, + GuestAccess, Membership, ) from synapse.api.errors import ( @@ -44,6 +45,7 @@ from synapse.types import ( RoomID, StateMap, UserID, + create_requester, get_domain_from_id, ) from synapse.util.async_helpers import Linearizer @@ -70,6 +72,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config + self._server_name = hs.hostname self.federation_handler = hs.get_federation_handler() self.directory_handler = hs.get_directory_handler() @@ -115,9 +118,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, ) - # This is only used to get at ratelimit function, and - # maybe_kick_guest_users. It's fine there are multiple of these as - # it doesn't store state. + # This is only used to get at the ratelimit function. It's fine there are + # multiple of these as it doesn't store state. 
self.base_handler = BaseHandler(hs) @abc.abstractmethod @@ -1095,10 +1097,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return bool( guest_access and guest_access.content - and "guest_access" in guest_access.content - and guest_access.content["guest_access"] == "can_join" + and guest_access.content.get(EventContentFields.GUEST_ACCESS) + == GuestAccess.CAN_JOIN ) + async def kick_guest_users(self, current_state: Iterable[EventBase]) -> None: + """Kick any local guest users from the room. + + This is called when the room state changes from guests allowed to not-allowed. + + Params: + current_state: the current state of the room. We will iterate this to look + for guest users to kick. + """ + for member_event in current_state: + try: + if member_event.type != EventTypes.Member: + continue + + if not self.hs.is_mine_id(member_event.state_key): + continue + + if member_event.content["membership"] not in { + Membership.JOIN, + Membership.INVITE, + }: + continue + + if ( + "kind" not in member_event.content + or member_event.content["kind"] != "guest" + ): + continue + + # We make the user choose to leave, rather than have the + # event-sender kick them. This is partially because we don't + # need to worry about power levels, and partially because guest + # users are a concept which doesn't hugely work over federation, + # and having homeservers have their own users leave keeps more + # of that decision-making and control local to the guest-having + # homeserver. + target_user = UserID.from_string(member_event.state_key) + requester = create_requester( + target_user, is_guest=True, authenticated_entity=self._server_name + ) + handler = self.hs.get_room_member_handler() + await handler.update_membership( + requester, + target_user, + member_event.room_id, + "leave", + ratelimit=False, + require_consent=False, + ) + except Exception as e: + logger.exception("Error kicking guest user: %s" % (e,)) + async def lookup_room_alias( self, room_alias: RoomAlias ) -> Tuple[RoomID, List[str]]: @@ -1352,7 +1406,6 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor = hs.get_distributor() self.distributor.declare("user_left_room") - self._server_name = hs.hostname async def _is_remote_room_too_complex( self, room_id: str, remote_room_hosts: List[str] diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 3fd89af2a4..3a4c41c9ff 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple from typing_extensions import Counter as CounterType -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict @@ -273,7 +273,9 @@ class StatsHandler: elif typ == EventTypes.CanonicalAlias: room_state["canonical_alias"] = event_content.get("alias") elif typ == EventTypes.GuestAccess: - room_state["guest_access"] = event_content.get("guest_access") + room_state["guest_access"] = event_content.get( + EventContentFields.GUEST_ACCESS + ) for room_id, state in room_to_state_updates.items(): logger.debug("Updating room_stats_state for %s: %s", room_id, state) -- cgit 1.5.1 From e1641b46d19c9745f512e623544b2bddfc89551d Mon Sep 17 00:00:00 2001 From: David Teller Date: Mon, 6 Sep 2021 15:24:31 +0200 Subject: Doc: Clarifying undoing room shutdowns 
(#10480) --- changelog.d/10735.doc | 1 + docs/admin_api/rooms.md | 42 +++++++++++++++++++++++++++--------------- 2 files changed, 28 insertions(+), 15 deletions(-) create mode 100644 changelog.d/10735.doc diff --git a/changelog.d/10735.doc b/changelog.d/10735.doc new file mode 100644 index 0000000000..5d6207afb9 --- /dev/null +++ b/changelog.d/10735.doc @@ -0,0 +1 @@ +Clarify admin API documentation on undoing room deletions. diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 48777dd231..8e524e6509 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -481,32 +481,44 @@ The following fields are returned in the JSON response body: * `new_room_id` - A string representing the room ID of the new room. -## Undoing room shutdowns +## Undoing room deletions -*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level, +*Note*: This guide may be outdated by the time you read it. By nature of room deletions being performed at the database level, the structure can and does change without notice. -First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it +First, it's important to understand that a room deletion is very destructive. Undoing a deletion is not as simple as pretending it never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible to recover at all: * If the room was invite-only, your users will need to be re-invited. * If the room no longer has any members at all, it'll be impossible to rejoin. -* The first user to rejoin will have to do so via an alias on a different server. +* The first user to rejoin will have to do so via an alias on a different + server (or receive an invite from a user on a different server). With all that being said, if you still want to try and recover the room: -1. For safety reasons, shut down Synapse. -2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';` - * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`. - * The room ID is the same one supplied to the shutdown room API, not the Content Violation room. -3. Restart Synapse. +1. If the room was `block`ed, you must unblock it on your server. This can be + accomplished as follows: -You will have to manually handle, if you so choose, the following: + 1. For safety reasons, shut down Synapse. + 2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';` + * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`. + * The room ID is the same one supplied to the delete room API, not the Content Violation room. + 3. Restart Synapse. -* Aliases that would have been redirected to the Content Violation room. -* Users that would have been booted from the room (and will have been force-joined to the Content Violation room). -* Removal of the Content Violation room if desired. + This step is unnecessary if `block` was not set. + +2. Any room aliases on your server that pointed to the deleted room may have + been deleted, or redirected to the Content Violation room. These will need + to be restored manually. + +3. Users on your server that were in the deleted room will have been kicked + from the room. 
Consider whether you want to update their membership
+   (possibly via the [Edit Room Membership API](room_membership.md)) or let
+   them handle rejoining themselves.
+
+4. If `new_room_user_id` was given, a 'Content Violation' room will have been
+   created. Consider whether you want to delete that room.

 ## Deprecated endpoint

@@ -536,7 +548,7 @@ POST /_synapse/admin/v1/rooms/<room_id>/make_room_admin

 # Forward Extremities Admin API

 Enables querying and deleting forward extremities from rooms. When a lot of forward
-extremities accumulate in a room, performance can become degraded. For details, see 
+extremities accumulate in a room, performance can become degraded. For details, see
 [#1760](https://github.com/matrix-org/synapse/issues/1760).

 ## Check for forward extremities
@@ -565,7 +577,7 @@ A response as follows will be returned:

 ## Deleting forward extremities

-**WARNING**: Please ensure you know what you're doing and have read 
+**WARNING**: Please ensure you know what you're doing and have read
 the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
 Under no situations should this API be executed as an automated maintenance task!

-- 
cgit 1.5.1


From 7bb3673f37d02d8bcda586daaac3334aebe81195 Mon Sep 17 00:00:00 2001
From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Date: Mon, 6 Sep 2021 14:35:56 +0100
Subject: Ease searching for M_TOO_LARGE-related error codes (#10750)

---
 changelog.d/10750.misc |  1 +
 synapse/event_auth.py  | 15 ++++++---------
 2 files changed, 7 insertions(+), 9 deletions(-)
 create mode 100644 changelog.d/10750.misc

diff --git a/changelog.d/10750.misc b/changelog.d/10750.misc
new file mode 100644
index 0000000000..ded5cf626c
--- /dev/null
+++ b/changelog.d/10750.misc
@@ -0,0 +1 @@
+Refactor event size checking code to simplify searching the codebase for the origins of certain error strings that are occasionally emitted.
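Illustrative aside (not part of the patch): the point of the refactor in this commit is that each failing check now raises with a literal string, so the exact error text can be grepped for. A minimal self-contained sketch of that shape, using stand-ins for Synapse's `EventSizeError` (which lives in `synapse.api.errors`) and for the field limit:

```python
# Stand-ins so the sketch runs on its own; the real EventSizeError lives in
# synapse.api.errors and the real limits in synapse/event_auth.py.
class EventSizeError(Exception):
    """Raised when a field of an event exceeds the allowed size."""


MAX_FIELD_LENGTH = 255


def check_field_sizes(user_id: str, room_id: str, event_type: str) -> None:
    # Each branch raises with a literal string, so searching the codebase
    # for "'room_id' too large" leads straight to its single origin.
    if len(user_id) > MAX_FIELD_LENGTH:
        raise EventSizeError("'user_id' too large")
    if len(room_id) > MAX_FIELD_LENGTH:
        raise EventSizeError("'room_id' too large")
    if len(event_type) > MAX_FIELD_LENGTH:
        raise EventSizeError("'type' too large")
```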
\ No newline at end of file diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c3a0c10499..b63a1afe93 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -216,21 +216,18 @@ def check( def _check_size_limits(event: EventBase) -> None: - def too_big(field): - raise EventSizeError("%s too large" % (field,)) - if len(event.user_id) > 255: - too_big("user_id") + raise EventSizeError("'user_id' too large") if len(event.room_id) > 255: - too_big("room_id") + raise EventSizeError("'room_id' too large") if event.is_state() and len(event.state_key) > 255: - too_big("state_key") + raise EventSizeError("'state_key' too large") if len(event.type) > 255: - too_big("type") + raise EventSizeError("'type' too large") if len(event.event_id) > 255: - too_big("event_id") + raise EventSizeError("'event_id' too large") if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: - too_big("event") + raise EventSizeError("event too large") def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool: -- cgit 1.5.1 From 40a1fddd1b7843c25248fca409debab6d8393c59 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 6 Sep 2021 14:37:15 +0100 Subject: Allow `room_alias_name` parameter to be handled by /createRoom calls on workers (#10757) --- changelog.d/10757.bugfix | 1 + synapse/storage/databases/main/directory.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/10757.bugfix diff --git a/changelog.d/10757.bugfix b/changelog.d/10757.bugfix new file mode 100644 index 0000000000..bce36ef242 --- /dev/null +++ b/changelog.d/10757.bugfix @@ -0,0 +1 @@ +Fix a bug which prevented calls to `/createRoom` that included the `room_alias_name` parameter from being handled by worker processes. \ No newline at end of file diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 86075bc55b..6daf8b8ffb 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -75,8 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore): desc="get_aliases_for_room", ) - -class DirectoryStore(DirectoryWorkerStore): async def create_room_alias_association( self, room_alias: RoomAlias, @@ -126,6 +124,8 @@ class DirectoryStore(DirectoryWorkerStore): 409, "Room alias %s already exists" % room_alias.to_string() ) + +class DirectoryStore(DirectoryWorkerStore): async def delete_room_alias(self, room_alias: RoomAlias) -> str: room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias -- cgit 1.5.1 From b298de780a488274846cb76f986f2f66a92bb64a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Sep 2021 14:49:33 +0100 Subject: Stop using BaseHandler in `FederationEventHandler` (#10745) It's now only used in a couple of places, so we can drop it altogether. --- changelog.d/10744.misc | 2 +- changelog.d/10745.misc | 1 + synapse/handlers/federation_event.py | 19 ++++++++++--------- 3 files changed, 12 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10745.misc diff --git a/changelog.d/10744.misc b/changelog.d/10744.misc index f0f789ce3c..9a765435db 100644 --- a/changelog.d/10744.misc +++ b/changelog.d/10744.misc @@ -1 +1 @@ -Move `kick_guest_users` into `RoomMemberHandler`. +Clean up some of the federation event authentication code for clarity. 
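An aside on the commit below (illustrative, not part of the patch): with the shared base class gone, every dependency is wired explicitly in the constructor, and where eager wiring would recurse at construction time the commit stores the *getter* for a handler rather than the handler itself. A minimal sketch of that pattern, with stand-in classes rather than Synapse's real ones:

```python
class Container:
    """Stand-in for HomeServer: hands out lazily-built singletons."""

    def __init__(self) -> None:
        self._room_member_handler = None

    def get_room_member_handler(self) -> "MemberHandler":
        if self._room_member_handler is None:
            self._room_member_handler = MemberHandler(self)
        return self._room_member_handler


class EventHandler:
    def __init__(self, hs: Container) -> None:
        # Storing the getter rather than calling it defers construction to
        # first use, so two handlers may depend on each other without
        # recursing in their constructors.
        self._get_room_member_handler = hs.get_room_member_handler

    def kick_guests(self) -> None:
        self._get_room_member_handler().kick()


class MemberHandler:
    def __init__(self, hs: Container) -> None:
        # May itself (eventually) reach back for EventHandler via `hs`.
        self._hs = hs

    def kick(self) -> None:
        print("guest kicked")


hs = Container()
EventHandler(hs).kick_guests()  # -> "guest kicked"
```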
diff --git a/changelog.d/10745.misc b/changelog.d/10745.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10745.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3414747f49..afeb2892dc 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -54,7 +54,6 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_client import InvalidResponseError -from synapse.handlers._base import BaseHandler from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -117,7 +116,7 @@ class _NewEventInfo: claimed_auth_event_map: StateMap[EventBase] -class FederationEventHandler(BaseHandler): +class FederationEventHandler: """Handles events that originated from federation. Responsible for handing incoming events and passing them on to the rest @@ -125,8 +124,6 @@ class FederationEventHandler(BaseHandler): """ def __init__(self, hs: "HomeServer"): - super().__init__(hs) - self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -137,11 +134,15 @@ class FederationEventHandler(BaseHandler): self._message_handler = hs.get_message_handler() self.action_generator = hs.get_action_generator() self._state_resolution_handler = hs.get_state_resolution_handler() + # avoid a circular dependency by deferring execution here + self._get_room_member_handler = hs.get_room_member_handler self.federation_client = hs.get_federation_client() self.third_party_event_rules = hs.get_third_party_event_rules() + self._notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id + self._server_name = hs.hostname self._instance_name = hs.get_instance_name() self.config = hs.config @@ -222,7 +223,7 @@ class FederationEventHandler(BaseHandler): # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. is_in_room = await self._event_auth_handler.check_host_in_room( - room_id, self.server_name + room_id, self._server_name ) if not is_in_room: logger.info( @@ -435,7 +436,7 @@ class FederationEventHandler(BaseHandler): server from invalid events (there is probably no point in trying to re-fetch invalid events from every other HS in the room.) 
""" - if dest == self.server_name: + if dest == self._server_name: raise SynapseError(400, "Can't backfill from self.") events = await self.federation_client.backfill( @@ -1030,7 +1031,7 @@ class FederationEventHandler(BaseHandler): room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) if ( not room_version.msc2716_historical - or not self.hs.config.experimental.msc2716_enabled + or not self.config.experimental.msc2716_enabled or marker_event.sender != room_creator ): return @@ -1349,7 +1350,7 @@ class FederationEventHandler(BaseHandler): current_state_map = await self.state_handler.get_current_state(event.room_id) current_state = list(current_state_map.values()) - await self.hs.get_room_member_handler().kick_guest_users(current_state) + await self._get_room_member_handler().kick_guest_users(current_state) async def _check_for_soft_fail( self, @@ -1804,7 +1805,7 @@ class FederationEventHandler(BaseHandler): event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) - self.notifier.on_new_room_event( + self._notifier.on_new_room_event( event, event_pos, max_stream_token, extra_users=extra_users ) -- cgit 1.5.1 From 8c9e723fe0414c7d93aa93180cdedc163b2058c5 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 6 Sep 2021 16:23:50 +0200 Subject: Add a warning when using deprecated template_dir settings (#10768) The deprecation itself happened in #10596 which shipped with Synapse v1.41.0. However, it doesn't seem fair to suddenly drop support for these settings in ~4-6w without being more vocal about said deprecation. --- changelog.d/10768.misc | 1 + synapse/config/account_validity.py | 14 ++++++++++++++ synapse/config/emailconfig.py | 14 ++++++++++++++ synapse/config/sso.py | 13 +++++++++++++ 4 files changed, 42 insertions(+) create mode 100644 changelog.d/10768.misc diff --git a/changelog.d/10768.misc b/changelog.d/10768.misc new file mode 100644 index 0000000000..afd64ca1b0 --- /dev/null +++ b/changelog.d/10768.misc @@ -0,0 +1 @@ +Print a warning when using one of the deprecated `template_dir` settings. diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index 52e63ab1f6..ffaffc4931 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -11,8 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + from synapse.config._base import Config, ConfigError +logger = logging.getLogger(__name__) + +LEGACY_TEMPLATE_DIR_WARNING = """ +This server's configuration file is using the deprecated 'template_dir' setting in the +'account_validity' section. Support for this setting has been deprecated and will be +removed in a future version of Synapse. Server admins should instead use the new +'custom_templates_directory' setting documented here: +https://matrix-org.github.io/synapse/latest/templates.html +---------------------------------------------------------------------------------------""" + class AccountValidityConfig(Config): section = "account_validity" @@ -69,6 +81,8 @@ class AccountValidityConfig(Config): # Load account validity templates. 
account_validity_template_dir = account_validity_config.get("template_dir") + if account_validity_template_dir is not None: + logger.warning(LEGACY_TEMPLATE_DIR_WARNING) account_renewed_template_filename = account_validity_config.get( "account_renewed_html_path", "account_renewed.html" diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 4477419196..936abe6178 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -16,6 +16,7 @@ # This file can't be called email.py because if it is, we cannot: import email.utils +import logging import os from enum import Enum from typing import Optional @@ -24,6 +25,8 @@ import attr from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + MISSING_PASSWORD_RESET_CONFIG_ERROR = """\ Password reset emails are enabled on this homeserver due to a partial 'email' block. However, the following required keys are missing: @@ -44,6 +47,14 @@ DEFAULT_SUBJECTS = { "email_validation": "[%(server_name)s] Validate your email", } +LEGACY_TEMPLATE_DIR_WARNING = """ +This server's configuration file is using the deprecated 'template_dir' setting in the +'email' section. Support for this setting has been deprecated and will be removed in a +future version of Synapse. Server admins should instead use the new +'custom_templates_directory' setting documented here: +https://matrix-org.github.io/synapse/latest/templates.html +---------------------------------------------------------------------------------------""" + @attr.s(slots=True, frozen=True) class EmailSubjectConfig: @@ -105,6 +116,9 @@ class EmailConfig(Config): # A user-configurable template directory template_dir = email_config.get("template_dir") + if template_dir is not None: + logger.warning(LEGACY_TEMPLATE_DIR_WARNING) + if isinstance(template_dir, str): # We need an absolute path, because we change directory after starting (and # we don't yet know what auxiliary templates like mail.css we will need). diff --git a/synapse/config/sso.py b/synapse/config/sso.py index fe1177ab81..524a7ff3aa 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -11,12 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import Any, Dict, Optional import attr from ._base import Config +logger = logging.getLogger(__name__) + +LEGACY_TEMPLATE_DIR_WARNING = """ +This server's configuration file is using the deprecated 'template_dir' setting in the +'sso' section. Support for this setting has been deprecated and will be removed in a +future version of Synapse. 
Server admins should instead use the new +'custom_templates_directory' setting documented here: +https://matrix-org.github.io/synapse/latest/templates.html +---------------------------------------------------------------------------------------""" + @attr.s(frozen=True) class SsoAttributeRequirement: @@ -43,6 +54,8 @@ class SSOConfig(Config): # The sso-specific template_dir self.sso_template_dir = sso_config.get("template_dir") + if self.sso_template_dir is not None: + logger.warning(LEGACY_TEMPLATE_DIR_WARNING) # Read templates from disk custom_template_directories = ( -- cgit 1.5.1 From e9958d908da9b0d29d7e93538c12692519ebcfd6 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 6 Sep 2021 15:25:23 +0100 Subject: 1.42.0rc2 --- CHANGES.md | 15 +++++++++++++++ changelog.d/10747.feature | 1 - changelog.d/10768.misc | 1 - debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 5 files changed, 22 insertions(+), 3 deletions(-) delete mode 100644 changelog.d/10747.feature delete mode 100644 changelog.d/10768.misc diff --git a/CHANGES.md b/CHANGES.md index 986efbba0d..3ef521f225 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,18 @@ +Synapse 1.42.0rc2 (2021-09-06) +============================== + +Features +-------- + +- Support room version 9 from [MSC3375](https://github.com/matrix-org/matrix-doc/pull/3375). ([\#10747](https://github.com/matrix-org/synapse/issues/10747)) + + +Internal Changes +---------------- + +- Print a warning when using one of the deprecated `template_dir` settings. ([\#10768](https://github.com/matrix-org/synapse/issues/10768)) + + Synapse 1.42.0rc1 (2021-09-01) ============================== diff --git a/changelog.d/10747.feature b/changelog.d/10747.feature deleted file mode 100644 index c1cca9a17b..0000000000 --- a/changelog.d/10747.feature +++ /dev/null @@ -1 +0,0 @@ -Support room version 9 from [MSC3375](https://github.com/matrix-org/matrix-doc/pull/3375). diff --git a/changelog.d/10768.misc b/changelog.d/10768.misc deleted file mode 100644 index afd64ca1b0..0000000000 --- a/changelog.d/10768.misc +++ /dev/null @@ -1 +0,0 @@ -Print a warning when using one of the deprecated `template_dir` settings. diff --git a/debian/changelog b/debian/changelog index 0f7dbdf71e..e865e0d2f6 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.42.0~rc2) stable; urgency=medium + + * New synapse release 1.42.0~rc2. + + -- Synapse Packaging team Mon, 06 Sep 2021 15:25:13 +0100 + matrix-synapse-py3 (1.42.0~rc1) stable; urgency=medium * New synapse release 1.42.0rc1. diff --git a/synapse/__init__.py b/synapse/__init__.py index e5b075c53b..e4302d81a8 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.42.0rc1" +__version__ = "1.42.0rc2" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From 20d773906ca7aa1e834118809e0a090651ae96a9 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 6 Sep 2021 15:26:12 +0100 Subject: Move the upgrade notes reminder up to rc2 --- CHANGES.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 3ef521f225..64c30eed10 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,8 @@ Synapse 1.42.0rc2 (2021-09-06) ============================== +Server administrators are reminded to read [the upgrade notes](docs/upgrade.md#upgrading-to-v1420). 
+ Features -------- @@ -16,9 +18,6 @@ Internal Changes Synapse 1.42.0rc1 (2021-09-01) ============================== -Server administrators are reminded to read [the upgrade notes](docs/upgrade.md#upgrading-to-v1420). - - Features -------- -- cgit 1.5.1 From ca3cb1e039451b108015fa39d2b9d20022314182 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 6 Sep 2021 15:57:57 +0100 Subject: Expand on why users should read upgrade notes --- CHANGES.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 64c30eed10..67d649a4dd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,7 +1,10 @@ Synapse 1.42.0rc2 (2021-09-06) ============================== -Server administrators are reminded to read [the upgrade notes](docs/upgrade.md#upgrading-to-v1420). +This version of Synapse removes deprecated room-management admin APIs and out-of-date +email pushers, and improves error handling for fallback templates for user-interactive +authentication. For more information on these points, server administrators are +encouraged to read [the upgrade notes](docs/upgrade.md#upgrading-to-v1420). Features -------- -- cgit 1.5.1 From ff039df70d9d176f5953c0fd3770fe2e918b1ecd Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 6 Sep 2021 16:05:05 +0100 Subject: Improve changelog wording --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 67d649a4dd..f9659d596d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,7 +1,7 @@ Synapse 1.42.0rc2 (2021-09-06) ============================== -This version of Synapse removes deprecated room-management admin APIs and out-of-date +This version of Synapse removes deprecated room-management admin APIs, removes out-of-date email pushers, and improves error handling for fallback templates for user-interactive authentication. For more information on these points, server administrators are encouraged to read [the upgrade notes](docs/upgrade.md#upgrading-to-v1420). -- cgit 1.5.1 From 6e895366ea7f194cd48fae08a9909ee01a9fadae Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Mon, 6 Sep 2021 16:08:03 +0100 Subject: Add config option to use non-default manhole password and keys (#10643) --- changelog.d/10643.feature | 1 + docs/manhole.md | 29 +++++++++++++-- docs/sample_config.yaml | 18 +++++++++ synapse/app/_base.py | 10 ++++- synapse/app/generic_worker.py | 5 ++- synapse/app/homeserver.py | 5 ++- synapse/config/server.py | 87 ++++++++++++++++++++++++++++++++++++++++++- synapse/util/manhole.py | 15 ++++++-- tests/config/test_server.py | 8 ++-- 9 files changed, 161 insertions(+), 17 deletions(-) create mode 100644 changelog.d/10643.feature diff --git a/changelog.d/10643.feature b/changelog.d/10643.feature new file mode 100644 index 0000000000..bd63a3d258 --- /dev/null +++ b/changelog.d/10643.feature @@ -0,0 +1 @@ +Add config option to use non-default manhole password and keys. \ No newline at end of file diff --git a/docs/manhole.md b/docs/manhole.md index db92df88dc..715ed840f2 100644 --- a/docs/manhole.md +++ b/docs/manhole.md @@ -11,7 +11,7 @@ Note that this will give administrative access to synapse to **all users** with shell access to the server. It should therefore **not** be enabled in environments where untrusted users have shell access. -*** +## Configuring the manhole To enable it, first uncomment the `manhole` listener configuration in `homeserver.yaml`. The configuration is slightly different if you're using docker. 
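Illustrative aside (not part of the patch): the defaulting behaviour documented below can be pictured as a small helper. The function name here is hypothetical; only the keys and fallback values are taken from the commit:

```python
from typing import Any, Dict, Tuple


def resolve_manhole_credentials(config: Dict[str, Any]) -> Tuple[str, str]:
    # Hypothetical helper: absent settings keep the historical
    # matrix/rabbithole pair, exactly as the new config section documents.
    settings = config.get("manhole_settings") or {}
    return (
        settings.get("username", "matrix"),
        settings.get("password", "rabbithole"),
    )


assert resolve_manhole_credentials({}) == ("matrix", "rabbithole")
assert resolve_manhole_credentials(
    {"manhole_settings": {"username": "manhole", "password": "mypassword"}}
) == ("manhole", "mypassword")
```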
@@ -52,16 +52,37 @@ listeners: type: manhole ``` -#### Accessing synapse manhole +### Security settings + +The following config options are available: + +- `username` - The username for the manhole (defaults to `matrix`) +- `password` - The password for the manhole (defaults to `rabbithole`) +- `ssh_priv_key` - The path to a private SSH key (defaults to a hardcoded value) +- `ssh_pub_key` - The path to a public SSH key (defaults to a hardcoded value) + +For example: + +```yaml +manhole_settings: + username: manhole + password: mypassword + ssh_priv_key: "/home/synapse/manhole_keys/id_rsa" + ssh_pub_key: "/home/synapse/manhole_keys/id_rsa.pub" +``` + + +## Accessing synapse manhole Then restart synapse, and point an ssh client at port 9000 on localhost, using -the username `matrix`: +the username and password configured in `homeserver.yaml` - with the default +configuration, this would be: ```bash ssh -p9000 matrix@localhost ``` -The password is `rabbithole`. +Then enter the password when prompted (the default is `rabbithole`). This gives a Python REPL in which `hs` gives access to the `synapse.server.HomeServer` object - which in turn gives access to many other diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index e155b978d8..e15a832220 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -335,6 +335,24 @@ listeners: # bind_addresses: ['::1', '127.0.0.1'] # type: manhole +# Connection settings for the manhole +# +manhole_settings: + # The username for the manhole. This defaults to 'matrix'. + # + #username: manhole + + # The password for the manhole. This defaults to 'rabbithole'. + # + #password: mypassword + + # The private and public SSH key pair used to encrypt the manhole traffic. + # If these are left unset, then hardcoded and non-secret keys are used, + # which could allow traffic to be intercepted if sent over a public network. + # + #ssh_priv_key_path: CONFDIR/id_rsa + #ssh_pub_key_path: CONFDIR/id_rsa.pub + # Forward extremities can build up in a room due to networking delays between # homeservers. Once this happens in a large room, calculation of the state of # that room can become quite expensive. To mitigate this, once the number of diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 6fc14930d1..89bda00090 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -37,6 +37,7 @@ from synapse.api.constants import MAX_PDU_SIZE from synapse.app import check_bind_error from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import ManholeConfig from synapse.crypto import context_factory from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers @@ -230,7 +231,12 @@ def listen_metrics(bind_addresses, port): start_http_server(port, addr=host, registry=RegistryProxy) -def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: dict): +def listen_manhole( + bind_addresses: Iterable[str], + port: int, + manhole_settings: ManholeConfig, + manhole_globals: dict, +): # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so # suppress the warning for now. 
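Illustrative aside (not part of the patch): the config side of this change, shown further below, loads optional SSH keys with Twisted's `Key.fromFile` and converts failures into configuration errors at startup. A sketch of that load-and-wrap pattern, with a stand-in `ConfigError`:

```python
from typing import Optional

from twisted.conch.ssh.keys import Key


class ConfigError(Exception):
    """Stand-in for synapse.config._base.ConfigError."""


def load_optional_ssh_key(path: Optional[str]) -> Optional[Key]:
    # An absent path means "fall back to the built-in, non-secret key";
    # a path that is present but unreadable or invalid should instead
    # abort startup with a clear configuration error.
    if path is None:
        return None
    try:
        return Key.fromFile(path)
    except Exception as e:
        raise ConfigError(f"Failed to read manhole key file {path}") from e
```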
@@ -245,7 +251,7 @@ def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: di listen_tcp( bind_addresses, port, - manhole(username="matrix", password="rabbithole", globals=manhole_globals), + manhole(settings=manhole_settings, globals=manhole_globals), ) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 9b71dd75e6..2eb8d5a79c 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -395,7 +395,10 @@ class GenericWorkerServer(HomeServer): self._listen_http(listener) elif listener.type == "manhole": _base.listen_manhole( - listener.bind_addresses, listener.port, manhole_globals={"hs": self} + listener.bind_addresses, + listener.port, + manhole_settings=self.config.server.manhole_settings, + manhole_globals={"hs": self}, ) elif listener.type == "metrics": if not self.config.enable_metrics: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7dae163c1a..708db86f5d 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -291,7 +291,10 @@ class SynapseHomeServer(HomeServer): ) elif listener.type == "manhole": _base.listen_manhole( - listener.bind_addresses, listener.port, manhole_globals={"hs": self} + listener.bind_addresses, + listener.port, + manhole_settings=self.config.server.manhole_settings, + manhole_globals={"hs": self}, ) elif listener.type == "replication": services = listen_tcp( diff --git a/synapse/config/server.py b/synapse/config/server.py index d2c900f50c..7b9109a592 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -25,11 +25,14 @@ import attr import yaml from netaddr import AddrFormatError, IPNetwork, IPSet +from twisted.conch.ssh.keys import Key + from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError +from ._util import validate_config logger = logging.Logger(__name__) @@ -216,6 +219,16 @@ class ListenerConfig: http_options = attr.ib(type=Optional[HttpListenerConfig], default=None) +@attr.s(frozen=True) +class ManholeConfig: + """Object describing the configuration of the manhole""" + + username = attr.ib(type=str, validator=attr.validators.instance_of(str)) + password = attr.ib(type=str, validator=attr.validators.instance_of(str)) + priv_key = attr.ib(type=Optional[Key]) + pub_key = attr.ib(type=Optional[Key]) + + class ServerConfig(Config): section = "server" @@ -649,6 +662,41 @@ class ServerConfig(Config): ) ) + manhole_settings = config.get("manhole_settings") or {} + validate_config( + _MANHOLE_SETTINGS_SCHEMA, manhole_settings, ("manhole_settings",) + ) + + manhole_username = manhole_settings.get("username", "matrix") + manhole_password = manhole_settings.get("password", "rabbithole") + manhole_priv_key_path = manhole_settings.get("ssh_priv_key_path") + manhole_pub_key_path = manhole_settings.get("ssh_pub_key_path") + + manhole_priv_key = None + if manhole_priv_key_path is not None: + try: + manhole_priv_key = Key.fromFile(manhole_priv_key_path) + except Exception as e: + raise ConfigError( + f"Failed to read manhole private key file {manhole_priv_key_path}" + ) from e + + manhole_pub_key = None + if manhole_pub_key_path is not None: + try: + manhole_pub_key = Key.fromFile(manhole_pub_key_path) + except Exception as e: + raise ConfigError( + f"Failed to read manhole public key file {manhole_pub_key_path}" + ) from e + + self.manhole_settings = ManholeConfig( + 
username=manhole_username, + password=manhole_password, + priv_key=manhole_priv_key, + pub_key=manhole_pub_key, + ) + metrics_port = config.get("metrics_port") if metrics_port: logger.warning(METRICS_PORT_WARNING) @@ -715,7 +763,7 @@ class ServerConfig(Config): if not isinstance(templates_config, dict): raise ConfigError("The 'templates' section must be a dictionary") - self.custom_template_directory = templates_config.get( + self.custom_template_directory: Optional[str] = templates_config.get( "custom_template_directory" ) if self.custom_template_directory is not None and not isinstance( @@ -727,7 +775,13 @@ class ServerConfig(Config): return any(listener.tls for listener in self.listeners) def generate_config_section( - self, server_name, data_dir_path, open_private_ports, listeners, **kwargs + self, + server_name, + data_dir_path, + open_private_ports, + listeners, + config_dir_path, + **kwargs, ): ip_range_blacklist = "\n".join( " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST @@ -1068,6 +1122,24 @@ class ServerConfig(Config): # bind_addresses: ['::1', '127.0.0.1'] # type: manhole + # Connection settings for the manhole + # + manhole_settings: + # The username for the manhole. This defaults to 'matrix'. + # + #username: manhole + + # The password for the manhole. This defaults to 'rabbithole'. + # + #password: mypassword + + # The private and public SSH key pair used to encrypt the manhole traffic. + # If these are left unset, then hardcoded and non-secret keys are used, + # which could allow traffic to be intercepted if sent over a public network. + # + #ssh_priv_key_path: %(config_dir_path)s/id_rsa + #ssh_pub_key_path: %(config_dir_path)s/id_rsa.pub + # Forward extremities can build up in a room due to networking delays between # homeservers. Once this happens in a large room, calculation of the state of # that room can become quite expensive. To mitigate this, once the number of @@ -1436,3 +1508,14 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None: if name == "webclient": logger.warning(NO_MORE_WEB_CLIENT_WARNING) return + + +_MANHOLE_SETTINGS_SCHEMA = { + "type": "object", + "properties": { + "username": {"type": "string"}, + "password": {"type": "string"}, + "ssh_priv_key_path": {"type": "string"}, + "ssh_pub_key_path": {"type": "string"}, + }, +} diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 522daa323d..cfb5b94ca9 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -61,7 +61,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" -def manhole(username, password, globals): +def manhole(settings, globals): """Starts a ssh listener with password authentication using the given username and password. 
Clients connecting to the ssh listener will find themselves in a colored python shell with @@ -75,6 +75,15 @@ def manhole(username, password, globals): Returns: twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ + username = settings.username + password = settings.password + priv_key = settings.priv_key + if priv_key is None: + priv_key = Key.fromString(PRIVATE_KEY) + pub_key = settings.pub_key + if pub_key is None: + pub_key = Key.fromString(PUBLIC_KEY) + if not isinstance(password, bytes): password = password.encode("ascii") @@ -86,8 +95,8 @@ def manhole(username, password, globals): ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY) - factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY) + factory.privateKeys[b"ssh-rsa"] = priv_key + factory.publicKeys[b"ssh-rsa"] = pub_key return factory diff --git a/tests/config/test_server.py b/tests/config/test_server.py index 6f2b9e997d..b6f21294ba 100644 --- a/tests/config/test_server.py +++ b/tests/config/test_server.py @@ -35,7 +35,7 @@ class ServerConfigTestCase(unittest.TestCase): def test_unsecure_listener_no_listeners_open_private_ports_false(self): conf = yaml.safe_load( ServerConfig().generate_config_section( - "che.org", "/data_dir_path", False, None + "che.org", "/data_dir_path", False, None, config_dir_path="CONFDIR" ) ) @@ -55,7 +55,7 @@ class ServerConfigTestCase(unittest.TestCase): def test_unsecure_listener_no_listeners_open_private_ports_true(self): conf = yaml.safe_load( ServerConfig().generate_config_section( - "che.org", "/data_dir_path", True, None + "che.org", "/data_dir_path", True, None, config_dir_path="CONFDIR" ) ) @@ -89,7 +89,7 @@ class ServerConfigTestCase(unittest.TestCase): conf = yaml.safe_load( ServerConfig().generate_config_section( - "this.one.listens", "/data_dir_path", True, listeners + "this.one.listens", "/data_dir_path", True, listeners, "CONFDIR" ) ) @@ -123,7 +123,7 @@ class ServerConfigTestCase(unittest.TestCase): conf = yaml.safe_load( ServerConfig().generate_config_section( - "this.one.listens", "/data_dir_path", True, listeners + "this.one.listens", "/data_dir_path", True, listeners, "CONFDIR" ) ) -- cgit 1.5.1 From f1c6b76418c2a129ec2f0183936221b134b1b07e Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Mon, 6 Sep 2021 16:08:25 +0100 Subject: Add logging to help debug #9424 (#10704) --- changelog.d/10704.bugfix | 1 + synapse/handlers/sync.py | 67 ++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 55 insertions(+), 13 deletions(-) create mode 100644 changelog.d/10704.bugfix diff --git a/changelog.d/10704.bugfix b/changelog.d/10704.bugfix new file mode 100644 index 0000000000..4284cddc2b --- /dev/null +++ b/changelog.d/10704.bugfix @@ -0,0 +1 @@ +Added opentrace logging to help debug #9424. \ No newline at end of file diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 86c3c7f0df..e017b28cd2 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -505,10 +505,13 @@ class SyncHandler: else: limited = False + log_kv({"limited": limited}) + if potential_recents: recents = sync_config.filter_collection.filter_room_timeline( potential_recents ) + log_kv({"recents_after_sync_filtering": len(recents)}) # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. 
This is to @@ -526,6 +529,7 @@ class SyncHandler: recents, always_include_ids=current_state_ids, ) + log_kv({"recents_after_visibility_filtering": len(recents)}) else: recents = [] @@ -566,10 +570,15 @@ class SyncHandler: events, end_key = await self.store.get_recent_events_for_room( room_id, limit=load_limit + 1, end_token=end_key ) + + log_kv({"loaded_recents": len(events)}) + loaded_recents = sync_config.filter_collection.filter_room_timeline( events ) + log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)}) + # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. This is to # ensure that we always include current state in the timeline @@ -586,6 +595,9 @@ class SyncHandler: loaded_recents, always_include_ids=current_state_ids, ) + + log_kv({"loaded_recents_after_client_filtering": len(loaded_recents)}) + loaded_recents.extend(recents) recents = loaded_recents @@ -1116,6 +1128,8 @@ class SyncHandler: logger.debug("Fetching group data") await self._generate_sync_entry_for_groups(sync_result_builder) + num_events = 0 + # debug for https://github.com/matrix-org/synapse/issues/4422 for joined_room in sync_result_builder.joined: room_id = joined_room.room_id @@ -1123,6 +1137,14 @@ class SyncHandler: issue4422_logger.debug( "Sync result for newly joined room %s: %r", room_id, joined_room ) + num_events += len(joined_room.timeline.events) + + log_kv( + { + "joined_rooms_in_result": len(sync_result_builder.joined), + "events_in_result": num_events, + } + ) logger.debug("Sync response calculation complete") return SyncResult( @@ -1467,6 +1489,7 @@ class SyncHandler: if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: have_changed = await self._have_rooms_changed(sync_result_builder) + log_kv({"rooms_have_changed": have_changed}) if not have_changed: tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key @@ -1501,25 +1524,30 @@ class SyncHandler: tags_by_room = await self.store.get_tags_for_user(user_id) + log_kv({"rooms_changed": len(room_changes.room_entries)}) + room_entries = room_changes.room_entries invited = room_changes.invited knocked = room_changes.knocked newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms - async def handle_room_entries(room_entry): - logger.debug("Generating room entry for %s", room_entry.room_id) - res = await self._generate_room_entry( - sync_result_builder, - ignored_users, - room_entry, - ephemeral=ephemeral_by_room.get(room_entry.room_id, []), - tags=tags_by_room.get(room_entry.room_id), - account_data=account_data_by_room.get(room_entry.room_id, {}), - always_include=sync_result_builder.full_state, - ) - logger.debug("Generated room entry for %s", room_entry.room_id) - return res + async def handle_room_entries(room_entry: "RoomSyncResultBuilder"): + with start_active_span("generate_room_entry"): + set_tag("room_id", room_entry.room_id) + log_kv({"events": len(room_entry.events or [])}) + logger.debug("Generating room entry for %s", room_entry.room_id) + res = await self._generate_room_entry( + sync_result_builder, + ignored_users, + room_entry, + ephemeral=ephemeral_by_room.get(room_entry.room_id, []), + tags=tags_by_room.get(room_entry.room_id), + account_data=account_data_by_room.get(room_entry.room_id, {}), + always_include=sync_result_builder.full_state, + ) + logger.debug("Generated room entry for %s", room_entry.room_id) + return res 
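        # (Editorial annotation, as comments so the patch still reads as code:
        # `start_active_span("generate_room_entry")` opens one tracing span per
        # room, `set_tag` pins the room ID to that span, and `log_kv` records
        # point-in-time counts - together they show which room made a given
        # sync response slow.)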
await concurrently_execute(handle_room_entries, room_entries, 10) @@ -1932,6 +1960,12 @@ class SyncHandler: room_id = room_builder.room_id since_token = room_builder.since_token upto_token = room_builder.upto_token + log_kv( + { + "since_token": since_token, + "upto_token": upto_token, + } + ) batch = await self._load_filtered_recents( room_id, @@ -1941,6 +1975,13 @@ class SyncHandler: potential_recents=events, newly_joined_room=newly_joined, ) + log_kv( + { + "batch_events": len(batch.events), + "prev_batch": batch.prev_batch, + "batch_limited": batch.limited, + } + ) # Note: `batch` can be both empty and limited here in the case where # `_load_filtered_recents` can't find any events the user should see -- cgit 1.5.1 From f30c9745abe2c5996918bbd99bae5ce3088d0127 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 7 Sep 2021 11:15:51 +0100 Subject: Underscore-prefix private fields in `FederationEventHandler` (#10746) --- changelog.d/10746.misc | 1 + synapse/handlers/federation_event.py | 144 ++++++++++++++++++----------------- 2 files changed, 74 insertions(+), 71 deletions(-) create mode 100644 changelog.d/10746.misc diff --git a/changelog.d/10746.misc b/changelog.d/10746.misc new file mode 100644 index 0000000000..9a765435db --- /dev/null +++ b/changelog.d/10746.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index afeb2892dc..69f8287b2b 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -124,28 +124,28 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._store = hs.get_datastore() + self._storage = hs.get_storage() + self._state_store = self._storage.state - self.state_handler = hs.get_state_handler() - self.event_creation_handler = hs.get_event_creation_handler() + self._state_handler = hs.get_state_handler() + self._event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() self._message_handler = hs.get_message_handler() - self.action_generator = hs.get_action_generator() + self._action_generator = hs.get_action_generator() self._state_resolution_handler = hs.get_state_resolution_handler() # avoid a circular dependency by deferring execution here self._get_room_member_handler = hs.get_room_member_handler - self.federation_client = hs.get_federation_client() - self.third_party_event_rules = hs.get_third_party_event_rules() + self._federation_client = hs.get_federation_client() + self._third_party_event_rules = hs.get_third_party_event_rules() self._notifier = hs.get_notifier() - self.is_mine_id = hs.is_mine_id + self._is_mine_id = hs.is_mine_id self._server_name = hs.hostname self._instance_name = hs.get_instance_name() - self.config = hs.config + self._config = hs.config self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) @@ -177,7 +177,7 @@ class FederationEventHandler: event_id = pdu.event_id # We reprocess pdus when we have seen them only as outliers - existing = await self.store.get_event( + existing = await self._store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -240,7 +240,7 @@ class FederationEventHandler: # - Fetching state if we have a 
hole in the graph if not pdu.internal_metadata.is_outlier(): prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if missing_prevs: @@ -274,7 +274,7 @@ class FederationEventHandler: # Update the set of things we've seen after trying to # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if not missing_prevs: @@ -363,7 +363,7 @@ class FederationEventHandler: # the room, so we send it on their behalf. event.internal_metadata.send_on_behalf_of = origin - context = await self.state_handler.compute_event_context(event) + context = await self._state_handler.compute_event_context(event) context = await self._check_event_auth(origin, event, context) if context.rejected: raise SynapseError( @@ -377,7 +377,7 @@ class FederationEventHandler: # for knock events, we run the third-party event rules. It's not entirely clear # why we don't do this for other sorts of membership events. if event.membership == Membership.KNOCK: - event_allowed, _ = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self._third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -406,7 +406,7 @@ class FederationEventHandler: prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) prev_member_event = None if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) + prev_member_event = await self._store.get_event(prev_member_event_id) # Check if the member should be allowed access via membership in a space. await self._event_auth_handler.check_restricted_join_rules( @@ -439,7 +439,7 @@ class FederationEventHandler: if dest == self._server_name: raise SynapseError(400, "Can't backfill from self.") - events = await self.federation_client.backfill( + events = await self._federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -471,12 +471,12 @@ class FederationEventHandler: room_id = pdu.room_id event_id = pdu.event_id - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) if not prevs - seen: return - latest_list = await self.store.get_latest_event_ids_in_room(room_id) + latest_list = await self._store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -538,7 +538,7 @@ class FederationEventHandler: # All that said: Let's try increasing the timeout to 60s and see what happens. 
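        # (Editorial annotation: `get_missing_events` asks the `origin` server
        # to walk backwards from the new PDU towards events we already have,
        # returning whatever fills the gap - so a longer timeout trades request
        # latency for fewer holes left in the event graph.)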
try: - missing_events = await self.federation_client.get_missing_events( + missing_events = await self._federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -611,7 +611,7 @@ class FederationEventHandler: event_id = event.event_id - existing = await self.store.get_event( + existing = await self._store.get_event( event_id, allow_none=True, allow_rejected=True ) if existing: @@ -676,7 +676,7 @@ class FederationEventHandler: event_id = event.event_id prevs = set(event.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if not missing_prevs: @@ -693,7 +693,7 @@ class FederationEventHandler: event_map = {event_id: event} try: # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids(room_id, seen) + ours = await self._state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps: List[StateMap[str]] = list(ours.values()) @@ -722,13 +722,13 @@ class FederationEventHandler: for x in remote_state: event_map[x.event_id] = x - room_version = await self.store.get_room_version_id(room_id) + room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, event_map, - state_res_store=StateResolutionStore(self.store), + state_res_store=StateResolutionStore(self._store), ) # We need to give _process_received_pdu the actual state events @@ -736,7 +736,7 @@ class FederationEventHandler: # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = await self.store.get_events( + evs = await self._store.get_events( list(state_map.values()), get_prev_content=False, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -776,7 +776,7 @@ class FederationEventHandler: ( state_event_ids, auth_event_ids, - ) = await self.federation_client.get_room_state_ids( + ) = await self._federation_client.get_room_state_ids( destination, room_id, event_id=event_id ) @@ -790,7 +790,7 @@ class FederationEventHandler: desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self.store.get_events( + fetched_events = await self._store.get_events( desired_events, allow_rejected=True ) @@ -811,7 +811,7 @@ class FederationEventHandler: missing_auth_events = set(auth_event_ids) - fetched_events.keys() missing_auth_events.difference_update( - await self.store.have_seen_events(room_id, missing_auth_events) + await self._store.have_seen_events(room_id, missing_auth_events) ) logger.debug("We are also missing %i auth events", len(missing_auth_events)) @@ -824,7 +824,7 @@ class FederationEventHandler: # we need to make sure we re-load from the database to get the rejected # state correct. fetched_events.update( - await self.store.get_events(missing_desired_events, allow_rejected=True) + await self._store.get_events(missing_desired_events, allow_rejected=True) ) # check for events which were in the wrong room. 
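Illustrative aside (not part of the patch): the "wrong room" comment above refers to a defensive filter in which events fetched from a remote are discarded if they claim a different room ID. A toy version of that check, with a stand-in event type:

```python
from dataclasses import dataclass
from typing import Iterable, List, Tuple


@dataclass
class RemoteEvent:
    """Stand-in for EventBase, carrying only what the check needs."""

    event_id: str
    room_id: str


def split_by_room(
    events: Iterable[RemoteEvent], room_id: str
) -> Tuple[List[RemoteEvent], List[RemoteEvent]]:
    # Anything claiming a different room is rejected outright: a remote
    # server must not be able to smuggle events across room boundaries.
    kept: List[RemoteEvent] = []
    rejected: List[RemoteEvent] = []
    for ev in events:
        (kept if ev.room_id == room_id else rejected).append(ev)
    return kept, rejected
```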
@@ -903,7 +903,7 @@ class FederationEventHandler: logger.debug("Processing event: %s", event) try: - context = await self.state_handler.compute_event_context( + context = await self._state_handler.compute_event_context( event, old_state=state ) await self._auth_and_persist_event( @@ -921,7 +921,7 @@ class FederationEventHandler: device_id = event.content.get("device_id") sender_key = event.content.get("sender_key") - cached_devices = await self.store.get_cached_devices_for_user(event.sender) + cached_devices = await self._store.get_cached_devices_for_user(event.sender) resync = False # Whether we should resync device lists. @@ -997,10 +997,10 @@ class FederationEventHandler: """ try: - await self.store.mark_remote_user_device_cache_as_stale(sender) + await self._store.mark_remote_user_device_cache_as_stale(sender) # Immediately attempt a resync in the background - if self.config.worker_app: + if self._config.worker_app: await self._user_device_resync(user_id=sender) else: await self._device_list_updater.user_device_resync(sender) @@ -1026,12 +1026,12 @@ class FederationEventHandler: # Skip processing a marker event if the room version doesn't # support it or the event is not from the room creator. - room_version = await self.store.get_room_version(marker_event.room_id) - create_event = await self.store.get_create_event_for_room(marker_event.room_id) + room_version = await self._store.get_room_version(marker_event.room_id) + create_event = await self._store.get_create_event_for_room(marker_event.room_id) room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) if ( not room_version.msc2716_historical - or not self.config.experimental.msc2716_enabled + or not self._config.experimental.msc2716_enabled or marker_event.sender != room_creator ): return @@ -1056,7 +1056,7 @@ class FederationEventHandler: [insertion_event_id], ) - insertion_event = await self.store.get_event( + insertion_event = await self._store.get_event( insertion_event_id, allow_none=True ) if insertion_event is None: @@ -1074,7 +1074,7 @@ class FederationEventHandler: marker_event, ) - await self.store.insert_insertion_extremity( + await self._store.insert_insertion_extremity( insertion_event_id, marker_event.room_id ) @@ -1096,14 +1096,14 @@ class FederationEventHandler: Logs a warning if we can't find the given event. """ - room_version = await self.store.get_room_version(room_id) + room_version = await self._store.get_room_version(room_id) event_map: Dict[str, EventBase] = {} async def get_event(event_id: str): with nested_logging_context(event_id): try: - event = await self.federation_client.get_pdu( + event = await self._federation_client.get_pdu( [destination], event_id, room_version, @@ -1139,7 +1139,7 @@ class FederationEventHandler: for aid in event.auth_event_ids() if aid not in event_map ] - persisted_events = await self.store.get_events( + persisted_events = await self._store.get_events( auth_events, allow_rejected=True, ) @@ -1183,7 +1183,7 @@ class FederationEventHandler: async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context(event) + res = await self._state_handler.compute_event_context(event) res = await self._check_event_auth( origin, event, @@ -1286,7 +1286,7 @@ class FederationEventHandler: Returns: The updated context object. 
""" - room_version = await self.store.get_room_version_id(event.room_id) + room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if claimed_auth_event_map: @@ -1299,7 +1299,7 @@ class FederationEventHandler: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events_x = await self.store.get_events(auth_events_ids) + auth_events_x = await self._store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} try: @@ -1334,7 +1334,7 @@ class FederationEventHandler: # If we are going to send this event over federation we precaclculate # the joined hosts. if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event( + await self._event_creation_handler.cache_joined_hosts_for_event( event, context ) @@ -1348,7 +1348,7 @@ class FederationEventHandler: if guest_access == GuestAccess.CAN_JOIN: return - current_state_map = await self.state_handler.get_current_state(event.room_id) + current_state_map = await self._state_handler.get_current_state(event.room_id) current_state = list(current_state_map.values()) await self._get_room_member_handler().kick_guest_users(current_state) @@ -1374,7 +1374,7 @@ class FederationEventHandler: if backfilled or event.internal_metadata.is_outlier(): return - extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) @@ -1383,7 +1383,7 @@ class FederationEventHandler: # state at the event, so no point rechecking auth for soft fail. return - room_version = await self.store.get_room_version_id(event.room_id) + room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". @@ -1400,19 +1400,19 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self.state_store.get_state_groups( + state_sets_d = await self._state_store.get_state_groups( event.room_id, extrem_ids ) state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) state_sets.append(state) - current_states = await self.state_handler.resolve_events( + current_states = await self._state_handler.resolve_events( room_version, state_sets, event ) current_state_ids: StateMap[str] = { k: e.event_id for k, e in current_states.items() } else: - current_state_ids = await self.state_handler.get_current_state_ids( + current_state_ids = await self._state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids ) @@ -1428,7 +1428,7 @@ class FederationEventHandler: e for k, e in current_state_ids.items() if k in auth_types ] - auth_events_map = await self.store.get_events(current_state_ids_list) + auth_events_map = await self._store.get_events(current_state_ids_list) current_auth_events = { (e.type, e.state_key): e for e in auth_events_map.values() } @@ -1499,7 +1499,9 @@ class FederationEventHandler: # # we start by checking if they are in the store, and then try calling /event_auth/. 
if missing_auth: - have_events = await self.store.have_seen_events(event.room_id, missing_auth) + have_events = await self._store.have_seen_events( + event.room_id, missing_auth + ) logger.debug("Events %s are in the store", have_events) missing_auth.difference_update(have_events) @@ -1508,7 +1510,7 @@ class FederationEventHandler: logger.info("auth_events contains unknown events: %s", missing_auth) try: try: - remote_auth_chain = await self.federation_client.get_event_auth( + remote_auth_chain = await self._federation_client.get_event_auth( origin, event.room_id, event.event_id ) except RequestSendFailed as e1: @@ -1517,7 +1519,7 @@ class FederationEventHandler: logger.info("Failed to get event auth from remote: %s", e1) return context, auth_events - seen_remotes = await self.store.have_seen_events( + seen_remotes = await self._store.have_seen_events( event.room_id, [e.event_id for e in remote_auth_chain] ) @@ -1543,7 +1545,7 @@ class FederationEventHandler: e.event_id, ) missing_auth_event_context = ( - await self.state_handler.compute_event_context(e) + await self._state_handler.compute_event_context(e) ) await self._auth_and_persist_event( origin, @@ -1584,7 +1586,7 @@ class FederationEventHandler: # XXX: currently this checks for redactions but I'm not convinced that is # necessary? - different_events = await self.store.get_events_as_list(different_auth) + different_events = await self._store.get_events_as_list(different_auth) for d in different_events: if d.room_id != event.room_id: @@ -1610,8 +1612,8 @@ class FederationEventHandler: remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) remote_state = remote_auth_events.values() - room_version = await self.store.get_room_version_id(event.room_id) - new_state = await self.state_handler.resolve_events( + room_version = await self._store.get_room_version_id(event.room_id) + new_state = await self._state_handler.resolve_events( room_version, (local_state, remote_state), event ) @@ -1669,7 +1671,7 @@ class FederationEventHandler: # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = await self.state_store.store_state_group( + state_group = await self._state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -1701,9 +1703,9 @@ class FederationEventHandler: not event.internal_metadata.is_outlier() and not backfilled and not context.rejected - and (await self.store.get_min_depth(event.room_id)) <= event.depth + and (await self._store.get_min_depth(event.room_id)) <= event.depth ): - await self.action_generator.handle_push_actions_for_event( + await self._action_generator.handle_push_actions_for_event( event, context ) @@ -1712,7 +1714,7 @@ class FederationEventHandler: ) except Exception: run_in_background( - self.store.remove_push_actions_from_staging, event.event_id + self._store.remove_push_actions_from_staging, event.event_id ) raise @@ -1737,27 +1739,27 @@ class FederationEventHandler: The stream ID after which all events have been persisted. """ if not event_and_contexts: - return self.store.get_current_events_token() + return self._store.get_current_events_token() - instance = self.config.worker.events_shard_config.get_instance(room_id) + instance = self._config.worker.events_shard_config.get_instance(room_id) if instance != self._instance_name: # Limit the number of events sent over replication. 
             # here as that is what we default to in `max_request_body_size(..)`
             for batch in batch_iter(event_and_contexts, 200):
                 result = await self._send_events(
                     instance_name=instance,
-                    store=self.store,
+                    store=self._store,
                     room_id=room_id,
                     event_and_contexts=batch,
                     backfilled=backfilled,
                 )
             return result["max_stream_id"]
         else:
-            assert self.storage.persistence
+            assert self._storage.persistence
 
             # Note that this returns the events that were persisted, which may not be
             # the same as were passed in if some were deduplicated due to transaction IDs.
-            events, max_stream_token = await self.storage.persistence.persist_events(
+            events, max_stream_token = await self._storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
@@ -1791,7 +1793,7 @@ class FederationEventHandler:
         # users
         if event.internal_metadata.is_outlier():
             if event.membership != Membership.INVITE:
-                if not self.is_mine_id(target_user_id):
+                if not self._is_mine_id(target_user_id):
                     return
 
         target_user = UserID.from_string(target_user_id)
@@ -1840,4 +1842,4 @@ class FederationEventHandler:
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
 
     async def get_min_depth_for_context(self, context: str) -> int:
-        return await self.store.get_min_depth(context)
+        return await self._store.get_min_depth(context)
-- cgit 1.5.1

From a23f3abb9bc3d67888a2e1090e5df849bc442549 Mon Sep 17 00:00:00 2001
From: Patrick Cloke
Date: Tue, 7 Sep 2021 08:43:54 -0400
Subject: Return stripped m.space.child events via the space summary. (#10760)

The full event content cannot be trusted from this API (as no auth chain,
etc. is processed over federation). Returning the full event content was a
bug, as MSC2946 specifies that only the stripped state should be returned.

This also avoids calculating aggregations / annotations, which go unused.
---
 changelog.d/10760.bugfix         |  1 +
 synapse/handlers/room_summary.py | 26 ++++++++++++--------------
 2 files changed, 13 insertions(+), 14 deletions(-)
 create mode 100644 changelog.d/10760.bugfix

diff --git a/changelog.d/10760.bugfix b/changelog.d/10760.bugfix
new file mode 100644
index 0000000000..4995c28190
--- /dev/null
+++ b/changelog.d/10760.bugfix
@@ -0,0 +1 @@
+Only return the stripped state events for the `m.space.child` events in a room for the spaces summary from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4bc9c73e6e..781da9e811 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -37,7 +37,6 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.events import EventBase -from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache @@ -89,7 +88,6 @@ class RoomSummaryHandler: _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self._clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastore() self._event_serializer = hs.get_event_client_serializer() @@ -648,18 +646,18 @@ class RoomSummaryHandler: if max_children is None or max_children > MAX_ROOMS_PER_SPACE: max_children = MAX_ROOMS_PER_SPACE - now = self._clock.time_msec() - events_result: List[JsonDict] = [] - for edge_event in itertools.islice(child_events, max_children): - events_result.append( - await self._event_serializer.serialize_event( - edge_event, - time_now=now, - event_format=format_event_for_client_v2, - ) - ) - - return _RoomEntry(room_id, room_entry, events_result) + stripped_events: List[JsonDict] = [ + { + "type": e.type, + "state_key": e.state_key, + "content": e.content, + "room_id": e.room_id, + "sender": e.sender, + "origin_server_ts": e.origin_server_ts, + } + for e in itertools.islice(child_events, max_children) + ] + return _RoomEntry(room_id, room_entry, stripped_events) async def _summarize_remote_room( self, -- cgit 1.5.1
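For context on the new return shape in the room_summary.py hunk above: the comprehension keeps only six whitelisted keys, so one entry in the summary response ends up looking roughly like the sketch below (the identifiers, server name, and timestamp are illustrative placeholders, not values from this patch):

    # A single stripped m.space.child entry as built by the comprehension
    # above. Everything else on the event (signatures, hashes, prev_events,
    # auth_events, unsigned) is dropped.
    stripped_child = {
        "type": "m.space.child",
        "state_key": "!child:example.org",    # room ID of the child room
        "content": {"via": ["example.org"]},  # servers that can route a join
        "room_id": "!space:example.org",      # room ID of the parent space
        "sender": "@alice:example.org",
        "origin_server_ts": 1631000000000,
    }

Since no auth chain accompanies these events over federation, none of the dropped fields could have been verified anyway, which is why returning only the stripped keys is both safer and what MSC2946 specifies.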