diff options
122 files changed, 2828 insertions, 799 deletions
diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 325c1f7d39..eb294f1619 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -12,6 +12,10 @@ on: # we do the full build on tags. tags: ["v*"] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + permissions: contents: write @@ -44,12 +48,43 @@ jobs: distro: ${{ fromJson(needs.get-distros.outputs.distros) }} steps: - - uses: actions/checkout@v2 + - name: Checkout + uses: actions/checkout@v2 with: path: src - - uses: actions/setup-python@v2 - - run: ./src/scripts-dev/build_debian_packages "${{ matrix.distro }}" - - uses: actions/upload-artifact@v2 + + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + with: + install: true + + - name: Set up docker layer caching + uses: actions/cache@v2 + with: + path: /tmp/.buildx-cache + key: ${{ runner.os }}-buildx-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-buildx- + + - name: Set up python + uses: actions/setup-python@v2 + + - name: Build the packages + # see https://github.com/docker/build-push-action/issues/252 + # for the cache magic here + run: | + ./src/scripts-dev/build_debian_packages \ + --docker-build-arg=--cache-from=type=local,src=/tmp/.buildx-cache \ + --docker-build-arg=--cache-to=type=local,mode=max,dest=/tmp/.buildx-cache-new \ + --docker-build-arg=--progress=plain \ + --docker-build-arg=--load \ + "${{ matrix.distro }}" + rm -rf /tmp/.buildx-cache + mv /tmp/.buildx-cache-new /tmp/.buildx-cache + + - name: Upload debs as artifacts + uses: actions/upload-artifact@v2 with: name: debs path: debs/* diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cef4439477..0a62c62d02 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,6 +5,10 @@ on: branches: ["develop", "release-*"] pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: lint: runs-on: ubuntu-latest @@ -340,14 +344,19 @@ jobs: working-directory: complement/dockerfiles # Run Complement - - run: go test -v -tags synapse_blacklist,msc2403,msc2946,msc3083 ./tests + - run: go test -v -tags synapse_blacklist,msc2403,msc2946,msc3083 ./tests/... env: COMPLEMENT_BASE_IMAGE: complement-synapse:latest working-directory: complement # a job which marks all the other jobs as complete, thus allowing PRs to be merged. tests-done: + if: ${{ always() }} needs: + - lint + - lint-crlf + - lint-newsfile + - lint-sdist - trial - trial-olddeps - sytest @@ -355,4 +364,16 @@ jobs: - complement runs-on: ubuntu-latest steps: - - run: "true" \ No newline at end of file + - name: Set build result + env: + NEEDS_CONTEXT: ${{ toJSON(needs) }} + # the `jq` incantation dumps out a series of "<job> <result>" lines + run: | + set -o pipefail + jq -r 'to_entries[] | [.key,.value.result] | join(" ")' \ + <<< $NEEDS_CONTEXT | + while read job result; do + if [ "$result" != "success" ]; then + echo "::set-failed ::Job $job returned $result" + fi + done diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a4e6688042..e7eef23419 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -155,7 +155,7 @@ source ./env/bin/activate ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder ``` -## Run the unit tests. +## Run the unit tests (Twisted trial). The unit tests run parts of Synapse, including your changes, to see if anything was broken. They are slower than the linters but will typically catch more errors. @@ -186,7 +186,7 @@ SYNAPSE_TEST_LOG_LEVEL=DEBUG trial tests ``` -## Run the integration tests. +## Run the integration tests ([Sytest](https://github.com/matrix-org/sytest)). The integration tests are a more comprehensive suite of tests. They run a full version of Synapse, including your changes, to check if @@ -203,6 +203,43 @@ $ docker run --rm -it -v /path/where/you/have/cloned/the/repository\:/src:ro -v This configuration should generally cover your needs. For more details about other configurations, see [documentation in the SyTest repo](https://github.com/matrix-org/sytest/blob/develop/docker/README.md). +## Run the integration tests ([Complement](https://github.com/matrix-org/complement)). + +[Complement](https://github.com/matrix-org/complement) is a suite of black box tests that can be run on any homeserver implementation. It can also be thought of as end-to-end (e2e) tests. + +It's often nice to develop on Synapse and write Complement tests at the same time. +Here is how to run your local Synapse checkout against your local Complement checkout. + +(checkout [`complement`](https://github.com/matrix-org/complement) alongside your `synapse` checkout) +```sh +COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh +``` + +To run a specific test file, you can pass the test name at the end of the command. The name passed comes from the naming structure in your Complement tests. If you're unsure of the name, you can do a full run and copy it from the test output: + +```sh +COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh TestBackfillingHistory +``` + +To run a specific test, you can specify the whole name structure: + +```sh +COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh TestBackfillingHistory/parallel/Backfilled_historical_events_resolve_with_proper_state_in_correct_order +``` + + +### Access database for homeserver after Complement test runs. + +If you're curious what the database looks like after you run some tests, here are some steps to get you going in Synapse: + + 1. In your Complement test comment out `defer deployment.Destroy(t)` and replace with `defer time.Sleep(2 * time.Hour)` to keep the homeserver running after the tests complete + 1. Start the Complement tests + 1. Find the name of the container, `docker ps -f name=complement_` (this will filter for just the Compelement related Docker containers) + 1. Access the container replacing the name with what you found in the previous step: `docker exec -it complement_1_hs_with_application_service.hs1_2 /bin/bash` + 1. Install sqlite (database driver), `apt-get update && apt-get install -y sqlite3` + 1. Then run `sqlite3` and open the database `.open /conf/homeserver.db` (this db path comes from the Synapse homeserver.yaml) + + # 9. Submit your patch. Once you're happy with your patch, it's time to prepare a Pull Request. @@ -392,7 +429,7 @@ By now, you know the drill! # Notes for maintainers on merging PRs etc There are some notes for those with commit access to the project on how we -manage git [here](docs/dev/git.md). +manage git [here](docs/development/git.md). # Conclusion diff --git a/changelog.d/10254.feature b/changelog.d/10254.feature new file mode 100644 index 0000000000..df8bb51167 --- /dev/null +++ b/changelog.d/10254.feature @@ -0,0 +1 @@ +Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. diff --git a/changelog.d/10283.feature b/changelog.d/10283.feature new file mode 100644 index 0000000000..99d633dbfb --- /dev/null +++ b/changelog.d/10283.feature @@ -0,0 +1 @@ +Initial support for MSC3244, Room version capabilities over the /capabilities API. \ No newline at end of file diff --git a/changelog.d/10407.feature b/changelog.d/10407.feature new file mode 100644 index 0000000000..db277d9ecd --- /dev/null +++ b/changelog.d/10407.feature @@ -0,0 +1 @@ +Add a buffered logging handler which periodically flushes itself. diff --git a/changelog.d/10408.misc b/changelog.d/10408.misc new file mode 100644 index 0000000000..abccd210a9 --- /dev/null +++ b/changelog.d/10408.misc @@ -0,0 +1 @@ +Add type hints to `synapse.federation.transport.client` module. diff --git a/changelog.d/10410.bugfix b/changelog.d/10410.bugfix new file mode 100644 index 0000000000..65b418fd35 --- /dev/null +++ b/changelog.d/10410.bugfix @@ -0,0 +1 @@ +Improve character set detection in URL previews by supporting underscores (in addition to hyphens). Contributed by @srividyut. diff --git a/changelog.d/10411.feature b/changelog.d/10411.feature new file mode 100644 index 0000000000..ef0ab84b17 --- /dev/null +++ b/changelog.d/10411.feature @@ -0,0 +1 @@ +Add support for https connections to a proxy server. Contributed by @Bubu and @dklimpel. \ No newline at end of file diff --git a/changelog.d/10413.feature b/changelog.d/10413.feature new file mode 100644 index 0000000000..3964db7e0e --- /dev/null +++ b/changelog.d/10413.feature @@ -0,0 +1 @@ +Support for [MSC2285 (hidden read receipts)](https://github.com/matrix-org/matrix-doc/pull/2285). Contributed by @SimonBrandner. diff --git a/changelog.d/10426.feature b/changelog.d/10426.feature new file mode 100644 index 0000000000..9cca6dc456 --- /dev/null +++ b/changelog.d/10426.feature @@ -0,0 +1 @@ +Email notifications now state whether an invitation is to a room or a space. diff --git a/changelog.d/10429.misc b/changelog.d/10429.misc new file mode 100644 index 0000000000..ccb2217f64 --- /dev/null +++ b/changelog.d/10429.misc @@ -0,0 +1 @@ +Drop backwards-compatibility code that was required to support Ubuntu Xenial. diff --git a/changelog.d/10431.misc b/changelog.d/10431.misc new file mode 100644 index 0000000000..34b9b49da6 --- /dev/null +++ b/changelog.d/10431.misc @@ -0,0 +1 @@ +Use a docker image cache for the prerequisites for the debian package build. diff --git a/changelog.d/10432.misc b/changelog.d/10432.misc new file mode 100644 index 0000000000..3a8cdf0ae0 --- /dev/null +++ b/changelog.d/10432.misc @@ -0,0 +1 @@ +Connect historical chunks together with chunk events instead of a content field (MSC2716). diff --git a/changelog.d/10437.misc b/changelog.d/10437.misc new file mode 100644 index 0000000000..a557578499 --- /dev/null +++ b/changelog.d/10437.misc @@ -0,0 +1 @@ +Improve servlet type hints. diff --git a/changelog.d/10438.misc b/changelog.d/10438.misc new file mode 100644 index 0000000000..a557578499 --- /dev/null +++ b/changelog.d/10438.misc @@ -0,0 +1 @@ +Improve servlet type hints. diff --git a/changelog.d/10442.misc b/changelog.d/10442.misc new file mode 100644 index 0000000000..b8d412d732 --- /dev/null +++ b/changelog.d/10442.misc @@ -0,0 +1 @@ +Replace usage of `or_ignore` in `simple_insert` with `simple_upsert` usage, to stop spamming postgres logs with spurious ERROR messages. diff --git a/changelog.d/10444.misc b/changelog.d/10444.misc new file mode 100644 index 0000000000..c012e89f4b --- /dev/null +++ b/changelog.d/10444.misc @@ -0,0 +1 @@ +Update the `tests-done` Github Actions status. diff --git a/changelog.d/10445.doc b/changelog.d/10445.doc new file mode 100644 index 0000000000..4c023ded7c --- /dev/null +++ b/changelog.d/10445.doc @@ -0,0 +1 @@ +Fix hierarchy of providers on the OpenID page. diff --git a/changelog.d/10446.misc b/changelog.d/10446.misc new file mode 100644 index 0000000000..a5a0ca80eb --- /dev/null +++ b/changelog.d/10446.misc @@ -0,0 +1 @@ +Update type annotations to work with forthcoming Twisted 21.7.0 release. diff --git a/changelog.d/10448.feature b/changelog.d/10448.feature new file mode 100644 index 0000000000..f6579e0ca8 --- /dev/null +++ b/changelog.d/10448.feature @@ -0,0 +1 @@ +Add `creation_ts` to list users admin API. \ No newline at end of file diff --git a/changelog.d/10450.misc b/changelog.d/10450.misc new file mode 100644 index 0000000000..aa646f0841 --- /dev/null +++ b/changelog.d/10450.misc @@ -0,0 +1 @@ + Update type annotations to work with forthcoming Twisted 21.7.0 release. diff --git a/changelog.d/10451.misc b/changelog.d/10451.misc new file mode 100644 index 0000000000..e38f4b476d --- /dev/null +++ b/changelog.d/10451.misc @@ -0,0 +1 @@ +Cancel redundant GHA workflows when a new commit is pushed. diff --git a/changelog.d/10453.doc b/changelog.d/10453.doc new file mode 100644 index 0000000000..5d4db9bca2 --- /dev/null +++ b/changelog.d/10453.doc @@ -0,0 +1 @@ +Consolidate development documentation to `docs/development/`. diff --git a/changelog.d/10455.bugfix b/changelog.d/10455.bugfix new file mode 100644 index 0000000000..23c74a3c89 --- /dev/null +++ b/changelog.d/10455.bugfix @@ -0,0 +1 @@ +Fix `synapse_federation_server_oldest_inbound_pdu_in_staging` Prometheus metric to not report a max age of 51 years when the queue is empty. diff --git a/changelog.d/10463.misc b/changelog.d/10463.misc new file mode 100644 index 0000000000..d7b4d2222e --- /dev/null +++ b/changelog.d/10463.misc @@ -0,0 +1 @@ +Disable `msc2716` Complement tests until Complement updates are merged. diff --git a/changelog.d/10468.misc b/changelog.d/10468.misc new file mode 100644 index 0000000000..b9854bb4c1 --- /dev/null +++ b/changelog.d/10468.misc @@ -0,0 +1 @@ +Mitigate media repo XSS attacks on IE11 via the non-standard X-Content-Security-Policy header. diff --git a/changelog.d/10482.misc b/changelog.d/10482.misc new file mode 100644 index 0000000000..4e9e2126e1 --- /dev/null +++ b/changelog.d/10482.misc @@ -0,0 +1 @@ +Additional type hints in the state handler. diff --git a/changelog.d/10483.doc b/changelog.d/10483.doc new file mode 100644 index 0000000000..0f699fafdd --- /dev/null +++ b/changelog.d/10483.doc @@ -0,0 +1 @@ +Document how to use Complement while developing a new Synapse feature. diff --git a/changelog.d/10488.misc b/changelog.d/10488.misc new file mode 100644 index 0000000000..a55502c163 --- /dev/null +++ b/changelog.d/10488.misc @@ -0,0 +1 @@ +Update syntax used to run complement tests. diff --git a/changelog.d/10489.feature b/changelog.d/10489.feature new file mode 100644 index 0000000000..df8bb51167 --- /dev/null +++ b/changelog.d/10489.feature @@ -0,0 +1 @@ +Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events. diff --git a/changelog.d/10490.misc b/changelog.d/10490.misc new file mode 100644 index 0000000000..630c31adae --- /dev/null +++ b/changelog.d/10490.misc @@ -0,0 +1 @@ +Fix up type annotations to work with Twisted 21.7. diff --git a/changelog.d/9918.feature b/changelog.d/9918.feature new file mode 100644 index 0000000000..98f0a50893 --- /dev/null +++ b/changelog.d/9918.feature @@ -0,0 +1 @@ +Add support for [MSC2033](https://github.com/matrix-org/matrix-doc/pull/2033): `device_id` on `/account/whoami`. \ No newline at end of file diff --git a/debian/build_virtualenv b/debian/build_virtualenv index 21caad90cc..68c8659953 100755 --- a/debian/build_virtualenv +++ b/debian/build_virtualenv @@ -33,13 +33,11 @@ esac # Use --builtin-venv to use the better `venv` module from CPython 3.4+ rather # than the 2/3 compatible `virtualenv`. -# Pin pip to 20.3.4 to fix breakage in 21.0 on py3.5 (xenial) - dh_virtualenv \ --install-suffix "matrix-synapse" \ --builtin-venv \ --python "$SNAKE" \ - --upgrade-pip-to="20.3.4" \ + --upgrade-pip \ --preinstall="lxml" \ --preinstall="mock" \ --extra-pip-arg="--no-cache-dir" \ diff --git a/debian/changelog b/debian/changelog index 4944e55714..ff21599df1 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.39.0ubuntu1) UNRELEASED; urgency=medium + + * Drop backwards-compatibility code that was required to support Ubuntu Xenial. + + -- Richard van der Hoff <richard@matrix.org> Tue, 20 Jul 2021 00:10:03 +0100 + matrix-synapse-py3 (1.39.0~rc3) stable; urgency=medium * New synapse release 1.39.0~rc3. diff --git a/debian/compat b/debian/compat index ec635144f6..f599e28b8a 100644 --- a/debian/compat +++ b/debian/compat @@ -1 +1 @@ -9 +10 diff --git a/debian/control b/debian/control index 8167a901a4..763fabd6f6 100644 --- a/debian/control +++ b/debian/control @@ -3,11 +3,8 @@ Section: contrib/python Priority: extra Maintainer: Synapse Packaging team <packages@matrix.org> # keep this list in sync with the build dependencies in docker/Dockerfile-dhvirtualenv. -# TODO: Remove the dependency on dh-systemd after dropping support for Ubuntu xenial -# On all other supported releases, it's merely a transitional package which -# does nothing but depends on debhelper (> 9.20160709) Build-Depends: - debhelper (>= 9.20160709) | dh-systemd, + debhelper (>= 10), dh-virtualenv (>= 1.1), libsystemd-dev, libpq-dev, diff --git a/debian/rules b/debian/rules index c744060a57..b9d490adc9 100755 --- a/debian/rules +++ b/debian/rules @@ -51,7 +51,5 @@ override_dh_shlibdeps: override_dh_virtualenv: ./debian/build_virtualenv -# We are restricted to compat level 9 (because xenial), so have to -# enable the systemd bits manually. %: - dh $@ --with python-virtualenv --with systemd + dh $@ --with python-virtualenv diff --git a/docker/Dockerfile-dhvirtualenv b/docker/Dockerfile-dhvirtualenv index 0d74630370..017be8555e 100644 --- a/docker/Dockerfile-dhvirtualenv +++ b/docker/Dockerfile-dhvirtualenv @@ -15,6 +15,15 @@ ARG distro="" ### ### Stage 0: build a dh-virtualenv ### + +# This is only really needed on bionic and focal, since other distributions we +# care about have a recent version of dh-virtualenv by default. Unfortunately, +# it looks like focal is going to be with us for a while. +# +# (focal doesn't have a dh-virtualenv package at all. There is a PPA at +# https://launchpad.net/~jyrki-pulliainen/+archive/ubuntu/dh-virtualenv, but +# it's not obviously easier to use that than to build our own.) + FROM ${distro} as builder RUN apt-get update -qq -o Acquire::Languages=none @@ -27,7 +36,7 @@ RUN env DEBIAN_FRONTEND=noninteractive apt-get install \ wget # fetch and unpack the package -# TODO: Upgrade to 1.2.2 once xenial is dropped +# TODO: Upgrade to 1.2.2 once bionic is dropped (1.2.2 requires debhelper 12; bionic has only 11) RUN mkdir /dh-virtualenv RUN wget -q -O /dh-virtualenv.tar.gz https://github.com/spotify/dh-virtualenv/archive/ac6e1b1.tar.gz RUN tar -xv --strip-components=1 -C /dh-virtualenv -f /dh-virtualenv.tar.gz @@ -59,8 +68,6 @@ ENV LANG C.UTF-8 # # NB: keep this list in sync with the list of build-deps in debian/control # TODO: it would be nice to do that automatically. -# TODO: Remove the dh-systemd stanza after dropping support for Ubuntu xenial -# it's a transitional package on all other, more recent releases RUN apt-get update -qq -o Acquire::Languages=none \ && env DEBIAN_FRONTEND=noninteractive apt-get install \ -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \ @@ -76,10 +83,7 @@ RUN apt-get update -qq -o Acquire::Languages=none \ python3-venv \ sqlite3 \ libpq-dev \ - xmlsec1 \ - && ( env DEBIAN_FRONTEND=noninteractive apt-get install \ - -yqq --no-install-recommends -o Dpkg::Options::=--force-unsafe-io \ - dh-systemd || true ) + xmlsec1 COPY --from=builder /dh-virtualenv_1.2~dev-1_all.deb / diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index db4ef1a44e..f1bde91420 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -67,7 +67,7 @@ # Development - [Contributing Guide](development/contributing_guide.md) - [Code Style](code_style.md) - - [Git Usage](dev/git.md) + - [Git Usage](development/git.md) - [Testing]() - [OpenTracing](opentracing.md) - [Database Schemas](development/database_schema.md) @@ -77,8 +77,8 @@ - [TCP Replication](tcp_replication.md) - [Internal Documentation](development/internal_documentation/README.md) - [Single Sign-On]() - - [SAML](dev/saml.md) - - [CAS](dev/cas.md) + - [SAML](development/saml.md) + - [CAS](development/cas.md) - [State Resolution]() - [The Auth Chain Difference Algorithm](auth_chain_difference_algorithm.md) - [Media Repository](media_repository.md) diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 4a65d0c3bc..160899754e 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -144,7 +144,8 @@ A response body like the following is returned: "deactivated": 0, "shadow_banned": 0, "displayname": "<User One>", - "avatar_url": null + "avatar_url": null, + "creation_ts": 1560432668000 }, { "name": "<user_id2>", "is_guest": 0, @@ -153,7 +154,8 @@ A response body like the following is returned: "deactivated": 0, "shadow_banned": 0, "displayname": "<User Two>", - "avatar_url": "<avatar_url>" + "avatar_url": "<avatar_url>", + "creation_ts": 1561550621000 } ], "next_token": "100", @@ -197,11 +199,12 @@ The following parameters should be set in the URL: - `shadow_banned` - Users are ordered by `shadow_banned` status. - `displayname` - Users are ordered alphabetically by `displayname`. - `avatar_url` - Users are ordered alphabetically by avatar URL. + - `creation_ts` - Users are ordered by when the users was created in ms. - `dir` - Direction of media order. Either `f` for forwards or `b` for backwards. Setting this value to `b` will reverse the above sort order. Defaults to `f`. -Caution. The database only has indexes on the columns `name` and `created_ts`. +Caution. The database only has indexes on the columns `name` and `creation_ts`. This means that if a different sort order is used (`is_guest`, `admin`, `user_type`, `deactivated`, `shadow_banned`, `avatar_url` or `displayname`), this can cause a large load on the database, especially for large environments. @@ -222,6 +225,7 @@ The following fields are returned in the JSON response body: - `shadow_banned` - bool - Status if that user has been marked as shadow banned. - `displayname` - string - The user's display name if they have set one. - `avatar_url` - string - The user's avatar URL if they have set one. + - `creation_ts` - integer - The user's creation timestamp in ms. - `next_token`: string representing a positive integer - Indication for pagination. See above. - `total` - integer - Total number of media. diff --git a/docs/dev/cas.md b/docs/development/cas.md index 592b2d8d4f..592b2d8d4f 100644 --- a/docs/dev/cas.md +++ b/docs/development/cas.md diff --git a/docs/dev/git.md b/docs/development/git.md index 87950f07b2..9b1ed54b65 100644 --- a/docs/dev/git.md +++ b/docs/development/git.md @@ -9,7 +9,7 @@ commits each of which contains a single change building on what came before. Here, by way of an arbitrary example, is the top of `git log --graph b2dba0607`: -<img src="git/clean.png" alt="clean git graph" width="500px"> +<img src="img/git/clean.png" alt="clean git graph" width="500px"> Note how the commit comment explains clearly what is changing and why. Also note the *absence* of merge commits, as well as the absence of commits called @@ -61,7 +61,7 @@ Ok, so that's what we'd like to achieve. How do we achieve it? The TL;DR is: when you come to merge a pull request, you *probably* want to “squash and merge”: -![squash and merge](git/squash.png). +![squash and merge](img/git/squash.png). (This applies whether you are merging your own PR, or that of another contributor.) @@ -105,7 +105,7 @@ complicated. Here's how we do it. Let's start with a picture: -![branching model](git/branches.jpg) +![branching model](img/git/branches.jpg) It looks complicated, but it's really not. There's one basic rule: *anyone* is free to merge from *any* more-stable branch to *any* less-stable branch at diff --git a/docs/dev/git/branches.jpg b/docs/development/img/git/branches.jpg index 715ecc8cd0..715ecc8cd0 100644 --- a/docs/dev/git/branches.jpg +++ b/docs/development/img/git/branches.jpg Binary files differdiff --git a/docs/dev/git/clean.png b/docs/development/img/git/clean.png index 3accd7ccef..3accd7ccef 100644 --- a/docs/dev/git/clean.png +++ b/docs/development/img/git/clean.png Binary files differdiff --git a/docs/dev/git/squash.png b/docs/development/img/git/squash.png index 234caca3e4..234caca3e4 100644 --- a/docs/dev/git/squash.png +++ b/docs/development/img/git/squash.png Binary files differdiff --git a/docs/dev/saml.md b/docs/development/saml.md index a9bfd2dc05..a9bfd2dc05 100644 --- a/docs/dev/saml.md +++ b/docs/development/saml.md diff --git a/docs/openid.md b/docs/openid.md index cfaafc5015..f685fd551a 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -410,7 +410,7 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` -## Apple +### Apple Configuring "Sign in with Apple" (SiWA) requires an Apple Developer account. diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 669e600081..b088c83405 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -28,7 +28,7 @@ handlers: # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR # logs will still be flushed immediately. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file # The capacity is the number of log lines that are buffered before # being written to disk. Increasing this will lead to better @@ -36,6 +36,9 @@ handlers: # be written to disk. capacity: 10 flushLevel: 30 # Flush for WARNING logs as well + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index e25c5bb265..0ed1c679fd 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -17,6 +17,7 @@ import subprocess import sys import threading from concurrent.futures import ThreadPoolExecutor +from typing import Optional, Sequence DISTS = ( "debian:buster", @@ -39,8 +40,11 @@ projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) class Builder(object): - def __init__(self, redirect_stdout=False): + def __init__( + self, redirect_stdout=False, docker_build_args: Optional[Sequence[str]] = None + ): self.redirect_stdout = redirect_stdout + self._docker_build_args = tuple(docker_build_args or ()) self.active_containers = set() self._lock = threading.Lock() self._failed = False @@ -79,8 +83,8 @@ class Builder(object): stdout = None # first build a docker image for the build environment - subprocess.check_call( - [ + build_args = ( + ( "docker", "build", "--tag", @@ -89,8 +93,13 @@ class Builder(object): "distro=" + dist, "-f", "docker/Dockerfile-dhvirtualenv", - "docker", - ], + ) + + self._docker_build_args + + ("docker",) + ) + + subprocess.check_call( + build_args, stdout=stdout, stderr=subprocess.STDOUT, cwd=projdir, @@ -147,9 +156,7 @@ class Builder(object): self.active_containers.remove(c) -def run_builds(dists, jobs=1, skip_tests=False): - builder = Builder(redirect_stdout=(jobs > 1)) - +def run_builds(builder, dists, jobs=1, skip_tests=False): def sig(signum, _frame): print("Caught SIGINT") builder.kill_containers() @@ -181,6 +188,11 @@ if __name__ == "__main__": help="skip running tests after building", ) parser.add_argument( + "--docker-build-arg", + action="append", + help="specify an argument to pass to docker build", + ) + parser.add_argument( "--show-dists-json", action="store_true", help="instead of building the packages, just list the dists to build for, as a json array", @@ -195,4 +207,12 @@ if __name__ == "__main__": if args.show_dists_json: print(json.dumps(DISTS)) else: - run_builds(dists=args.dist, jobs=args.jobs, skip_tests=args.no_check) + builder = Builder( + redirect_stdout=(args.jobs > 1), docker_build_args=args.docker_build_arg + ) + run_builds( + builder, + dists=args.dist, + jobs=args.jobs, + skip_tests=args.no_check, + ) diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index aca32edc17..cba015d942 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! -go test -v -tags synapse_blacklist,msc2946,msc3083,msc2716,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests +go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 8c7ad2a407..a40d6d7961 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -120,6 +120,7 @@ class EventTypes: SpaceParent = "m.space.parent" MSC2716_INSERTION = "org.matrix.msc2716.insertion" + MSC2716_CHUNK = "org.matrix.msc2716.chunk" MSC2716_MARKER = "org.matrix.msc2716.marker" @@ -198,9 +199,10 @@ class EventContentFields: # Used on normal messages to indicate they were historically imported after the fact MSC2716_HISTORICAL = "org.matrix.msc2716.historical" - # For "insertion" events + # For "insertion" events to indicate what the next chunk ID should be in + # order to connect to it MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id" - # Used on normal message events to indicate where the chunk connects to + # Used on "chunk" events to indicate which insertion event it connects to MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id" # For "marker" events MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion" @@ -230,3 +232,7 @@ class HistoryVisibility: JOINED = "joined" SHARED = "shared" WORLD_READABLE = "world_readable" + + +class ReadReceiptEventFields: + MSC2285_HIDDEN = "org.matrix.msc2285.hidden" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 054ab14ab6..dc662bca83 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -75,6 +75,9 @@ class Codes: INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" BAD_ALIAS = "M_BAD_ALIAS" + # For restricted join rules. + UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN" + UNABLE_TO_GRANT_JOIN = "M_UNABLE_TO_GRANT_JOIN" class CodeMessageException(RuntimeError): diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index a20abc5a65..697319e52d 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from typing import Callable, Dict, Optional import attr @@ -168,7 +168,7 @@ class RoomVersions: msc2403_knocking=False, ) MSC3083 = RoomVersion( - "org.matrix.msc3083", + "org.matrix.msc3083.v2", RoomDisposition.UNSTABLE, EventFormatVersions.V3, StateResolutionVersions.V2, @@ -208,5 +208,39 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.MSC3083, RoomVersions.V7, ) - # Note that we do not include MSC2043 here unless it is enabled in the config. +} + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RoomVersionCapability: + """An object which describes the unique attributes of a room version.""" + + identifier: str # the identifier for this capability + preferred_version: Optional[RoomVersion] + support_check_lambda: Callable[[RoomVersion], bool] + + +MSC3244_CAPABILITIES = { + cap.identifier: { + "preferred": cap.preferred_version.identifier + if cap.preferred_version is not None + else None, + "support": [ + v.identifier + for v in KNOWN_ROOM_VERSIONS.values() + if cap.support_check_lambda(v) + ], + } + for cap in ( + RoomVersionCapability( + "knock", + RoomVersions.V7, + lambda room_version: room_version.msc2403_knocking, + ), + RoomVersionCapability( + "restricted", + None, + lambda room_version: room_version.msc3083_join_rules, + ), + ) } diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index bcecbfec03..8d8f166e9b 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -39,12 +39,13 @@ DEFAULT_SUBJECTS = { "messages_from_person_and_others": "[%(app)s] You have messages on %(app)s from %(person)s and others...", "invite_from_person": "[%(app)s] %(person)s has invited you to chat on %(app)s...", "invite_from_person_to_room": "[%(app)s] %(person)s has invited you to join the %(room)s room on %(app)s...", + "invite_from_person_to_space": "[%(app)s] %(person)s has invited you to join the %(space)s space on %(app)s...", "password_reset": "[%(server_name)s] Password reset", "email_validation": "[%(server_name)s] Validate your email", } -@attr.s +@attr.s(slots=True, frozen=True) class EmailSubjectConfig: message_from_person_in_room = attr.ib(type=str) message_from_person = attr.ib(type=str) @@ -54,6 +55,7 @@ class EmailSubjectConfig: messages_from_person_and_others = attr.ib(type=str) invite_from_person = attr.ib(type=str) invite_from_person_to_room = attr.ib(type=str) + invite_from_person_to_space = attr.ib(type=str) password_reset = attr.ib(type=str) email_validation = attr.ib(type=str) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e25ccba9ac..4c60ee8c28 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -32,3 +32,9 @@ class ExperimentalConfig(Config): # MSC2716 (backfill existing history) self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False) + + # MSC2285 (hidden read receipts) + self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False) + + # MSC3244 (room version capabilities) + self.msc3244_enabled: bool = experimental.get("msc3244_enabled", False) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index ad4e6e61c3..dcd3ed1dac 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -71,7 +71,7 @@ handlers: # will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR # logs will still be flushed immediately. buffer: - class: logging.handlers.MemoryHandler + class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler target: file # The capacity is the number of log lines that are buffered before # being written to disk. Increasing this will lead to better @@ -79,6 +79,9 @@ handlers: # be written to disk. capacity: 10 flushLevel: 30 # Flush for WARNING logs as well + # The period of time, in seconds, between forced flushes. + # Messages will not be delayed for longer than this time. + period: 5 # A handler that writes logs to stderr. Unused by default, but can be used # instead of "buffer" and "file" in the logger handlers. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 137dff2513..cc92d35477 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -106,6 +106,18 @@ def check( if not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") + is_invite_via_allow_rule = ( + event.type == EventTypes.Member + and event.membership == Membership.JOIN + and "join_authorised_via_users_server" in event.content + ) + if is_invite_via_allow_rule: + authoriser_domain = get_domain_from_id( + event.content["join_authorised_via_users_server"] + ) + if not event.signatures.get(authoriser_domain): + raise AuthError(403, "Event not signed by authorising server") + # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules # # 1. If type is m.room.create: @@ -177,7 +189,7 @@ def check( # https://github.com/vector-im/vector-web/issues/1208 hopefully if event.type == EventTypes.ThirdPartyInvite: user_level = get_user_power_level(event.user_id, auth_events) - invite_level = _get_named_level(auth_events, "invite", 0) + invite_level = get_named_level(auth_events, "invite", 0) if user_level < invite_level: raise AuthError(403, "You don't have permission to invite users") @@ -285,8 +297,8 @@ def _is_membership_change_allowed( user_level = get_user_power_level(event.user_id, auth_events) target_level = get_user_power_level(target_user_id, auth_events) - # FIXME (erikj): What should we do here as the default? - ban_level = _get_named_level(auth_events, "ban", 50) + invite_level = get_named_level(auth_events, "invite", 0) + ban_level = get_named_level(auth_events, "ban", 50) logger.debug( "_is_membership_change_allowed: %s", @@ -336,8 +348,6 @@ def _is_membership_change_allowed( elif target_in_room: # the target is already in the room. raise AuthError(403, "%s is already in the room." % target_user_id) else: - invite_level = _get_named_level(auth_events, "invite", 0) - if user_level < invite_level: raise AuthError(403, "You don't have permission to invite users") elif Membership.JOIN == membership: @@ -345,16 +355,41 @@ def _is_membership_change_allowed( # * They are not banned. # * They are accepting a previously sent invitation. # * They are already joined (it's a NOOP). - # * The room is public or restricted. + # * The room is public. + # * The room is restricted and the user meets the allows rules. if event.user_id != target_user_id: raise AuthError(403, "Cannot force another user to join.") elif target_banned: raise AuthError(403, "You are banned from this room") - elif join_rule == JoinRules.PUBLIC or ( + elif join_rule == JoinRules.PUBLIC: + pass + elif ( room_version.msc3083_join_rules and join_rule == JoinRules.MSC3083_RESTRICTED ): - pass + # This is the same as public, but the event must contain a reference + # to the server who authorised the join. If the event does not contain + # the proper content it is rejected. + # + # Note that if the caller is in the room or invited, then they do + # not need to meet the allow rules. + if not caller_in_room and not caller_invited: + authorising_user = event.content.get("join_authorised_via_users_server") + + if authorising_user is None: + raise AuthError(403, "Join event is missing authorising user.") + + # The authorising user must be in the room. + key = (EventTypes.Member, authorising_user) + member_event = auth_events.get(key) + _check_joined_room(member_event, authorising_user, event.room_id) + + authorising_user_level = get_user_power_level( + authorising_user, auth_events + ) + if authorising_user_level < invite_level: + raise AuthError(403, "Join event authorised by invalid server.") + elif join_rule == JoinRules.INVITE or ( room_version.msc2403_knocking and join_rule == JoinRules.KNOCK ): @@ -369,7 +404,7 @@ def _is_membership_change_allowed( if target_banned and user_level < ban_level: raise AuthError(403, "You cannot unban user %s." % (target_user_id,)) elif target_user_id != event.user_id: - kick_level = _get_named_level(auth_events, "kick", 50) + kick_level = get_named_level(auth_events, "kick", 50) if user_level < kick_level or user_level <= target_level: raise AuthError(403, "You cannot kick user %s." % target_user_id) @@ -445,7 +480,7 @@ def get_send_level( def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: - power_levels_event = _get_power_level_event(auth_events) + power_levels_event = get_power_level_event(auth_events) send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) user_level = get_user_power_level(event.user_id, auth_events) @@ -485,7 +520,7 @@ def check_redaction( """ user_level = get_user_power_level(event.user_id, auth_events) - redact_level = _get_named_level(auth_events, "redact", 50) + redact_level = get_named_level(auth_events, "redact", 50) if user_level >= redact_level: return False @@ -600,7 +635,7 @@ def _check_power_levels( ) -def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]: +def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]: return auth_events.get((EventTypes.PowerLevels, "")) @@ -616,7 +651,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int: Returns: the user's power level in this room. """ - power_level_event = _get_power_level_event(auth_events) + power_level_event = get_power_level_event(auth_events) if power_level_event: level = power_level_event.content.get("users", {}).get(user_id) if not level: @@ -640,8 +675,8 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int: return 0 -def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int: - power_level_event = _get_power_level_event(auth_events) +def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int: + power_level_event = get_power_level_event(auth_events) if not power_level_event: return default @@ -728,7 +763,9 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]: return public_keys -def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]: +def auth_types_for_event( + room_version: RoomVersion, event: Union[EventBase, EventBuilder] +) -> Set[Tuple[str, str]]: """Given an event, return a list of (EventType, StateKey) that may be needed to auth the event. The returned list may be a superset of what would actually be required depending on the full state of the room. @@ -760,4 +797,12 @@ def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str ) auth_types.add(key) + if room_version.msc3083_join_rules and membership == Membership.JOIN: + if "join_authorised_via_users_server" in event.content: + key = ( + EventTypes.Member, + event.content["join_authorised_via_users_server"], + ) + auth_types.add(key) + return auth_types diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ec96999e4e..f4da9e0923 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -109,6 +109,8 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: add_fields("creator") elif event_type == EventTypes.JoinRules: add_fields("join_rule") + if room_version.msc3083_join_rules: + add_fields("allow") elif event_type == EventTypes.PowerLevels: add_fields( "users", diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 2bfe6a3d37..024e440ff4 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -178,6 +178,34 @@ async def _check_sigs_on_pdu( ) raise SynapseError(403, errmsg, Codes.FORBIDDEN) + # If this is a join event for a restricted room it may have been authorised + # via a different server from the sending server. Check those signatures. + if ( + room_version.msc3083_join_rules + and pdu.type == EventTypes.Member + and pdu.membership == Membership.JOIN + and "join_authorised_via_users_server" in pdu.content + ): + authorising_server = get_domain_from_id( + pdu.content["join_authorised_via_users_server"] + ) + try: + await keyring.verify_event_for_server( + authorising_server, + pdu, + pdu.origin_server_ts if room_version.enforce_key_validity else 0, + ) + except Exception as e: + errmsg = ( + "event id %s: unable to verify signature for authorising server %s: %s" + % ( + pdu.event_id, + authorising_server, + e, + ) + ) + raise SynapseError(403, errmsg, Codes.FORBIDDEN) + def _is_invite_via_3pid(event: EventBase) -> bool: return ( diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c767d30627..dbadf102f2 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,7 +19,6 @@ import itertools import logging from typing import ( TYPE_CHECKING, - Any, Awaitable, Callable, Collection, @@ -79,7 +78,15 @@ class InvalidResponseError(RuntimeError): we couldn't parse """ - pass + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SendJoinResult: + # The event to persist. + event: EventBase + # A string giving the server the event was sent to. + origin: str + state: List[EventBase] + auth_chain: List[EventBase] class FederationClient(FederationBase): @@ -677,7 +684,7 @@ class FederationClient(FederationBase): async def send_join( self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion - ) -> Dict[str, Any]: + ) -> SendJoinResult: """Sends a join event to one of a list of homeservers. Doing so will cause the remote server to add the event to the graph, @@ -691,18 +698,38 @@ class FederationClient(FederationBase): did the make_join) Returns: - a dict with members ``origin`` (a string - giving the server the event was sent to, ``state`` (?) and - ``auth_chain``. + The result of the send join request. Raises: SynapseError: if the chosen remote server returns a 300/400 code, or no servers successfully handle the request. """ - async def send_request(destination) -> Dict[str, Any]: + async def send_request(destination) -> SendJoinResult: response = await self._do_send_join(room_version, destination, pdu) + # If an event was returned (and expected to be returned): + # + # * Ensure it has the same event ID (note that the event ID is a hash + # of the event fields for versions which support MSC3083). + # * Ensure the signatures are good. + # + # Otherwise, fallback to the provided event. + if room_version.msc3083_join_rules and response.event: + event = response.event + + valid_pdu = await self._check_sigs_and_hash_and_fetch_one( + pdu=event, + origin=destination, + outlier=True, + room_version=room_version, + ) + + if valid_pdu is None or event.event_id != pdu.event_id: + raise InvalidResponseError("Returned an invalid join event") + else: + event = pdu + state = response.state auth_chain = response.auth_events @@ -784,11 +811,21 @@ class FederationClient(FederationBase): % (auth_chain_create_events,) ) - return { - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - } + return SendJoinResult( + event=event, + state=signed_state, + auth_chain=signed_auth, + origin=destination, + ) + + if room_version.msc3083_join_rules: + # If the join is being authorised via allow rules, we need to send + # the /send_join back to the same server that was originally used + # with /make_join. + if "join_authorised_via_users_server" in pdu.content: + destinations = [ + get_domain_from_id(pdu.content["join_authorised_via_users_server"]) + ] return await self._try_destination_list("send_join", destinations, send_request) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 29619aeeb8..2892a11d7d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.crypto.event_signing import compute_event_signature from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_base import FederationBase, event_from_pdu_json @@ -64,7 +65,7 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.storage.databases.main.lock import Lock -from synapse.types import JsonDict +from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -586,7 +587,7 @@ class FederationServer(FederationBase): async def on_send_join_request( self, origin: str, content: JsonDict, room_id: str ) -> Dict[str, Any]: - context = await self._on_send_membership_event( + event, context = await self._on_send_membership_event( origin, content, Membership.JOIN, room_id ) @@ -597,6 +598,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return { + "org.matrix.msc3083.v2.event": event.get_pdu_json(), "state": [p.get_pdu_json(time_now) for p in state.values()], "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], } @@ -681,7 +683,7 @@ class FederationServer(FederationBase): Returns: The stripped room state. """ - event_context = await self._on_send_membership_event( + _, context = await self._on_send_membership_event( origin, content, Membership.KNOCK, room_id ) @@ -690,14 +692,14 @@ class FederationServer(FederationBase): # related to the room while the knock request is pending. stripped_room_state = ( await self.store.get_stripped_room_state_from_event_context( - event_context, self._room_prejoin_state_types + context, self._room_prejoin_state_types ) ) return {"knock_state_events": stripped_room_state} async def _on_send_membership_event( self, origin: str, content: JsonDict, membership_type: str, room_id: str - ) -> EventContext: + ) -> Tuple[EventBase, EventContext]: """Handle an on_send_{join,leave,knock} request Does some preliminary validation before passing the request on to the @@ -712,7 +714,7 @@ class FederationServer(FederationBase): in the event Returns: - The context of the event after inserting it into the room graph. + The event and context of the event after inserting it into the room graph. Raises: SynapseError if there is a problem with the request, including things like @@ -748,6 +750,33 @@ class FederationServer(FederationBase): logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures) + # Sign the event since we're vouching on behalf of the remote server that + # the event is valid to be sent into the room. Currently this is only done + # if the user is being joined via restricted join rules. + if ( + room_version.msc3083_join_rules + and event.membership == Membership.JOIN + and "join_authorised_via_users_server" in event.content + ): + # We can only authorise our own users. + authorising_server = get_domain_from_id( + event.content["join_authorised_via_users_server"] + ) + if authorising_server != self.server_name: + raise SynapseError( + 400, + f"Cannot authorise request from resident server: {authorising_server}", + ) + + event.signatures.update( + compute_event_signature( + room_version, + event.get_pdu_json(), + self.hs.hostname, + self.hs.signing_key, + ) + ) + event = await self._check_sigs_and_hash(room_version, event) return await self.handler.on_send_membership_event(origin, event) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 98b1bf77fd..6a8d3ad4fe 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,7 +15,7 @@ import logging import urllib -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union import attr import ijson @@ -29,6 +29,7 @@ from synapse.api.urls import ( FEDERATION_V2_PREFIX, ) from synapse.events import EventBase, make_event_from_dict +from synapse.federation.units import Transaction from synapse.http.matrixfederationclient import ByteParser from synapse.logging.utils import log_function from synapse.types import JsonDict @@ -49,23 +50,25 @@ class TransportLayerClient: self.client = hs.get_federation_http_client() @log_function - def get_room_state_ids(self, destination, room_id, event_id): + async def get_room_state_ids( + self, destination: str, room_id: str, event_id: str + ) -> JsonDict: """Requests all state for a given room from the given server at the given event. Returns the state's event_id's Args: - destination (str): The host name of the remote homeserver we want + destination: The host name of the remote homeserver we want to get the state from. - context (str): The name of the context we want the state of - event_id (str): The event we want the context at. + context: The name of the context we want the state of + event_id: The event we want the context at. Returns: - Awaitable: Results in a dict received from the remote homeserver. + Results in a dict received from the remote homeserver. """ logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state_ids/%s", room_id) - return self.client.get_json( + return await self.client.get_json( destination, path=path, args={"event_id": event_id}, @@ -73,39 +76,43 @@ class TransportLayerClient: ) @log_function - def get_event(self, destination, event_id, timeout=None): + async def get_event( + self, destination: str, event_id: str, timeout: Optional[int] = None + ) -> JsonDict: """Requests the pdu with give id and origin from the given server. Args: - destination (str): The host name of the remote homeserver we want + destination: The host name of the remote homeserver we want to get the state from. - event_id (str): The id of the event being requested. - timeout (int): How long to try (in ms) the destination for before + event_id: The id of the event being requested. + timeout: How long to try (in ms) the destination for before giving up. None indicates no timeout. Returns: - Awaitable: Results in a dict received from the remote homeserver. + Results in a dict received from the remote homeserver. """ logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) path = _create_v1_path("/event/%s", event_id) - return self.client.get_json( + return await self.client.get_json( destination, path=path, timeout=timeout, try_trailing_slash_on_400=True ) @log_function - def backfill(self, destination, room_id, event_tuples, limit): + async def backfill( + self, destination: str, room_id: str, event_tuples: Iterable[str], limit: int + ) -> Optional[JsonDict]: """Requests `limit` previous PDUs in a given context before list of PDUs. Args: - dest (str) - room_id (str) - event_tuples (list) - limit (int) + destination + room_id + event_tuples + limit Returns: - Awaitable: Results in a dict received from the remote homeserver. + Results in a dict received from the remote homeserver. """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", @@ -117,18 +124,22 @@ class TransportLayerClient: if not event_tuples: # TODO: raise? - return + return None path = _create_v1_path("/backfill/%s", room_id) args = {"v": event_tuples, "limit": [str(limit)]} - return self.client.get_json( + return await self.client.get_json( destination, path=path, args=args, try_trailing_slash_on_400=True ) @log_function - async def send_transaction(self, transaction, json_data_callback=None): + async def send_transaction( + self, + transaction: Transaction, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + ) -> JsonDict: """Sends the given Transaction to its destination Args: @@ -149,21 +160,21 @@ class TransportLayerClient: """ logger.debug( "send_data dest=%s, txid=%s", - transaction.destination, - transaction.transaction_id, + transaction.destination, # type: ignore + transaction.transaction_id, # type: ignore ) - if transaction.destination == self.server_name: + if transaction.destination == self.server_name: # type: ignore raise RuntimeError("Transport layer cannot send to itself!") # FIXME: This is only used by the tests. The actual json sent is # generated by the json_data_callback. json_data = transaction.get_dict() - path = _create_v1_path("/send/%s", transaction.transaction_id) + path = _create_v1_path("/send/%s", transaction.transaction_id) # type: ignore - response = await self.client.put_json( - transaction.destination, + return await self.client.put_json( + transaction.destination, # type: ignore path=path, data=json_data, json_data_callback=json_data_callback, @@ -172,8 +183,6 @@ class TransportLayerClient: try_trailing_slash_on_400=True, ) - return response - @log_function async def make_query( self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False @@ -193,8 +202,13 @@ class TransportLayerClient: @log_function async def make_membership_event( - self, destination, room_id, user_id, membership, params - ): + self, + destination: str, + room_id: str, + user_id: str, + membership: str, + params: Optional[Mapping[str, Union[str, Iterable[str]]]], + ) -> JsonDict: """Asks a remote server to build and sign us a membership event Note that this does not append any events to any graphs. @@ -240,7 +254,7 @@ class TransportLayerClient: ignore_backoff = True retry_on_dns_fail = True - content = await self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args=params, @@ -249,20 +263,18 @@ class TransportLayerClient: ignore_backoff=ignore_backoff, ) - return content - @log_function async def send_join_v1( self, - room_version, - destination, - room_id, - event_id, - content, + room_version: RoomVersion, + destination: str, + room_id: str, + event_id: str, + content: JsonDict, ) -> "SendJoinResponse": path = _create_v1_path("/send_join/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, @@ -270,15 +282,18 @@ class TransportLayerClient: max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) - return response - @log_function async def send_join_v2( - self, room_version, destination, room_id, event_id, content + self, + room_version: RoomVersion, + destination: str, + room_id: str, + event_id: str, + content: JsonDict, ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, @@ -286,13 +301,13 @@ class TransportLayerClient: max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) - return response - @log_function - async def send_leave_v1(self, destination, room_id, event_id, content): + async def send_leave_v1( + self, destination: str, room_id: str, event_id: str, content: JsonDict + ) -> Tuple[int, JsonDict]: path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, @@ -303,13 +318,13 @@ class TransportLayerClient: ignore_backoff=True, ) - return response - @log_function - async def send_leave_v2(self, destination, room_id, event_id, content): + async def send_leave_v2( + self, destination: str, room_id: str, event_id: str, content: JsonDict + ) -> JsonDict: path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, @@ -320,8 +335,6 @@ class TransportLayerClient: ignore_backoff=True, ) - return response - @log_function async def send_knock_v1( self, @@ -357,25 +370,25 @@ class TransportLayerClient: ) @log_function - async def send_invite_v1(self, destination, room_id, event_id, content): + async def send_invite_v1( + self, destination: str, room_id: str, event_id: str, content: JsonDict + ) -> Tuple[int, JsonDict]: path = _create_v1_path("/invite/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) - return response - @log_function - async def send_invite_v2(self, destination, room_id, event_id, content): + async def send_invite_v2( + self, destination: str, room_id: str, event_id: str, content: JsonDict + ) -> JsonDict: path = _create_v2_path("/invite/%s/%s", room_id, event_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) - return response - @log_function async def get_public_rooms( self, @@ -385,7 +398,7 @@ class TransportLayerClient: search_filter: Optional[Dict] = None, include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, - ): + ) -> JsonDict: """Get the list of public rooms from a remote homeserver See synapse.federation.federation_client.FederationClient.get_public_rooms for @@ -450,25 +463,27 @@ class TransportLayerClient: return response @log_function - async def exchange_third_party_invite(self, destination, room_id, event_dict): + async def exchange_third_party_invite( + self, destination: str, room_id: str, event_dict: JsonDict + ) -> JsonDict: path = _create_v1_path("/exchange_third_party_invite/%s", room_id) - response = await self.client.put_json( + return await self.client.put_json( destination=destination, path=path, data=event_dict ) - return response - @log_function - async def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth( + self, destination: str, room_id: str, event_id: str + ) -> JsonDict: path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = await self.client.get_json(destination=destination, path=path) - - return content + return await self.client.get_json(destination=destination, path=path) @log_function - async def query_client_keys(self, destination, query_content, timeout): + async def query_client_keys( + self, destination: str, query_content: JsonDict, timeout: int + ) -> JsonDict: """Query the device keys for a list of user ids hosted on a remote server. @@ -496,20 +511,21 @@ class TransportLayerClient: } Args: - destination(str): The server to query. - query_content(dict): The user ids to query. + destination: The server to query. + query_content: The user ids to query. Returns: A dict containing device and cross-signing keys. """ path = _create_v1_path("/user/keys/query") - content = await self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) - return content @log_function - async def query_user_devices(self, destination, user_id, timeout): + async def query_user_devices( + self, destination: str, user_id: str, timeout: int + ) -> JsonDict: """Query the devices for a user id hosted on a remote server. Response: @@ -535,20 +551,21 @@ class TransportLayerClient: } Args: - destination(str): The server to query. - query_content(dict): The user ids to query. + destination: The server to query. + query_content: The user ids to query. Returns: A dict containing device and cross-signing keys. """ path = _create_v1_path("/user/devices/%s", user_id) - content = await self.client.get_json( + return await self.client.get_json( destination=destination, path=path, timeout=timeout ) - return content @log_function - async def claim_client_keys(self, destination, query_content, timeout): + async def claim_client_keys( + self, destination: str, query_content: JsonDict, timeout: int + ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -572,33 +589,32 @@ class TransportLayerClient: } Args: - destination(str): The server to query. - query_content(dict): The user ids to query. + destination: The server to query. + query_content: The user ids to query. Returns: A dict containing the one-time keys. """ path = _create_v1_path("/user/keys/claim") - content = await self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) - return content @log_function async def get_missing_events( self, - destination, - room_id, - earliest_events, - latest_events, - limit, - min_depth, - timeout, - ): + destination: str, + room_id: str, + earliest_events: Iterable[str], + latest_events: Iterable[str], + limit: int, + min_depth: int, + timeout: int, + ) -> JsonDict: path = _create_v1_path("/get_missing_events/%s", room_id) - content = await self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data={ @@ -610,14 +626,14 @@ class TransportLayerClient: timeout=timeout, ) - return content - @log_function - def get_group_profile(self, destination, group_id, requester_user_id): + async def get_group_profile( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get a group profile""" path = _create_v1_path("/groups/%s/profile", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -625,14 +641,16 @@ class TransportLayerClient: ) @log_function - def update_group_profile(self, destination, group_id, requester_user_id, content): + async def update_group_profile( + self, destination: str, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Update a remote group profile Args: - destination (str) - group_id (str) - requester_user_id (str) - content (dict): The new profile of the group + destination + group_id + requester_user_id + content: The new profile of the group """ path = _create_v1_path("/groups/%s/profile", group_id) @@ -645,11 +663,13 @@ class TransportLayerClient: ) @log_function - def get_group_summary(self, destination, group_id, requester_user_id): + async def get_group_summary( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get a group summary""" path = _create_v1_path("/groups/%s/summary", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -657,24 +677,31 @@ class TransportLayerClient: ) @log_function - def get_rooms_in_group(self, destination, group_id, requester_user_id): + async def get_rooms_in_group( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get all rooms in a group""" path = _create_v1_path("/groups/%s/rooms", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, ignore_backoff=True, ) - def add_room_to_group( - self, destination, group_id, requester_user_id, room_id, content - ): + async def add_room_to_group( + self, + destination: str, + group_id: str, + requester_user_id: str, + room_id: str, + content: JsonDict, + ) -> JsonDict: """Add a room to a group""" path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -682,15 +709,21 @@ class TransportLayerClient: ignore_backoff=True, ) - def update_room_in_group( - self, destination, group_id, requester_user_id, room_id, config_key, content - ): + async def update_room_in_group( + self, + destination: str, + group_id: str, + requester_user_id: str, + room_id: str, + config_key: str, + content: JsonDict, + ) -> JsonDict: """Update room in group""" path = _create_v1_path( "/groups/%s/room/%s/config/%s", group_id, room_id, config_key ) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -698,11 +731,13 @@ class TransportLayerClient: ignore_backoff=True, ) - def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): + async def remove_room_from_group( + self, destination: str, group_id: str, requester_user_id: str, room_id: str + ) -> JsonDict: """Remove a room from a group""" path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -710,11 +745,13 @@ class TransportLayerClient: ) @log_function - def get_users_in_group(self, destination, group_id, requester_user_id): + async def get_users_in_group( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get users in a group""" path = _create_v1_path("/groups/%s/users", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -722,11 +759,13 @@ class TransportLayerClient: ) @log_function - def get_invited_users_in_group(self, destination, group_id, requester_user_id): + async def get_invited_users_in_group( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get users that have been invited to a group""" path = _create_v1_path("/groups/%s/invited_users", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -734,16 +773,20 @@ class TransportLayerClient: ) @log_function - def accept_group_invite(self, destination, group_id, user_id, content): + async def accept_group_invite( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Accept a group invite""" path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def join_group(self, destination, group_id, user_id, content): + def join_group( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Attempts to join a group""" path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) @@ -752,13 +795,18 @@ class TransportLayerClient: ) @log_function - def invite_to_group( - self, destination, group_id, user_id, requester_user_id, content - ): + async def invite_to_group( + self, + destination: str, + group_id: str, + user_id: str, + requester_user_id: str, + content: JsonDict, + ) -> JsonDict: """Invite a user to a group""" path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -767,25 +815,32 @@ class TransportLayerClient: ) @log_function - def invite_to_group_notification(self, destination, group_id, user_id, content): + async def invite_to_group_notification( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Sent by group server to inform a user's server that they have been invited. """ path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def remove_user_from_group( - self, destination, group_id, requester_user_id, user_id, content - ): + async def remove_user_from_group( + self, + destination: str, + group_id: str, + requester_user_id: str, + user_id: str, + content: JsonDict, + ) -> JsonDict: """Remove a user from a group""" path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -794,35 +849,43 @@ class TransportLayerClient: ) @log_function - def remove_user_from_group_notification( - self, destination, group_id, user_id, content - ): + async def remove_user_from_group_notification( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Sent by group server to inform a user's server that they have been kicked from the group. """ path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def renew_group_attestation(self, destination, group_id, user_id, content): + async def renew_group_attestation( + self, destination: str, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Sent by either a group server or a user's server to periodically update the attestations """ path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def update_group_summary_room( - self, destination, group_id, user_id, room_id, category_id, content - ): + async def update_group_summary_room( + self, + destination: str, + group_id: str, + user_id: str, + room_id: str, + category_id: str, + content: JsonDict, + ) -> JsonDict: """Update a room entry in a group summary""" if category_id: path = _create_v1_path( @@ -834,7 +897,7 @@ class TransportLayerClient: else: path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": user_id}, @@ -843,9 +906,14 @@ class TransportLayerClient: ) @log_function - def delete_group_summary_room( - self, destination, group_id, user_id, room_id, category_id - ): + async def delete_group_summary_room( + self, + destination: str, + group_id: str, + user_id: str, + room_id: str, + category_id: str, + ) -> JsonDict: """Delete a room entry in a group summary""" if category_id: path = _create_v1_path( @@ -857,7 +925,7 @@ class TransportLayerClient: else: path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": user_id}, @@ -865,11 +933,13 @@ class TransportLayerClient: ) @log_function - def get_group_categories(self, destination, group_id, requester_user_id): + async def get_group_categories( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get all categories in a group""" path = _create_v1_path("/groups/%s/categories", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -877,11 +947,13 @@ class TransportLayerClient: ) @log_function - def get_group_category(self, destination, group_id, requester_user_id, category_id): + async def get_group_category( + self, destination: str, group_id: str, requester_user_id: str, category_id: str + ) -> JsonDict: """Get category info in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -889,13 +961,18 @@ class TransportLayerClient: ) @log_function - def update_group_category( - self, destination, group_id, requester_user_id, category_id, content - ): + async def update_group_category( + self, + destination: str, + group_id: str, + requester_user_id: str, + category_id: str, + content: JsonDict, + ) -> JsonDict: """Update a category in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -904,13 +981,13 @@ class TransportLayerClient: ) @log_function - def delete_group_category( - self, destination, group_id, requester_user_id, category_id - ): + async def delete_group_category( + self, destination: str, group_id: str, requester_user_id: str, category_id: str + ) -> JsonDict: """Delete a category in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -918,11 +995,13 @@ class TransportLayerClient: ) @log_function - def get_group_roles(self, destination, group_id, requester_user_id): + async def get_group_roles( + self, destination: str, group_id: str, requester_user_id: str + ) -> JsonDict: """Get all roles in a group""" path = _create_v1_path("/groups/%s/roles", group_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -930,11 +1009,13 @@ class TransportLayerClient: ) @log_function - def get_group_role(self, destination, group_id, requester_user_id, role_id): + async def get_group_role( + self, destination: str, group_id: str, requester_user_id: str, role_id: str + ) -> JsonDict: """Get a roles info""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - return self.client.get_json( + return await self.client.get_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -942,13 +1023,18 @@ class TransportLayerClient: ) @log_function - def update_group_role( - self, destination, group_id, requester_user_id, role_id, content - ): + async def update_group_role( + self, + destination: str, + group_id: str, + requester_user_id: str, + role_id: str, + content: JsonDict, + ) -> JsonDict: """Update a role in a group""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -957,11 +1043,13 @@ class TransportLayerClient: ) @log_function - def delete_group_role(self, destination, group_id, requester_user_id, role_id): + async def delete_group_role( + self, destination: str, group_id: str, requester_user_id: str, role_id: str + ) -> JsonDict: """Delete a role in a group""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -969,9 +1057,15 @@ class TransportLayerClient: ) @log_function - def update_group_summary_user( - self, destination, group_id, requester_user_id, user_id, role_id, content - ): + async def update_group_summary_user( + self, + destination: str, + group_id: str, + requester_user_id: str, + user_id: str, + role_id: str, + content: JsonDict, + ) -> JsonDict: """Update a users entry in a group""" if role_id: path = _create_v1_path( @@ -980,7 +1074,7 @@ class TransportLayerClient: else: path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -989,11 +1083,13 @@ class TransportLayerClient: ) @log_function - def set_group_join_policy(self, destination, group_id, requester_user_id, content): + async def set_group_join_policy( + self, destination: str, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Sets the join policy for a group""" path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) - return self.client.put_json( + return await self.client.put_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, @@ -1002,9 +1098,14 @@ class TransportLayerClient: ) @log_function - def delete_group_summary_user( - self, destination, group_id, requester_user_id, user_id, role_id - ): + async def delete_group_summary_user( + self, + destination: str, + group_id: str, + requester_user_id: str, + user_id: str, + role_id: str, + ) -> JsonDict: """Delete a users entry in a group""" if role_id: path = _create_v1_path( @@ -1013,33 +1114,35 @@ class TransportLayerClient: else: path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) - return self.client.delete_json( + return await self.client.delete_json( destination=destination, path=path, args={"requester_user_id": requester_user_id}, ignore_backoff=True, ) - def bulk_get_publicised_groups(self, destination, user_ids): + async def bulk_get_publicised_groups( + self, destination: str, user_ids: Iterable[str] + ) -> JsonDict: """Get the groups a list of users are publicising""" path = _create_v1_path("/get_groups_publicised") content = {"user_ids": user_ids} - return self.client.post_json( + return await self.client.post_json( destination=destination, path=path, data=content, ignore_backoff=True ) - def get_room_complexity(self, destination, room_id): + async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict: """ Args: - destination (str): The remote server - room_id (str): The room ID to ask about. + destination: The remote server + room_id: The room ID to ask about. """ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id) - return self.client.get_json(destination=destination, path=path) + return await self.client.get_json(destination=destination, path=path) async def get_space_summary( self, @@ -1075,14 +1178,14 @@ class TransportLayerClient: ) -def _create_path(federation_prefix, path, *args): +def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ Ensures that all args are url encoded. """ return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args) -def _create_v1_path(path, *args): +def _create_v1_path(path: str, *args: str) -> str: """Creates a path against V1 federation API from the path template and args. Ensures that all args are url encoded. @@ -1091,16 +1194,13 @@ def _create_v1_path(path, *args): _create_v1_path("/event/%s", event_id) Args: - path (str): String template for the path - args: ([str]): Args to insert into path. Each arg will be url encoded - - Returns: - str + path: String template for the path + args: Args to insert into path. Each arg will be url encoded """ return _create_path(FEDERATION_V1_PREFIX, path, *args) -def _create_v2_path(path, *args): +def _create_v2_path(path: str, *args: str) -> str: """Creates a path against V2 federation API from the path template and args. Ensures that all args are url encoded. @@ -1109,11 +1209,8 @@ def _create_v2_path(path, *args): _create_v2_path("/event/%s", event_id) Args: - path (str): String template for the path - args: ([str]): Args to insert into path. Each arg will be url encoded - - Returns: - str + path: String template for the path + args: Args to insert into path. Each arg will be url encoded """ return _create_path(FEDERATION_V2_PREFIX, path, *args) @@ -1122,8 +1219,26 @@ def _create_v2_path(path, *args): class SendJoinResponse: """The parsed response of a `/send_join` request.""" + # The list of auth events from the /send_join response. auth_events: List[EventBase] + # The list of state from the /send_join response. state: List[EventBase] + # The raw join event from the /send_join response. + event_dict: JsonDict + # The parsed join event from the /send_join response. This will be None if + # "event" is not included in the response. + event: Optional[EventBase] = None + + +@ijson.coroutine +def _event_parser(event_dict: JsonDict): + """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs + to add them to a given dictionary. + """ + + while True: + key, value = yield + event_dict[key] = value @ijson.coroutine @@ -1149,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]): CONTENT_TYPE = "application/json" def __init__(self, room_version: RoomVersion, v1_api: bool): - self._response = SendJoinResponse([], []) + self._response = SendJoinResponse([], [], {}) + self._room_version = room_version # The V1 API has the shape of `[200, {...}]`, which we handle by # prefixing with `item.*`. @@ -1163,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]): _event_list_parser(room_version, self._response.auth_events), prefix + "auth_chain.item", ) + self._coro_event = ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "org.matrix.msc3083.v2.event", + ) def write(self, data: bytes) -> int: self._coro_state.send(data) self._coro_auth.send(data) + self._coro_event.send(data) return len(data) def finish(self) -> SendJoinResponse: + if self._response.event_dict: + self._response.event = make_event_from_dict( + self._response.event_dict, self._room_version + ) return self._response diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2974d4d0cc..5e059d6e09 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -984,7 +984,7 @@ class PublicRoomList(BaseFederationServlet): limit = parse_integer_from_args(query, "limit", 0) since_token = parse_string_from_args(query, "since", None) include_all_networks = parse_boolean_from_args( - query, "include_all_networks", False + query, "include_all_networks", default=False ) third_party_instance_id = parse_string_from_args( query, "third_party_instance_id", None @@ -1908,16 +1908,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") - exclude_rooms = [] - if b"exclude_rooms" in query: - try: - exclude_rooms = [ - room_id.decode("ascii") for room_id in query[b"exclude_rooms"] - ] - except Exception: - raise SynapseError( - 400, "Bad query parameter for exclude_rooms", Codes.INVALID_PARAM - ) + exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) return 200, await self.handler.federation_space_summary( origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 41dbdfd0a1..53fac1f8a3 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import TYPE_CHECKING, Collection, List, Optional, Union from synapse import event_auth @@ -20,16 +21,18 @@ from synapse.api.constants import ( Membership, RestrictedJoinRuleTypes, ) -from synapse.api.errors import AuthError +from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.types import StateMap +from synapse.types import StateMap, get_domain_from_id from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + class EventAuthHandler: """ @@ -39,6 +42,7 @@ class EventAuthHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._store = hs.get_datastore() + self._server_name = hs.hostname async def check_from_context( self, room_version: str, event, context, do_sig_check=True @@ -81,15 +85,76 @@ class EventAuthHandler: # introduce undesirable "state reset" behaviour. # # All of which sounds a bit tricky so we don't bother for now. - auth_ids = [] - for etype, state_key in event_auth.auth_types_for_event(event): + for etype, state_key in event_auth.auth_types_for_event( + event.room_version, event + ): auth_ev_id = current_state_ids.get((etype, state_key)) if auth_ev_id: auth_ids.append(auth_ev_id) return auth_ids + async def get_user_which_could_invite( + self, room_id: str, current_state_ids: StateMap[str] + ) -> str: + """ + Searches the room state for a local user who has the power level necessary + to invite other users. + + Args: + room_id: The room ID under search. + current_state_ids: The current state of the room. + + Returns: + The MXID of the user which could issue an invite. + + Raises: + SynapseError if no appropriate user is found. + """ + power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, "")) + invite_level = 0 + users_default_level = 0 + if power_level_event_id: + power_level_event = await self._store.get_event(power_level_event_id) + invite_level = power_level_event.content.get("invite", invite_level) + users_default_level = power_level_event.content.get( + "users_default", users_default_level + ) + users = power_level_event.content.get("users", {}) + else: + users = {} + + # Find the user with the highest power level. + users_in_room = await self._store.get_users_in_room(room_id) + # Only interested in local users. + local_users_in_room = [ + u for u in users_in_room if get_domain_from_id(u) == self._server_name + ] + chosen_user = max( + local_users_in_room, + key=lambda user: users.get(user, users_default_level), + default=None, + ) + + # Return the chosen if they can issue invites. + user_power_level = users.get(chosen_user, users_default_level) + if chosen_user and user_power_level >= invite_level: + logger.debug( + "Found a user who can issue invites %s with power level %d >= invite level %d", + chosen_user, + user_power_level, + invite_level, + ) + return chosen_user + + # No user was found. + raise SynapseError( + 400, + "Unable to find a user which could issue an invite", + Codes.UNABLE_TO_GRANT_JOIN, + ) + async def check_host_in_room(self, room_id: str, host: str) -> bool: with Measure(self._clock, "check_host_in_room"): return await self._store.is_host_joined(room_id, host) @@ -134,6 +199,18 @@ class EventAuthHandler: # in any of them. allowed_rooms = await self.get_rooms_that_allow_join(state_ids) if not await self.is_user_in_rooms(allowed_rooms, user_id): + + # If this is a remote request, the user might be in an allowed room + # that we do not know about. + if get_domain_from_id(user_id) != self._server_name: + for room_id in allowed_rooms: + if not await self._store.is_host_joined(room_id, self._server_name): + raise SynapseError( + 400, + f"Unable to check if {user_id} is in allowed rooms.", + Codes.UNABLE_AUTHORISE_JOIN, + ) + raise AuthError( 403, "You do not belong to any of the required rooms to join this room.", diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5728719909..aba095d2e1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1494,9 +1494,10 @@ class FederationHandler(BaseHandler): host_list, event, room_version_obj ) - origin = ret["origin"] - state = ret["state"] - auth_chain = ret["auth_chain"] + event = ret.event + origin = ret.origin + state = ret.state + auth_chain = ret.auth_chain auth_chain.sort(key=lambda e: e.depth) logger.debug("do_invite_join auth_chain: %s", auth_chain) @@ -1676,7 +1677,7 @@ class FederationHandler(BaseHandler): # checking the room version will check that we've actually heard of the room # (and return a 404 otherwise) - room_version = await self.store.get_room_version_id(room_id) + room_version = await self.store.get_room_version(room_id) # now check that we are *still* in the room is_in_room = await self._event_auth_handler.check_host_in_room( @@ -1691,8 +1692,38 @@ class FederationHandler(BaseHandler): event_content = {"membership": Membership.JOIN} + # If the current room is using restricted join rules, additional information + # may need to be included in the event content in order to efficiently + # validate the event. + # + # Note that this requires the /send_join request to come back to the + # same server. + if room_version.msc3083_join_rules: + state_ids = await self.store.get_current_state_ids(room_id) + if await self._event_auth_handler.has_restricted_join_rules( + state_ids, room_version + ): + prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None) + # If the user is invited or joined to the room already, then + # no additional info is needed. + include_auth_user_id = True + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + include_auth_user_id = prev_member_event.membership not in ( + Membership.JOIN, + Membership.INVITE, + ) + + if include_auth_user_id: + event_content[ + "join_authorised_via_users_server" + ] = await self._event_auth_handler.get_user_which_could_invite( + room_id, + state_ids, + ) + builder = self.event_builder_factory.new( - room_version, + room_version.identifier, { "type": EventTypes.Member, "content": event_content, @@ -1710,10 +1741,13 @@ class FederationHandler(BaseHandler): logger.warning("Failed to create join to %s because %s", room_id, e) raise + # Ensure the user can even join the room. + await self._check_join_restrictions(context, event) + # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` await self._event_auth_handler.check_from_context( - room_version, event, context, do_sig_check=False + room_version.identifier, event, context, do_sig_check=False ) return event @@ -1958,7 +1992,7 @@ class FederationHandler(BaseHandler): @log_function async def on_send_membership_event( self, origin: str, event: EventBase - ) -> EventContext: + ) -> Tuple[EventBase, EventContext]: """ We have received a join/leave/knock event for a room via send_join/leave/knock. @@ -1981,7 +2015,7 @@ class FederationHandler(BaseHandler): event: The member event that has been signed by the remote homeserver. Returns: - The context of the event after inserting it into the room graph. + The event and context of the event after inserting it into the room graph. Raises: SynapseError if the event is not accepted into the room @@ -2037,7 +2071,7 @@ class FederationHandler(BaseHandler): # all looks good, we can persist the event. await self._run_push_actions_and_persist_event(event, context) - return context + return event, context async def _check_join_restrictions( self, context: EventContext, event: EventBase @@ -2473,7 +2507,7 @@ class FederationHandler(BaseHandler): ) # Now check if event pass auth against said current state - auth_types = auth_types_for_event(event) + auth_types = auth_types_for_event(room_version_obj, event) current_state_ids_list = [ e for k, e in current_state_ids.items() if k in auth_types ] diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 5d49640760..e1c544a3c9 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -21,6 +21,7 @@ from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state +from synapse.handlers.receipts import ReceiptEventSource from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig @@ -134,6 +135,8 @@ class InitialSyncHandler(BaseHandler): joined_rooms, to_key=int(now_token.receipt_key), ) + if self.hs.config.experimental.msc2285_enabled: + receipt = ReceiptEventSource.filter_out_hidden(receipt, user_id) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -430,7 +433,9 @@ class InitialSyncHandler(BaseHandler): room_id, to_key=now_token.receipt_key ) if not receipts: - receipts = [] + return [] + if self.hs.config.experimental.msc2285_enabled: + receipts = ReceiptEventSource.filter_out_hidden(receipts, user_id) return receipts presence, receipts, (messages, token) = await make_deferred_yieldable( diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 283483fc2c..b9085bbccb 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,9 +14,10 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple +from synapse.api.constants import ReadReceiptEventFields from synapse.appservice import ApplicationService from synapse.handlers._base import BaseHandler -from synapse.types import JsonDict, ReadReceipt, get_domain_from_id +from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id if TYPE_CHECKING: from synapse.server import HomeServer @@ -137,7 +138,7 @@ class ReceiptsHandler(BaseHandler): return True async def received_client_receipt( - self, room_id: str, receipt_type: str, user_id: str, event_id: str + self, room_id: str, receipt_type: str, user_id: str, event_id: str, hidden: bool ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. @@ -147,23 +148,67 @@ class ReceiptsHandler(BaseHandler): receipt_type=receipt_type, user_id=user_id, event_ids=[event_id], - data={"ts": int(self.clock.time_msec())}, + data={"ts": int(self.clock.time_msec()), "hidden": hidden}, ) is_new = await self._handle_new_receipts([receipt]) if not is_new: return - if self.federation_sender: + if self.federation_sender and not ( + self.hs.config.experimental.msc2285_enabled and hidden + ): await self.federation_sender.send_read_receipt(receipt) class ReceiptEventSource: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() + self.config = hs.config + + @staticmethod + def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: + visible_events = [] + + # filter out hidden receipts the user shouldn't see + for event in events: + content = event.get("content", {}) + new_event = event.copy() + new_event["content"] = {} + + for event_id in content.keys(): + event_content = content.get(event_id, {}) + m_read = event_content.get("m.read", {}) + + # If m_read is missing copy over the original event_content as there is nothing to process here + if not m_read: + new_event["content"][event_id] = event_content.copy() + continue + + new_users = {} + for rr_user_id, user_rr in m_read.items(): + hidden = user_rr.get("hidden", None) + if hidden is not True or rr_user_id == user_id: + new_users[rr_user_id] = user_rr.copy() + # If hidden has a value replace hidden with the correct prefixed key + if hidden is not None: + new_users[rr_user_id].pop("hidden") + new_users[rr_user_id][ + ReadReceiptEventFields.MSC2285_HIDDEN + ] = hidden + + # Set new users unless empty + if len(new_users.keys()) > 0: + new_event["content"][event_id] = {"m.read": new_users} + + # Append new_event to visible_events unless empty + if len(new_event["content"].keys()) > 0: + visible_events.append(new_event) + + return visible_events async def get_new_events( - self, from_key: int, room_ids: List[str], **kwargs + self, from_key: int, room_ids: List[str], user: UserID, **kwargs ) -> Tuple[List[JsonDict], int]: from_key = int(from_key) to_key = self.get_current_key() @@ -175,6 +220,9 @@ class ReceiptEventSource: room_ids, from_key=from_key, to_key=to_key ) + if self.config.experimental.msc2285_enabled: + events = ReceiptEventSource.filter_out_hidden(events, user.to_string()) + return (events, to_key) async def get_new_events_as( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1192591609..65ad3efa6a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -16,7 +16,7 @@ import abc import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple from synapse import types from synapse.api.constants import AccountDataTypes, EventTypes, Membership @@ -28,6 +28,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.types import ( @@ -340,16 +341,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if event.membership == Membership.JOIN: newly_joined = True - prev_member_event = None if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) newly_joined = prev_member_event.membership != Membership.JOIN - # Check if the member should be allowed access via membership in a space. - await self.event_auth_handler.check_restricted_join_rules( - prev_state_ids, event.room_version, user_id, prev_member_event - ) - # Only rate-limit if the user actually joined the room, otherwise we'll end # up blocking profile updates. if newly_joined and ratelimit: @@ -701,7 +696,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") - if not is_host_in_room: + # Check if a remote join should be performed. + remote_join, remote_room_hosts = await self._should_perform_remote_join( + target.to_string(), room_id, remote_room_hosts, content, is_host_in_room + ) + if remote_join: if ratelimit: time_now_s = self.clock.time() ( @@ -826,6 +825,106 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): outlier=outlier, ) + async def _should_perform_remote_join( + self, + user_id: str, + room_id: str, + remote_room_hosts: List[str], + content: JsonDict, + is_host_in_room: bool, + ) -> Tuple[bool, List[str]]: + """ + Check whether the server should do a remote join (as opposed to a local + join) for a user. + + Generally a remote join is used if: + + * The server is not yet in the room. + * The server is in the room, the room has restricted join rules, the user + is not joined or invited to the room, and the server does not have + another user who is capable of issuing invites. + + Args: + user_id: The user joining the room. + room_id: The room being joined. + remote_room_hosts: A list of remote room hosts. + content: The content to use as the event body of the join. This may + be modified. + is_host_in_room: True if the host is in the room. + + Returns: + A tuple of: + True if a remote join should be performed. False if the join can be + done locally. + + A list of remote room hosts to use. This is an empty list if a + local join is to be done. + """ + # If the host isn't in the room, pass through the prospective hosts. + if not is_host_in_room: + return True, remote_room_hosts + + # If the host is in the room, but not one of the authorised hosts + # for restricted join rules, a remote join must be used. + room_version = await self.store.get_room_version(room_id) + current_state_ids = await self.store.get_current_state_ids(room_id) + + # If restricted join rules are not being used, a local join can always + # be used. + if not await self.event_auth_handler.has_restricted_join_rules( + current_state_ids, room_version + ): + return False, [] + + # If the user is invited to the room or already joined, the join + # event can always be issued locally. + prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + if prev_member_event.membership in ( + Membership.JOIN, + Membership.INVITE, + ): + return False, [] + + # If the local host has a user who can issue invites, then a local + # join can be done. + # + # If not, generate a new list of remote hosts based on which + # can issue invites. + event_map = await self.store.get_events(current_state_ids.values()) + current_state = { + state_key: event_map[event_id] + for state_key, event_id in current_state_ids.items() + } + allowed_servers = get_servers_from_users( + get_users_which_can_issue_invite(current_state) + ) + + # If the local server is not one of allowed servers, then a remote + # join must be done. Return the list of prospective servers based on + # which can issue invites. + if self.hs.hostname not in allowed_servers: + return True, list(allowed_servers) + + # Ensure the member should be allowed access via membership in a room. + await self.event_auth_handler.check_restricted_join_rules( + current_state_ids, room_version, user_id, prev_member_event + ) + + # If this is going to be a local join, additional information must + # be included in the event content in order to efficiently validate + # the event. + content[ + "join_authorised_via_users_server" + ] = await self.event_auth_handler.get_user_which_could_invite( + room_id, + current_state_ids, + ) + + return False, [] + async def transfer_room_state_on_room_upgrade( self, old_room_id: str, room_id: str ) -> None: @@ -1514,3 +1613,63 @@ class RoomMemberMasterHandler(RoomMemberHandler): if membership: await self.store.forget(user_id, room_id) + + +def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]: + """ + Return the list of users which can issue invites. + + This is done by exploring the joined users and comparing their power levels + to the necessyar power level to issue an invite. + + Args: + auth_events: state in force at this point in the room + + Returns: + The users which can issue invites. + """ + invite_level = get_named_level(auth_events, "invite", 0) + users_default_level = get_named_level(auth_events, "users_default", 0) + power_level_event = get_power_level_event(auth_events) + + # Custom power-levels for users. + if power_level_event: + users = power_level_event.content.get("users", {}) + else: + users = {} + + result = [] + + # Check which members are able to invite by ensuring they're joined and have + # the necessary power level. + for (event_type, state_key), event in auth_events.items(): + if event_type != EventTypes.Member: + continue + + if event.membership != Membership.JOIN: + continue + + # Check if the user has a custom power level. + if users.get(state_key, users_default_level) >= invite_level: + result.append(state_key) + + return result + + +def get_servers_from_users(users: List[str]) -> Set[str]: + """ + Resolve a list of users into their servers. + + Args: + users: A list of users. + + Returns: + A set of servers. + """ + servers = set() + for user in users: + try: + servers.add(get_domain_from_id(user)) + except SynapseError: + pass + return servers diff --git a/synapse/http/client.py b/synapse/http/client.py index 2ac76b15c2..c2ea51ee16 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -847,7 +847,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): def read_body_with_max_size( response: IResponse, stream: ByteWriteable, max_size: Optional[int] -) -> defer.Deferred: +) -> "defer.Deferred[int]": """ Read a HTTP response body to a file-object. Optionally enforcing a maximum file size. @@ -862,7 +862,7 @@ def read_body_with_max_size( Returns: A Deferred which resolves to the length of the read body. """ - d = defer.Deferred() + d: "defer.Deferred[int]" = defer.Deferred() # If the Content-Length header gives a size larger than the maximum allowed # size, do not bother downloading the body. diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 950770201a..c16b7f10e6 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -27,7 +27,7 @@ from twisted.internet.interfaces import ( ) from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer +from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.client import BlacklistingAgentWrapper @@ -116,7 +116,7 @@ class MatrixFederationAgent: uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> Generator[defer.Deferred, Any, defer.Deferred]: + ) -> Generator[defer.Deferred, Any, IResponse]: """ Args: method: HTTP method: GET/POST/etc diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index f7193e60bd..19e987f118 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -14,21 +14,32 @@ import base64 import logging import re -from typing import Optional, Tuple -from urllib.request import getproxies_environment, proxy_bypass_environment +from typing import Any, Dict, Optional, Tuple +from urllib.parse import urlparse +from urllib.request import ( # type: ignore[attr-defined] + getproxies_environment, + proxy_bypass_environment, +) import attr from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS +from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint from twisted.python.failure import Failure -from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase +from twisted.web.client import ( + URI, + BrowserLikePolicyForHTTPS, + HTTPConnectionPool, + _AgentBase, +) from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IPolicyForHTTPS +from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint +from synapse.types import ISynapseReactor logger = logging.getLogger(__name__) @@ -63,35 +74,38 @@ class ProxyAgent(_AgentBase): reactor might have some blacklisting applied (i.e. for DNS queries), but we need unblocked access to the proxy. - contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the + contextFactory: A factory for TLS contexts, to control the verification parameters of OpenSSL. The default is to use a `BrowserLikePolicyForHTTPS`, so unless you have special requirements you can leave this as-is. - connectTimeout (Optional[float]): The amount of time that this Agent will wait + connectTimeout: The amount of time that this Agent will wait for the peer to accept a connection, in seconds. If 'None', HostnameEndpoint's default (30s) will be used. - This is used for connections to both proxies and destination servers. - bindAddress (bytes): The local address for client sockets to bind to. + bindAddress: The local address for client sockets to bind to. - pool (HTTPConnectionPool|None): connection pool to be used. If None, a + pool: connection pool to be used. If None, a non-persistent pool instance will be created. - use_proxy (bool): Whether proxy settings should be discovered and used + use_proxy: Whether proxy settings should be discovered and used from conventional environment variables. + + Raises: + ValueError if use_proxy is set and the environment variables + contain an invalid proxy specification. """ def __init__( self, - reactor, - proxy_reactor=None, + reactor: IReactorCore, + proxy_reactor: Optional[ISynapseReactor] = None, contextFactory: Optional[IPolicyForHTTPS] = None, - connectTimeout=None, - bindAddress=None, - pool=None, - use_proxy=False, + connectTimeout: Optional[float] = None, + bindAddress: Optional[bytes] = None, + pool: Optional[HTTPConnectionPool] = None, + use_proxy: bool = False, ): contextFactory = contextFactory or BrowserLikePolicyForHTTPS() @@ -102,7 +116,7 @@ class ProxyAgent(_AgentBase): else: self.proxy_reactor = proxy_reactor - self._endpoint_kwargs = {} + self._endpoint_kwargs: Dict[str, Any] = {} if connectTimeout is not None: self._endpoint_kwargs["timeout"] = connectTimeout if bindAddress is not None: @@ -117,16 +131,12 @@ class ProxyAgent(_AgentBase): https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None - # Parse credentials from http and https proxy connection string if present - self.http_proxy_creds, http_proxy = parse_username_password(http_proxy) - self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) - - self.http_proxy_endpoint = _http_proxy_endpoint( - http_proxy, self.proxy_reactor, **self._endpoint_kwargs + self.http_proxy_endpoint, self.http_proxy_creds = _http_proxy_endpoint( + http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs ) - self.https_proxy_endpoint = _http_proxy_endpoint( - https_proxy, self.proxy_reactor, **self._endpoint_kwargs + self.https_proxy_endpoint, self.https_proxy_creds = _http_proxy_endpoint( + https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs ) self.no_proxy = no_proxy @@ -134,7 +144,13 @@ class ProxyAgent(_AgentBase): self._policy_for_https = contextFactory self._reactor = reactor - def request(self, method, uri, headers=None, bodyProducer=None): + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> defer.Deferred: """ Issue a request to the server indicated by the given uri. @@ -146,16 +162,15 @@ class ProxyAgent(_AgentBase): See also: twisted.web.iweb.IAgent.request Args: - method (bytes): The request method to use, such as `GET`, `POST`, etc + method: The request method to use, such as `GET`, `POST`, etc - uri (bytes): The location of the resource to request. + uri: The location of the resource to request. - headers (Headers|None): Extra headers to send with the request + headers: Extra headers to send with the request - bodyProducer (IBodyProducer|None): An object which can generate bytes to - make up the body of this request (for example, the properly encoded - contents of a file for a file upload). Or, None if the request is to - have no body. + bodyProducer: An object which can generate bytes to make up the body of + this request (for example, the properly encoded contents of a file for + a file upload). Or, None if the request is to have no body. Returns: Deferred[IResponse]: completes when the header of the response has @@ -253,70 +268,89 @@ class ProxyAgent(_AgentBase): ) -def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs): +def _http_proxy_endpoint( + proxy: Optional[bytes], + reactor: IReactorCore, + tls_options_factory: IPolicyForHTTPS, + **kwargs, +) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: """Parses an http proxy setting and returns an endpoint for the proxy Args: - proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>] - Note that compared to other apps, this function currently lacks support - for specifying a protocol schema (i.e. protocol://...). + proxy: the proxy setting in the form: [scheme://][<username>:<password>@]<host>[:<port>] + This currently supports http:// and https:// proxies. + A hostname without scheme is assumed to be http. reactor: reactor to be used to connect to the proxy + tls_options_factory: the TLS options to use when connecting through a https proxy + kwargs: other args to be passed to HostnameEndpoint Returns: - interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy, - or None + a tuple of + endpoint to use to connect to the proxy, or None + ProxyCredentials or if no credentials were found, or None + + Raise: + ValueError if proxy has no hostname or unsupported scheme. """ if proxy is None: - return None + return None, None - # Parse the connection string - host, port = parse_host_port(proxy, default_port=1080) - return HostnameEndpoint(reactor, host, port, **kwargs) + # Note: urlsplit/urlparse cannot be used here as that does not work (for Python + # 3.9+) on scheme-less proxies, e.g. host:port. + scheme, host, port, credentials = parse_proxy(proxy) + proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs) -def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]: - """ - Parses the username and password from a proxy declaration e.g - username:password@hostname:port. + if scheme == b"https": + tls_options = tls_options_factory.creatorForNetloc(host, port) + proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) - Args: - proxy: The proxy connection string. + return proxy_endpoint, credentials - Returns - An instance of ProxyCredentials and the proxy connection string with any credentials - stripped, i.e u:p@host:port -> host:port. If no credentials were found, the - ProxyCredentials instance is replaced with None. - """ - if proxy and b"@" in proxy: - # We use rsplit here as the password could contain an @ character - credentials, proxy_without_credentials = proxy.rsplit(b"@", 1) - return ProxyCredentials(credentials), proxy_without_credentials - return None, proxy +def parse_proxy( + proxy: bytes, default_scheme: bytes = b"http", default_port: int = 1080 +) -> Tuple[bytes, bytes, int, Optional[ProxyCredentials]]: + """ + Parse a proxy connection string. + Given a HTTP proxy URL, breaks it down into components and checks that it + has a hostname (otherwise it is not useful to us when trying to find a + proxy) and asserts that the URL has a scheme we support. -def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]: - """ - Parse the hostname and port from a proxy connection byte string. Args: - hostport: The proxy connection string. Must be in the form 'host[:port]'. - default_port: The default port to return if one is not found in `hostport`. + proxy: The proxy connection string. Must be in the form '[scheme://][<username>:<password>@]host[:port]'. + default_scheme: The default scheme to return if one is not found in `proxy`. Defaults to http + default_port: The default port to return if one is not found in `proxy`. Defaults to 1080 Returns: - A tuple containing the hostname and port. Uses `default_port` if one was not found. + A tuple containing the scheme, hostname, port and ProxyCredentials. + If no credentials were found, the ProxyCredentials instance is replaced with None. + + Raise: + ValueError if proxy has no hostname or unsupported scheme. """ - if b":" in hostport: - host, port = hostport.rsplit(b":", 1) - try: - port = int(port) - return host, port - except ValueError: - # the thing after the : wasn't a valid port; presumably this is an - # IPv6 address. - pass + # First check if we have a scheme present + # Note: urlsplit/urlparse cannot be used (for Python # 3.9+) on scheme-less proxies, e.g. host:port. + if b"://" not in proxy: + proxy = b"".join([default_scheme, b"://", proxy]) + + url = urlparse(proxy) + + if not url.hostname: + raise ValueError("Proxy URL did not contain a hostname! Please specify one.") + + if url.scheme not in (b"http", b"https"): + raise ValueError( + f"Unknown proxy scheme {url.scheme!s}; only 'http' and 'https' is supported." + ) + + credentials = None + if url.username and url.password: + credentials = ProxyCredentials(b"".join([url.username, b":", url.password])) - return hostport, default_port + return url.scheme, url.hostname, url.port or default_port, credentials diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 04560fb589..732a1e6aeb 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,47 +14,86 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging -from typing import Dict, Iterable, List, Optional, overload +from typing import Iterable, List, Mapping, Optional, Sequence, overload from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.types import JsonDict from synapse.util import json_decoder logger = logging.getLogger(__name__) -def parse_integer(request, name, default=None, required=False): +@overload +def parse_integer(request: Request, name: str, default: int) -> int: + ... + + +@overload +def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: + ... + + +@overload +def parse_integer( + request: Request, name: str, default: Optional[int] = None, required: bool = False +) -> Optional[int]: + ... + + +def parse_integer( + request: Request, name: str, default: Optional[int] = None, required: bool = False +) -> Optional[int]: """Parse an integer parameter from the request string Args: request: the twisted HTTP request. - name (bytes/unicode): the name of the query parameter. - default (int|None): value to use if the parameter is absent, defaults - to None. - required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. Returns: - int|None: An int value or the default. + An int value or the default. Raises: SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ - return parse_integer_from_args(request.args, name, default, required) + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore + return parse_integer_from_args(args, name, default, required) + +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, + required: bool = False, +) -> Optional[int]: + """Parse an integer parameter from the request string + + Args: + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. -def parse_integer_from_args(args, name, default=None, required=False): + Returns: + An int value or the default. - if not isinstance(name, bytes): - name = name.encode("ascii") + Raises: + SynapseError: if the parameter is absent and required, or if the + parameter is present and not an integer. + """ + name_bytes = name.encode("ascii") - if name in args: + if name_bytes in args: try: - return int(args[name][0]) + return int(args[name_bytes][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) @@ -66,36 +105,102 @@ def parse_integer_from_args(args, name, default=None, required=False): return default -def parse_boolean(request, name, default=None, required=False): +@overload +def parse_boolean(request: Request, name: str, default: bool) -> bool: + ... + + +@overload +def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool: + ... + + +@overload +def parse_boolean( + request: Request, name: str, default: Optional[bool] = None, required: bool = False +) -> Optional[bool]: + ... + + +def parse_boolean( + request: Request, name: str, default: Optional[bool] = None, required: bool = False +) -> Optional[bool]: """Parse a boolean parameter from the request query string Args: request: the twisted HTTP request. - name (bytes/unicode): the name of the query parameter. - default (bool|None): value to use if the parameter is absent, defaults - to None. - required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. Returns: - bool|None: A bool value or the default. + A bool value or the default. Raises: SynapseError: if the parameter is absent and required, or if the parameter is present and not one of "true" or "false". """ + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore + return parse_boolean_from_args(args, name, default, required) + + +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: bool, +) -> bool: + ... + - return parse_boolean_from_args(request.args, name, default, required) +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + *, + required: Literal[True], +) -> bool: + ... -def parse_boolean_from_args(args, name, default=None, required=False): +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[bool] = None, + required: bool = False, +) -> Optional[bool]: + ... - if not isinstance(name, bytes): - name = name.encode("ascii") - if name in args: +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[bool] = None, + required: bool = False, +) -> Optional[bool]: + """Parse a boolean parameter from the request query string + + Args: + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. + + Returns: + A bool value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the + parameter is present and not one of "true" or "false". + """ + name_bytes = name.encode("ascii") + + if name_bytes in args: try: - return {b"true": True, b"false": False}[args[name][0]] + return {b"true": True, b"false": False}[args[name_bytes][0]] except Exception: message = ( "Boolean query parameter %r must be one of ['true', 'false']" @@ -111,7 +216,7 @@ def parse_boolean_from_args(args, name, default=None, required=False): @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, ) -> Optional[bytes]: @@ -120,7 +225,7 @@ def parse_bytes_from_args( @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Literal[None] = None, *, @@ -131,7 +236,7 @@ def parse_bytes_from_args( @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, required: bool = False, @@ -140,7 +245,7 @@ def parse_bytes_from_args( def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, required: bool = False, @@ -172,6 +277,42 @@ def parse_bytes_from_args( return default +@overload +def parse_string( + request: Request, + name: str, + default: str, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> str: + ... + + +@overload +def parse_string( + request: Request, + name: str, + *, + required: Literal[True], + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> str: + ... + + +@overload +def parse_string( + request: Request, + name: str, + *, + required: bool = False, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + def parse_string( request: Request, name: str, @@ -179,7 +320,7 @@ def parse_string( required: bool = False, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", -): +) -> Optional[str]: """ Parse a string parameter from the request query string. @@ -205,7 +346,7 @@ def parse_string( parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ - args: Dict[bytes, List[bytes]] = request.args # type: ignore + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore return parse_string_from_args( args, name, @@ -239,9 +380,8 @@ def _parse_string_value( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[List[str]] = None, *, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -251,9 +391,20 @@ def parse_strings_from_args( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: List[str], + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> List[str]: + ... + + +@overload +def parse_strings_from_args( + args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[List[str]] = None, *, required: Literal[True], allowed_values: Optional[Iterable[str]] = None, @@ -264,7 +415,7 @@ def parse_strings_from_args( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[List[str]] = None, *, @@ -276,7 +427,7 @@ def parse_strings_from_args( def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[List[str]] = None, required: bool = False, @@ -325,7 +476,7 @@ def parse_strings_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, *, @@ -337,7 +488,7 @@ def parse_string_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, *, @@ -350,7 +501,7 @@ def parse_string_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, required: bool = False, @@ -361,7 +512,7 @@ def parse_string_from_args( def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, required: bool = False, @@ -409,13 +560,14 @@ def parse_string_from_args( return strings[0] -def parse_json_value_from_request(request, allow_empty_body=False): +def parse_json_value_from_request( + request: Request, allow_empty_body: bool = False +) -> Optional[JsonDict]: """Parse a JSON value from the body of a twisted HTTP request. Args: request: the twisted HTTP request. - allow_empty_body (bool): if True, an empty body will be accepted and - turned into None + allow_empty_body: if True, an empty body will be accepted and turned into None Returns: The JSON value. @@ -424,7 +576,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): SynapseError if the request body couldn't be decoded as JSON. """ try: - content_bytes = request.content.read() + content_bytes = request.content.read() # type: ignore except Exception: raise SynapseError(400, "Error reading JSON content.") @@ -440,13 +592,15 @@ def parse_json_value_from_request(request, allow_empty_body=False): return content -def parse_json_object_from_request(request, allow_empty_body=False): +def parse_json_object_from_request( + request: Request, allow_empty_body: bool = False +) -> JsonDict: """Parse a JSON object from the body of a twisted HTTP request. Args: request: the twisted HTTP request. - allow_empty_body (bool): if True, an empty body will be accepted and - turned into an empty dict. + allow_empty_body: if True, an empty body will be accepted and turned into + an empty dict. Raises: SynapseError if the request body couldn't be decoded as JSON or @@ -457,14 +611,14 @@ def parse_json_object_from_request(request, allow_empty_body=False): if allow_empty_body and content is None: return {} - if type(content) != dict: + if not isinstance(content, dict): message = "Content must be a JSON object." raise SynapseError(400, message, errcode=Codes.BAD_JSON) return content -def assert_params_in_dict(body, required): +def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent = [] for k in required: if k not in body: diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 18ac507802..02e5ddd2ef 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -25,7 +25,7 @@ See doc/log_contexts.rst for details on how this works. import inspect import logging import threading -import types +import typing import warnings from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union @@ -745,7 +745,7 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred: # by synchronous exceptions, so let's turn them into Failures. return defer.fail() - if isinstance(res, types.CoroutineType): + if isinstance(res, typing.Coroutine): res = defer.ensureDeferred(res) # At this point we should have a Deferred, if not then f was a synchronous diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py new file mode 100644 index 0000000000..a6c212f300 --- /dev/null +++ b/synapse/logging/handlers.py @@ -0,0 +1,88 @@ +import logging +import time +from logging import Handler, LogRecord +from logging.handlers import MemoryHandler +from threading import Thread +from typing import Optional + +from twisted.internet.interfaces import IReactorCore + + +class PeriodicallyFlushingMemoryHandler(MemoryHandler): + """ + This is a subclass of MemoryHandler that additionally spawns a background + thread to periodically flush the buffer. + + This prevents messages from being buffered for too long. + + Additionally, all messages will be immediately flushed if the reactor has + not yet been started. + """ + + def __init__( + self, + capacity: int, + flushLevel: int = logging.ERROR, + target: Optional[Handler] = None, + flushOnClose: bool = True, + period: float = 5.0, + reactor: Optional[IReactorCore] = None, + ) -> None: + """ + period: the period between automatic flushes + + reactor: if specified, a custom reactor to use. If not specifies, + defaults to the globally-installed reactor. + Log entries will be flushed immediately until this reactor has + started. + """ + super().__init__(capacity, flushLevel, target, flushOnClose) + + self._flush_period: float = period + self._active: bool = True + self._reactor_started = False + + self._flushing_thread: Thread = Thread( + name="PeriodicallyFlushingMemoryHandler flushing thread", + target=self._flush_periodically, + ) + self._flushing_thread.start() + + def on_reactor_running(): + self._reactor_started = True + + reactor_to_use: IReactorCore + if reactor is None: + from twisted.internet import reactor as global_reactor + + reactor_to_use = global_reactor # type: ignore[assignment] + else: + reactor_to_use = reactor + + # call our hook when the reactor start up + reactor_to_use.callWhenRunning(on_reactor_running) + + def shouldFlush(self, record: LogRecord) -> bool: + """ + Before reactor start-up, log everything immediately. + Otherwise, fall back to original behaviour of waiting for the buffer to fill. + """ + + if self._reactor_started: + return super().shouldFlush(record) + else: + return True + + def _flush_periodically(self): + """ + Whilst this handler is active, flush the handler periodically. + """ + + while self._active: + # flush is thread-safe; it acquires and releases the lock internally + self.flush() + time.sleep(self._flush_period) + + def close(self) -> None: + self._active = False + super().close() diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 1259fc2d90..473812b8e2 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -484,7 +484,7 @@ class ModuleApi: @defer.inlineCallbacks def get_state_events_in_room( self, room_id: str, types: Iterable[Tuple[str, Optional[str]]] - ) -> Generator[defer.Deferred, Any, defer.Deferred]: + ) -> Generator[defer.Deferred, Any, Iterable[EventBase]]: """Gets current state events for the given room. (This is exposed for compatibility with the old SpamCheckerApi. We should diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 7be5fe1e9b..941fb238b7 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar import bleach import jinja2 -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import StoreError from synapse.config.emailconfig import EmailSubjectConfig from synapse.events import EventBase @@ -600,6 +600,22 @@ class Mailer: "app": self.app_name, } + # If the room is a space, it gets a slightly different topic. + create_event_id = room_state_ids.get(("m.room.create", "")) + if create_event_id: + create_event = await self.store.get_event( + create_event_id, allow_none=True + ) + if ( + create_event + and create_event.content.get("room_type") == RoomTypes.SPACE + ): + return self.email_subjects.invite_from_person_to_space % { + "person": inviter_name, + "space": room_name, + "app": self.app_name, + } + return self.email_subjects.invite_from_person_to_room % { "person": inviter_name, "room": room_name, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 9d4859798b..3fd2811713 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -285,7 +285,7 @@ class ReplicationDataHandler: # Create a new deferred that times out after N seconds, as we don't want # to wedge here forever. - deferred = Deferred() + deferred: "Deferred[None]" = Deferred() deferred = timeout_deferred( deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor ) @@ -393,6 +393,11 @@ class FederationSenderHandler: # we only want to send on receipts for our own users if not self._is_mine_id(receipt.user_id): continue + if ( + receipt.data.get("hidden", False) + and self._hs.config.experimental.msc2285_enabled + ): + continue receipt_info = ReadReceipt( receipt.room_id, receipt.receipt_type, diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 589e47fa47..eef76ab18a 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -62,6 +62,7 @@ class UsersRestServletV2(RestServlet): The parameter `name` can be used to filter by user id or display name. The parameter `guests` can be used to exclude guest users. The parameter `deactivated` can be used to include deactivated users. + The parameter `order_by` can be used to order the result. """ def __init__(self, hs: "HomeServer"): @@ -90,8 +91,8 @@ class UsersRestServletV2(RestServlet): errcode=Codes.INVALID_PARAM, ) - user_id = parse_string(request, "user_id", default=None) - name = parse_string(request, "name", default=None) + user_id = parse_string(request, "user_id") + name = parse_string(request, "name") guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) @@ -108,6 +109,7 @@ class UsersRestServletV2(RestServlet): UserSortOrder.USER_TYPE.value, UserSortOrder.AVATAR_URL.value, UserSortOrder.SHADOW_BANNED.value, + UserSortOrder.CREATION_TS.value, ), ) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 31a1193cd3..25ba52c624 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -413,7 +413,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): assert_params_in_dict(body, ["state_events_at_start", "events"]) prev_events_from_query = parse_strings_from_args(request.args, "prev_event") - chunk_id_from_query = parse_string(request, "chunk_id", default=None) + chunk_id_from_query = parse_string(request, "chunk_id") if prev_events_from_query is None: raise SynapseError( @@ -553,9 +553,18 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): ] # Connect this current chunk to the insertion event from the previous chunk - last_event_in_chunk["content"][ - EventContentFields.MSC2716_CHUNK_ID - ] = chunk_id_to_connect_to + chunk_event = { + "type": EventTypes.MSC2716_CHUNK, + "sender": requester.user.to_string(), + "room_id": room_id, + "content": {EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to}, + # Since the chunk event is put at the end of the chunk, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_chunk["origin_server_ts"], + } + # Add the chunk event to the end of the chunk (newest-in-time) + events_to_create.append(chunk_event) # Add an "insertion" event to the start of each chunk (next to the oldest-in-time # event in the chunk) so the next chunk can be connected to this one. @@ -567,7 +576,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet): # the first event we're inserting origin_server_ts=events_to_create[0]["origin_server_ts"], ) - # Prepend the insertion event to the start of the chunk + # Prepend the insertion event to the start of the chunk (oldest-in-time) events_to_create = [insertion_event] + events_to_create event_ids = [] @@ -726,7 +735,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): self.auth = hs.get_auth() async def on_GET(self, request): - server = parse_string(request, "server", default=None) + server = parse_string(request, "server") try: await self.auth.get_user_by_req(request, allow_guest=True) @@ -745,8 +754,8 @@ class PublicRoomListRestServlet(TransactionRestServlet): if server: raise e - limit = parse_integer(request, "limit", 0) - since_token = parse_string(request, "since", None) + limit: Optional[int] = parse_integer(request, "limit", 0) + since_token = parse_string(request, "since") if limit == 0: # zero is a special value which corresponds to no limit. @@ -780,7 +789,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): async def on_POST(self, request): await self.auth.get_user_by_req(request, allow_guest=True) - server = parse_string(request, "server", default=None) + server = parse_string(request, "server") content = parse_json_object_from_request(request) limit: Optional[int] = int(content.get("limit", 100)) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 085561d3e9..fb5ad2906e 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -884,7 +884,14 @@ class WhoamiRestServlet(RestServlet): async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) - return 200, {"user_id": requester.user.to_string()} + response = {"user_id": requester.user.to_string()} + + # Appservices and similar accounts do not have device IDs + # that we can report on, so exclude them for compliance. + if requester.device_id is not None: + response["device_id"] = requester.device_id + + return 200, response def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index 6a24021484..88e3aac797 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest from synapse.types import JsonDict @@ -55,6 +55,12 @@ class CapabilitiesRestServlet(RestServlet): "m.change_password": {"enabled": change_password}, } } + + if self.config.experimental.msc3244_enabled: + response["capabilities"]["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ] = MSC3244_CAPABILITIES + return 200, response diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 33cf8de186..d0d9d30d40 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -194,7 +194,7 @@ class KeyChangesServlet(RestServlet): async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) - from_token_string = parse_string(request, "from") + from_token_string = parse_string(request, "from", required=True) set_tag("from", from_token_string) # We want to enforce they do pass us one, but we ignore it and return diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index 5988fa47e5..027f8b81fa 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -14,6 +14,8 @@ import logging +from synapse.api.constants import ReadReceiptEventFields +from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -37,14 +39,24 @@ class ReadMarkerRestServlet(RestServlet): await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) - read_event_id = body.get("m.read", None) + hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) + + if not isinstance(hidden, bool): + raise SynapseError( + 400, + "Param %s must be a boolean, if given" + % ReadReceiptEventFields.MSC2285_HIDDEN, + Codes.BAD_JSON, + ) + if read_event_id: await self.receipts_handler.received_client_receipt( room_id, "m.read", user_id=requester.user.to_string(), event_id=read_event_id, + hidden=hidden, ) read_marker_event_id = body.get("m.fully_read", None) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 8cf4aebdbe..4b98979b47 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -14,8 +14,9 @@ import logging -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet +from synapse.api.constants import ReadReceiptEventFields +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -42,10 +43,25 @@ class ReceiptRestServlet(RestServlet): if receipt_type != "m.read": raise SynapseError(400, "Receipt type must be 'm.read'") + body = parse_json_object_from_request(request) + hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) + + if not isinstance(hidden, bool): + raise SynapseError( + 400, + "Param %s must be a boolean, if given" + % ReadReceiptEventFields.MSC2285_HIDDEN, + Codes.BAD_JSON, + ) + await self.presence_handler.bump_presence_active_time(requester.user) await self.receipts_handler.received_client_receipt( - room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id + room_id, + receipt_type, + user_id=requester.user.to_string(), + event_id=event_id, + hidden=hidden, ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index c7da6759db..0821cd285f 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -158,19 +158,21 @@ class RelationPaginationServlet(RestServlet): event = await self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) - from_token = parse_string(request, "from") - to_token = parse_string(request, "to") + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") if event.internal_metadata.is_redacted(): # If the event is redacted, return an empty list of relations pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - if from_token: - from_token = RelationPaginationToken.from_string(from_token) + from_token = None + if from_token_str: + from_token = RelationPaginationToken.from_string(from_token_str) - if to_token: - to_token = RelationPaginationToken.from_string(to_token) + to_token = None + if to_token_str: + to_token = RelationPaginationToken.from_string(to_token_str) pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, @@ -256,19 +258,21 @@ class RelationAggregationPaginationServlet(RestServlet): raise SynapseError(400, "Relation type must be 'annotation'") limit = parse_integer(request, "limit", default=5) - from_token = parse_string(request, "from") - to_token = parse_string(request, "to") + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") if event.internal_metadata.is_redacted(): # If the event is redacted, return an empty list of relations pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - if from_token: - from_token = AggregationPaginationToken.from_string(from_token) + from_token = None + if from_token_str: + from_token = AggregationPaginationToken.from_string(from_token_str) - if to_token: - to_token = AggregationPaginationToken.from_string(to_token) + to_token = None + if to_token_str: + to_token = AggregationPaginationToken.from_string(to_token_str) pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, @@ -336,14 +340,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet): raise SynapseError(400, "Relation type must be 'annotation'") limit = parse_integer(request, "limit", default=5) - from_token = parse_string(request, "from") - to_token = parse_string(request, "to") + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") - if from_token: - from_token = RelationPaginationToken.from_string(from_token) + from_token = None + if from_token_str: + from_token = RelationPaginationToken.from_string(from_token_str) - if to_token: - to_token = RelationPaginationToken.from_string(to_token) + to_token = None + if to_token_str: + to_token = RelationPaginationToken.from_string(to_token_str) result = await self.store.get_relations_for_event( event_id=parent_id, diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 32e8500795..e321668698 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -112,7 +112,7 @@ class SyncRestServlet(RestServlet): default="online", allowed_values=self.ALLOWED_PRESENCE, ) - filter_id = parse_string(request, "filter", default=None) + filter_id = parse_string(request, "filter") full_state = parse_boolean(request, "full_state", default=False) logger.debug( diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 4582c274c7..fa2e4e9cba 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -82,6 +82,8 @@ class VersionsRestServlet(RestServlet): "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private, # Supports the busy presence state described in MSC3026. "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, + # Supports receiving hidden read receipts as per MSC2285 + "org.matrix.msc2285": self.config.experimental.msc2285_enabled, }, }, ) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 4282e2b228..11f7320832 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -112,7 +112,7 @@ class ConsentResource(DirectServeHtmlResource): request (twisted.web.http.Request): """ version = parse_string(request, "v", default=self._default_consent_version) - username = parse_string(request, "u", required=False, default="") + username = parse_string(request, "u", default="") userhmac = None has_consented = False public_version = username == "" diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index cd2468f9c5..d6d938953e 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -49,6 +49,8 @@ class DownloadResource(DirectServeJsonResource): b" media-src 'self';" b" object-src 'self';", ) + # Limited non-standard form of CSP for IE11 + request.setHeader(b"X-Content-Security-Policy", b"sandbox;") request.setHeader( b"Referrer-Policy", b"no-referrer", diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 8e7fead3a2..0f051d4041 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -58,9 +58,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I) +_charset_match = re.compile( + br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I +) _xml_encoding_match = re.compile( - br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I + br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I ) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) @@ -186,15 +188,11 @@ class PreviewUrlResource(DirectServeJsonResource): respond_with_json(request, 200, {}, send_cors=True) async def _async_render_GET(self, request: SynapseRequest) -> None: - # This will always be set by the time Twisted calls us. - assert request.args is not None - # XXX: if get_user_by_req fails, what should we do in an async render? requester = await self.auth.get_user_by_req(request) - url = parse_string(request, "url") - if b"ts" in request.args: - ts = parse_integer(request, "ts") - else: + url = parse_string(request, "url", required=True) + ts = parse_integer(request, "ts") + if ts is None: ts = self.clock.time_msec() # XXX: we could move this into _do_preview if we wanted. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 6223daf522..463ce58dae 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,6 +16,7 @@ import heapq import logging from collections import defaultdict, namedtuple from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -52,6 +53,10 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore + logger = logging.getLogger(__name__) metrics_logger = logging.getLogger("synapse.state.metrics") @@ -74,7 +79,7 @@ _NEXT_STATE_ID = 1 POWER_KEY = (EventTypes.PowerLevels, "") -def _gen_state_id(): +def _gen_state_id() -> str: global _NEXT_STATE_ID s = "X%d" % (_NEXT_STATE_ID,) _NEXT_STATE_ID += 1 @@ -109,7 +114,7 @@ class _StateCacheEntry: # `state_id` is either a state_group (and so an int) or a string. This # ensures we don't accidentally persist a state_id as a stateg_group if state_group: - self.state_id = state_group + self.state_id: Union[str, int] = state_group else: self.state_id = _gen_state_id() @@ -122,7 +127,7 @@ class StateHandler: where necessary """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() self.state_store = hs.get_storage().state @@ -507,7 +512,7 @@ class StateResolutionHandler: be storage-independent. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.resolve_linearizer = Linearizer(name="state_resolve_lock") @@ -636,16 +641,20 @@ class StateResolutionHandler: """ try: with Measure(self.clock, "state._resolve_events") as m: - v = KNOWN_ROOM_VERSIONS[room_version] - if v.state_res == StateResolutionVersions.V1: + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + if room_version_obj.state_res == StateResolutionVersions.V1: return await v1.resolve_events_with_store( - room_id, state_sets, event_map, state_res_store.get_events + room_id, + room_version_obj, + state_sets, + event_map, + state_res_store.get_events, ) else: return await v2.resolve_events_with_store( self.clock, room_id, - room_version, + room_version_obj, state_sets, event_map, state_res_store, @@ -653,13 +662,15 @@ class StateResolutionHandler: finally: self._record_state_res_metrics(room_id, m.get_resource_usage()) - def _record_state_res_metrics(self, room_id: str, rusage: ContextResourceUsage): + def _record_state_res_metrics( + self, room_id: str, rusage: ContextResourceUsage + ) -> None: room_metrics = self._state_res_metrics[room_id] room_metrics.cpu_time += rusage.ru_utime + rusage.ru_stime room_metrics.db_time += rusage.db_txn_duration_sec room_metrics.db_events += rusage.evt_db_fetch_count - def _report_metrics(self): + def _report_metrics(self) -> None: if not self._state_res_metrics: # no state res has happened since the last iteration: don't bother logging. return @@ -769,16 +780,13 @@ def _make_state_cache_entry( ) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class StateResolutionStore: """Interface that allows state resolution algorithms to access the database in well defined way. - - Args: - store (DataStore) """ - store = attr.ib() + store: "DataStore" def get_events( self, event_ids: Iterable[str], allow_rejected: bool = False diff --git a/synapse/state/v1.py b/synapse/state/v1.py index 267193cedf..92336d7cc8 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -29,7 +29,7 @@ from typing import ( from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersions +from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap @@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "") async def resolve_events_with_store( room_id: str, + room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], @@ -104,7 +105,7 @@ async def resolve_events_with_store( # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. auth_events = _create_auth_events_from_maps( - unconflicted_state, conflicted_state, state_map + room_version, unconflicted_state, conflicted_state, state_map ) new_needed_events = set(auth_events.values()) @@ -132,7 +133,7 @@ async def resolve_events_with_store( state_map.update(state_map_new) return _resolve_with_state( - unconflicted_state, conflicted_state, auth_events, state_map + room_version, unconflicted_state, conflicted_state, auth_events, state_map ) @@ -187,6 +188,7 @@ def _seperate( def _create_auth_events_from_maps( + room_version: RoomVersion, unconflicted_state: StateMap[str], conflicted_state: StateMap[Set[str]], state_map: Dict[str, EventBase], @@ -194,6 +196,7 @@ def _create_auth_events_from_maps( """ Args: + room_version: The room version. unconflicted_state: The unconflicted state map. conflicted_state: The conflicted state map. state_map: @@ -205,7 +208,9 @@ def _create_auth_events_from_maps( for event_ids in conflicted_state.values(): for event_id in event_ids: if event_id in state_map: - keys = event_auth.auth_types_for_event(state_map[event_id]) + keys = event_auth.auth_types_for_event( + room_version, state_map[event_id] + ) for key in keys: if key not in auth_events: auth_event_id = unconflicted_state.get(key, None) @@ -215,6 +220,7 @@ def _create_auth_events_from_maps( def _resolve_with_state( + room_version: RoomVersion, unconflicted_state_ids: MutableStateMap[str], conflicted_state_ids: StateMap[Set[str]], auth_event_ids: StateMap[str], @@ -235,7 +241,9 @@ def _resolve_with_state( } try: - resolved_state = _resolve_state_events(conflicted_state, auth_events) + resolved_state = _resolve_state_events( + room_version, conflicted_state, auth_events + ) except Exception: logger.exception("Failed to resolve state") raise @@ -248,7 +256,9 @@ def _resolve_with_state( def _resolve_state_events( - conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] + room_version: RoomVersion, + conflicted_state: StateMap[List[EventBase]], + auth_events: MutableStateMap[EventBase], ) -> StateMap[EventBase]: """This is where we actually decide which of the conflicted state to use. @@ -263,21 +273,27 @@ def _resolve_state_events( if POWER_KEY in conflicted_state: events = conflicted_state[POWER_KEY] logger.debug("Resolving conflicted power levels %r", events) - resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events) + resolved_state[POWER_KEY] = _resolve_auth_events( + room_version, events, auth_events + ) auth_events.update(resolved_state) for key, events in conflicted_state.items(): if key[0] == EventTypes.JoinRules: logger.debug("Resolving conflicted join rules %r", events) - resolved_state[key] = _resolve_auth_events(events, auth_events) + resolved_state[key] = _resolve_auth_events( + room_version, events, auth_events + ) auth_events.update(resolved_state) for key, events in conflicted_state.items(): if key[0] == EventTypes.Member: logger.debug("Resolving conflicted member lists %r", events) - resolved_state[key] = _resolve_auth_events(events, auth_events) + resolved_state[key] = _resolve_auth_events( + room_version, events, auth_events + ) auth_events.update(resolved_state) @@ -290,12 +306,14 @@ def _resolve_state_events( def _resolve_auth_events( - events: List[EventBase], auth_events: StateMap[EventBase] + room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase] ) -> EventBase: reverse = list(reversed(_ordered_events(events))) auth_keys = { - key for event in events for key in event_auth.auth_types_for_event(event) + key + for event in events + for key in event_auth.auth_types_for_event(room_version, event) } new_auth_events = {} diff --git a/synapse/state/v2.py b/synapse/state/v2.py index e66e6571c8..7b1e8361de 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -36,7 +36,7 @@ import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap from synapse.util import Clock @@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100 async def resolve_events_with_store( clock: Clock, room_id: str, - room_version: str, + room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", @@ -497,7 +497,7 @@ async def _reverse_topological_power_sort( async def _iterative_auth_checks( clock: Clock, room_id: str, - room_version: str, + room_version: RoomVersion, event_ids: List[str], base_state: StateMap[str], event_map: Dict[str, EventBase], @@ -519,7 +519,6 @@ async def _iterative_auth_checks( Returns the final updated state """ resolved_state = dict(base_state) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] for idx, event_id in enumerate(event_ids, start=1): event = event_map[event_id] @@ -538,7 +537,7 @@ async def _iterative_auth_checks( if ev.rejected_reason is None: auth_events[(ev.type, ev.state_key)] = ev - for key in event_auth.auth_types_for_event(event): + for key in event_auth.auth_types_for_event(room_version, event): if key in resolved_state: ev_id = resolved_state[key] ev = await _get_event(room_id, ev_id, event_map, state_res_store) @@ -548,7 +547,7 @@ async def _iterative_auth_checks( try: event_auth.check( - room_version_obj, + room_version, event, auth_events, do_sig_check=False, diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ccf9ac51ef..4d4643619f 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -832,31 +832,16 @@ class DatabasePool: self, table: str, values: Dict[str, Any], - or_ignore: bool = False, desc: str = "simple_insert", - ) -> bool: + ) -> None: """Executes an INSERT query on the named table. Args: table: string giving the table name values: dict of new column names and values for them - or_ignore: bool stating whether an exception should be raised - when a conflicting row already exists. If True, False will be - returned by the function instead desc: description of the transaction, for logging and metrics - - Returns: - Whether the row was inserted or not. Only useful when `or_ignore` is True """ - try: - await self.runInteraction(desc, self.simple_insert_txn, table, values) - except self.engine.module.IntegrityError: - # We have to do or_ignore flag at this layer, since we can't reuse - # a cursor after we receive an error from the db. - if not or_ignore: - raise - return False - return True + await self.runInteraction(desc, self.simple_insert_txn, table, values) @staticmethod def simple_insert_txn( @@ -930,7 +915,7 @@ class DatabasePool: insertion_values: Optional[Dict[str, Any]] = None, desc: str = "simple_upsert", lock: bool = True, - ) -> Optional[bool]: + ) -> bool: """ `lock` should generally be set to True (the default), but can be set @@ -951,8 +936,8 @@ class DatabasePool: desc: description of the transaction, for logging and metrics lock: True to lock the table when doing the upsert. Returns: - Native upserts always return None. Emulated upserts return True if a - new entry was created, False if an existing one was updated. + Returns True if a row was inserted or updated (i.e. if `values` is + not empty then this always returns True) """ insertion_values = insertion_values or {} @@ -995,7 +980,7 @@ class DatabasePool: values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, lock: bool = True, - ) -> Optional[bool]: + ) -> bool: """ Pick the UPSERT method which works best on the platform. Either the native one (Pg9.5+, recent SQLites), or fall back to an emulated method. @@ -1008,16 +993,15 @@ class DatabasePool: insertion_values: additional key/values to use only when inserting lock: True to lock the table when doing the upsert. Returns: - Native upserts always return None. Emulated upserts return True if a - new entry was created, False if an existing one was updated. + Returns True if a row was inserted or updated (i.e. if `values` is + not empty then this always returns True) """ insertion_values = insertion_values or {} if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: - self.simple_upsert_txn_native_upsert( + return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) - return None else: return self.simple_upsert_txn_emulated( txn, @@ -1045,8 +1029,8 @@ class DatabasePool: insertion_values: additional key/values to use only when inserting lock: True to lock the table when doing the upsert. Returns: - Returns True if a new entry was created, False if an existing - one was updated. + Returns True if a row was inserted or updated (i.e. if `values` is + not empty then this always returns True) """ insertion_values = insertion_values or {} @@ -1086,8 +1070,7 @@ class DatabasePool: txn.execute(sql, sqlargs) if txn.rowcount > 0: - # successfully updated at least one row. - return False + return True # We didn't find any existing rows, so insert a new one allvalues: Dict[str, Any] = {} @@ -1111,15 +1094,19 @@ class DatabasePool: keyvalues: Dict[str, Any], values: Dict[str, Any], insertion_values: Optional[Dict[str, Any]] = None, - ) -> None: + ) -> bool: """ - Use the native UPSERT functionality in recent PostgreSQL versions. + Use the native UPSERT functionality in PostgreSQL. Args: table: The table to upsert into keyvalues: The unique key tables and their new values values: The nonunique columns and their new values insertion_values: additional key/values to use only when inserting + + Returns: + Returns True if a row was inserted or updated (i.e. if `values` is + not empty then this always returns True) """ allvalues: Dict[str, Any] = {} allvalues.update(keyvalues) @@ -1140,6 +1127,8 @@ class DatabasePool: ) txn.execute(sql, list(allvalues.values())) + return bool(txn.rowcount) + async def simple_upsert_many( self, table: str, diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index a3fddea042..8d9f07111d 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -249,7 +249,7 @@ class DataStore( name: Optional[str] = None, guests: bool = True, deactivated: bool = False, - order_by: UserSortOrder = UserSortOrder.USER_ID.value, + order_by: str = UserSortOrder.USER_ID.value, direction: str = "f", ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users from @@ -297,27 +297,22 @@ class DataStore( where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" - sql_base = """ + sql_base = f""" FROM users as u LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ? - {} - """.format( - where_clause - ) + {where_clause} + """ sql = "SELECT COUNT(*) as total_users " + sql_base txn.execute(sql, args) count = txn.fetchone()[0] - sql = """ - SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url + sql = f""" + SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, + displayname, avatar_url, creation_ts * 1000 as creation_ts {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? - """.format( - sql_base=sql_base, - order_by_column=order_by_column, - order=order, - ) + """ args += [limit, start] txn.execute(sql, args) users = self.db_pool.cursor_to_dict(txn) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 18f07d96dc..3816a0ca53 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1078,16 +1078,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = await self.db_pool.simple_insert( + inserted = await self.db_pool.simple_upsert( "devices", - values={ + keyvalues={ "user_id": user_id, "device_id": device_id, + }, + values={}, + insertion_values={ "display_name": initial_device_display_name, "hidden": False, }, desc="store_device", - or_ignore=True, ) if not inserted: # if the device already exists, check if it's a real device, or @@ -1099,6 +1101,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) if hidden: raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) + self.device_id_exists_cache.set(key, True) return inserted except StoreError: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index d39368c20e..f4a00b0736 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1227,12 +1227,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas (count,) = txn.fetchone() txn.execute( - "SELECT coalesce(min(received_ts), 0) FROM federation_inbound_events_staging" + "SELECT min(received_ts) FROM federation_inbound_events_staging" ) (received_ts,) = txn.fetchone() - age = self._clock.time_msec() - received_ts + # If there is nothing in the staging area default it to 0. + age = 0 + if received_ts is not None: + age = self._clock.time_msec() - received_ts return count, age diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index fe25638289..d213b26703 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -297,17 +297,13 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): Args: txn (cursor): user_id (str): user to add/update - - Returns: - bool: True if a new entry was created, False if an - existing one was updated. """ # Am consciously deciding to lock the table on the basis that is ought # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self.db_pool.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -322,8 +318,6 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): txn, self.user_last_seen_monthly_active, (user_id,) ) - return is_insert - async def populate_monthly_active_users(self, user_id): """Checks on the state of monthly active user limits and optionally add the user to the monthly active tables diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 6ddafe5434..443e5f3315 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -363,7 +363,7 @@ class RoomWorkerStore(SQLBaseStore): self, start: int, limit: int, - order_by: RoomSortOrder, + order_by: str, reverse_order: bool, search_term: Optional[str], ) -> Tuple[List[Dict[str, Any]], int]: diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 59d67c255b..42edbcc057 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -75,6 +75,7 @@ class UserSortOrder(Enum): USER_TYPE = ordered alphabetically by `user_type` AVATAR_URL = ordered alphabetically by `avatar_url` SHADOW_BANNED = ordered by `shadow_banned` + CREATION_TS = ordered by `creation_ts` """ MEDIA_LENGTH = "media_length" @@ -88,6 +89,7 @@ class UserSortOrder(Enum): USER_TYPE = "user_type" AVATAR_URL = "avatar_url" SHADOW_BANNED = "shadow_banned" + CREATION_TS = "creation_ts" class StatsStore(StateDeltasStore): @@ -647,10 +649,10 @@ class StatsStore(StateDeltasStore): limit: int, from_ts: Optional[int] = None, until_ts: Optional[int] = None, - order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value, + order_by: Optional[str] = UserSortOrder.USER_ID.value, direction: Optional[str] = "f", search_term: Optional[str] = None, - ) -> Tuple[List[JsonDict], Dict[str, int]]: + ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users and their uploaded local media (size and number). This will return a json list of users and the total number of users matching the filter criteria. diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index d211c423b2..7728d5f102 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -134,16 +134,18 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): response_dict: The response, to be encoded into JSON. """ - await self.db_pool.simple_insert( + await self.db_pool.simple_upsert( table="received_transactions", - values={ + keyvalues={ "transaction_id": transaction_id, "origin": origin, + }, + values={}, + insertion_values={ "response_code": code, "response_json": db_binary_type(encode_canonical_json(response_dict)), "ts": self._clock.time_msec(), }, - or_ignore=True, desc="set_received_txn_response", ) diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index a6bfb4902a..9d28d69ac7 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -377,7 +377,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): avatar_url = None def _update_profile_in_user_dir_txn(txn): - new_entry = self.db_pool.simple_upsert_txn( + self.db_pool.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -388,8 +388,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): if isinstance(self.database_engine, PostgresEngine): # We weight the localpart most highly, then display name and finally # server name - if self.database_engine.can_native_upsert: - sql = """ + sql = """ INSERT INTO user_directory_search(user_id, vector) VALUES (?, setweight(to_tsvector('simple', ?), 'A') @@ -397,58 +396,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector """ - txn.execute( - sql, - ( - user_id, - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - ), - ) - else: - # TODO: Remove this code after we've bumped the minimum version - # of postgres to always support upserts, so we can get rid of - # `new_entry` usage - if new_entry is True: - sql = """ - INSERT INTO user_directory_search(user_id, vector) - VALUES (?, - setweight(to_tsvector('simple', ?), 'A') - || setweight(to_tsvector('simple', ?), 'D') - || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') - ) - """ - txn.execute( - sql, - ( - user_id, - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - ), - ) - elif new_entry is False: - sql = """ - UPDATE user_directory_search - SET vector = setweight(to_tsvector('simple', ?), 'A') - || setweight(to_tsvector('simple', ?), 'D') - || setweight(to_tsvector('simple', COALESCE(?, '')), 'B') - WHERE user_id = ? - """ - txn.execute( - sql, - ( - get_localpart_from_id(user_id), - get_domain_from_id(user_id), - display_name, - user_id, - ), - ) - else: - raise RuntimeError( - "upsert returned None when 'can_native_upsert' is False" - ) + txn.execute( + sql, + ( + user_id, + get_localpart_from_id(user_id), + get_domain_from_id(user_id), + display_name, + ), + ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id self.db_pool.simple_upsert_txn( diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index e38461adbc..f839c0c24f 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -372,18 +372,23 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) async def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[StateMap[str]], + current_state_ids: StateMap[str], ) -> int: """Store a new set of state, returning a newly assigned state group. Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and + event_id: The event ID for which the state was calculated + room_id + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) + current_state_ids: The state to store. Map of (type, state_key) to event_id. Returns: diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f8fbba9d38..e5400d681a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -570,8 +570,8 @@ class StateGroupStorage: event_id: str, room_id: str, prev_group: Optional[int], - delta_ids: Optional[dict], - current_state_ids: dict, + delta_ids: Optional[StateMap[str]], + current_state_ids: StateMap[str], ) -> int: """Store a new set of state, returning a newly assigned state group. diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 13d300588b..cf4005984b 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -47,20 +47,22 @@ class PaginationConfig: ) -> "PaginationConfig": direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) - from_tok = parse_string(request, "from") - to_tok = parse_string(request, "to") + from_tok_str = parse_string(request, "from") + to_tok_str = parse_string(request, "to") try: - if from_tok == "END": + from_tok = None + if from_tok_str == "END": from_tok = None # For backwards compat. - elif from_tok: - from_tok = await StreamToken.from_string(store, from_tok) + elif from_tok_str: + from_tok = await StreamToken.from_string(store, from_tok_str) except Exception: raise SynapseError(400, "'from' parameter is invalid") try: - if to_tok: - to_tok = await StreamToken.from_string(store, to_tok) + to_tok = None + if to_tok_str: + to_tok = await StreamToken.from_string(store, to_tok_str) except Exception: raise SynapseError(400, "'to' parameter is invalid") diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 014db1355b..912cf85f89 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -49,6 +49,8 @@ from synapse.util import Clock, unwrapFirstError logger = logging.getLogger(__name__) +_T = TypeVar("_T") + class ObservableDeferred: """Wraps a deferred object so that we can add observer deferreds. These @@ -121,7 +123,7 @@ class ObservableDeferred: effect the underlying deferred. """ if not self._result: - d = defer.Deferred() + d: "defer.Deferred[Any]" = defer.Deferred() def remove(r): self._observers.discard(d) @@ -415,7 +417,7 @@ class ReadWriteLock: self.key_to_current_writer: Dict[str, defer.Deferred] = {} async def read(self, key: str) -> ContextManager: - new_defer = defer.Deferred() + new_defer: "defer.Deferred[None]" = defer.Deferred() curr_readers = self.key_to_current_readers.setdefault(key, set()) curr_writer = self.key_to_current_writer.get(key, None) @@ -438,7 +440,7 @@ class ReadWriteLock: return _ctx_manager() async def write(self, key: str) -> ContextManager: - new_defer = defer.Deferred() + new_defer: "defer.Deferred[None]" = defer.Deferred() curr_readers = self.key_to_current_readers.get(key, set()) curr_writer = self.key_to_current_writer.get(key, None) @@ -471,10 +473,8 @@ R = TypeVar("R") def timeout_deferred( - deferred: defer.Deferred, - timeout: float, - reactor: IReactorTime, -) -> defer.Deferred: + deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime +) -> "defer.Deferred[_T]": """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new deferred that wraps and times out the given deferred, correctly handling @@ -497,7 +497,7 @@ def timeout_deferred( Returns: A new Deferred, which will errback with defer.TimeoutError on timeout. """ - new_d = defer.Deferred() + new_d: "defer.Deferred[_T]" = defer.Deferred() timed_out = [False] diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py index 891bee0b33..e58dd91eda 100644 --- a/synapse/util/caches/cached_call.py +++ b/synapse/util/caches/cached_call.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import enum from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union from twisted.internet.defer import Deferred @@ -22,6 +22,10 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background TV = TypeVar("TV") +class _Sentinel(enum.Enum): + sentinel = object() + + class CachedCall(Generic[TV]): """A wrapper for asynchronous calls whose results should be shared @@ -65,7 +69,7 @@ class CachedCall(Generic[TV]): """ self._callable: Optional[Callable[[], Awaitable[TV]]] = f self._deferred: Optional[Deferred] = None - self._result: Union[None, Failure, TV] = None + self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel async def get(self) -> TV: """Kick off the call if necessary, and return the result""" @@ -78,8 +82,9 @@ class CachedCall(Generic[TV]): self._callable = None # once the deferred completes, store the result. We cannot simply leave the - # result in the deferred, since if it's a Failure, GCing the deferred - # would then log a critical error about unhandled Failures. + # result in the deferred, since `awaiting` a deferred destroys its result. + # (Also, if it's a Failure, GCing the deferred would log a critical error + # about unhandled Failures) def got_result(r): self._result = r @@ -92,13 +97,15 @@ class CachedCall(Generic[TV]): # and any eventual exception may not be reported. # we can now await the deferred, and once it completes, return the result. - await make_deferred_yieldable(self._deferred) + if isinstance(self._result, _Sentinel): + await make_deferred_yieldable(self._deferred) + assert not isinstance(self._result, _Sentinel) + + if isinstance(self._result, Failure): + self._result.raiseException() + raise AssertionError("unexpected return from Failure.raiseException") - # I *think* this is the easiest way to correctly raise a Failure without having - # to gut-wrench into the implementation of Deferred. - d = Deferred() - d.callback(self._result) - return await d + return self._result class RetryOnExceptionCachedCall(Generic[TV]): diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 8c6fafc677..b6456392cd 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -16,7 +16,16 @@ import enum import threading -from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union +from typing import ( + Callable, + Generic, + Iterable, + MutableMapping, + Optional, + TypeVar, + Union, + cast, +) from prometheus_client import Gauge @@ -166,7 +175,7 @@ class DeferredCache(Generic[KT, VT]): def set( self, key: KT, - value: defer.Deferred, + value: "defer.Deferred[VT]", callback: Optional[Callable[[], None]] = None, ) -> defer.Deferred: """Adds a new entry to the cache (or updates an existing one). @@ -214,7 +223,7 @@ class DeferredCache(Generic[KT, VT]): if value.called: result = value.result if not isinstance(result, failure.Failure): - self.cache.set(key, result, callbacks) + self.cache.set(key, cast(VT, result), callbacks) return value # otherwise, we'll add an entry to the _pending_deferred_cache for now, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1e8e6b1d01..1ca31e41ac 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -413,7 +413,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): # relevant result for that key. deferreds_map = {} for arg in missing: - deferred = defer.Deferred() + deferred: "defer.Deferred[Any]" = defer.Deferred() deferreds_map[arg] = deferred key = arg_to_cache_key(arg) cache.set(key, deferred, callback=invalidate_callback) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 9274ce4c39..e2a5fc018c 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -301,6 +301,49 @@ class PruneEventTestCase(unittest.TestCase): room_version=RoomVersions.MSC2176, ) + def test_join_rules(self): + """Join rules events have changed behavior starting with MSC3083.""" + self.run_test( + { + "type": "m.room.join_rules", + "event_id": "$test:domain", + "content": { + "join_rule": "invite", + "allow": [], + "other_key": "stripped", + }, + }, + { + "type": "m.room.join_rules", + "event_id": "$test:domain", + "content": {"join_rule": "invite"}, + "signatures": {}, + "unsigned": {}, + }, + ) + + # After MSC3083, alias events have no special behavior. + self.run_test( + { + "type": "m.room.join_rules", + "content": { + "join_rule": "invite", + "allow": [], + "other_key": "stripped", + }, + }, + { + "type": "m.room.join_rules", + "content": { + "join_rule": "invite", + "allow": [], + }, + "signatures": {}, + "unsigned": {}, + }, + room_version=RoomVersions.MSC3083, + ) + class SerializeEventTestCase(unittest.TestCase): def serialize(self, ev, fields): diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py new file mode 100644 index 0000000000..93a9a084b2 --- /dev/null +++ b/tests/handlers/test_receipts.py @@ -0,0 +1,294 @@ +# Copyright 2021 Šimon Brandner <simon.bra.ag@gmail.com> +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +from synapse.api.constants import ReadReceiptEventFields +from synapse.types import JsonDict + +from tests import unittest + + +class ReceiptsTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.event_source = hs.get_event_sources().sources["receipt"] + + # In the first param of _test_filters_hidden we use "hidden" instead of + # ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking + # the data from the database which doesn't use the prefix + + def test_filters_out_hidden_receipt(self): + self._test_filters_hidden( + [ + { + "content": { + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@rikj:jki.re": { + "ts": 1436451550453, + "hidden": True, + } + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + [], + ) + + def test_does_not_filter_out_our_hidden_receipt(self): + self._test_filters_hidden( + [ + { + "content": { + "$1435641916hfgh4394fHBLK:matrix.org": { + "m.read": { + "@me:server.org": { + "ts": 1436451550453, + "hidden": True, + }, + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + [ + { + "content": { + "$1435641916hfgh4394fHBLK:matrix.org": { + "m.read": { + "@me:server.org": { + "ts": 1436451550453, + ReadReceiptEventFields.MSC2285_HIDDEN: True, + }, + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + ) + + def test_filters_out_hidden_receipt_and_ignores_rest(self): + self._test_filters_hidden( + [ + { + "content": { + "$1dgdgrd5641916114394fHBLK:matrix.org": { + "m.read": { + "@rikj:jki.re": { + "ts": 1436451550453, + "hidden": True, + }, + "@user:jki.re": { + "ts": 1436451550453, + }, + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + [ + { + "content": { + "$1dgdgrd5641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + ) + + def test_filters_out_event_with_only_hidden_receipts_and_ignores_the_rest(self): + self._test_filters_hidden( + [ + { + "content": { + "$14356419edgd14394fHBLK:matrix.org": { + "m.read": { + "@rikj:jki.re": { + "ts": 1436451550453, + "hidden": True, + }, + } + }, + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + [ + { + "content": { + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + ) + + def test_handles_missing_content_of_m_read(self): + self._test_filters_hidden( + [ + { + "content": { + "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}}, + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + [ + { + "content": { + "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}}, + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + ) + + def test_handles_empty_event(self): + self._test_filters_hidden( + [ + { + "content": { + "$143564gdfg6114394fHBLK:matrix.org": {}, + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + [ + { + "content": { + "$143564gdfg6114394fHBLK:matrix.org": {}, + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + ) + + def test_filters_out_receipt_event_with_only_hidden_receipt_and_ignores_rest(self): + self._test_filters_hidden( + [ + { + "content": { + "$14356419edgd14394fHBLK:matrix.org": { + "m.read": { + "@rikj:jki.re": { + "ts": 1436451550453, + "hidden": True, + }, + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + }, + { + "content": { + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + }, + ], + [ + { + "content": { + "$1435641916114394fHBLK:matrix.org": { + "m.read": { + "@user:jki.re": { + "ts": 1436451550453, + } + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + ) + + def _test_filters_hidden( + self, events: List[JsonDict], expected_output: List[JsonDict] + ): + """Tests that the _filter_out_hidden returns the expected output""" + filtered_events = self.event_source.filter_out_hidden(events, "@me:server.org") + self.assertEquals(filtered_events, expected_output) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 437113929a..e5865c161d 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -14,19 +14,22 @@ import base64 import logging import os -from typing import Optional +from typing import Iterable, Optional from unittest.mock import patch import treq from netaddr import IPSet +from parameterized import parameterized from twisted.internet import interfaces # noqa: F401 +from twisted.internet.endpoints import HostnameEndpoint, _WrapperEndpoint +from twisted.internet.interfaces import IProtocol, IProtocolFactory from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from synapse.http.client import BlacklistingReactorWrapper -from synapse.http.proxyagent import ProxyAgent +from synapse.http.proxyagent import ProxyAgent, ProxyCredentials, parse_proxy from tests.http import TestServerTLSConnectionFactory, get_test_https_policy from tests.server import FakeTransport, ThreadedMemoryReactorClock @@ -37,33 +40,208 @@ logger = logging.getLogger(__name__) HTTPFactory = Factory.forProtocol(HTTPChannel) +class ProxyParserTests(TestCase): + """ + Values for test + [ + proxy_string, + expected_scheme, + expected_hostname, + expected_port, + expected_credentials, + ] + """ + + @parameterized.expand( + [ + # host + [b"localhost", b"http", b"localhost", 1080, None], + [b"localhost:9988", b"http", b"localhost", 9988, None], + # host+scheme + [b"https://localhost", b"https", b"localhost", 1080, None], + [b"https://localhost:1234", b"https", b"localhost", 1234, None], + # ipv4 + [b"1.2.3.4", b"http", b"1.2.3.4", 1080, None], + [b"1.2.3.4:9988", b"http", b"1.2.3.4", 9988, None], + # ipv4+scheme + [b"https://1.2.3.4", b"https", b"1.2.3.4", 1080, None], + [b"https://1.2.3.4:9988", b"https", b"1.2.3.4", 9988, None], + # ipv6 - without brackets is broken + # [ + # b"2001:0db8:85a3:0000:0000:8a2e:0370:effe", + # b"http", + # b"2001:0db8:85a3:0000:0000:8a2e:0370:effe", + # 1080, + # None, + # ], + # [ + # b"2001:0db8:85a3:0000:0000:8a2e:0370:1234", + # b"http", + # b"2001:0db8:85a3:0000:0000:8a2e:0370:1234", + # 1080, + # None, + # ], + # [b"::1", b"http", b"::1", 1080, None], + # [b"::ffff:0.0.0.0", b"http", b"::ffff:0.0.0.0", 1080, None], + # ipv6 - with brackets + [ + b"[2001:0db8:85a3:0000:0000:8a2e:0370:effe]", + b"http", + b"2001:0db8:85a3:0000:0000:8a2e:0370:effe", + 1080, + None, + ], + [ + b"[2001:0db8:85a3:0000:0000:8a2e:0370:1234]", + b"http", + b"2001:0db8:85a3:0000:0000:8a2e:0370:1234", + 1080, + None, + ], + [b"[::1]", b"http", b"::1", 1080, None], + [b"[::ffff:0.0.0.0]", b"http", b"::ffff:0.0.0.0", 1080, None], + # ipv6+port + [ + b"[2001:0db8:85a3:0000:0000:8a2e:0370:effe]:9988", + b"http", + b"2001:0db8:85a3:0000:0000:8a2e:0370:effe", + 9988, + None, + ], + [ + b"[2001:0db8:85a3:0000:0000:8a2e:0370:1234]:9988", + b"http", + b"2001:0db8:85a3:0000:0000:8a2e:0370:1234", + 9988, + None, + ], + [b"[::1]:9988", b"http", b"::1", 9988, None], + [b"[::ffff:0.0.0.0]:9988", b"http", b"::ffff:0.0.0.0", 9988, None], + # ipv6+scheme + [ + b"https://[2001:0db8:85a3:0000:0000:8a2e:0370:effe]", + b"https", + b"2001:0db8:85a3:0000:0000:8a2e:0370:effe", + 1080, + None, + ], + [ + b"https://[2001:0db8:85a3:0000:0000:8a2e:0370:1234]", + b"https", + b"2001:0db8:85a3:0000:0000:8a2e:0370:1234", + 1080, + None, + ], + [b"https://[::1]", b"https", b"::1", 1080, None], + [b"https://[::ffff:0.0.0.0]", b"https", b"::ffff:0.0.0.0", 1080, None], + # ipv6+scheme+port + [ + b"https://[2001:0db8:85a3:0000:0000:8a2e:0370:effe]:9988", + b"https", + b"2001:0db8:85a3:0000:0000:8a2e:0370:effe", + 9988, + None, + ], + [ + b"https://[2001:0db8:85a3:0000:0000:8a2e:0370:1234]:9988", + b"https", + b"2001:0db8:85a3:0000:0000:8a2e:0370:1234", + 9988, + None, + ], + [b"https://[::1]:9988", b"https", b"::1", 9988, None], + # with credentials + [ + b"https://user:pass@1.2.3.4:9988", + b"https", + b"1.2.3.4", + 9988, + b"user:pass", + ], + [b"user:pass@1.2.3.4:9988", b"http", b"1.2.3.4", 9988, b"user:pass"], + [ + b"https://user:pass@proxy.local:9988", + b"https", + b"proxy.local", + 9988, + b"user:pass", + ], + [ + b"user:pass@proxy.local:9988", + b"http", + b"proxy.local", + 9988, + b"user:pass", + ], + ] + ) + def test_parse_proxy( + self, + proxy_string: bytes, + expected_scheme: bytes, + expected_hostname: bytes, + expected_port: int, + expected_credentials: Optional[bytes], + ): + """ + Tests that a given proxy URL will be broken into the components. + Args: + proxy_string: The proxy connection string. + expected_scheme: Expected value of proxy scheme. + expected_hostname: Expected value of proxy hostname. + expected_port: Expected value of proxy port. + expected_credentials: Expected value of credentials. + Must be in form '<username>:<password>' or None + """ + proxy_cred = None + if expected_credentials: + proxy_cred = ProxyCredentials(expected_credentials) + self.assertEqual( + ( + expected_scheme, + expected_hostname, + expected_port, + proxy_cred, + ), + parse_proxy(proxy_string), + ) + + class MatrixFederationAgentTests(TestCase): def setUp(self): self.reactor = ThreadedMemoryReactorClock() def _make_connection( - self, client_factory, server_factory, ssl=False, expected_sni=None - ): + self, + client_factory: IProtocolFactory, + server_factory: IProtocolFactory, + ssl: bool = False, + expected_sni: Optional[bytes] = None, + tls_sanlist: Optional[Iterable[bytes]] = None, + ) -> IProtocol: """Builds a test server, and completes the outgoing client connection Args: - client_factory (interfaces.IProtocolFactory): the the factory that the + client_factory: the the factory that the application is trying to use to make the outbound connection. We will invoke it to build the client Protocol - server_factory (interfaces.IProtocolFactory): a factory to build the + server_factory: a factory to build the server-side protocol - ssl (bool): If true, we will expect an ssl connection and wrap + ssl: If true, we will expect an ssl connection and wrap server_factory with a TLSMemoryBIOFactory - expected_sni (bytes|None): the expected SNI value + expected_sni: the expected SNI value + + tls_sanlist: list of SAN entries for the TLS cert presented by the server. + Defaults to [b'DNS:test.com'] Returns: - IProtocol: the server Protocol returned by server_factory + the server Protocol returned by server_factory """ if ssl: - server_factory = _wrap_server_factory_for_tls(server_factory) + server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist) server_protocol = server_factory.buildProtocol(None) @@ -98,22 +276,28 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual( server_name, expected_sni, - "Expected SNI %s but got %s" % (expected_sni, server_name), + f"Expected SNI {expected_sni!s} but got {server_name!s}", ) return http_protocol - def _test_request_direct_connection(self, agent, scheme, hostname, path): + def _test_request_direct_connection( + self, + agent: ProxyAgent, + scheme: bytes, + hostname: bytes, + path: bytes, + ): """Runs a test case for a direct connection not going through a proxy. Args: - agent (ProxyAgent): the proxy agent being tested + agent: the proxy agent being tested - scheme (bytes): expected to be either "http" or "https" + scheme: expected to be either "http" or "https" - hostname (bytes): the hostname to connect to in the test + hostname: the hostname to connect to in the test - path (bytes): the path to connect to in the test + path: the path to connect to in the test """ is_https = scheme == b"https" @@ -208,7 +392,7 @@ class MatrixFederationAgentTests(TestCase): """ Tests that requests can be made through a proxy. """ - self._do_http_request_via_proxy(auth_credentials=None) + self._do_http_request_via_proxy(ssl=False, auth_credentials=None) @patch.dict( os.environ, @@ -218,12 +402,28 @@ class MatrixFederationAgentTests(TestCase): """ Tests that authenticated requests can be made through a proxy. """ - self._do_http_request_via_proxy(auth_credentials="bob:pinkponies") + self._do_http_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + + @patch.dict( + os.environ, {"http_proxy": "https://proxy.com:8888", "no_proxy": "unused.com"} + ) + def test_http_request_via_https_proxy(self): + self._do_http_request_via_proxy(ssl=True, auth_credentials=None) + + @patch.dict( + os.environ, + { + "http_proxy": "https://bob:pinkponies@proxy.com:8888", + "no_proxy": "unused.com", + }, + ) + def test_http_request_via_https_proxy_with_auth(self): + self._do_http_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) def test_https_request_via_proxy(self): """Tests that TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials=None) + self._do_https_request_via_proxy(ssl=False, auth_credentials=None) @patch.dict( os.environ, @@ -231,16 +431,40 @@ class MatrixFederationAgentTests(TestCase): ) def test_https_request_via_proxy_with_auth(self): """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" - self._do_https_request_via_proxy(auth_credentials="bob:pinkponies") + self._do_https_request_via_proxy(ssl=False, auth_credentials=b"bob:pinkponies") + + @patch.dict( + os.environ, {"https_proxy": "https://proxy.com", "no_proxy": "unused.com"} + ) + def test_https_request_via_https_proxy(self): + """Tests that TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(ssl=True, auth_credentials=None) + + @patch.dict( + os.environ, + {"https_proxy": "https://bob:pinkponies@proxy.com", "no_proxy": "unused.com"}, + ) + def test_https_request_via_https_proxy_with_auth(self): + """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" + self._do_https_request_via_proxy(ssl=True, auth_credentials=b"bob:pinkponies") def _do_http_request_via_proxy( self, - auth_credentials: Optional[str] = None, + ssl: bool = False, + auth_credentials: Optional[bytes] = None, ): + """Send a http request via an agent and check that it is correctly received at + the proxy. The proxy can use either http or https. + Args: + ssl: True if we expect the request to connect via https to proxy + auth_credentials: credentials to authenticate at proxy """ - Tests that requests can be made through a proxy. - """ - agent = ProxyAgent(self.reactor, use_proxy=True) + if ssl: + agent = ProxyAgent( + self.reactor, use_proxy=True, contextFactory=get_test_https_policy() + ) + else: + agent = ProxyAgent(self.reactor, use_proxy=True) self.reactor.lookups["proxy.com"] = "1.2.3.5" d = agent.request(b"GET", b"http://test.com") @@ -254,7 +478,11 @@ class MatrixFederationAgentTests(TestCase): # make a test server, and wire up the client http_server = self._make_connection( - client_factory, _get_test_protocol_factory() + client_factory, + _get_test_protocol_factory(), + ssl=ssl, + tls_sanlist=[b"DNS:proxy.com"] if ssl else None, + expected_sni=b"proxy.com" if ssl else None, ) # the FakeTransport is async, so we need to pump the reactor @@ -272,7 +500,7 @@ class MatrixFederationAgentTests(TestCase): if auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(b"bob:pinkponies") + encoded_credentials = base64.b64encode(auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -295,8 +523,15 @@ class MatrixFederationAgentTests(TestCase): def _do_https_request_via_proxy( self, - auth_credentials: Optional[str] = None, + ssl: bool = False, + auth_credentials: Optional[bytes] = None, ): + """Send a https request via an agent and check that it is correctly received at + the proxy and client. The proxy can use either http or https. + Args: + ssl: True if we expect the request to connect via https to proxy + auth_credentials: credentials to authenticate at proxy + """ agent = ProxyAgent( self.reactor, contextFactory=get_test_https_policy(), @@ -313,18 +548,15 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(host, "1.2.3.5") self.assertEqual(port, 1080) - # make a test HTTP server, and wire up the client + # make a test server to act as the proxy, and wire up the client proxy_server = self._make_connection( - client_factory, _get_test_protocol_factory() + client_factory, + _get_test_protocol_factory(), + ssl=ssl, + tls_sanlist=[b"DNS:proxy.com"] if ssl else None, + expected_sni=b"proxy.com" if ssl else None, ) - - # fish the transports back out so that we can do the old switcheroo - s2c_transport = proxy_server.transport - client_protocol = s2c_transport.other - c2s_transport = client_protocol.transport - - # the FakeTransport is async, so we need to pump the reactor - self.reactor.advance(0) + assert isinstance(proxy_server, HTTPChannel) # now there should be a pending CONNECT request self.assertEqual(len(proxy_server.requests), 1) @@ -340,7 +572,7 @@ class MatrixFederationAgentTests(TestCase): if auth_credentials is not None: # Compute the correct header value for Proxy-Authorization - encoded_credentials = base64.b64encode(b"bob:pinkponies") + encoded_credentials = base64.b64encode(auth_credentials) expected_header_value = b"Basic " + encoded_credentials # Validate the header's value @@ -352,31 +584,49 @@ class MatrixFederationAgentTests(TestCase): # tell the proxy server not to close the connection proxy_server.persistent = True - # this just stops the http Request trying to do a chunked response - # request.setHeader(b"Content-Length", b"0") request.finish() - # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel - ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory()) - ssl_protocol = ssl_factory.buildProtocol(None) - http_server = ssl_protocol.wrappedProtocol + # now we make another test server to act as the upstream HTTP server. + server_ssl_protocol = _wrap_server_factory_for_tls( + _get_test_protocol_factory() + ).buildProtocol(None) - ssl_protocol.makeConnection( - FakeTransport(client_protocol, self.reactor, ssl_protocol) - ) - c2s_transport.other = ssl_protocol + # Tell the HTTP server to send outgoing traffic back via the proxy's transport. + proxy_server_transport = proxy_server.transport + server_ssl_protocol.makeConnection(proxy_server_transport) + + # ... and replace the protocol on the proxy's transport with the + # TLSMemoryBIOProtocol for the test server, so that incoming traffic + # to the proxy gets sent over to the HTTP(s) server. + # + # This needs a bit of gut-wrenching, which is different depending on whether + # the proxy is using TLS or not. + # + # (an alternative, possibly more elegant, approach would be to use a custom + # Protocol to implement the proxy, which starts out by forwarding to an + # HTTPChannel (to implement the CONNECT command) and can then be switched + # into a mode where it forwards its traffic to another Protocol.) + if ssl: + assert isinstance(proxy_server_transport, TLSMemoryBIOProtocol) + proxy_server_transport.wrappedProtocol = server_ssl_protocol + else: + assert isinstance(proxy_server_transport, FakeTransport) + client_protocol = proxy_server_transport.other + c2s_transport = client_protocol.transport + c2s_transport.other = server_ssl_protocol self.reactor.advance(0) - server_name = ssl_protocol._tlsConnection.get_servername() + server_name = server_ssl_protocol._tlsConnection.get_servername() expected_sni = b"test.com" self.assertEqual( server_name, expected_sni, - "Expected SNI %s but got %s" % (expected_sni, server_name), + f"Expected SNI {expected_sni!s} but got {server_name!s}", ) # now there should be a pending request + http_server = server_ssl_protocol.wrappedProtocol self.assertEqual(len(http_server.requests), 1) request = http_server.requests[0] @@ -510,7 +760,7 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual( server_name, expected_sni, - "Expected SNI %s but got %s" % (expected_sni, server_name), + f"Expected SNI {expected_sni!s} but got {server_name!s}", ) # now there should be a pending request @@ -529,16 +779,48 @@ class MatrixFederationAgentTests(TestCase): body = self.successResultOf(treq.content(resp)) self.assertEqual(body, b"result") + @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) + def test_proxy_with_no_scheme(self): + http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) + self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) + self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") + self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) + + @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) + def test_proxy_with_unsupported_scheme(self): + with self.assertRaises(ValueError): + ProxyAgent(self.reactor, use_proxy=True) + + @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) + def test_proxy_with_http_scheme(self): + http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) + self.assertIsInstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint) + self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com") + self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888) + + @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) + def test_proxy_with_https_scheme(self): + https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) + self.assertIsInstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint) + self.assertEqual( + https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com" + ) + self.assertEqual( + https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888 + ) + -def _wrap_server_factory_for_tls(factory, sanlist=None): +def _wrap_server_factory_for_tls( + factory: IProtocolFactory, sanlist: Iterable[bytes] = None +) -> IProtocolFactory: """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory The resultant factory will create a TLS server which presents a certificate signed by our test CA, valid for the domains in `sanlist` Args: - factory (interfaces.IProtocolFactory): protocol factory to wrap - sanlist (iterable[bytes]): list of domains the cert should be valid for + factory: protocol factory to wrap + sanlist: list of domains the cert should be valid for Returns: interfaces.IProtocolFactory @@ -552,7 +834,7 @@ def _wrap_server_factory_for_tls(factory, sanlist=None): ) -def _get_test_protocol_factory(): +def _get_test_protocol_factory() -> IProtocolFactory: """Get a protocol Factory which will build an HTTPChannel Returns: @@ -566,6 +848,6 @@ def _get_test_protocol_factory(): return server_factory -def _log_request(request): +def _log_request(request: str): """Implements Factory.log, which is expected by Request.finish""" - logger.info("Completed request %s", request) + logger.info(f"Completed request {request}") diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 6fee0f95b6..7198fd293f 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -261,7 +261,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( - "Missing integer query parameter b'before_ts'", channel.json_body["error"] + "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) def test_invalid_parameter(self): @@ -303,7 +303,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( - "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']", + "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], ) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 4fccce34fd..42f50c0921 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -473,7 +473,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -485,7 +485,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=other_user_token) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_all_users(self): @@ -497,11 +497,11 @@ class UsersListTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", self.url + "?deactivated=true", - b"{}", + {}, access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(3, len(channel.json_body["users"])) self.assertEqual(3, channel.json_body["total"]) @@ -532,7 +532,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( "GET", - url.encode("ascii"), + url, access_token=self.admin_user_tok, ) self.assertEqual(expected_http_code, channel.code, msg=channel.json_body) @@ -598,7 +598,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # negative from @@ -608,7 +608,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) # invalid guests @@ -618,7 +618,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # invalid deactivated @@ -628,7 +628,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) # unkown order_by @@ -648,7 +648,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) def test_limit(self): @@ -666,7 +666,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 5) self.assertEqual(channel.json_body["next_token"], "5") @@ -687,7 +687,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 15) self.assertNotIn("next_token", channel.json_body) @@ -708,7 +708,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(channel.json_body["next_token"], "15") self.assertEqual(len(channel.json_body["users"]), 10) @@ -731,7 +731,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -744,7 +744,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), number_users) self.assertNotIn("next_token", channel.json_body) @@ -757,7 +757,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 19) self.assertEqual(channel.json_body["next_token"], "19") @@ -771,7 +771,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): access_token=self.admin_user_tok, ) - self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(channel.json_body["total"], number_users) self.assertEqual(len(channel.json_body["users"]), 1) self.assertNotIn("next_token", channel.json_body) @@ -781,7 +781,10 @@ class UsersListTestCase(unittest.HomeserverTestCase): Testing order list with parameter `order_by` """ + # make sure that the users do not have the same timestamps + self.reactor.advance(10) user1 = self.register_user("user1", "pass1", admin=False, displayname="Name Z") + self.reactor.advance(10) user2 = self.register_user("user2", "pass2", admin=False, displayname="Name Y") # Modify user @@ -841,6 +844,11 @@ class UsersListTestCase(unittest.HomeserverTestCase): self._order_test([self.admin_user, user2, user1], "avatar_url", "f") self._order_test([user1, user2, self.admin_user], "avatar_url", "b") + # order by creation_ts + self._order_test([self.admin_user, user1, user2], "creation_ts") + self._order_test([self.admin_user, user1, user2], "creation_ts", "f") + self._order_test([user2, user1, self.admin_user], "creation_ts", "b") + def _order_test( self, expected_user_list: List[str], @@ -863,7 +871,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): url += "dir=%s" % (dir,) channel = self.make_request( "GET", - url.encode("ascii"), + url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -887,6 +895,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertIn("shadow_banned", u) self.assertIn("displayname", u) self.assertIn("avatar_url", u) + self.assertIn("creation_ts", u) def _create_users(self, number_users: int): """ diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index 4ef19145d1..317a2287e3 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,6 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership from synapse.api.errors import Codes, HttpResponseException +from synapse.appservice import ApplicationService from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource @@ -397,7 +398,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id))) # Check that this access token has been invalidated. - channel = self.make_request("GET", "account/whoami") + channel = self.make_request("GET", "account/whoami", access_token=tok) self.assertEqual(channel.code, 401) def test_pending_invites(self): @@ -458,6 +459,46 @@ class DeactivateTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) +class WhoamiTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + account.register_servlets, + register.register_servlets, + ] + + def test_GET_whoami(self): + device_id = "wouldgohere" + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test", device_id=device_id) + + whoami = self.whoami(tok) + self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) + + def test_GET_whoami_appservices(self): + user_id = "@as:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, + self.hs.config.server_name, + id="1234", + namespaces={"users": [{"regex": user_id, "exclusive": True}]}, + sender=user_id, + ) + self.hs.get_datastore().services_cache.append(appservice) + + whoami = self.whoami(as_token) + self.assertEqual(whoami, {"user_id": user_id}) + self.assertFalse(hasattr(whoami, "device_id")) + + def whoami(self, tok): + channel = self.make_request("GET", "account/whoami", {}, access_token=tok) + self.assertEqual(channel.code, 200) + return channel.json_body + + class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index 874052c61c..f80f48a455 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -102,3 +102,49 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) + + def test_get_does_not_include_msc3244_fields_by_default(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertNotIn( + "org.matrix.msc3244.room_capabilities", capabilities["m.room_versions"] + ) + + @override_config({"experimental_features": {"msc3244_enabled": True}}) + def test_get_does_include_msc3244_fields_when_enabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + for details in capabilities["m.room_versions"][ + "org.matrix.msc3244.room_capabilities" + ].values(): + if details["preferred"] is not None: + self.assertTrue( + details["preferred"] in KNOWN_ROOM_VERSIONS, + str(details["preferred"]), + ) + + self.assertGreater(len(details["support"]), 0) + for room_version in details["support"]: + self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index cdca3a3e23..f6ae9ae181 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -15,9 +15,14 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + ReadReceiptEventFields, + RelationTypes, +) from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import knock, read_marker, sync +from synapse.rest.client.v2_alpha import knock, read_marker, receipts, sync from tests import unittest from tests.federation.transport.test_knocking import ( @@ -368,6 +373,76 @@ class SyncKnockTestCase( ) +class ReadReceiptsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + receipts.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Join the second user + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + + @override_config({"experimental_features": {"msc2285_enabled": True}}) + def test_hidden_read_receipts(self): + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt to tell the server the first user's message was read + body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + body, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that the first user can't see the other user's hidden read receipt + self.assertEqual(self._get_read_receipt(), None) + + def _get_read_receipt(self): + """Syncs and returns the read receipt.""" + + # Checks if event is a read receipt + def is_read_receipt(event): + return event["type"] == "m.receipt" + + # Sync + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + # Return the read receipt + ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ + "ephemeral" + ]["events"] + return next(filter(is_read_receipt, ephemeral_events), None) + + class UnreadMessagesTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -375,6 +450,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): read_marker.register_servlets, room.register_servlets, sync.register_servlets, + receipts.register_servlets, ] def prepare(self, reactor, clock, hs): @@ -448,6 +524,23 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Check that the unread counter is back to 0. self._check_unread_count(0) + # Check that hidden read receipts don't break unread counts + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + channel = self.make_request( + "POST", + "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + body, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + # Check that room name changes increase the unread counter. self.helper.send_state( self.room_id, diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 43fc79ca74..8370a27195 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -484,7 +484,7 @@ class StateTestCase(unittest.TestCase): state_d = resolve_events_with_store( FakeClock(), ROOM_ID, - RoomVersions.V2.identifier, + RoomVersions.V2, [state_at_event[n] for n in prev_events], event_map=event_map, state_res_store=TestStateResolutionStore(event_map), @@ -496,7 +496,7 @@ class StateTestCase(unittest.TestCase): if fake_event.state_key is not None: state_after[(fake_event.type, fake_event.state_key)] = event_id - auth_types = set(auth_types_for_event(fake_event)) + auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) auth_events = [] for key in auth_types: @@ -633,7 +633,7 @@ class SimpleParamStateTestCase(unittest.TestCase): state_d = resolve_events_with_store( FakeClock(), ROOM_ID, - RoomVersions.V2.identifier, + RoomVersions.V2, [self.state_at_bob, self.state_at_charlie], event_map=None, state_res_store=TestStateResolutionStore(self.event_map), diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index dbacce4380..8c95a0a2fb 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import List, Optional from canonicaljson import json @@ -234,8 +234,8 @@ class RedactionTestCase(unittest.HomeserverTestCase): async def build( self, - prev_event_ids, - auth_event_ids, + prev_event_ids: List[str], + auth_event_ids: Optional[List[str]], depth: Optional[int] = None, ): built_event = await self._base_builder.build( diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index f73306ecc4..e5550aec4d 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -351,7 +351,11 @@ class EventAuthTestCase(unittest.TestCase): """ Test joining a restricted room from MSC3083. - This is pretty much the same test as public. + This is similar to the public test, but has some additional checks on + signatures. + + The checks which care about signatures fake them by simply adding an + object of the proper form, not generating valid signatures. """ creator = "@creator:example.com" pleb = "@joiner:example.com" @@ -359,6 +363,7 @@ class EventAuthTestCase(unittest.TestCase): auth_events = { ("m.room.create", ""): _create_event(creator), ("m.room.member", creator): _join_event(creator), + ("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}), ("m.room.join_rules", ""): _join_rules_event(creator, "restricted"), } @@ -371,19 +376,81 @@ class EventAuthTestCase(unittest.TestCase): do_sig_check=False, ) - # Check join. + # A properly formatted join event should work. + authorised_join_event = _join_event( + pleb, + additional_content={ + "join_authorised_via_users_server": "@creator:example.com" + }, + ) event_auth.check( RoomVersions.MSC3083, - _join_event(pleb), + authorised_join_event, auth_events, do_sig_check=False, ) - # A user cannot be force-joined to a room. + # A join issued by a specific user works (i.e. the power level checks + # are done properly). + pl_auth_events = auth_events.copy() + pl_auth_events[("m.room.power_levels", "")] = _power_levels_event( + creator, {"invite": 100, "users": {"@inviter:foo.test": 150}} + ) + pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event( + "@inviter:foo.test" + ) + event_auth.check( + RoomVersions.MSC3083, + _join_event( + pleb, + additional_content={ + "join_authorised_via_users_server": "@inviter:foo.test" + }, + ), + pl_auth_events, + do_sig_check=False, + ) + + # A join which is missing an authorised server is rejected. with self.assertRaises(AuthError): event_auth.check( RoomVersions.MSC3083, - _member_event(pleb, "join", sender=creator), + _join_event(pleb), + auth_events, + do_sig_check=False, + ) + + # An join authorised by a user who is not in the room is rejected. + pl_auth_events = auth_events.copy() + pl_auth_events[("m.room.power_levels", "")] = _power_levels_event( + creator, {"invite": 100, "users": {"@other:example.com": 150}} + ) + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC3083, + _join_event( + pleb, + additional_content={ + "join_authorised_via_users_server": "@other:example.com" + }, + ), + auth_events, + do_sig_check=False, + ) + + # A user cannot be force-joined to a room. (This uses an event which + # *would* be valid, but is sent be a different user.) + with self.assertRaises(AuthError): + event_auth.check( + RoomVersions.MSC3083, + _member_event( + pleb, + "join", + sender=creator, + additional_content={ + "join_authorised_via_users_server": "@inviter:foo.test" + }, + ), auth_events, do_sig_check=False, ) @@ -393,7 +460,7 @@ class EventAuthTestCase(unittest.TestCase): with self.assertRaises(AuthError): event_auth.check( RoomVersions.MSC3083, - _join_event(pleb), + authorised_join_event, auth_events, do_sig_check=False, ) @@ -402,12 +469,13 @@ class EventAuthTestCase(unittest.TestCase): auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave") event_auth.check( RoomVersions.MSC3083, - _join_event(pleb), + authorised_join_event, auth_events, do_sig_check=False, ) - # A user can send a join if they're in the room. + # A user can send a join if they're in the room. (This doesn't need to + # be authorised since the user is already joined.) auth_events[("m.room.member", pleb)] = _member_event(pleb, "join") event_auth.check( RoomVersions.MSC3083, @@ -416,7 +484,8 @@ class EventAuthTestCase(unittest.TestCase): do_sig_check=False, ) - # A user can accept an invite. + # A user can accept an invite. (This doesn't need to be authorised since + # the user was invited.) auth_events[("m.room.member", pleb)] = _member_event( pleb, "invite", sender=creator ) @@ -446,7 +515,10 @@ def _create_event(user_id: str) -> EventBase: def _member_event( - user_id: str, membership: str, sender: Optional[str] = None + user_id: str, + membership: str, + sender: Optional[str] = None, + additional_content: Optional[dict] = None, ) -> EventBase: return make_event_from_dict( { @@ -455,14 +527,14 @@ def _member_event( "type": "m.room.member", "sender": sender or user_id, "state_key": user_id, - "content": {"membership": membership}, + "content": {"membership": membership, **(additional_content or {})}, "prev_events": [], } ) -def _join_event(user_id: str) -> EventBase: - return _member_event(user_id, "join") +def _join_event(user_id: str, additional_content: Optional[dict] = None) -> EventBase: + return _member_event(user_id, "join", additional_content=additional_content) def _power_levels_event(sender: str, content: JsonDict) -> EventBase: diff --git a/tests/test_preview.py b/tests/test_preview.py index cac3d81ac1..48e792b55b 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -325,6 +325,19 @@ class MediaEncodingTestCase(unittest.TestCase): ) self.assertEqual(encoding, "ascii") + def test_meta_charset_underscores(self): + """A character encoding contains underscore.""" + encoding = get_html_media_encoding( + b""" + <html> + <head><meta charset="Shift_JIS"> + </head> + </html> + """, + "text/html", + ) + self.assertEqual(encoding, "Shift_JIS") + def test_xml_encoding(self): """A character encoding is found via the meta tag.""" encoding = get_html_media_encoding( |