diff options
222 files changed, 6439 insertions, 3066 deletions
diff --git a/.ci/scripts/test_old_deps.sh b/.ci/scripts/test_old_deps.sh index a54aa86fbc..b2859f7522 100755 --- a/.ci/scripts/test_old_deps.sh +++ b/.ci/scripts/test_old_deps.sh @@ -8,11 +8,13 @@ export DEBIAN_FRONTEND=noninteractive set -ex apt-get update -apt-get install -y python3 python3-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev +apt-get install -y \ + python3 python3-dev python3-pip python3-venv \ + libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox libjpeg-dev libwebp-dev export LANG="C.UTF-8" # Prevent virtualenv from auto-updating pip to an incompatible version export VIRTUALENV_NO_DOWNLOAD=1 -exec tox -e py3-old,combine +exec tox -e py3-old diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index eb294f1619..eee3633d50 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -7,7 +7,7 @@ on: # of things breaking (but only build one set of debs) pull_request: push: - branches: ["develop"] + branches: ["develop", "release-*"] # we do the full build on tags. tags: ["v*"] @@ -91,17 +91,7 @@ jobs: build-sdist: name: "Build pypi distribution files" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - - run: pip install wheel - - run: | - python setup.py sdist bdist_wheel - - uses: actions/upload-artifact@v2 - with: - name: python-dist - path: dist/* + uses: "matrix-org/backend-meta/.github/workflows/packaging.yml@v1" # if it's a tag, create a release and attach the artifacts to it attach-assets: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e0f80aaaa7..bbf1033bdd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,24 +48,10 @@ jobs: env: PULL_REQUEST_NUMBER: ${{ github.event.number }} - lint-sdist: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: "3.x" - - run: pip install wheel - - run: python setup.py sdist bdist_wheel - - uses: actions/upload-artifact@v2 - with: - name: Python Distributions - path: dist/* - # Dummy step to gate other tests on without repeating the whole list linting-done: if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - needs: [lint, lint-crlf, lint-newsfile, lint-sdist] + needs: [lint, lint-crlf, lint-newsfile] runs-on: ubuntu-latest steps: - run: "true" @@ -345,7 +331,7 @@ jobs: path: synapse # Attempt to check out the same branch of Complement as the PR. If it - # doesn't exist, fallback to master. + # doesn't exist, fallback to HEAD. - name: Checkout complement shell: bash run: | @@ -358,8 +344,8 @@ jobs: # for pull requests, otherwise GITHUB_REF). # 2. Attempt to use the base branch, e.g. when merging into release-vX.Y # (GITHUB_BASE_REF for pull requests). - # 3. Use the default complement branch ("master"). - for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "master"; do + # 3. Use the default complement branch ("HEAD"). + for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "HEAD"; do # Skip empty branch names and merge commits. if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then continue @@ -383,7 +369,7 @@ jobs: # Run Complement - run: | set -o pipefail - go test -v -json -tags synapse_blacklist,msc2403 ./tests/... 2>&1 | gotestfmt + go test -v -json -p 1 -tags synapse_blacklist,msc2403 ./tests/... 2>&1 | gotestfmt shell: bash name: Run Complement Tests env: @@ -397,7 +383,6 @@ jobs: - lint - lint-crlf - lint-newsfile - - lint-sdist - trial - trial-olddeps - sytest diff --git a/CHANGES.md b/CHANGES.md index cd62e5256a..576dab74f4 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,82 @@ +Synapse 1.53.0rc1 (2022-02-15) +============================== + +Features +-------- + +- Add experimental support for sending to-device messages to application services, as specified by [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409). ([\#11215](https://github.com/matrix-org/synapse/issues/11215), [\#11966](https://github.com/matrix-org/synapse/issues/11966)) +- Remove account data (including client config, push rules and ignored users) upon user deactivation. ([\#11655](https://github.com/matrix-org/synapse/issues/11655)) +- Experimental support for [MSC3666](https://github.com/matrix-org/matrix-doc/pull/3666): including bundled aggregations in server side search results. ([\#11837](https://github.com/matrix-org/synapse/issues/11837)) +- Enable cache time-based expiry by default. The `expiry_time` config flag has been superseded by `expire_caches` and `cache_entry_ttl`. ([\#11849](https://github.com/matrix-org/synapse/issues/11849)) +- Add a callback to allow modules to allow or forbid a 3PID (email address, phone number) from being associated to a local account. ([\#11854](https://github.com/matrix-org/synapse/issues/11854)) +- Stabilize support and remove unstable endpoints for [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231). Clients must switch to the stable identifier and endpoint. See the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#stablisation-of-msc3231) for more information. ([\#11867](https://github.com/matrix-org/synapse/issues/11867)) +- Allow modules to retrieve the current instance's server name and worker name. ([\#11868](https://github.com/matrix-org/synapse/issues/11868)) +- Use a dedicated configurable rate limiter for 3PID invites. ([\#11892](https://github.com/matrix-org/synapse/issues/11892)) +- Support the stable API endpoint for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283): new settings in `/capabilities` endpoint. ([\#11933](https://github.com/matrix-org/synapse/issues/11933), [\#11989](https://github.com/matrix-org/synapse/issues/11989)) +- Support the `dir` parameter on the `/relations` endpoint, per [MSC3715](https://github.com/matrix-org/matrix-doc/pull/3715). ([\#11941](https://github.com/matrix-org/synapse/issues/11941)) +- Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size. ([\#11967](https://github.com/matrix-org/synapse/issues/11967)) + + +Bugfixes +-------- + +- Fix [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) historical messages backfilling in random order on remote homeservers. ([\#11114](https://github.com/matrix-org/synapse/issues/11114)) +- Fix a bug introduced in Synapse 1.51.0 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#11890](https://github.com/matrix-org/synapse/issues/11890)) +- Fix a long-standing bug where some unknown endpoints would return HTML error pages instead of JSON `M_UNRECOGNIZED` errors. ([\#11930](https://github.com/matrix-org/synapse/issues/11930)) +- Implement an allow list of content types for which we will attempt to preview a URL. This prevents Synapse from making useless longer-lived connections to streaming media servers. ([\#11936](https://github.com/matrix-org/synapse/issues/11936)) +- Fix a long-standing bug where pagination tokens from `/sync` and `/messages` could not be provided to the `/relations` API. ([\#11952](https://github.com/matrix-org/synapse/issues/11952)) +- Require that modules register their callbacks using keyword arguments. ([\#11975](https://github.com/matrix-org/synapse/issues/11975)) +- Fix a long-standing bug where `M_WRONG_ROOM_KEYS_VERSION` errors would not include the specced `current_version` field. ([\#11988](https://github.com/matrix-org/synapse/issues/11988)) + + +Improved Documentation +---------------------- + +- Fix typo in User Admin API: unpind -> unbind. ([\#11859](https://github.com/matrix-org/synapse/issues/11859)) +- Document images returned by the User List Media Admin API can include those generated by URL previews. ([\#11862](https://github.com/matrix-org/synapse/issues/11862)) +- Remove outdated MSC1711 FAQ document. ([\#11907](https://github.com/matrix-org/synapse/issues/11907)) +- Correct the structured logging configuration example. Contributed by Brad Jones. ([\#11946](https://github.com/matrix-org/synapse/issues/11946)) +- Add information on the Synapse release cycle. ([\#11954](https://github.com/matrix-org/synapse/issues/11954)) +- Fix broken link in the README to the admin API for password reset. ([\#11955](https://github.com/matrix-org/synapse/issues/11955)) + + +Deprecations and Removals +------------------------- + +- Drop support for `webclient` listeners and configuring `web_client_location` to a non-HTTP(S) URL. Deprecated configurations are a configuration error. ([\#11895](https://github.com/matrix-org/synapse/issues/11895)) +- Remove deprecated `user_may_create_room_with_invites` spam checker callback. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#removal-of-user_may_create_room_with_invites) for more information. ([\#11950](https://github.com/matrix-org/synapse/issues/11950)) +- No longer build `.deb` packages for Ubuntu 21.04 Hirsute Hippo, which has now EOLed. ([\#11961](https://github.com/matrix-org/synapse/issues/11961)) + + +Internal Changes +---------------- + +- Enhance user registration test helpers to make them more useful for tests involving application services and devices. ([\#11615](https://github.com/matrix-org/synapse/issues/11615), [\#11616](https://github.com/matrix-org/synapse/issues/11616)) +- Improve performance when fetching bundled aggregations for multiple events. ([\#11660](https://github.com/matrix-org/synapse/issues/11660), [\#11752](https://github.com/matrix-org/synapse/issues/11752)) +- Fix type errors introduced by new annotations in the Prometheus Client library. ([\#11832](https://github.com/matrix-org/synapse/issues/11832)) +- Add missing type hints to replication code. ([\#11856](https://github.com/matrix-org/synapse/issues/11856), [\#11938](https://github.com/matrix-org/synapse/issues/11938)) +- Ensure that `opentracing` scopes are activated and closed at the right time. ([\#11869](https://github.com/matrix-org/synapse/issues/11869)) +- Improve opentracing for incoming federation requests. ([\#11870](https://github.com/matrix-org/synapse/issues/11870)) +- Improve internal docstrings in `synapse.util.caches`. ([\#11876](https://github.com/matrix-org/synapse/issues/11876)) +- Do not needlessly clear the `get_users_in_room` and `get_users_in_room_with_profiles` caches when any room state changes. ([\#11878](https://github.com/matrix-org/synapse/issues/11878)) +- Convert `ApplicationServiceTestCase` to use `simple_async_mock`. ([\#11880](https://github.com/matrix-org/synapse/issues/11880)) +- Remove experimental changes to the default push rules which were introduced in Synapse 1.19.0 but never enabled. ([\#11884](https://github.com/matrix-org/synapse/issues/11884)) +- Disable coverage calculation for olddeps build. ([\#11888](https://github.com/matrix-org/synapse/issues/11888)) +- Preparation to support sending device list updates to application services. ([\#11905](https://github.com/matrix-org/synapse/issues/11905)) +- Add a test that checks users receive their own device list updates down `/sync`. ([\#11909](https://github.com/matrix-org/synapse/issues/11909)) +- Run Complement tests sequentially. ([\#11910](https://github.com/matrix-org/synapse/issues/11910)) +- Various refactors to the application service notifier code. ([\#11911](https://github.com/matrix-org/synapse/issues/11911), [\#11912](https://github.com/matrix-org/synapse/issues/11912)) +- Tests: replace mocked `Authenticator` with the real thing. ([\#11913](https://github.com/matrix-org/synapse/issues/11913)) +- Various refactors to the typing notifications code. ([\#11914](https://github.com/matrix-org/synapse/issues/11914)) +- Use the proper type for the `Content-Length` header in the `UploadResource`. ([\#11927](https://github.com/matrix-org/synapse/issues/11927)) +- Remove an unnecessary ignoring of type hints due to fixes in upstream packages. ([\#11939](https://github.com/matrix-org/synapse/issues/11939)) +- Add missing type hints. ([\#11953](https://github.com/matrix-org/synapse/issues/11953)) +- Fix an import cycle in `synapse.event_auth`. ([\#11965](https://github.com/matrix-org/synapse/issues/11965)) +- Unpin `frozendict` but exclude the known bad version 2.1.2. ([\#11969](https://github.com/matrix-org/synapse/issues/11969)) +- Prepare for rename of default Complement branch. ([\#11971](https://github.com/matrix-org/synapse/issues/11971)) +- Fetch Synapse's version using a helper from `matrix-common`. ([\#11979](https://github.com/matrix-org/synapse/issues/11979)) + + Synapse 1.52.0 (2022-02-08) =========================== diff --git a/README.rst b/README.rst index 50de3a49b0..4281c87d1f 100644 --- a/README.rst +++ b/README.rst @@ -246,7 +246,7 @@ Password reset ============== Users can reset their password through their client. Alternatively, a server admin -can reset a users password using the `admin API <docs/admin_api/user_admin_api.rst#reset-password>`_ +can reset a users password using the `admin API <docs/admin_api/user_admin_api.md#reset-password>`_ or by directly editing the database as shown below. First calculate the hash of the new password:: diff --git a/changelog.d/10870.misc b/changelog.d/10870.misc new file mode 100644 index 0000000000..3af049b969 --- /dev/null +++ b/changelog.d/10870.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/11835.feature b/changelog.d/11835.feature new file mode 100644 index 0000000000..7cee39b08c --- /dev/null +++ b/changelog.d/11835.feature @@ -0,0 +1 @@ +Make a `POST` to `/rooms/<room_id>/receipt/m.read/<event_id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. diff --git a/changelog.d/11972.misc b/changelog.d/11972.misc new file mode 100644 index 0000000000..29c38bfd82 --- /dev/null +++ b/changelog.d/11972.misc @@ -0,0 +1 @@ +Add tests for device list changes between local users. \ No newline at end of file diff --git a/changelog.d/11974.misc b/changelog.d/11974.misc new file mode 100644 index 0000000000..1debad2361 --- /dev/null +++ b/changelog.d/11974.misc @@ -0,0 +1 @@ +Optimise calculating device_list changes in `/sync`. diff --git a/changelog.d/11984.misc b/changelog.d/11984.misc new file mode 100644 index 0000000000..8e405b9226 --- /dev/null +++ b/changelog.d/11984.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. \ No newline at end of file diff --git a/changelog.d/11991.misc b/changelog.d/11991.misc new file mode 100644 index 0000000000..34a3b3a6b9 --- /dev/null +++ b/changelog.d/11991.misc @@ -0,0 +1 @@ +Refactor the search code for improved readability. diff --git a/changelog.d/11992.bugfix b/changelog.d/11992.bugfix new file mode 100644 index 0000000000..f73c86bb25 --- /dev/null +++ b/changelog.d/11992.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. diff --git a/changelog.d/11994.misc b/changelog.d/11994.misc new file mode 100644 index 0000000000..d64297dd78 --- /dev/null +++ b/changelog.d/11994.misc @@ -0,0 +1 @@ +Move common deduplication code down into `_auth_and_persist_outliers`. diff --git a/changelog.d/11996.misc b/changelog.d/11996.misc new file mode 100644 index 0000000000..6c675fd193 --- /dev/null +++ b/changelog.d/11996.misc @@ -0,0 +1 @@ +Limit concurrent joins from applications services. \ No newline at end of file diff --git a/changelog.d/11997.docker b/changelog.d/11997.docker new file mode 100644 index 0000000000..1b3271457e --- /dev/null +++ b/changelog.d/11997.docker @@ -0,0 +1 @@ +The docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. diff --git a/changelog.d/11999.bugfix b/changelog.d/11999.bugfix new file mode 100644 index 0000000000..fd84095900 --- /dev/null +++ b/changelog.d/11999.bugfix @@ -0,0 +1 @@ +Fix long standing bug where `get_rooms_for_user` was not correctly invalidated for remote users when the server left a room. diff --git a/changelog.d/12000.feature b/changelog.d/12000.feature new file mode 100644 index 0000000000..246cc87f0b --- /dev/null +++ b/changelog.d/12000.feature @@ -0,0 +1 @@ +Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. diff --git a/changelog.d/12003.doc b/changelog.d/12003.doc new file mode 100644 index 0000000000..1ac8163559 --- /dev/null +++ b/changelog.d/12003.doc @@ -0,0 +1 @@ +Explain the meaning of spam checker callbacks' return values. diff --git a/changelog.d/12004.doc b/changelog.d/12004.doc new file mode 100644 index 0000000000..0b4baef210 --- /dev/null +++ b/changelog.d/12004.doc @@ -0,0 +1 @@ +Clarify information about external Identity Provider IDs. diff --git a/changelog.d/12005.misc b/changelog.d/12005.misc new file mode 100644 index 0000000000..45e21dbe59 --- /dev/null +++ b/changelog.d/12005.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. diff --git a/changelog.d/12008.removal b/changelog.d/12008.removal new file mode 100644 index 0000000000..57599d9ee9 --- /dev/null +++ b/changelog.d/12008.removal @@ -0,0 +1 @@ +Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). diff --git a/changelog.d/12009.feature b/changelog.d/12009.feature new file mode 100644 index 0000000000..c8a531481e --- /dev/null +++ b/changelog.d/12009.feature @@ -0,0 +1 @@ +Enable modules to set a custom display name when registering a user. diff --git a/changelog.d/12011.misc b/changelog.d/12011.misc new file mode 100644 index 0000000000..258b0e389f --- /dev/null +++ b/changelog.d/12011.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: parse msc3706 fields in send_join response. diff --git a/changelog.d/12015.misc b/changelog.d/12015.misc new file mode 100644 index 0000000000..3aa32ab4cf --- /dev/null +++ b/changelog.d/12015.misc @@ -0,0 +1 @@ +Configure `tox` to use `venv` rather than `virtualenv`. diff --git a/changelog.d/12016.misc b/changelog.d/12016.misc new file mode 100644 index 0000000000..8856ef46a9 --- /dev/null +++ b/changelog.d/12016.misc @@ -0,0 +1 @@ +Fix bug in `StateFilter.return_expanded()` and add some tests. \ No newline at end of file diff --git a/changelog.d/12018.removal b/changelog.d/12018.removal new file mode 100644 index 0000000000..e940b62228 --- /dev/null +++ b/changelog.d/12018.removal @@ -0,0 +1 @@ +Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. diff --git a/changelog.d/12019.misc b/changelog.d/12019.misc new file mode 100644 index 0000000000..b2186320ea --- /dev/null +++ b/changelog.d/12019.misc @@ -0,0 +1 @@ +Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. \ No newline at end of file diff --git a/changelog.d/12020.feature b/changelog.d/12020.feature new file mode 100644 index 0000000000..1ac9d2060e --- /dev/null +++ b/changelog.d/12020.feature @@ -0,0 +1 @@ +Advertise Matrix 1.1 support on `/_matrix/client/versions`. \ No newline at end of file diff --git a/changelog.d/12021.feature b/changelog.d/12021.feature new file mode 100644 index 0000000000..01378df8ca --- /dev/null +++ b/changelog.d/12021.feature @@ -0,0 +1 @@ +Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. \ No newline at end of file diff --git a/changelog.d/12022.feature b/changelog.d/12022.feature new file mode 100644 index 0000000000..188fb12570 --- /dev/null +++ b/changelog.d/12022.feature @@ -0,0 +1 @@ +Advertise Matrix 1.2 support on `/_matrix/client/versions`. \ No newline at end of file diff --git a/changelog.d/12024.bugfix b/changelog.d/12024.bugfix new file mode 100644 index 0000000000..59bcdb93a5 --- /dev/null +++ b/changelog.d/12024.bugfix @@ -0,0 +1 @@ +Fix 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. diff --git a/changelog.d/12025.misc b/changelog.d/12025.misc new file mode 100644 index 0000000000..d9475a7718 --- /dev/null +++ b/changelog.d/12025.misc @@ -0,0 +1 @@ +Update the `olddeps` CI job to use an old version of `markupsafe`. diff --git a/changelog.d/12030.misc b/changelog.d/12030.misc new file mode 100644 index 0000000000..607ee97ce6 --- /dev/null +++ b/changelog.d/12030.misc @@ -0,0 +1 @@ +Upgrade mypy to version 0.931. diff --git a/changelog.d/12033.misc b/changelog.d/12033.misc new file mode 100644 index 0000000000..3af049b969 --- /dev/null +++ b/changelog.d/12033.misc @@ -0,0 +1 @@ +Deduplicate in-flight requests in `_get_state_for_groups`. diff --git a/changelog.d/12034.misc b/changelog.d/12034.misc new file mode 100644 index 0000000000..8374a63220 --- /dev/null +++ b/changelog.d/12034.misc @@ -0,0 +1 @@ +Minor typing fixes. diff --git a/changelog.d/12039.misc b/changelog.d/12039.misc new file mode 100644 index 0000000000..45e21dbe59 --- /dev/null +++ b/changelog.d/12039.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. diff --git a/changelog.d/12051.misc b/changelog.d/12051.misc new file mode 100644 index 0000000000..9959191352 --- /dev/null +++ b/changelog.d/12051.misc @@ -0,0 +1 @@ +Tidy up GitHub Actions config which builds distributions for PyPI. \ No newline at end of file diff --git a/changelog.d/12052.misc b/changelog.d/12052.misc new file mode 100644 index 0000000000..fbaff67e95 --- /dev/null +++ b/changelog.d/12052.misc @@ -0,0 +1 @@ +Move `isort` configuration to `pyproject.toml`. diff --git a/debian/changelog b/debian/changelog index 64ea103f3e..fe79d7ed57 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.53.0~rc1) stable; urgency=medium + + * New synapse release 1.53.0~rc1. + + -- Synapse Packaging team <packages@matrix.org> Tue, 15 Feb 2022 10:40:50 +0000 + matrix-synapse-py3 (1.52.0) stable; urgency=medium * New synapse release 1.52.0. diff --git a/docker/Dockerfile b/docker/Dockerfile index 306f75ae56..e4c1c19b86 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -98,8 +98,6 @@ COPY --from=builder /install /usr/local COPY ./docker/start.py /start.py COPY ./docker/conf /conf -VOLUME ["/data"] - EXPOSE 8008/tcp 8009/tcp 8448/tcp ENTRYPOINT ["/start.py"] diff --git a/docs/MSC1711_certificates_FAQ.md b/docs/MSC1711_certificates_FAQ.md deleted file mode 100644 index 32ba15652d..0000000000 --- a/docs/MSC1711_certificates_FAQ.md +++ /dev/null @@ -1,314 +0,0 @@ -# MSC1711 Certificates FAQ - -## Historical Note -This document was originally written to guide server admins through the upgrade -path towards Synapse 1.0. Specifically, -[MSC1711](https://github.com/matrix-org/matrix-doc/blob/main/proposals/1711-x509-for-federation.md) -required that all servers present valid TLS certificates on their federation -API. Admins were encouraged to achieve compliance from version 0.99.0 (released -in February 2019) ahead of version 1.0 (released June 2019) enforcing the -certificate checks. - -Much of what follows is now outdated since most admins will have already -upgraded, however it may be of use to those with old installs returning to the -project. - -If you are setting up a server from scratch you almost certainly should look at -the [installation guide](setup/installation.md) instead. - -## Introduction -The goal of Synapse 0.99.0 is to act as a stepping stone to Synapse 1.0.0. It -supports the r0.1 release of the server to server specification, but is -compatible with both the legacy Matrix federation behaviour (pre-r0.1) as well -as post-r0.1 behaviour, in order to allow for a smooth upgrade across the -federation. - -The most important thing to know is that Synapse 1.0.0 will require a valid TLS -certificate on federation endpoints. Self signed certificates will not be -sufficient. - -Synapse 0.99.0 makes it easy to configure TLS certificates and will -interoperate with both >= 1.0.0 servers as well as existing servers yet to -upgrade. - -**It is critical that all admins upgrade to 0.99.0 and configure a valid TLS -certificate.** Admins will have 1 month to do so, after which 1.0.0 will be -released and those servers without a valid certificate will not longer be able -to federate with >= 1.0.0 servers. - -Full details on how to carry out this configuration change is given -[below](#configuring-certificates-for-compatibility-with-synapse-100). A -timeline and some frequently asked questions are also given below. - -For more details and context on the release of the r0.1 Server/Server API and -imminent Matrix 1.0 release, you can also see our -[main talk from FOSDEM 2019](https://matrix.org/blog/2019/02/04/matrix-at-fosdem-2019/). - -## Timeline - -**5th Feb 2019 - Synapse 0.99.0 is released.** - -All server admins are encouraged to upgrade. - -0.99.0: - -- provides support for ACME to make setting up Let's Encrypt certs easy, as - well as .well-known support. - -- does not enforce that a valid CA cert is present on the federation API, but - rather makes it easy to set one up. - -- provides support for .well-known - -Admins should upgrade and configure a valid CA cert. Homeservers that require a -.well-known entry (see below), should retain their SRV record and use it -alongside their .well-known record. - -**10th June 2019 - Synapse 1.0.0 is released** - -1.0.0 is scheduled for release on 10th June. In -accordance with the the [S2S spec](https://matrix.org/docs/spec/server_server/r0.1.0.html) -1.0.0 will enforce certificate validity. This means that any homeserver without a -valid certificate after this point will no longer be able to federate with -1.0.0 servers. - -## Configuring certificates for compatibility with Synapse 1.0.0 - -### If you do not currently have an SRV record - -In this case, your `server_name` points to the host where your Synapse is -running. There is no need to create a `.well-known` URI or an SRV record, but -you will need to give Synapse a valid, signed, certificate. - -### If you do have an SRV record currently - -If you are using an SRV record, your matrix domain (`server_name`) may not -point to the same host that your Synapse is running on (the 'target -domain'). (If it does, you can follow the recommendation above; otherwise, read -on.) - -Let's assume that your `server_name` is `example.com`, and your Synapse is -hosted at a target domain of `customer.example.net`. Currently you should have -an SRV record which looks like: - -``` -_matrix._tcp.example.com. IN SRV 10 5 8000 customer.example.net. -``` - -In this situation, you have three choices for how to proceed: - -#### Option 1: give Synapse a certificate for your matrix domain - -Synapse 1.0 will expect your server to present a TLS certificate for your -`server_name` (`example.com` in the above example). You can achieve this by acquiring a -certificate for the `server_name` yourself (for example, using `certbot`), and giving it -and the key to Synapse via `tls_certificate_path` and `tls_private_key_path`. - -#### Option 2: run Synapse behind a reverse proxy - -If you have an existing reverse proxy set up with correct TLS certificates for -your domain, you can simply route all traffic through the reverse proxy by -updating the SRV record appropriately (or removing it, if the proxy listens on -8448). - -See [the reverse proxy documentation](reverse_proxy.md) for information on setting up a -reverse proxy. - -#### Option 3: add a .well-known file to delegate your matrix traffic - -This will allow you to keep Synapse on a separate domain, without having to -give it a certificate for the matrix domain. - -You can do this with a `.well-known` file as follows: - - 1. Keep the SRV record in place - it is needed for backwards compatibility - with Synapse 0.34 and earlier. - - 2. Give Synapse a certificate corresponding to the target domain - (`customer.example.net` in the above example). You can do this by acquire a - certificate for the target domain and giving it to Synapse via `tls_certificate_path` - and `tls_private_key_path`. - - 3. Restart Synapse to ensure the new certificate is loaded. - - 4. Arrange for a `.well-known` file at - `https://<server_name>/.well-known/matrix/server` with contents: - - ```json - {"m.server": "<target server name>"} - ``` - - where the target server name is resolved as usual (i.e. SRV lookup, falling - back to talking to port 8448). - - In the above example, where synapse is listening on port 8000, - `https://example.com/.well-known/matrix/server` should have `m.server` set to one of: - - 1. `customer.example.net` ─ with a SRV record on - `_matrix._tcp.customer.example.com` pointing to port 8000, or: - - 2. `customer.example.net` ─ updating synapse to listen on the default port - 8448, or: - - 3. `customer.example.net:8000` ─ ensuring that if there is a reverse proxy - on `customer.example.net:8000` it correctly handles HTTP requests with - Host header set to `customer.example.net:8000`. - -## FAQ - -### Synapse 0.99.0 has just been released, what do I need to do right now? - -Upgrade as soon as you can in preparation for Synapse 1.0.0, and update your -TLS certificates as [above](#configuring-certificates-for-compatibility-with-synapse-100). - -### What will happen if I do not set up a valid federation certificate immediately? - -Nothing initially, but once 1.0.0 is in the wild it will not be possible to -federate with 1.0.0 servers. - -### What will happen if I do nothing at all? - -If the admin takes no action at all, and remains on a Synapse < 0.99.0 then the -homeserver will be unable to federate with those who have implemented -.well-known. Then, as above, once the month upgrade window has expired the -homeserver will not be able to federate with any Synapse >= 1.0.0 - -### When do I need a SRV record or .well-known URI? - -If your homeserver listens on the default federation port (8448), and your -`server_name` points to the host that your homeserver runs on, you do not need an -SRV record or `.well-known/matrix/server` URI. - -For instance, if you registered `example.com` and pointed its DNS A record at a -fresh Upcloud VPS or similar, you could install Synapse 0.99 on that host, -giving it a server_name of `example.com`, and it would automatically generate a -valid TLS certificate for you via Let's Encrypt and no SRV record or -`.well-known` URI would be needed. - -This is the common case, although you can add an SRV record or -`.well-known/matrix/server` URI for completeness if you wish. - -**However**, if your server does not listen on port 8448, or if your `server_name` -does not point to the host that your homeserver runs on, you will need to let -other servers know how to find it. - -In this case, you should see ["If you do have an SRV record -currently"](#if-you-do-have-an-srv-record-currently) above. - -### Can I still use an SRV record? - -Firstly, if you didn't need an SRV record before (because your server is -listening on port 8448 of your server_name), you certainly don't need one now: -the defaults are still the same. - -If you previously had an SRV record, you can keep using it provided you are -able to give Synapse a TLS certificate corresponding to your server name. For -example, suppose you had the following SRV record, which directs matrix traffic -for example.com to matrix.example.com:443: - -``` -_matrix._tcp.example.com. IN SRV 10 5 443 matrix.example.com -``` - -In this case, Synapse must be given a certificate for example.com - or be -configured to acquire one from Let's Encrypt. - -If you are unable to give Synapse a certificate for your server_name, you will -also need to use a .well-known URI instead. However, see also "I have created a -.well-known URI. Do I still need an SRV record?". - -### I have created a .well-known URI. Do I still need an SRV record? - -As of Synapse 0.99, Synapse will first check for the existence of a `.well-known` -URI and follow any delegation it suggests. It will only then check for the -existence of an SRV record. - -That means that the SRV record will often be redundant. However, you should -remember that there may still be older versions of Synapse in the federation -which do not understand `.well-known` URIs, so if you removed your SRV record you -would no longer be able to federate with them. - -It is therefore best to leave the SRV record in place for now. Synapse 0.34 and -earlier will follow the SRV record (and not care about the invalid -certificate). Synapse 0.99 and later will follow the .well-known URI, with the -correct certificate chain. - -### It used to work just fine, why are you breaking everything? - -We have always wanted Matrix servers to be as easy to set up as possible, and -so back when we started federation in 2014 we didn't want admins to have to go -through the cumbersome process of buying a valid TLS certificate to run a -server. This was before Let's Encrypt came along and made getting a free and -valid TLS certificate straightforward. So instead, we adopted a system based on -[Perspectives](https://en.wikipedia.org/wiki/Convergence_(SSL)): an approach -where you check a set of "notary servers" (in practice, homeservers) to vouch -for the validity of a certificate rather than having it signed by a CA. As long -as enough different notaries agree on the certificate's validity, then it is -trusted. - -However, in practice this has never worked properly. Most people only use the -default notary server (matrix.org), leading to inadvertent centralisation which -we want to eliminate. Meanwhile, we never implemented the full consensus -algorithm to query the servers participating in a room to determine consensus -on whether a given certificate is valid. This is fiddly to get right -(especially in face of sybil attacks), and we found ourselves questioning -whether it was worth the effort to finish the work and commit to maintaining a -secure certificate validation system as opposed to focusing on core Matrix -development. - -Meanwhile, Let's Encrypt came along in 2016, and put the final nail in the -coffin of the Perspectives project (which was already pretty dead). So, the -Spec Core Team decided that a better approach would be to mandate valid TLS -certificates for federation alongside the rest of the Web. More details can be -found in -[MSC1711](https://github.com/matrix-org/matrix-doc/blob/main/proposals/1711-x509-for-federation.md#background-the-failure-of-the-perspectives-approach). - -This results in a breaking change, which is disruptive, but absolutely critical -for the security model. However, the existence of Let's Encrypt as a trivial -way to replace the old self-signed certificates with valid CA-signed ones helps -smooth things over massively, especially as Synapse can now automate Let's -Encrypt certificate generation if needed. - -### Can I manage my own certificates rather than having Synapse renew certificates itself? - -Yes, you are welcome to manage your certificates yourself. Synapse will only -attempt to obtain certificates from Let's Encrypt if you configure it to do -so.The only requirement is that there is a valid TLS cert present for -federation end points. - -### Do you still recommend against using a reverse proxy on the federation port? - -We no longer actively recommend against using a reverse proxy. Many admins will -find it easier to direct federation traffic to a reverse proxy and manage their -own TLS certificates, and this is a supported configuration. - -See [the reverse proxy documentation](reverse_proxy.md) for information on setting up a -reverse proxy. - -### Do I still need to give my TLS certificates to Synapse if I am using a reverse proxy? - -Practically speaking, this is no longer necessary. - -If you are using a reverse proxy for all of your TLS traffic, then you can set -`no_tls: True`. In that case, the only reason Synapse needs the certificate is -to populate a legacy 'tls_fingerprints' field in the federation API. This is -ignored by Synapse 0.99.0 and later, and the only time pre-0.99 Synapses will -check it is when attempting to fetch the server keys - and generally this is -delegated via `matrix.org`, which is on 0.99.0. - -However, there is a bug in Synapse 0.99.0 -[4554](<https://github.com/matrix-org/synapse/issues/4554>) which prevents -Synapse from starting if you do not give it a TLS certificate. To work around -this, you can give it any TLS certificate at all. This will be fixed soon. - -### Do I need the same certificate for the client and federation port? - -No. There is nothing stopping you from using different certificates, -particularly if you are using a reverse proxy. However, Synapse will use the -same certificate on any ports where TLS is configured. - -### How do I tell Synapse to reload my keys/certificates after I replace them? - -Synapse will reload the keys and certificates when it receives a SIGHUP - for -example `kill -HUP $(cat homeserver.pid)`. Alternatively, simply restart -Synapse, though this will result in downtime while it restarts. diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 11f597b3ed..ef9cabf555 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -13,7 +13,6 @@ # Upgrading - [Upgrading between Synapse Versions](upgrade.md) - - [Upgrading from pre-Synapse 1.0](MSC1711_certificates_FAQ.md) # Usage - [Federation](federate.md) @@ -72,7 +71,7 @@ - [Understanding Synapse Through Grafana Graphs](usage/administration/understanding_synapse_through_grafana_graphs.md) - [Useful SQL for Admins](usage/administration/useful_sql_for_admins.md) - [Database Maintenance Tools](usage/administration/database_maintenance_tools.md) - - [State Groups](usage/administration/state_groups.md) + - [State Groups](usage/administration/state_groups.md) - [Request log format](usage/administration/request_log.md) - [Admin FAQ](usage/administration/admin_faq.md) - [Scripts]() @@ -80,6 +79,7 @@ # Development - [Contributing Guide](development/contributing_guide.md) - [Code Style](code_style.md) + - [Release Cycle](development/releases.md) - [Git Usage](development/git.md) - [Testing]() - [OpenTracing](opentracing.md) diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index a8cdf19727..96b3668f2a 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -2,6 +2,9 @@ These APIs allow extracting media information from the homeserver. +Details about the format of the `media_id` and storage of the media in the file system +are documented under [media repository](../media_repository.md). + To use it, you will need to authenticate by providing an `access_token` for a server admin: see [Admin API](../usage/administration/admin_api). diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 4f5f377b38..4076fcab65 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -126,7 +126,8 @@ Body parameters: [Sample Configuration File](../usage/configuration/homeserver_sample_config.html) section `sso` and `oidc_providers`. - `auth_provider` - string. ID of the external identity provider. Value of `idp_id` - in homeserver configuration. + in the homeserver configuration. Note that no error is raised if the provided + value is not in the homeserver configuration. - `external_id` - string, user ID in the external identity provider. - `avatar_url` - string, optional, must be a [MXC URI](https://matrix.org/docs/spec/client_server/r0.6.0#matrix-content-mxc-uris). @@ -331,7 +332,7 @@ An empty body may be passed for backwards compatibility. The following actions are performed when deactivating an user: -- Try to unpind 3PIDs from the identity server +- Try to unbind 3PIDs from the identity server - Remove all 3PIDs from the homeserver - Delete all devices and E2EE keys - Delete all access tokens @@ -539,6 +540,11 @@ The following fields are returned in the JSON response body: ### List media uploaded by a user Gets a list of all local media that a specific `user_id` has created. +These are media that the user has uploaded themselves +([local media](../media_repository.md#local-media)), as well as +[URL preview images](../media_repository.md#url-previews) requested by the user if the +[feature is enabled](../development/url_previews.md). + By default, the response is ordered by descending creation date and ascending media ID. The newest media is on top. You can change the order with parameters `order_by` and `dir`. @@ -635,7 +641,9 @@ The following fields are returned in the JSON response body: Media objects contain the following fields: - `created_ts` - integer - Timestamp when the content was uploaded in ms. - `last_access_ts` - integer - Timestamp when the content was last accessed in ms. - - `media_id` - string - The id used to refer to the media. + - `media_id` - string - The id used to refer to the media. Details about the format + are documented under + [media repository](../media_repository.md). - `media_length` - integer - Length of the media in bytes. - `media_type` - string - The MIME-type of the media. - `quarantined_by` - string - The user ID that initiated the quarantine request diff --git a/docs/development/releases.md b/docs/development/releases.md new file mode 100644 index 0000000000..c9a8c69945 --- /dev/null +++ b/docs/development/releases.md @@ -0,0 +1,37 @@ +# Synapse Release Cycle + +Releases of Synapse follow a two week release cycle with new releases usually +occurring on Tuesdays: + +* Day 0: Synapse `N - 1` is released. +* Day 7: Synapse `N` release candidate 1 is released. +* Days 7 - 13: Synapse `N` release candidates 2+ are released, if bugs are found. +* Day 14: Synapse `N` is released. + +Note that this schedule might be modified depending on the availability of the +Synapse team, e.g. releases may be skipped to avoid holidays. + +Release announcements can be found in the +[release category of the Matrix blog](https://matrix.org/blog/category/releases). + +## Bugfix releases + +If a bug is found after release that is deemed severe enough (by a combination +of the impacted users and the impact on those users) then a bugfix release may +be issued. This may be at any point in the release cycle. + +## Security releases + +Security will sometimes be backported to the previous version and released +immediately before the next release candidate. An example of this might be: + +* Day 0: Synapse N - 1 is released. +* Day 7: Synapse (N - 1).1 is released as Synapse N - 1 + the security fix. +* Day 7: Synapse N release candidate 1 is released (including the security fix). + +Depending on the impact and complexity of security fixes, multiple fixes might +be held to be released together. + +In some cases, a pre-disclosure of a security release will be issued as a notice +to Synapse operators that there is an upcoming security release. These can be +found in the [security category of the Matrix blog](https://matrix.org/blog/category/security). diff --git a/docs/modules/password_auth_provider_callbacks.md b/docs/modules/password_auth_provider_callbacks.md index ec8324d292..ec810fd292 100644 --- a/docs/modules/password_auth_provider_callbacks.md +++ b/docs/modules/password_auth_provider_callbacks.md @@ -85,7 +85,7 @@ If the authentication is unsuccessful, the module must return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the authentication is denied. ### `on_logged_out` @@ -148,7 +148,7 @@ Here's an example featuring all currently supported keys: "address": "33123456789", "validated_at": 1642701357084, }, - "org.matrix.msc3231.login.registration_token": "sometoken", # User has registered through the flow described in MSC3231 + "m.login.registration_token": "sometoken", # User has registered through a registration token } ``` @@ -162,10 +162,57 @@ return `None`. If multiple modules implement this callback, they will be considered in order. If a callback returns `None`, Synapse falls through to the next one. The value of the first callback that does not return `None` will be used. If this happens, Synapse will not call -any of the subsequent implementations of this callback. If every callback return `None`, +any of the subsequent implementations of this callback. If every callback returns `None`, the username provided by the user is used, if any (otherwise one is automatically generated). +### `get_displayname_for_registration` + +_First introduced in Synapse v1.54.0_ + +```python +async def get_displayname_for_registration( + uia_results: Dict[str, Any], + params: Dict[str, Any], +) -> Optional[str] +``` + +Called when registering a new user. The module can return a display name to set for the +user being registered by returning it as a string, or `None` if it doesn't wish to force a +display name for this user. + +This callback is called once [User-Interactive Authentication](https://spec.matrix.org/latest/client-server-api/#user-interactive-authentication-api) +has been completed by the user. It is not called when registering a user via SSO. It is +passed two dictionaries, which include the information that the user has provided during +the registration process. These dictionaries are identical to the ones passed to +[`get_username_for_registration`](#get_username_for_registration), so refer to the +documentation of this callback for more information about them. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `None`, Synapse falls through to the next one. The value of the first +callback that does not return `None` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. If every callback returns `None`, +the username will be used (e.g. `alice` if the user being registered is `@alice:example.com`). + +## `is_3pid_allowed` + +_First introduced in Synapse v1.53.0_ + +```python +async def is_3pid_allowed(self, medium: str, address: str, registration: bool) -> bool +``` + +Called when attempting to bind a third-party identifier (i.e. an email address or a phone +number). The module is given the medium of the third-party identifier (which is `email` if +the identifier is an email address, or `msisdn` if the identifier is a phone number) and +its address, as well as a boolean indicating whether the attempt to bind is happening as +part of registering a new user. The module must return a boolean indicating whether the +identifier can be allowed to be bound to an account on the local homeserver. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `True`, Synapse falls through to the next one. The value of the first +callback that does not return `True` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. ## Example @@ -175,8 +222,7 @@ The example module below implements authentication checkers for two different lo - Is checked by the method: `self.check_my_login` - `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based)) - Expects a `password` field to be sent to `/login` - - Is checked by the method: `self.check_pass` - + - Is checked by the method: `self.check_pass` ```python from typing import Awaitable, Callable, Optional, Tuple diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 2eb9032f41..2b672b78f9 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -16,10 +16,12 @@ _First introduced in Synapse v1.37.0_ async def check_event_for_spam(event: "synapse.events.EventBase") -> Union[bool, str] ``` -Called when receiving an event from a client or via federation. The module can return -either a `bool` to indicate whether the event must be rejected because of spam, or a `str` -to indicate the event must be rejected because of spam and to give a rejection reason to -forward to clients. +Called when receiving an event from a client or via federation. The callback must return +either: +- an error message string, to indicate the event must be rejected because of spam and + give a rejection reason to forward to clients; +- the boolean `True`, to indicate that the event is spammy, but not provide further details; or +- the booelan `False`, to indicate that the event is not considered spammy. If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first @@ -35,7 +37,10 @@ async def user_may_join_room(user: str, room: str, is_invited: bool) -> bool ``` Called when a user is trying to join a room. The module must return a `bool` to indicate -whether the user can join the room. The user is represented by their Matrix user ID (e.g. +whether the user can join the room. Return `False` to prevent the user from joining the +room; otherwise return `True` to permit the joining. + +The user is represented by their Matrix user ID (e.g. `@alice:example.com`) and the room is represented by its Matrix ID (e.g. `!room:example.com`). The module is also given a boolean to indicate whether the user currently has a pending invite in the room. @@ -58,7 +63,8 @@ async def user_may_invite(inviter: str, invitee: str, room_id: str) -> bool Called when processing an invitation. The module must return a `bool` indicating whether the inviter can invite the invitee to the given room. Both inviter and invitee are -represented by their Matrix user ID (e.g. `@alice:example.com`). +represented by their Matrix user ID (e.g. `@alice:example.com`). Return `False` to prevent +the invitation; otherwise return `True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -80,7 +86,8 @@ async def user_may_send_3pid_invite( Called when processing an invitation using a third-party identifier (also called a 3PID, e.g. an email address or a phone number). The module must return a `bool` indicating -whether the inviter can invite the invitee to the given room. +whether the inviter can invite the invitee to the given room. Return `False` to prevent +the invitation; otherwise return `True` to permit it. The inviter is represented by their Matrix user ID (e.g. `@alice:example.com`), and the invitee is represented by its medium (e.g. "email") and its address @@ -117,6 +124,7 @@ async def user_may_create_room(user: str) -> bool Called when processing a room creation request. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed to create a room. +Return `False` to prevent room creation; otherwise return `True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -133,7 +141,8 @@ async def user_may_create_room_alias(user: str, room_alias: "synapse.types.RoomA Called when trying to associate an alias with an existing room. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed -to set the given alias. +to set the given alias. Return `False` to prevent the alias creation; otherwise return +`True` to permit it. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -150,7 +159,8 @@ async def user_may_publish_room(user: str, room_id: str) -> bool Called when trying to publish a room to the homeserver's public rooms directory. The module must return a `bool` indicating whether the given user (represented by their -Matrix user ID) is allowed to publish the given room. +Matrix user ID) is allowed to publish the given room. Return `False` to prevent the +room from being published; otherwise return `True` to permit its publication. If multiple modules implement this callback, they will be considered in order. If a callback returns `True`, Synapse falls through to the next one. The value of the first @@ -166,8 +176,11 @@ async def check_username_for_spam(user_profile: Dict[str, str]) -> bool ``` Called when computing search results in the user directory. The module must return a -`bool` indicating whether the given user profile can appear in search results. The profile -is represented as a dictionary with the following keys: +`bool` indicating whether the given user should be excluded from user directory +searches. Return `True` to indicate that the user is spammy and exclude them from +search results; otherwise return `False`. + +The profile is represented as a dictionary with the following keys: * `user_id`: The Matrix ID for this user. * `display_name`: The user's display name. @@ -225,8 +238,9 @@ async def check_media_file_for_spam( ) -> bool ``` -Called when storing a local or remote file. The module must return a boolean indicating -whether the given file can be stored in the homeserver's media store. +Called when storing a local or remote file. The module must return a `bool` indicating +whether the given file should be excluded from the homeserver's media store. Return +`True` to prevent this file from being stored; otherwise return `False`. If multiple modules implement this callback, they will be considered in order. If a callback returns `False`, Synapse falls through to the next one. The value of the first diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 689b207fc0..d2bb3d4208 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -751,11 +751,16 @@ caches: per_cache_factors: #get_users_who_share_room_with_user: 2.0 - # Controls how long an entry can be in a cache without having been - # accessed before being evicted. Defaults to None, which means - # entries are never evicted based on time. + # Controls whether cache entries are evicted after a specified time + # period. Defaults to true. Uncomment to disable this feature. # - #expiry_time: 30m + #expire_caches: false + + # If expire_caches is enabled, this flag controls how long an entry can + # be in a cache without having been accessed before being evicted. + # Defaults to 30m. Uncomment to set a different time to live for cache entries. + # + #cache_entry_ttl: 30m # Controls how long the results of a /sync request are cached for after # a successful response is returned. A higher duration can help clients with @@ -857,6 +862,9 @@ log_config: "CONFDIR/SERVERNAME.log.config" # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # - two for ratelimiting how often invites can be sent in a room or to a # specific user. +# - one for ratelimiting 3PID invites (i.e. invites sent to a third-party ID +# such as an email address or a phone number) based on the account that's +# sending the invite. # # The defaults are as shown below. # @@ -906,6 +914,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # per_user: # per_second: 0.003 # burst_count: 5 +# +#rc_third_party_invite: +# per_second: 0.2 +# burst_count: 10 # Ratelimiting settings for incoming federation # diff --git a/docs/structured_logging.md b/docs/structured_logging.md index b1281667e0..805c867653 100644 --- a/docs/structured_logging.md +++ b/docs/structured_logging.md @@ -81,14 +81,12 @@ remote endpoint at 10.1.2.3:9999. ## Upgrading from legacy structured logging configuration -Versions of Synapse prior to v1.23.0 included a custom structured logging -configuration which is deprecated. It used a `structured: true` flag and -configured `drains` instead of ``handlers`` and `formatters`. +Versions of Synapse prior to v1.54.0 automatically converted the legacy +structured logging configuration, which was deprecated in v1.23.0, to the standard +library logging configuration. -Synapse currently automatically converts the old configuration to the new -configuration, but this will be removed in a future version of Synapse. The -following reference can be used to update your configuration. Based on the drain -`type`, we can pick a new handler: +The following reference can be used to update your configuration. Based on the +drain `type`, we can pick a new handler: 1. For a type of `console`, `console_json`, or `console_json_terse`: a handler with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout` @@ -141,7 +139,7 @@ formatters: handlers: console: class: logging.StreamHandler - location: ext://sys.stdout + stream: ext://sys.stdout file: class: logging.FileHandler formatter: json diff --git a/docs/upgrade.md b/docs/upgrade.md index df873e5317..f9be3ac6bc 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,79 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.54.0 + +## Legacy structured logging configuration removal + +This release removes support for the `structured: true` logging configuration +which was deprecated in Synapse v1.23.0. If your logging configuration contains +`structured: true` then it should be modified based on the +[structured logging documentation](structured_logging.md). + +# Upgrading to v1.53.0 + +## Dropping support for `webclient` listeners and non-HTTP(S) `web_client_location` + +Per the deprecation notice in Synapse v1.51.0, listeners of type `webclient` +are no longer supported and configuring them is a now a configuration error. + +Configuring a non-HTTP(S) `web_client_location` configuration is is now a +configuration error. Since the `webclient` listener is no longer supported, this +setting only applies to the root path `/` of Synapse's web server and no longer +the `/_matrix/client/` path. + +## Stablisation of MSC3231 + +The unstable validity-check endpoint for the +[Registration Tokens](https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv1registermloginregistration_tokenvalidity) +feature has been stabilised and moved from: + +`/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity` + +to: + +`/_matrix/client/v1/register/m.login.registration_token/validity` + +Please update any relevant reverse proxy or firewall configurations appropriately. + +## Time-based cache expiry is now enabled by default + +Formerly, entries in the cache were not evicted regardless of whether they were accessed after storing. +This behavior has now changed. By default entries in the cache are now evicted after 30m of not being accessed. +To change the default behavior, go to the `caches` section of the config and change the `expire_caches` and +`cache_entry_ttl` flags as necessary. Please note that these flags replace the `expiry_time` flag in the config. +The `expiry_time` flag will still continue to work, but it has been deprecated and will be removed in the future. + +## Deprecation of `capability` `org.matrix.msc3283.*` + +The `capabilities` of MSC3283 from the REST API `/_matrix/client/r0/capabilities` +becomes stable. + +The old `capabilities` +- `org.matrix.msc3283.set_displayname`, +- `org.matrix.msc3283.set_avatar_url` and +- `org.matrix.msc3283.3pid_changes` + +are deprecated and scheduled to be removed in Synapse v1.54.0. + +The new `capabilities` +- `m.set_displayname`, +- `m.set_avatar_url` and +- `m.3pid_changes` + +are now active by default. + +## Removal of `user_may_create_room_with_invites` + +As announced with the release of [Synapse 1.47.0](#deprecation-of-the-user_may_create_room_with_invites-module-callback), +the deprecated `user_may_create_room_with_invites` module callback has been removed. + +Modules relying on it can instead implement [`user_may_invite`](https://matrix-org.github.io/synapse/latest/modules/spam_checker_callbacks.html#user_may_invite) +and use the [`get_room_state`](https://github.com/matrix-org/synapse/blob/872f23b95fa980a61b0866c1475e84491991fa20/synapse/module_api/__init__.py#L869-L876) +module API to infer whether the invite is happening while creating a room (see [this function](https://github.com/matrix-org/synapse-domain-rule-checker/blob/e7d092dd9f2a7f844928771dbfd9fd24c2332e48/synapse_domain_rule_checker/__init__.py#L56-L89) +as an example). Alternately, modules can also implement [`on_create_room`](https://matrix-org.github.io/synapse/latest/modules/third_party_rules_callbacks.html#on_create_room). + + # Upgrading to v1.52.0 ## Twisted security release @@ -1141,8 +1214,7 @@ more details on upgrading your database. Synapse v1.0 is the first release to enforce validation of TLS certificates for the federation API. It is therefore essential that your -certificates are correctly configured. See the -[FAQ](MSC1711_certificates_FAQ.md) for more information. +certificates are correctly configured. Note, v1.0 installations will also no longer be able to federate with servers that have not correctly configured their certificates. @@ -1207,9 +1279,6 @@ you will need to replace any self-signed certificates with those verified by a root CA. Information on how to do so can be found at the ACME docs. -For more information on configuring TLS certificates see the -[FAQ](MSC1711_certificates_FAQ.md). - # Upgrading to v0.34.0 1. This release is the first to fully support Python 3. Synapse will diff --git a/docs/workers.md b/docs/workers.md index fd83e2ddeb..dadde4d726 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -241,7 +241,7 @@ expressions: # Registration/login requests ^/_matrix/client/(api/v1|r0|v3|unstable)/login$ ^/_matrix/client/(r0|v3|unstable)/register$ - ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$ + ^/_matrix/client/v1/register/m.login.registration_token/validity$ # Event sending requests ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/redact diff --git a/mypy.ini b/mypy.ini index 2884078d0a..610660b9b7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,14 +31,11 @@ exclude = (?x) |synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/monthly_active_users.py - |synapse/storage/databases/main/presence.py - |synapse/storage/databases/main/purge_events.py |synapse/storage/databases/main/push_rule.py |synapse/storage/databases/main/receipts.py |synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/search.py |synapse/storage/databases/main/state.py - |synapse/storage/databases/main/user_directory.py |synapse/storage/schema/ |tests/api/test_auth.py @@ -142,6 +139,9 @@ disallow_untyped_defs = True [mypy-synapse.crypto.*] disallow_untyped_defs = True +[mypy-synapse.event_auth] +disallow_untyped_defs = True + [mypy-synapse.events.*] disallow_untyped_defs = True @@ -166,9 +166,15 @@ disallow_untyped_defs = True [mypy-synapse.module_api.*] disallow_untyped_defs = True +[mypy-synapse.notifier] +disallow_untyped_defs = True + [mypy-synapse.push.*] disallow_untyped_defs = True +[mypy-synapse.replication.*] +disallow_untyped_defs = True + [mypy-synapse.rest.*] disallow_untyped_defs = True diff --git a/pyproject.toml b/pyproject.toml index 963f149c6a..c9cd0cf6ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,3 +54,15 @@ exclude = ''' )/ ) ''' + +[tool.isort] +line_length = 88 +sections = ["FUTURE", "STDLIB", "THIRDPARTY", "TWISTED", "FIRSTPARTY", "TESTS", "LOCALFOLDER"] +default_section = "THIRDPARTY" +known_first_party = ["synapse"] +known_tests = ["tests"] +known_twisted = ["twisted", "OpenSSL"] +multi_line_output = 3 +include_trailing_comma = true +combine_as_imports = true + diff --git a/scripts-dev/build_debian_packages b/scripts-dev/build_debian_packages index 4d34e90703..7ff96a1ee6 100755 --- a/scripts-dev/build_debian_packages +++ b/scripts-dev/build_debian_packages @@ -25,7 +25,6 @@ DISTS = ( "debian:bookworm", "debian:sid", "ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14) - "ubuntu:hirsute", # 21.04 (EOL 2022-01-05) "ubuntu:impish", # 21.10 (EOL 2022-07) ) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 640ff15277..db354b3c8c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -24,10 +24,10 @@ import traceback from typing import Dict, Iterable, Optional, Set import yaml +from matrix_common.versionstring import get_distribution_version_string from twisted.internet import defer, reactor -import synapse from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import ( @@ -36,6 +36,8 @@ from synapse.logging.context import ( run_in_background, ) from synapse.storage.database import DatabasePool, make_conn +from synapse.storage.databases.main import PushRuleStore +from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore @@ -65,7 +67,6 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock -from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse_port_db") @@ -180,6 +181,8 @@ class Store( UserDirectoryBackgroundUpdateStore, EndToEndKeyBackgroundStore, StatsStore, + AccountDataWorkerStore, + PushRuleStore, PusherWorkerStore, PresenceBackgroundUpdateStore, GroupServerWorkerStore, @@ -218,7 +221,9 @@ class MockHomeserver: self.clock = Clock(reactor) self.config = config self.hostname = config.server.server_name - self.version_string = "Synapse/" + get_version_string(synapse) + self.version_string = "Synapse/" + get_distribution_version_string( + "matrix-synapse" + ) def get_clock(self): return self.clock diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 6c088bad93..5c6453d77f 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -18,15 +18,14 @@ import logging import sys import yaml +from matrix_common.versionstring import get_distribution_version_string from twisted.internet import defer, reactor -import synapse from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.util.versionstring import get_version_string logger = logging.getLogger("update_database") @@ -39,7 +38,9 @@ class MockHomeserver(HomeServer): config.server.server_name, reactor=reactor, config=config, **kwargs ) - self.version_string = "Synapse/" + get_version_string(synapse) + self.version_string = "Synapse/" + get_distribution_version_string( + "matrix-synapse" + ) def run_background_updates(hs): diff --git a/setup.cfg b/setup.cfg index e5ceb7ed19..a0506572d9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,14 +19,3 @@ ignore = # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) ignore=W503,W504,E203,E731,E501 - -[isort] -line_length = 88 -sections=FUTURE,STDLIB,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER -default_section=THIRDPARTY -known_first_party = synapse -known_tests=tests -known_twisted=twisted,OpenSSL -multi_line_output=3 -include_trailing_comma=true -combine_as_imports=true diff --git a/setup.py b/setup.py index d0511c767f..c80cb6f207 100755 --- a/setup.py +++ b/setup.py @@ -103,8 +103,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [ ] CONDITIONAL_REQUIREMENTS["mypy"] = [ - "mypy==0.910", - "mypy-zope==0.3.2", + "mypy==0.931", + "mypy-zope==0.3.5", "types-bleach>=4.1.0", "types-jsonschema>=3.2.0", "types-opentracing>=2.4.2", diff --git a/stubs/sortedcontainers/sorteddict.pyi b/stubs/sortedcontainers/sorteddict.pyi index 0eaef00498..344d55cce1 100644 --- a/stubs/sortedcontainers/sorteddict.pyi +++ b/stubs/sortedcontainers/sorteddict.pyi @@ -66,13 +66,18 @@ class SortedDict(Dict[_KT, _VT]): def __copy__(self: _SD) -> _SD: ... @classmethod @overload - def fromkeys(cls, seq: Iterable[_T_h]) -> SortedDict[_T_h, None]: ... + def fromkeys( + cls, seq: Iterable[_T_h], value: None = ... + ) -> SortedDict[_T_h, None]: ... @classmethod @overload def fromkeys(cls, seq: Iterable[_T_h], value: _S) -> SortedDict[_T_h, _S]: ... - def keys(self) -> SortedKeysView[_KT]: ... - def items(self) -> SortedItemsView[_KT, _VT]: ... - def values(self) -> SortedValuesView[_VT]: ... + # As of Python 3.10, `dict_{keys,items,values}` have an extra `mapping` attribute and so + # `Sorted{Keys,Items,Values}View` are no longer compatible with them. + # See https://github.com/python/typeshed/issues/6837 + def keys(self) -> SortedKeysView[_KT]: ... # type: ignore[override] + def items(self) -> SortedItemsView[_KT, _VT]: ... # type: ignore[override] + def values(self) -> SortedValuesView[_VT]: ... # type: ignore[override] @overload def pop(self, key: _KT) -> _VT: ... @overload diff --git a/synapse/__init__.py b/synapse/__init__.py index a23563937a..2bf8eb2a11 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.52.0" +__version__ = "1.53.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 52c083a20b..36ace7c613 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -81,7 +81,7 @@ class LoginType: TERMS: Final = "m.login.terms" SSO: Final = "m.login.sso" DUMMY: Final = "m.login.dummy" - REGISTRATION_TOKEN: Final = "org.matrix.msc3231.login.registration_token" + REGISTRATION_TOKEN: Final = "m.login.registration_token" # This is used in the `type` parameter for /register when called by diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 85302163da..e92db29f6d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -406,6 +406,9 @@ class RoomKeysVersionError(SynapseError): super().__init__(403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION) self.current_version = current_version + def error_dict(self) -> "JsonDict": + return cs_error(self.msg, self.errcode, current_version=self.current_version) + class UnsupportedRoomVersionError(SynapseError): """The client's request to create a room used a room version that the server does diff --git a/synapse/api/urls.py b/synapse/api/urls.py index f9f9467dc1..bd49fa6a5f 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -28,7 +28,6 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" -WEB_CLIENT_PREFIX = "/_matrix/client" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" MEDIA_R0_PREFIX = "/_matrix/media/r0" MEDIA_V3_PREFIX = "/_matrix/media/v3" diff --git a/synapse/app/_base.py b/synapse/app/_base.py index bbab8a052a..452c0c09d5 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -37,6 +37,7 @@ from typing import ( ) from cryptography.utils import CryptographyDeprecationWarning +from matrix_common.versionstring import get_distribution_version_string import twisted from twisted.internet import defer, error, reactor as _reactor @@ -67,7 +68,6 @@ from synapse.util.caches.lrucache import setup_expire_lru_cache_entries from synapse.util.daemonize import daemonize_process from synapse.util.gai_resolver import GAIResolver from synapse.util.rlimit import change_resource_limit -from synapse.util.versionstring import get_version_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -487,7 +487,8 @@ def setup_sentry(hs: "HomeServer") -> None: import sentry_sdk sentry_sdk.init( - dsn=hs.config.metrics.sentry_dsn, release=get_version_string(synapse) + dsn=hs.config.metrics.sentry_dsn, + release=get_distribution_version_string("matrix-synapse"), ) # We set some default tags that give some context to this instance diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 42238f7f28..6f8e33a156 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -19,6 +19,8 @@ import sys import tempfile from typing import List, Optional +from matrix_common.versionstring import get_distribution_version_string + from twisted.internet import defer, task import synapse @@ -44,7 +46,6 @@ from synapse.server import HomeServer from synapse.storage.databases.main.room import RoomWorkerStore from synapse.types import StateMap from synapse.util.logcontext import LoggingContext -from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.admin_cmd") @@ -223,7 +224,7 @@ def start(config_options: List[str]) -> None: ss = AdminCmdServer( config.server.server_name, config=config, - version_string="Synapse/" + get_version_string(synapse), + version_string="Synapse/" + get_distribution_version_string("matrix-synapse"), ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e256de2003..aadc882bf8 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -16,6 +16,8 @@ import logging import sys from typing import Dict, List, Optional, Tuple +from matrix_common.versionstring import get_distribution_version_string + from twisted.internet import address from twisted.web.resource import Resource @@ -122,7 +124,6 @@ from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore from synapse.types import JsonDict from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.generic_worker") @@ -482,7 +483,7 @@ def start(config_options: List[str]) -> None: hs = GenericWorkerServer( config.server.server_name, config=config, - version_string="Synapse/" + get_version_string(synapse), + version_string="Synapse/" + get_distribution_version_string("matrix-synapse"), ) setup_logging(hs, config, use_worker_options=True) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index efedcc8889..bfb30003c2 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -18,22 +18,23 @@ import os import sys from typing import Dict, Iterable, Iterator, List +from matrix_common.versionstring import get_distribution_version_string + from twisted.internet.tcp import Port from twisted.web.resource import EncodingResourceWrapper, Resource from twisted.web.server import GzipEncoderFactory -from twisted.web.static import File import synapse import synapse.config.logger from synapse import events from synapse.api.urls import ( + CLIENT_API_PREFIX, FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, SERVER_KEY_V2_PREFIX, STATIC_PREFIX, - WEB_CLIENT_PREFIX, ) from synapse.app import _base from synapse.app._base import ( @@ -53,7 +54,6 @@ from synapse.http.additional_resource import AdditionalResource from synapse.http.server import ( OptionsResource, RootOptionsRedirectResource, - RootRedirect, StaticResource, ) from synapse.http.site import SynapseSite @@ -72,7 +72,6 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.util.httpresourcetree import create_resource_tree from synapse.util.module_loader import load_module -from synapse.util.versionstring import get_version_string logger = logging.getLogger("synapse.app.homeserver") @@ -134,15 +133,12 @@ class SynapseHomeServer(HomeServer): # Try to find something useful to serve at '/': # # 1. Redirect to the web client if it is an HTTP(S) URL. - # 2. Redirect to the web client served via Synapse. - # 3. Redirect to the static "Synapse is running" page. - # 4. Do not redirect and use a blank resource. - if self.config.server.web_client_location_is_redirect: + # 2. Redirect to the static "Synapse is running" page. + # 3. Do not redirect and use a blank resource. + if self.config.server.web_client_location: root_resource: Resource = RootOptionsRedirectResource( self.config.server.web_client_location ) - elif WEB_CLIENT_PREFIX in resources: - root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) elif STATIC_PREFIX in resources: root_resource = RootOptionsRedirectResource(STATIC_PREFIX) else: @@ -201,13 +197,7 @@ class SynapseHomeServer(HomeServer): resources.update( { - "/_matrix/client/api/v1": client_resource, - "/_matrix/client/r0": client_resource, - "/_matrix/client/v1": client_resource, - "/_matrix/client/v3": client_resource, - "/_matrix/client/unstable": client_resource, - "/_matrix/client/v2_alpha": client_resource, - "/_matrix/client/versions": client_resource, + CLIENT_API_PREFIX: client_resource, "/.well-known": well_known_resource(self), "/_synapse/admin": AdminRestResource(self), **build_synapse_client_resource_tree(self), @@ -270,28 +260,6 @@ class SynapseHomeServer(HomeServer): if name in ["keys", "federation"]: resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) - if name == "webclient": - # webclient listeners are deprecated as of Synapse v1.51.0, remove it - # in > v1.53.0. - webclient_loc = self.config.server.web_client_location - - if webclient_loc is None: - logger.warning( - "Not enabling webclient resource, as web_client_location is unset." - ) - elif self.config.server.web_client_location_is_redirect: - resources[WEB_CLIENT_PREFIX] = RootRedirect(webclient_loc) - else: - logger.warning( - "Running webclient on the same domain is not recommended: " - "https://github.com/matrix-org/synapse#security-note - " - "after you move webclient to different host you can set " - "web_client_location to its full URL to enable redirection." - ) - # GZip is disabled here due to - # https://twistedmatrix.com/trac/ticket/7678 - resources[WEB_CLIENT_PREFIX] = File(webclient_loc) - if name == "metrics" and self.config.metrics.enable_metrics: resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) @@ -383,7 +351,7 @@ def setup(config_options: List[str]) -> SynapseHomeServer: hs = SynapseHomeServer( config.server.server_name, config=config, - version_string="Synapse/" + get_version_string(synapse), + version_string="Synapse/" + get_distribution_version_string("matrix-synapse"), ) synapse.config.logger.setup_logging(hs, config, use_worker_options=False) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 8c9ff93b2c..a340a8c9c7 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -165,23 +165,16 @@ class ApplicationService: return namespace.exclusive return False - async def _matches_user( - self, event: Optional[EventBase], store: Optional["DataStore"] = None - ) -> bool: - if not event: - return False - + async def _matches_user(self, event: EventBase, store: "DataStore") -> bool: if self.is_interested_in_user(event.sender): return True + # also check m.room.member state key if event.type == EventTypes.Member and self.is_interested_in_user( event.state_key ): return True - if not store: - return False - does_match = await self.matches_user_in_member_list(event.room_id, store) return does_match @@ -216,21 +209,15 @@ class ApplicationService: return self.is_interested_in_room(event.room_id) return False - async def _matches_aliases( - self, event: EventBase, store: Optional["DataStore"] = None - ) -> bool: - if not store or not event: - return False - + async def _matches_aliases(self, event: EventBase, store: "DataStore") -> bool: alias_list = await store.get_aliases_for_room(event.room_id) for alias in alias_list: if self.is_interested_in_alias(alias): return True + return False - async def is_interested( - self, event: EventBase, store: Optional["DataStore"] = None - ) -> bool: + async def is_interested(self, event: EventBase, store: "DataStore") -> bool: """Check if this service is interested in this event. Args: @@ -351,11 +338,13 @@ class AppServiceTransaction: id: int, events: List[EventBase], ephemeral: List[JsonDict], + to_device_messages: List[JsonDict], ): self.service = service self.id = id self.events = events self.ephemeral = ephemeral + self.to_device_messages = to_device_messages async def send(self, as_api: "ApplicationServiceApi") -> bool: """Sends this transaction using the provided AS API interface. @@ -369,6 +358,7 @@ class AppServiceTransaction: service=self.service, events=self.events, ephemeral=self.ephemeral, + to_device_messages=self.to_device_messages, txn_id=self.id, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index def4424af0..73be7ff3d4 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -218,8 +218,23 @@ class ApplicationServiceApi(SimpleHttpClient): service: "ApplicationService", events: List[EventBase], ephemeral: List[JsonDict], + to_device_messages: List[JsonDict], txn_id: Optional[int] = None, ) -> bool: + """ + Push data to an application service. + + Args: + service: The application service to send to. + events: The persistent events to send. + ephemeral: The ephemeral events to send. + to_device_messages: The to-device messages to send. + txn_id: An unique ID to assign to this transaction. Application services should + deduplicate transactions received with identitical IDs. + + Returns: + True if the task succeeded, False if it failed. + """ if service.url is None: return True @@ -237,13 +252,15 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id))) # Never send ephemeral events to appservices that do not support it + body: Dict[str, List[JsonDict]] = {"events": serialized_events} if service.supports_ephemeral: - body = { - "events": serialized_events, - "de.sorunome.msc2409.ephemeral": ephemeral, - } - else: - body = {"events": serialized_events} + body.update( + { + # TODO: Update to stable prefixes once MSC2409 completes FCP merge. + "de.sorunome.msc2409.ephemeral": ephemeral, + "de.sorunome.msc2409.to_device": to_device_messages, + } + ) try: await self.put_json( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 185e3a5278..c42fa32fff 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -48,7 +48,16 @@ This is all tied together by the AppServiceScheduler which DIs the required components. """ import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Collection, + Dict, + List, + Optional, + Set, +) from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.appservice.api import ApplicationServiceApi @@ -71,6 +80,9 @@ MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100 # Maximum number of ephemeral events to provide in an AS transaction. MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100 +# Maximum number of to-device messages to provide in an AS transaction. +MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100 + class ApplicationServiceScheduler: """Public facing API for this module. Does the required DI to tie the @@ -97,15 +109,40 @@ class ApplicationServiceScheduler: for service in services: self.txn_ctrl.start_recoverer(service) - def submit_event_for_as( - self, service: ApplicationService, event: EventBase + def enqueue_for_appservice( + self, + appservice: ApplicationService, + events: Optional[Collection[EventBase]] = None, + ephemeral: Optional[Collection[JsonDict]] = None, + to_device_messages: Optional[Collection[JsonDict]] = None, ) -> None: - self.queuer.enqueue_event(service, event) + """ + Enqueue some data to be sent off to an application service. - def submit_ephemeral_events_for_as( - self, service: ApplicationService, events: List[JsonDict] - ) -> None: - self.queuer.enqueue_ephemeral(service, events) + Args: + appservice: The application service to create and send a transaction to. + events: The persistent room events to send. + ephemeral: The ephemeral events to send. + to_device_messages: The to-device messages to send. These differ from normal + to-device messages sent to clients, as they have 'to_device_id' and + 'to_user_id' fields. + """ + # We purposefully allow this method to run with empty events/ephemeral + # collections, so that callers do not need to check iterable size themselves. + if not events and not ephemeral and not to_device_messages: + return + + if events: + self.queuer.queued_events.setdefault(appservice.id, []).extend(events) + if ephemeral: + self.queuer.queued_ephemeral.setdefault(appservice.id, []).extend(ephemeral) + if to_device_messages: + self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend( + to_device_messages + ) + + # Kick off a new application service transaction + self.queuer.start_background_request(appservice) class _ServiceQueuer: @@ -121,13 +158,15 @@ class _ServiceQueuer: self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} self.queued_ephemeral: Dict[str, List[JsonDict]] = {} + # dict of {service_id: [to_device_message_json]} + self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} # the appservices which currently have a transaction in flight self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl self.clock = clock - def _start_background_request(self, service: ApplicationService) -> None: + def start_background_request(self, service: ApplicationService) -> None: # start a sender for this appservice if we don't already have one if service.id in self.requests_in_flight: return @@ -136,16 +175,6 @@ class _ServiceQueuer: "as-sender-%s" % (service.id,), self._send_request, service ) - def enqueue_event(self, service: ApplicationService, event: EventBase) -> None: - self.queued_events.setdefault(service.id, []).append(event) - self._start_background_request(service) - - def enqueue_ephemeral( - self, service: ApplicationService, events: List[JsonDict] - ) -> None: - self.queued_ephemeral.setdefault(service.id, []).extend(events) - self._start_background_request(service) - async def _send_request(self, service: ApplicationService) -> None: # sanity-check: we shouldn't get here if this service already has a sender # running. @@ -162,11 +191,21 @@ class _ServiceQueuer: ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION] - if not events and not ephemeral: + all_to_device_messages = self.queued_to_device_messages.get( + service.id, [] + ) + to_device_messages_to_send = all_to_device_messages[ + :MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION + ] + del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION] + + if not events and not ephemeral and not to_device_messages_to_send: return try: - await self.txn_ctrl.send(service, events, ephemeral) + await self.txn_ctrl.send( + service, events, ephemeral, to_device_messages_to_send + ) except Exception: logger.exception("AS request failed") finally: @@ -198,10 +237,24 @@ class _TransactionController: service: ApplicationService, events: List[EventBase], ephemeral: Optional[List[JsonDict]] = None, + to_device_messages: Optional[List[JsonDict]] = None, ) -> None: + """ + Create a transaction with the given data and send to the provided + application service. + + Args: + service: The application service to send the transaction to. + events: The persistent events to include in the transaction. + ephemeral: The ephemeral events to include in the transaction. + to_device_messages: The to-device messages to include in the transaction. + """ try: txn = await self.store.create_appservice_txn( - service=service, events=events, ephemeral=ephemeral or [] + service=service, + events=events, + ephemeral=ephemeral or [], + to_device_messages=to_device_messages or [], ) service_is_up = await self._is_service_up(service) if service_is_up: diff --git a/synapse/config/cache.py b/synapse/config/cache.py index d9d85f98e1..387ac6d115 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import re import threading @@ -23,6 +24,8 @@ from synapse.python_dependencies import DependencyException, check_requirements from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + # The prefix for all cache factor-related environment variables _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" @@ -148,11 +151,16 @@ class CacheConfig(Config): per_cache_factors: #get_users_who_share_room_with_user: 2.0 - # Controls how long an entry can be in a cache without having been - # accessed before being evicted. Defaults to None, which means - # entries are never evicted based on time. + # Controls whether cache entries are evicted after a specified time + # period. Defaults to true. Uncomment to disable this feature. + # + #expire_caches: false + + # If expire_caches is enabled, this flag controls how long an entry can + # be in a cache without having been accessed before being evicted. + # Defaults to 30m. Uncomment to set a different time to live for cache entries. # - #expiry_time: 30m + #cache_entry_ttl: 30m # Controls how long the results of a /sync request are cached for after # a successful response is returned. A higher duration can help clients with @@ -217,12 +225,30 @@ class CacheConfig(Config): e.message # noqa: B306, DependencyException.message is a property ) - expiry_time = cache_config.get("expiry_time") - if expiry_time: - self.expiry_time_msec: Optional[int] = self.parse_duration(expiry_time) + expire_caches = cache_config.get("expire_caches", True) + cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m") + + if expire_caches: + self.expiry_time_msec: Optional[int] = self.parse_duration(cache_entry_ttl) else: self.expiry_time_msec = None + # Backwards compatibility support for the now-removed "expiry_time" config flag. + expiry_time = cache_config.get("expiry_time") + + if expiry_time and expire_caches: + logger.warning( + "You have set two incompatible options, expiry_time and expire_caches. Please only use the " + "expire_caches and cache_entry_ttl options and delete the expiry_time option as it is " + "deprecated." + ) + if expiry_time: + logger.warning( + "Expiry_time is a deprecated option, please use the expire_caches and cache_entry_ttl options " + "instead." + ) + self.expiry_time_msec = self.parse_duration(expiry_time) + self.sync_response_cache_duration = self.parse_duration( cache_config.get("sync_response_cache_duration", 0) ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 65c807a19a..bcdeb9ee23 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -26,6 +26,8 @@ class ExperimentalConfig(Config): # MSC3440 (thread relation) self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False) + # MSC3666: including bundled relations in /search. + self.msc3666_enabled: bool = experimental.get("msc3666_enabled", False) # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) @@ -39,9 +41,6 @@ class ExperimentalConfig(Config): # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) - # MSC3283 (set displayname, avatar_url and change 3pid capabilities) - self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False) - # MSC3266 (room summary api) self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False) @@ -52,3 +51,17 @@ class ExperimentalConfig(Config): self.msc3202_device_masquerading_enabled: bool = experimental.get( "msc3202_device_masquerading", False ) + + # MSC2409 (this setting only relates to optionally sending to-device messages). + # Presence, typing and read receipt EDUs are already sent to application services that + # have opted in to receive them. If enabled, this adds to-device messages to that list. + self.msc2409_to_device_messages_enabled: bool = experimental.get( + "msc2409_to_device_messages_enabled", False + ) + + # MSC3706 (server-side support for partial state in /send_join responses) + self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False) + + # experimental support for faster joins over federation (msc2775, msc3706) + # requires a target server with msc3706_enabled enabled. + self.faster_joins_enabled: bool = experimental.get("faster_joins", False) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index ea69b9bd9b..cbbe221965 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -22,6 +22,7 @@ from string import Template from typing import TYPE_CHECKING, Any, Dict, Optional import yaml +from matrix_common.versionstring import get_distribution_version_string from zope.interface import implementer from twisted.logger import ( @@ -32,11 +33,8 @@ from twisted.logger import ( globalLogBeginner, ) -import synapse -from synapse.logging._structured import setup_structured_logging from synapse.logging.context import LoggingContextFilter from synapse.logging.filter import MetadataFilter -from synapse.util.versionstring import get_version_string from ._base import Config, ConfigError @@ -139,6 +137,12 @@ Support for the log_file configuration option and --log-file command-line option removed in Synapse 1.3.0. You should instead set up a separate log configuration file. """ +STRUCTURED_ERROR = """\ +Support for the structured configuration option was removed in Synapse 1.54.0. +You should instead use the standard logging configuration. See +https://matrix-org.github.io/synapse/v1.54/structured_logging.html +""" + class LoggingConfig(Config): section = "logging" @@ -293,10 +297,9 @@ def _load_logging_config(log_config_path: str) -> None: if not log_config: logging.warning("Loaded a blank logging config?") - # If the old structured logging configuration is being used, convert it to - # the new style configuration. + # If the old structured logging configuration is being used, raise an error. if "structured" in log_config and log_config.get("structured"): - log_config = setup_structured_logging(log_config) + raise ConfigError(STRUCTURED_ERROR) logging.config.dictConfig(log_config) @@ -347,6 +350,10 @@ def setup_logging( # Log immediately so we can grep backwards. logging.warning("***** STARTING SERVER *****") - logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) + logging.warning( + "Server %s version %s", + sys.argv[0], + get_distribution_version_string("matrix-synapse"), + ) logging.info("Server hostname: %s", config.server.server_name) logging.info("Instance name: %s", hs.get_instance_name()) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 36636ab07e..e9ccf1bd62 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -134,6 +134,14 @@ class RatelimitConfig(Config): defaults={"per_second": 0.003, "burst_count": 5}, ) + self.rc_third_party_invite = RateLimitConfig( + config.get("rc_third_party_invite", {}), + defaults={ + "per_second": self.rc_message.per_second, + "burst_count": self.rc_message.burst_count, + }, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -168,6 +176,9 @@ class RatelimitConfig(Config): # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # - two for ratelimiting how often invites can be sent in a room or to a # specific user. + # - one for ratelimiting 3PID invites (i.e. invites sent to a third-party ID + # such as an email address or a phone number) based on the account that's + # sending the invite. # # The defaults are as shown below. # @@ -217,6 +228,10 @@ class RatelimitConfig(Config): # per_user: # per_second: 0.003 # burst_count: 5 + # + #rc_third_party_invite: + # per_second: 0.2 + # burst_count: 10 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/server.py b/synapse/config/server.py index a460cf25b4..7bc9624546 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -179,7 +179,6 @@ KNOWN_RESOURCES = { "openid", "replication", "static", - "webclient", } @@ -519,16 +518,12 @@ class ServerConfig(Config): self.listeners = l2 self.web_client_location = config.get("web_client_location", None) - self.web_client_location_is_redirect = self.web_client_location and ( + # Non-HTTP(S) web client location is not supported. + if self.web_client_location and not ( self.web_client_location.startswith("http://") or self.web_client_location.startswith("https://") - ) - # A non-HTTP(S) web client location is deprecated. - if self.web_client_location and not self.web_client_location_is_redirect: - logger.warning(NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING) - - # Warn if webclient is configured for a worker. - _warn_if_webclient_configured(self.listeners) + ): + raise ConfigError("web_client_location must point to a HTTP(S) URL.") self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None)) @@ -656,19 +651,6 @@ class ServerConfig(Config): False, ) - # List of users trialing the new experimental default push rules. This setting is - # not included in the sample configuration file on purpose as it's a temporary - # hack, so that some users can trial the new defaults without impacting every - # user on the homeserver. - users_new_default_push_rules: list = ( - config.get("users_new_default_push_rules") or [] - ) - if not isinstance(users_new_default_push_rules, list): - raise ConfigError("'users_new_default_push_rules' must be a list") - - # Turn the list into a set to improve lookup speed. - self.users_new_default_push_rules: set = set(users_new_default_push_rules) - # Whitelist of domain names that given next_link parameters must have next_link_domain_whitelist: Optional[List[str]] = config.get( "next_link_domain_whitelist" @@ -1364,11 +1346,16 @@ def parse_listener_def(listener: Any) -> ListenerConfig: http_config = None if listener_type == "http": + try: + resources = [ + HttpResourceConfig(**res) for res in listener.get("resources", []) + ] + except ValueError as e: + raise ConfigError("Unknown listener resource") from e + http_config = HttpListenerConfig( x_forwarded=listener.get("x_forwarded", False), - resources=[ - HttpResourceConfig(**res) for res in listener.get("resources", []) - ], + resources=resources, additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), ) @@ -1376,30 +1363,6 @@ def parse_listener_def(listener: Any) -> ListenerConfig: return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) -NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING = """ -Synapse no longer supports serving a web client. To remove this warning, -configure 'web_client_location' with an HTTP(S) URL. -""" - - -NO_MORE_WEB_CLIENT_WARNING = """ -Synapse no longer includes a web client. To redirect the root resource to a web client, configure -'web_client_location'. To remove this warning, remove 'webclient' from the 'listeners' -configuration. -""" - - -def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None: - for listener in listeners: - if not listener.http_options: - continue - for res in listener.http_options.resources: - for name in res.names: - if name == "webclient": - logger.warning(NO_MORE_WEB_CLIENT_WARNING) - return - - _MANHOLE_SETTINGS_SCHEMA = { "type": "object", "properties": { diff --git a/synapse/event_auth.py b/synapse/event_auth.py index e885961698..eca00bc975 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import typing from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from canonicaljson import encode_canonical_json @@ -34,15 +35,18 @@ from synapse.api.room_versions import ( EventFormatVersions, RoomVersion, ) -from synapse.events import EventBase -from synapse.events.builder import EventBuilder from synapse.types import StateMap, UserID, get_domain_from_id +if typing.TYPE_CHECKING: + # conditional imports to avoid import cycle + from synapse.events import EventBase + from synapse.events.builder import EventBuilder + logger = logging.getLogger(__name__) def validate_event_for_room_version( - room_version_obj: RoomVersion, event: EventBase + room_version_obj: RoomVersion, event: "EventBase" ) -> None: """Ensure that the event complies with the limits, and has the right signatures @@ -113,7 +117,9 @@ def validate_event_for_room_version( def check_auth_rules_for_event( - room_version_obj: RoomVersion, event: EventBase, auth_events: Iterable[EventBase] + room_version_obj: RoomVersion, + event: "EventBase", + auth_events: Iterable["EventBase"], ) -> None: """Check that an event complies with the auth rules @@ -256,7 +262,7 @@ def check_auth_rules_for_event( logger.debug("Allowing! %s", event) -def _check_size_limits(event: EventBase) -> None: +def _check_size_limits(event: "EventBase") -> None: if len(event.user_id) > 255: raise EventSizeError("'user_id' too large") if len(event.room_id) > 255: @@ -271,7 +277,7 @@ def _check_size_limits(event: EventBase) -> None: raise EventSizeError("event too large") -def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool: +def _can_federate(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool: creation_event = auth_events.get((EventTypes.Create, "")) # There should always be a creation event, but if not don't federate. if not creation_event: @@ -281,7 +287,7 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool: def _is_membership_change_allowed( - room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase] + room_version: RoomVersion, event: "EventBase", auth_events: StateMap["EventBase"] ) -> None: """ Confirms that the event which changes membership is an allowed change. @@ -471,7 +477,7 @@ def _is_membership_change_allowed( def _check_event_sender_in_room( - event: EventBase, auth_events: StateMap[EventBase] + event: "EventBase", auth_events: StateMap["EventBase"] ) -> None: key = (EventTypes.Member, event.user_id) member_event = auth_events.get(key) @@ -479,7 +485,9 @@ def _check_event_sender_in_room( _check_joined_room(member_event, event.user_id, event.room_id) -def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None: +def _check_joined_room( + member: Optional["EventBase"], user_id: str, room_id: str +) -> None: if not member or member.membership != Membership.JOIN: raise AuthError( 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member)) @@ -487,7 +495,7 @@ def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) def get_send_level( - etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase] + etype: str, state_key: Optional[str], power_levels_event: Optional["EventBase"] ) -> int: """Get the power level required to send an event of a given type @@ -523,7 +531,7 @@ def get_send_level( return int(send_level) -def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: +def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool: power_levels_event = get_power_level_event(auth_events) send_level = get_send_level(event.type, event.get("state_key"), power_levels_event) @@ -547,8 +555,8 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: def check_redaction( room_version_obj: RoomVersion, - event: EventBase, - auth_events: StateMap[EventBase], + event: "EventBase", + auth_events: StateMap["EventBase"], ) -> bool: """Check whether the event sender is allowed to redact the target event. @@ -585,8 +593,8 @@ def check_redaction( def check_historical( room_version_obj: RoomVersion, - event: EventBase, - auth_events: StateMap[EventBase], + event: "EventBase", + auth_events: StateMap["EventBase"], ) -> None: """Check whether the event sender is allowed to send historical related events like "insertion", "batch", and "marker". @@ -616,8 +624,8 @@ def check_historical( def _check_power_levels( room_version_obj: RoomVersion, - event: EventBase, - auth_events: StateMap[EventBase], + event: "EventBase", + auth_events: StateMap["EventBase"], ) -> None: user_list = event.content.get("users", {}) # Validate users @@ -710,11 +718,11 @@ def _check_power_levels( ) -def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]: +def get_power_level_event(auth_events: StateMap["EventBase"]) -> Optional["EventBase"]: return auth_events.get((EventTypes.PowerLevels, "")) -def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int: +def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> int: """Get a user's power level Args: @@ -750,7 +758,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int: return 0 -def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int: +def get_named_level(auth_events: StateMap["EventBase"], name: str, default: int) -> int: power_level_event = get_power_level_event(auth_events) if not power_level_event: @@ -763,7 +771,9 @@ def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) - return default -def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]): +def _verify_third_party_invite( + event: "EventBase", auth_events: StateMap["EventBase"] +) -> bool: """ Validates that the invite event is authorized by a previous third-party invite. @@ -827,7 +837,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase return False -def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]: +def get_public_keys(invite_event: "EventBase") -> List[Dict[str, Any]]: public_keys = [] if "public_key" in invite_event.content: o = {"public_key": invite_event.content["public_key"]} @@ -839,7 +849,7 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]: def auth_types_for_event( - room_version: RoomVersion, event: Union[EventBase, EventBuilder] + room_version: RoomVersion, event: Union["EventBase", "EventBuilder"] ) -> Set[Tuple[str, str]]: """Given an event, return a list of (EventType, StateKey) that may be needed to auth the event. The returned list may be a superset of what diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 3134beb8d3..04afd48274 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -48,9 +48,6 @@ USER_MAY_JOIN_ROOM_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]] USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bool]] USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] -USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[ - [str, List[str], List[Dict[str, str]]], Awaitable[bool] -] USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]] @@ -174,9 +171,6 @@ class SpamChecker: USER_MAY_SEND_3PID_INVITE_CALLBACK ] = [] self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = [] - self._user_may_create_room_with_invites_callbacks: List[ - USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK - ] = [] self._user_may_create_room_alias_callbacks: List[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = [] @@ -198,9 +192,6 @@ class SpamChecker: user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, - user_may_create_room_with_invites: Optional[ - USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK - ] = None, user_may_create_room_alias: Optional[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = None, @@ -229,11 +220,6 @@ class SpamChecker: if user_may_create_room is not None: self._user_may_create_room_callbacks.append(user_may_create_room) - if user_may_create_room_with_invites is not None: - self._user_may_create_room_with_invites_callbacks.append( - user_may_create_room_with_invites, - ) - if user_may_create_room_alias is not None: self._user_may_create_room_alias_callbacks.append( user_may_create_room_alias, @@ -359,34 +345,6 @@ class SpamChecker: return True - async def user_may_create_room_with_invites( - self, - userid: str, - invites: List[str], - threepid_invites: List[Dict[str, str]], - ) -> bool: - """Checks if a given user may create a room with invites - - If this method returns false, the creation request will be rejected. - - Args: - userid: The ID of the user attempting to create a room - invites: The IDs of the Matrix users to be invited if the room creation is - allowed. - threepid_invites: The threepids to be invited if the room creation is allowed, - as a dict including a "medium" key indicating the threepid's medium (e.g. - "email") and an "address" key indicating the threepid's address (e.g. - "alice@example.com") - - Returns: - True if the user may create the room, otherwise False - """ - for callback in self._user_may_create_room_with_invites_callbacks: - if await callback(userid, invites, threepid_invites) is False: - return False - - return True - async def user_may_create_room_alias( self, userid: str, room_alias: RoomAlias ) -> bool: diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 243696b357..9386fa29dd 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -425,6 +425,33 @@ class EventClientSerializer: return serialized_event + def _apply_edit( + self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase + ) -> None: + """Replace the content, preserving existing relations of the serialized event. + + Args: + orig_event: The original event. + serialized_event: The original event, serialized. This is modified. + edit: The event which edits the above. + """ + + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations + relates_to = orig_event.content.get("m.relates_to") + if relates_to: + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relates_to.copy() + else: + serialized_event["content"].pop("m.relates_to", None) + def _inject_bundled_aggregations( self, event: EventBase, @@ -450,26 +477,11 @@ class EventClientSerializer: serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references if aggregations.replace: - # If there is an edit replace the content, preserving existing - # relations. + # If there is an edit, apply it to the event. edit = aggregations.replace + self._apply_edit(event, serialized_event, edit) - # Ensure we take copies of the edit content, otherwise we risk modifying - # the original event. - edit_content = edit.content.copy() - - # Unfreeze the event content if necessary, so that we may modify it below - edit_content = unfreeze(edit_content) - serialized_event["content"] = edit_content.get("m.new_content", {}) - - # Check for existing relations - relates_to = event.content.get("m.relates_to") - if relates_to: - # Keep the relations, ensuring we use a dict copy of the original - serialized_event["content"]["m.relates_to"] = relates_to.copy() - else: - serialized_event["content"].pop("m.relates_to", None) - + # Include information about it in the relations dict. serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, @@ -478,13 +490,22 @@ class EventClientSerializer: # If this event is the start of a thread, include a summary of the replies. if aggregations.thread: + thread = aggregations.thread + + # Don't bundle aggregations as this could recurse forever. + serialized_latest_event = self.serialize_event( + thread.latest_event, time_now, bundle_aggregations=None + ) + # Manually apply an edit, if one exists. + if thread.latest_edit: + self._apply_edit( + thread.latest_event, serialized_latest_event, thread.latest_edit + ) + serialized_aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": self.serialize_event( - aggregations.thread.latest_event, time_now, bundle_aggregations=None - ), - "count": aggregations.thread.count, - "current_user_participated": aggregations.thread.current_user_participated, + "latest_event": serialized_latest_event, + "count": thread.count, + "current_user_participated": thread.current_user_participated, } # Include the bundled aggregations in the event. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 74f17aa4da..48c90bf0bb 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1,4 +1,4 @@ -# Copyright 2015-2021 The Matrix.org Foundation C.I.C. +# Copyright 2015-2022 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -89,6 +89,12 @@ class SendJoinResult: state: List[EventBase] auth_chain: List[EventBase] + # True if 'state' elides non-critical membership events + partial_state: bool + + # if 'partial_state' is set, a list of the servers in the room (otherwise empty) + servers_in_room: List[str] + class FederationClient(FederationBase): def __init__(self, hs: "HomeServer"): @@ -864,23 +870,32 @@ class FederationClient(FederationBase): for s in signed_state: s.internal_metadata = copy.deepcopy(s.internal_metadata) - # double-check that the same create event has ended up in the auth chain + # double-check that the auth chain doesn't include a different create event auth_chain_create_events = [ e.event_id for e in signed_auth if (e.type, e.state_key) == (EventTypes.Create, "") ] - if auth_chain_create_events != [create_event.event_id]: + if auth_chain_create_events and auth_chain_create_events != [ + create_event.event_id + ]: raise InvalidResponseError( "Unexpected create event(s) in auth chain: %s" % (auth_chain_create_events,) ) + if response.partial_state and not response.servers_in_room: + raise InvalidResponseError( + "partial_state was set, but no servers were listed in the room" + ) + return SendJoinResult( event=event, state=signed_state, auth_chain=signed_auth, origin=destination, + partial_state=response.partial_state, + servers_in_room=response.servers_in_room or [], ) # MSC3083 defines additional error codes for room joins. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index af9cb98f67..482bbdd867 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -20,6 +20,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -64,7 +65,7 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.storage.databases.main.lock import Lock -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict, StateMap, get_domain_from_id from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache @@ -571,7 +572,7 @@ class FederationServer(FederationBase): ) -> JsonDict: state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) - return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} + return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)} async def _on_context_state_request_compute( self, room_id: str, event_id: Optional[str] @@ -645,27 +646,61 @@ class FederationServer(FederationBase): return {"event": ret_pdu.get_pdu_json(time_now)} async def on_send_join_request( - self, origin: str, content: JsonDict, room_id: str + self, + origin: str, + content: JsonDict, + room_id: str, + caller_supports_partial_state: bool = False, ) -> Dict[str, Any]: event, context = await self._on_send_membership_event( origin, content, Membership.JOIN, room_id ) prev_state_ids = await context.get_prev_state_ids() - state_ids = list(prev_state_ids.values()) - auth_chain = await self.store.get_auth_chain(room_id, state_ids) - state = await self.store.get_events(state_ids) + state_event_ids: Collection[str] + servers_in_room: Optional[Collection[str]] + if caller_supports_partial_state: + state_event_ids = _get_event_ids_for_partial_state_join( + event, prev_state_ids + ) + servers_in_room = await self.state.get_hosts_in_room_at_events( + room_id, event_ids=event.prev_event_ids() + ) + else: + state_event_ids = prev_state_ids.values() + servers_in_room = None + + auth_chain_event_ids = await self.store.get_auth_chain_ids( + room_id, state_event_ids + ) + + # if the caller has opted in, we can omit any auth_chain events which are + # already in state_event_ids + if caller_supports_partial_state: + auth_chain_event_ids.difference_update(state_event_ids) + + auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids) + state_events = await self.store.get_events_as_list(state_event_ids) + + # we try to do all the async stuff before this point, so that time_now is as + # accurate as possible. time_now = self._clock.time_msec() - event_json = event.get_pdu_json() - return { + event_json = event.get_pdu_json(time_now) + resp = { # TODO Remove the unstable prefix when servers have updated. "org.matrix.msc3083.v2.event": event_json, "event": event_json, - "state": [p.get_pdu_json(time_now) for p in state.values()], - "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], + "state": [p.get_pdu_json(time_now) for p in state_events], + "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events], + "org.matrix.msc3706.partial_state": caller_supports_partial_state, } + if servers_in_room is not None: + resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room) + + return resp + async def on_make_leave_request( self, origin: str, room_id: str, user_id: str ) -> Dict[str, Any]: @@ -1339,3 +1374,39 @@ class FederationHandlerRegistry: # error. logger.warning("No handler registered for query type %s", query_type) raise NotFoundError("No handler for Query type '%s'" % (query_type,)) + + +def _get_event_ids_for_partial_state_join( + join_event: EventBase, + prev_state_ids: StateMap[str], +) -> Collection[str]: + """Calculate state to be retuned in a partial_state send_join + + Args: + join_event: the join event being send_joined + prev_state_ids: the event ids of the state before the join + + Returns: + the event ids to be returned + """ + + # return all non-member events + state_event_ids = { + event_id + for (event_type, state_key), event_id in prev_state_ids.items() + if event_type != EventTypes.Member + } + + # we also need the current state of the current user (it's going to + # be an auth event for the new join, so we may as well return it) + current_membership_event_id = prev_state_ids.get( + (EventTypes.Member, join_event.state_key) + ) + if current_membership_event_id is not None: + state_event_ids.add(current_membership_event_id) + + # TODO: return a few more members: + # - those with invites + # - those that are kicked? / banned + + return state_event_ids diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8152e80b88..c3132f7319 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -381,7 +381,9 @@ class PerDestinationQueue: ) ) - if self._last_successful_stream_ordering is None: + last_successful_stream_ordering = self._last_successful_stream_ordering + + if last_successful_stream_ordering is None: # if it's still None, then this means we don't have the information # in our database we haven't successfully sent a PDU to this server # (at least since the introduction of the feature tracking @@ -394,8 +396,7 @@ class PerDestinationQueue: # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( - self._destination, - self._last_successful_stream_ordering, + self._destination, last_successful_stream_ordering ) if not event_ids: @@ -403,7 +404,7 @@ class PerDestinationQueue: # of a race condition, so we check that no new events have been # skipped due to us being in catch-up mode - if self._catchup_last_skipped > self._last_successful_stream_ordering: + if self._catchup_last_skipped > last_successful_stream_ordering: # another event has been skipped because we were in catch-up mode continue @@ -470,7 +471,7 @@ class PerDestinationQueue: # offline if ( p.internal_metadata.stream_ordering - < self._last_successful_stream_ordering + < last_successful_stream_ordering ): continue @@ -513,12 +514,11 @@ class PerDestinationQueue: # from the *original* PDU, rather than the PDU(s) we actually # send. This is because we use it to mark our position in the # queue of missed PDUs to process. - self._last_successful_stream_ordering = ( - pdu.internal_metadata.stream_ordering - ) + last_successful_stream_ordering = pdu.internal_metadata.stream_ordering + self._last_successful_stream_ordering = last_successful_stream_ordering await self._store.set_destination_last_successful_stream_ordering( - self._destination, self._last_successful_stream_ordering + self._destination, last_successful_stream_ordering ) def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8782586cd6..dca6e5c45d 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -1,4 +1,4 @@ -# Copyright 2014-2021 The Matrix.org Foundation C.I.C. +# Copyright 2014-2022 The Matrix.org Foundation C.I.C. # Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -60,6 +60,7 @@ class TransportLayerClient: def __init__(self, hs): self.server_name = hs.hostname self.client = hs.get_federation_http_client() + self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled async def get_room_state_ids( self, destination: str, room_id: str, event_id: str @@ -336,10 +337,15 @@ class TransportLayerClient: content: JsonDict, ) -> "SendJoinResponse": path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + query_params: Dict[str, str] = {} + if self._faster_joins_enabled: + # lazy-load state on join + query_params["org.matrix.msc3706.partial_state"] = "true" return await self.client.put_json( destination=destination, path=path, + args=query_params, data=content, parser=SendJoinParser(room_version, v1_api=False), max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, @@ -1271,6 +1277,12 @@ class SendJoinResponse: # "event" is not included in the response. event: Optional[EventBase] = None + # The room state is incomplete + partial_state: bool = False + + # List of servers in the room + servers_in_room: Optional[List[str]] = None + @ijson.coroutine def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: @@ -1297,6 +1309,32 @@ def _event_list_parser( events.append(event) +@ijson.coroutine +def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]: + """Helper function for use with `ijson.items_coro` + + Parses the partial_state field in send_join responses + """ + while True: + val = yield + if not isinstance(val, bool): + raise TypeError("partial_state must be a boolean") + response.partial_state = val + + +@ijson.coroutine +def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]: + """Helper function for use with `ijson.items_coro` + + Parses the servers_in_room field in send_join responses + """ + while True: + val = yield + if not isinstance(val, list) or any(not isinstance(x, str) for x in val): + raise TypeError("servers_in_room must be a list of strings") + response.servers_in_room = val + + class SendJoinParser(ByteParser[SendJoinResponse]): """A parser for the response to `/send_join` requests. @@ -1308,44 +1346,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]): CONTENT_TYPE = "application/json" def __init__(self, room_version: RoomVersion, v1_api: bool): - self._response = SendJoinResponse([], [], {}) + self._response = SendJoinResponse([], [], event_dict={}) self._room_version = room_version + self._coros = [] # The V1 API has the shape of `[200, {...}]`, which we handle by # prefixing with `item.*`. prefix = "item." if v1_api else "" - self._coro_state = ijson.items_coro( - _event_list_parser(room_version, self._response.state), - prefix + "state.item", - use_float=True, - ) - self._coro_auth = ijson.items_coro( - _event_list_parser(room_version, self._response.auth_events), - prefix + "auth_chain.item", - use_float=True, - ) - # TODO Remove the unstable prefix when servers have updated. - # - # By re-using the same event dictionary this will cause the parsing of - # org.matrix.msc3083.v2.event and event to stomp over each other. - # Generally this should be fine. - self._coro_unstable_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "org.matrix.msc3083.v2.event", - use_float=True, - ) - self._coro_event = ijson.kvitems_coro( - _event_parser(self._response.event_dict), - prefix + "event", - use_float=True, - ) + self._coros = [ + ijson.items_coro( + _event_list_parser(room_version, self._response.state), + prefix + "state.item", + use_float=True, + ), + ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + prefix + "auth_chain.item", + use_float=True, + ), + # TODO Remove the unstable prefix when servers have updated. + # + # By re-using the same event dictionary this will cause the parsing of + # org.matrix.msc3083.v2.event and event to stomp over each other. + # Generally this should be fine. + ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "org.matrix.msc3083.v2.event", + use_float=True, + ), + ijson.kvitems_coro( + _event_parser(self._response.event_dict), + prefix + "event", + use_float=True, + ), + ] + + if not v1_api: + self._coros.append( + ijson.items_coro( + _partial_state_parser(self._response), + "org.matrix.msc3706.partial_state", + use_float="True", + ) + ) + + self._coros.append( + ijson.items_coro( + _servers_in_room_parser(self._response), + "org.matrix.msc3706.servers_in_room", + use_float="True", + ) + ) def write(self, data: bytes) -> int: - self._coro_state.send(data) - self._coro_auth.send(data) - self._coro_unstable_event.send(data) - self._coro_event.send(data) + for c in self._coros: + c.send(data) return len(data) diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 2ca7c05835..dff2b68359 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -15,6 +15,7 @@ import functools import logging import re +import time from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, cast from synapse.api.errors import Codes, FederationDeniedError, SynapseError @@ -24,8 +25,10 @@ from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( + active_span, set_tag, span_context_from_request, + start_active_span, start_active_span_follows_from, whitelisted_homeserver, ) @@ -265,9 +268,10 @@ class BaseFederationServlet: content = parse_json_object_from_request(request) try: - origin: Optional[str] = await authenticator.authenticate_request( - request, content - ) + with start_active_span("authenticate_request"): + origin: Optional[str] = await authenticator.authenticate_request( + request, content + ) except NoAuthenticationError: origin = None if self.REQUIRE_AUTH: @@ -282,32 +286,57 @@ class BaseFederationServlet: # update the active opentracing span with the authenticated entity set_tag("authenticated_entity", origin) - # if the origin is authenticated and whitelisted, link to its span context + # if the origin is authenticated and whitelisted, use its span context + # as the parent. context = None if origin and whitelisted_homeserver(origin): context = span_context_from_request(request) - scope = start_active_span_follows_from( - "incoming-federation-request", contexts=(context,) if context else () - ) + if context: + servlet_span = active_span() + # a scope which uses the origin's context as a parent + processing_start_time = time.time() + scope = start_active_span_follows_from( + "incoming-federation-request", + child_of=context, + contexts=(servlet_span,), + start_time=processing_start_time, + ) - with scope: - if origin and self.RATELIMIT: - with ratelimiter.ratelimit(origin) as d: - await d - if request._disconnected: - logger.warning( - "client disconnected before we started processing " - "request" + else: + # just use our context as a parent + scope = start_active_span( + "incoming-federation-request", + ) + + try: + with scope: + if origin and self.RATELIMIT: + with ratelimiter.ratelimit(origin) as d: + await d + if request._disconnected: + logger.warning( + "client disconnected before we started processing " + "request" + ) + return None + response = await func( + origin, content, request.args, *args, **kwargs ) - return None + else: response = await func( origin, content, request.args, *args, **kwargs ) - else: - response = await func( - origin, content, request.args, *args, **kwargs + finally: + # if we used the origin's context as the parent, add a new span using + # the servlet span as a parent, so that we have a link + if context: + scope2 = start_active_span_follows_from( + "process-federation_request", + contexts=(scope.span,), + start_time=processing_start_time, ) + scope2.close() return response diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 9c1ad5851f..e85a8eda5b 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -24,9 +24,9 @@ from typing import ( Union, ) +from matrix_common.versionstring import get_distribution_version_string from typing_extensions import Literal -import synapse from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX @@ -42,7 +42,6 @@ from synapse.http.servlet import ( ) from synapse.types import JsonDict from synapse.util.ratelimitutils import FederationRateLimiter -from synapse.util.versionstring import get_version_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -109,11 +108,11 @@ class FederationSendServlet(BaseFederationServerServlet): ) if issue_8631_logger.isEnabledFor(logging.DEBUG): - DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"} + DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"] device_list_updates = [ edu.content for edu in transaction_data.get("edus", []) - if edu.edu_type in DEVICE_UPDATE_EDUS + if edu.get("edu_type") in DEVICE_UPDATE_EDUS ] if device_list_updates: issue_8631_logger.debug( @@ -412,6 +411,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX + def __init__( + self, + hs: "HomeServer", + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self._msc3706_enabled = hs.config.experimental.msc3706_enabled + async def on_PUT( self, origin: str, @@ -422,7 +431,15 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): ) -> Tuple[int, JsonDict]: # TODO(paul): assert that event_id parsed from path actually # match those given in content - result = await self.handler.on_send_join_request(origin, content, room_id) + + partial_state = False + if self._msc3706_enabled: + partial_state = parse_boolean_from_args( + query, "org.matrix.msc3706.partial_state", default=False + ) + result = await self.handler.on_send_join_request( + origin, content, room_id, caller_supports_partial_state=partial_state + ) return 200, result @@ -598,7 +615,12 @@ class FederationVersionServlet(BaseFederationServlet): ) -> Tuple[int, JsonDict]: return ( 200, - {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, + { + "server": { + "name": "Synapse", + "version": get_distribution_version_string("matrix-synapse"), + } + }, ) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 7833e77e2b..a42c3558e4 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -55,6 +55,9 @@ class ApplicationServicesHandler: self.clock = hs.get_clock() self.notify_appservices = hs.config.appservice.notify_appservices self.event_sources = hs.get_event_sources() + self._msc2409_to_device_messages_enabled = ( + hs.config.experimental.msc2409_to_device_messages_enabled + ) self.current_max = 0 self.is_processing = False @@ -132,7 +135,9 @@ class ApplicationServicesHandler: # Fork off pushes to these services for service in services: - self.scheduler.submit_event_for_as(service, event) + self.scheduler.enqueue_for_appservice( + service, events=[event] + ) now = self.clock.time_msec() ts = await self.store.get_received_ts(event.event_id) @@ -199,8 +204,9 @@ class ApplicationServicesHandler: Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key" or "presence_key". Any other - value for `stream_key` will cause this function to return early. + `stream_key` can be "typing_key", "receipt_key", "presence_key" or + "to_device_key". Any other value for `stream_key` will cause this function + to return early. Ephemeral events will only be pushed to appservices that have opted into receiving them by setting `push_ephemeral` to true in their registration @@ -216,8 +222,15 @@ class ApplicationServicesHandler: if not self.notify_appservices: return - # Ignore any unsupported streams - if stream_key not in ("typing_key", "receipt_key", "presence_key"): + # Notify appservices of updates in ephemeral event streams. + # Only the following streams are currently supported. + # FIXME: We should use constants for these values. + if stream_key not in ( + "typing_key", + "receipt_key", + "presence_key", + "to_device_key", + ): return # Assert that new_token is an integer (and not a RoomStreamToken). @@ -233,6 +246,13 @@ class ApplicationServicesHandler: # Additional context: https://github.com/matrix-org/synapse/pull/11137 assert isinstance(new_token, int) + # Ignore to-device messages if the feature flag is not enabled + if ( + stream_key == "to_device_key" + and not self._msc2409_to_device_messages_enabled + ): + return + # Check whether there are any appservices which have registered to receive # ephemeral events. # @@ -266,7 +286,7 @@ class ApplicationServicesHandler: with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: if stream_key == "typing_key": - # Note that we don't persist the token (via set_type_stream_id_for_appservice) + # Note that we don't persist the token (via set_appservice_stream_type_pos) # for typing_key due to performance reasons and due to their highly # ephemeral nature. # @@ -274,7 +294,7 @@ class ApplicationServicesHandler: # and, if they apply to this application service, send it off. events = await self._handle_typing(service, new_token) if events: - self.scheduler.submit_ephemeral_events_for_as(service, events) + self.scheduler.enqueue_for_appservice(service, ephemeral=events) continue # Since we read/update the stream position for this AS/stream @@ -285,28 +305,37 @@ class ApplicationServicesHandler: ): if stream_key == "receipt_key": events = await self._handle_receipts(service, new_token) - if events: - self.scheduler.submit_ephemeral_events_for_as( - service, events - ) + self.scheduler.enqueue_for_appservice(service, ephemeral=events) # Persist the latest handled stream token for this appservice - await self.store.set_type_stream_id_for_appservice( + await self.store.set_appservice_stream_type_pos( service, "read_receipt", new_token ) elif stream_key == "presence_key": events = await self._handle_presence(service, users, new_token) - if events: - self.scheduler.submit_ephemeral_events_for_as( - service, events - ) + self.scheduler.enqueue_for_appservice(service, ephemeral=events) # Persist the latest handled stream token for this appservice - await self.store.set_type_stream_id_for_appservice( + await self.store.set_appservice_stream_type_pos( service, "presence", new_token ) + elif stream_key == "to_device_key": + # Retrieve a list of to-device message events, as well as the + # maximum stream token of the messages we were able to retrieve. + to_device_messages = await self._get_to_device_messages( + service, new_token, users + ) + self.scheduler.enqueue_for_appservice( + service, to_device_messages=to_device_messages + ) + + # Persist the latest handled stream token for this appservice + await self.store.set_appservice_stream_type_pos( + service, "to_device", new_token + ) + async def _handle_typing( self, service: ApplicationService, new_token: int ) -> List[JsonDict]: @@ -440,6 +469,79 @@ class ApplicationServicesHandler: return events + async def _get_to_device_messages( + self, + service: ApplicationService, + new_token: int, + users: Collection[Union[str, UserID]], + ) -> List[JsonDict]: + """ + Given an application service, determine which events it should receive + from those between the last-recorded to-device message stream token for this + appservice and the given stream token. + + Args: + service: The application service to check for which events it should receive. + new_token: The latest to-device event stream token. + users: The users to be notified for the new to-device messages + (ie, the recipients of the messages). + + Returns: + A list of JSON dictionaries containing data derived from the to-device events + that should be sent to the given application service. + """ + # Get the stream token that this application service has processed up until + from_key = await self.store.get_type_stream_id_for_appservice( + service, "to_device" + ) + + # Filter out users that this appservice is not interested in + users_appservice_is_interested_in: List[str] = [] + for user in users: + # FIXME: We should do this farther up the call stack. We currently repeat + # this operation in _handle_presence. + if isinstance(user, UserID): + user = user.to_string() + + if service.is_interested_in_user(user): + users_appservice_is_interested_in.append(user) + + if not users_appservice_is_interested_in: + # Return early if the AS was not interested in any of these users + return [] + + # Retrieve the to-device messages for each user + recipient_device_to_messages = await self.store.get_messages_for_user_devices( + users_appservice_is_interested_in, + from_key, + new_token, + ) + + # According to MSC2409, we'll need to add 'to_user_id' and 'to_device_id' fields + # to the event JSON so that the application service will know which user/device + # combination this messages was intended for. + # + # So we mangle this dict into a flat list of to-device messages with the relevant + # user ID and device ID embedded inside each message dict. + message_payload: List[JsonDict] = [] + for ( + user_id, + device_id, + ), messages in recipient_device_to_messages.items(): + for message_json in messages: + # Remove 'message_id' from the to-device message, as it's an internal ID + message_json.pop("message_id", None) + + message_payload.append( + { + "to_user_id": user_id, + "to_device_id": device_id, + **message_json, + } + ) + + return message_payload + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. @@ -547,7 +649,7 @@ class ApplicationServicesHandler: """Retrieve a list of application services interested in this event. Args: - event: The event to check. Can be None if alias_list is not. + event: The event to check. Returns: A list of services interested in this event based on the service regex. """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index e32c93e234..572f54b1e3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -2064,6 +2064,11 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[ [JsonDict, JsonDict], Awaitable[Optional[str]], ] +GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[ + [JsonDict, JsonDict], + Awaitable[Optional[str]], +] +IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]] class PasswordAuthProvider: @@ -2079,6 +2084,10 @@ class PasswordAuthProvider: self.get_username_for_registration_callbacks: List[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = [] + self.get_displayname_for_registration_callbacks: List[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = [] + self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = [] # Mapping from login type to login parameters self._supported_login_types: Dict[str, Iterable[str]] = {} @@ -2090,12 +2099,16 @@ class PasswordAuthProvider: self, check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, + is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None, auth_checkers: Optional[ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] ] = None, get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: # Register check_3pid_auth callback if check_3pid_auth is not None: @@ -2145,6 +2158,14 @@ class PasswordAuthProvider: get_username_for_registration, ) + if get_displayname_for_registration is not None: + self.get_displayname_for_registration_callbacks.append( + get_displayname_for_registration, + ) + + if is_3pid_allowed is not None: + self.is_3pid_allowed_callbacks.append(is_3pid_allowed) + def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: """Get the login types supported by this password provider @@ -2343,3 +2364,84 @@ class PasswordAuthProvider: raise SynapseError(code=500, msg="Internal Server Error") return None + + async def get_displayname_for_registration( + self, + uia_results: JsonDict, + params: JsonDict, + ) -> Optional[str]: + """Defines the display name to use when registering the user, using the + credentials and parameters provided during the UIA flow. + + Stops at the first callback that returns a tuple containing at least one string. + + Args: + uia_results: The credentials provided during the UIA flow. + params: The parameters provided by the registration request. + + Returns: + A tuple which first element is the display name, and the second is an MXC URL + to the user's avatar. + """ + for callback in self.get_displayname_for_registration_callbacks: + try: + res = await callback(uia_results, params) + + if isinstance(res, str): + return res + elif res is not None: + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " get_displayname_for_registration callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error( + "Module raised an exception in get_displayname_for_registration: %s", + e, + ) + raise SynapseError(code=500, msg="Internal Server Error") + + return None + + async def is_3pid_allowed( + self, + medium: str, + address: str, + registration: bool, + ) -> bool: + """Check if the user can be allowed to bind a 3PID on this homeserver. + + Args: + medium: The medium of the 3PID. + address: The address of the 3PID. + registration: Whether the 3PID is being bound when registering a new user. + + Returns: + Whether the 3PID is allowed to be bound on this homeserver + """ + for callback in self.is_3pid_allowed_callbacks: + try: + res = await callback(medium, address, registration) + + if res is False: + return res + elif not isinstance(res, bool): + # mypy complains that this line is unreachable because it assumes the + # data returned by the module fits the expected type. We just want + # to make sure this is the case. + logger.warning( # type: ignore[unreachable] + "Ignoring non-string value returned by" + " is_3pid_allowed callback %s: %s", + callback, + res, + ) + except Exception as e: + logger.error("Module raised an exception in is_3pid_allowed: %s", e) + raise SynapseError(code=500, msg="Internal Server Error") + + return True diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b184a48cb1..36c05f8363 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -495,13 +495,11 @@ class DeviceHandler(DeviceWorkerHandler): "Notifying about update %r/%r, ID: %r", user_id, device_id, position ) - room_ids = await self.store.get_rooms_for_user(user_id) - # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. - self.notifier.on_new_event( - "device_list_key", position, users=[user_id], rooms=room_ids - ) + users_to_notify = users_who_share_room.union({user_id}) + + self.notifier.on_new_event("device_list_key", position, users=users_to_notify) if hosts: logger.info( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a37ae0ca09..c8356f233d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -166,9 +166,14 @@ class FederationHandler: oldest_events_with_depth = ( await self.store.get_oldest_event_ids_with_depth_in_room(room_id) ) - insertion_events_to_be_backfilled = ( - await self.store.get_insertion_event_backwards_extremities_in_room(room_id) - ) + + insertion_events_to_be_backfilled: Dict[str, int] = {} + if self.hs.config.experimental.msc2716_enabled: + insertion_events_to_be_backfilled = ( + await self.store.get_insertion_event_backward_extremities_in_room( + room_id + ) + ) logger.debug( "_maybe_backfill_inner: extremities oldest_events_with_depth=%s insertion_events_to_be_backfilled=%s", oldest_events_with_depth, @@ -271,11 +276,12 @@ class FederationHandler: ] logger.debug( - "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems: %s filtered_sorted_extremeties_tuple: %s", + "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems (%d): %s filtered_sorted_extremeties_tuple: %s", room_id, current_depth, limit, max_depth, + len(sorted_extremeties_tuple), sorted_extremeties_tuple, filtered_sorted_extremeties_tuple, ) @@ -510,7 +516,7 @@ class FederationHandler: await self.store.upsert_room_on_join( room_id=room_id, room_version=room_version_obj, - auth_events=auth_chain, + state_events=state, ) max_stream_id = await self._federation_event_handler.process_remote_join( @@ -1047,6 +1053,19 @@ class FederationHandler: limit = min(limit, 100) events = await self.store.get_backfill_events(room_id, pdu_list, limit) + logger.debug( + "on_backfill_request: backfill events=%s", + [ + "event_id=%s,depth=%d,body=%s,prevs=%s\n" + % ( + event.event_id, + event.depth, + event.content.get("body", event.type), + event.prev_event_ids(), + ) + for event in events + ], + ) events = await filter_events_for_server(self.storage, origin, events) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 3905f60b3a..7683246bef 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -419,10 +419,8 @@ class FederationEventHandler: Raises: SynapseError if the response is in some way invalid. """ - event_map = {e.event_id: e for e in itertools.chain(auth_events, state)} - create_event = None - for e in auth_events: + for e in state: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break @@ -439,11 +437,6 @@ class FederationEventHandler: if room_version.identifier != room_version_id: raise SynapseError(400, "Room version mismatch") - # filter out any events we have already seen - seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) - for s in seen_remotes: - event_map.pop(s, None) - # persist the auth chain and state events. # # any invalid events here will be marked as rejected, and we'll carry on. @@ -455,7 +448,9 @@ class FederationEventHandler: # signatures right now doesn't mean that we will *never* be able to, so it # is premature to reject them. # - await self._auth_and_persist_outliers(room_id, event_map.values()) + await self._auth_and_persist_outliers( + room_id, itertools.chain(auth_events, state) + ) # and now persist the join event itself. logger.info("Peristing join-via-remote %s", event) @@ -508,7 +503,11 @@ class FederationEventHandler: f"room {ev.room_id}, when we were backfilling in {room_id}" ) - await self._process_pulled_events(dest, events, backfilled=True) + await self._process_pulled_events( + dest, + events, + backfilled=True, + ) async def _get_missing_events_for_pdu( self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int @@ -626,11 +625,24 @@ class FederationEventHandler: backfilled: True if this is part of a historical batch of events (inhibits notification to clients, and validation of device keys.) """ + logger.debug( + "processing pulled backfilled=%s events=%s", + backfilled, + [ + "event_id=%s,depth=%d,body=%s,prevs=%s\n" + % ( + event.event_id, + event.depth, + event.content.get("body", event.type), + event.prev_event_ids(), + ) + for event in events + ], + ) # We want to sort these by depth so we process them and # tell clients about them in order. sorted_events = sorted(events, key=lambda x: x.depth) - for ev in sorted_events: with nested_logging_context(ev.event_id): await self._process_pulled_event(origin, ev, backfilled=backfilled) @@ -992,6 +1004,8 @@ class FederationEventHandler: await self._run_push_actions_and_persist_event(event, context, backfilled) + await self._handle_marker_event(origin, event) + if backfilled or context.rejected: return @@ -1071,8 +1085,6 @@ class FederationEventHandler: event.sender, ) - await self._handle_marker_event(origin, event) - async def _resync_device(self, sender: str) -> None: """We have detected that the device list for the given user may be out of sync, so we try and resync them. @@ -1228,6 +1240,16 @@ class FederationEventHandler: """ event_map = {event.event_id: event for event in events} + # filter out any events we have already seen. This might happen because + # the events were eagerly pushed to us (eg, during a room join), or because + # another thread has raced against us since we decided to request the event. + # + # This is just an optimisation, so it doesn't need to be watertight - the event + # persister does another round of deduplication. + seen_remotes = await self._store.have_seen_events(room_id, event_map.keys()) + for s in seen_remotes: + event_map.pop(s, None) + # XXX: it might be possible to kick this process off in parallel with fetching # the events. while event_map: @@ -1323,7 +1345,14 @@ class FederationEventHandler: return event, context events_to_persist = (x for x in (prep(event) for event in fetched_events) if x) - await self.persist_events_and_notify(room_id, tuple(events_to_persist)) + await self.persist_events_and_notify( + room_id, + tuple(events_to_persist), + # Mark these events backfilled as they're historic events that will + # eventually be backfilled. For example, missing events we fetch + # during backfill should be marked as backfilled as well. + backfilled=True, + ) async def _check_event_auth( self, @@ -1693,31 +1722,22 @@ class FederationEventHandler: event_id: the event for which we are lacking auth events """ try: - remote_event_map = { - e.event_id: e - for e in await self._federation_client.get_event_auth( - destination, room_id, event_id - ) - } + remote_events = await self._federation_client.get_event_auth( + destination, room_id, event_id + ) + except RequestSendFailed as e1: # The other side isn't around or doesn't implement the # endpoint, so lets just bail out. logger.info("Failed to get event auth from remote: %s", e1) return - logger.info("/event_auth returned %i events", len(remote_event_map)) + logger.info("/event_auth returned %i events", len(remote_events)) # `event` may be returned, but we should not yet process it. - remote_event_map.pop(event_id, None) - - # nor should we reprocess any events we have already seen. - seen_remotes = await self._store.have_seen_events( - room_id, remote_event_map.keys() - ) - for s in seen_remotes: - remote_event_map.pop(s, None) + remote_auth_events = (e for e in remote_events if e.event_id != event_id) - await self._auth_and_persist_outliers(room_id, remote_event_map.values()) + await self._auth_and_persist_outliers(room_id, remote_auth_events) async def _update_context_for_auth_events( self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b37250aa38..4d0da84287 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -490,12 +490,12 @@ class EventCreationHandler: requester: Requester, event_dict: dict, txn_id: Optional[str] = None, + allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, require_consent: bool = True, outlier: bool = False, historical: bool = False, - allow_no_prev_events: bool = False, depth: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: """ @@ -510,6 +510,10 @@ class EventCreationHandler: requester event_dict: An entire event txn_id + allow_no_prev_events: Whether to allow this event to be created an empty + list of prev_events. Normally this is prohibited just because most + events should have a prev_event and we should only use this in special + cases like MSC2716. prev_event_ids: the forward extremities to use as the prev_events for the new event. @@ -546,10 +550,11 @@ class EventCreationHandler: if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": room_version_id = event_dict["content"]["room_version"] - room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version_obj: + maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not maybe_room_version_obj: # this can happen if support is withdrawn for a room version raise UnsupportedRoomVersionError(room_version_id) + room_version_obj = maybe_room_version_obj else: try: room_version_obj = await self.store.get_room_version( @@ -604,10 +609,10 @@ class EventCreationHandler: event, context = await self.create_new_client_event( builder=builder, requester=requester, + allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, depth=depth, - allow_no_prev_events=allow_no_prev_events, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -764,6 +769,7 @@ class EventCreationHandler: self, requester: Requester, event_dict: dict, + allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ratelimit: bool = True, @@ -781,6 +787,10 @@ class EventCreationHandler: Args: requester: The requester sending the event. event_dict: An entire event. + allow_no_prev_events: Whether to allow this event to be created an empty + list of prev_events. Normally this is prohibited just because most + events should have a prev_event and we should only use this in special + cases like MSC2716. prev_event_ids: The event IDs to use as the prev events. Should normally be left as None to automatically request them @@ -880,16 +890,20 @@ class EventCreationHandler: self, builder: EventBuilder, requester: Optional[Requester] = None, + allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, - allow_no_prev_events: bool = False, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client Args: builder: requester: + allow_no_prev_events: Whether to allow this event to be created an empty + list of prev_events. Normally this is prohibited just because most + events should have a prev_event and we should only use this in special + cases like MSC2716. prev_event_ids: the forward extremities to use as the prev_events for the new event. @@ -908,7 +922,6 @@ class EventCreationHandler: Returns: Tuple of created event, context """ - # Strip down the auth_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender full_state_ids_at_event = None @@ -1133,12 +1146,13 @@ class EventCreationHandler: room_version_id = event.content.get( "room_version", RoomVersions.V1.identifier ) - room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) - if not room_version_obj: + maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) + if not maybe_room_version_obj: raise UnsupportedRoomVersionError( "Attempt to create a room with unsupported room version %s" % (room_version_id,) ) + room_version_obj = maybe_room_version_obj else: room_version_obj = await self.store.get_room_version(event.room_id) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index deb3539751..8f71d975e9 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -544,9 +544,9 @@ class OidcProvider: """ metadata = await self.load_metadata() token_endpoint = metadata.get("token_endpoint") - raw_headers = { + raw_headers: Dict[str, str] = { "Content-Type": "application/x-www-form-urlencoded", - "User-Agent": self._http_client.user_agent, + "User-Agent": self._http_client.user_agent.decode("ascii"), "Accept": "application/json", } diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 067c43ae47..b223b72623 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC): Returns: dict: `user_id` -> `UserPresenceState` """ - states = { - user_id: self.user_to_current_state.get(user_id, None) - for user_id in user_ids - } + states = {} + missing = [] + for user_id in user_ids: + state = self.user_to_current_state.get(user_id, None) + if state: + states[user_id] = state + else: + missing.append(user_id) - missing = [user_id for user_id, state in states.items() if not state] if missing: # There are things not in our in memory cache. Lets pull them out of # the database. res = await self.store.get_presence_for_users(missing) states.update(res) - missing = [user_id for user_id, state in states.items() if not state] - if missing: - new = { - user_id: UserPresenceState.default(user_id) for user_id in missing - } - states.update(new) - self.user_to_current_state.update(new) + for user_id in missing: + # if user has no state in database, create the state + if not res.get(user_id, None): + new_state = UserPresenceState.default(user_id) + states[user_id] = new_state + self.user_to_current_state[user_id] = new_state return states diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a719d5eef3..80320d2c07 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -320,12 +320,12 @@ class RegistrationHandler: if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = await self.store.generate_user_id() - user = UserID(localpart, self.hs.hostname) + generated_localpart = await self.store.generate_user_id() + user = UserID(generated_localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) if generate_display_name: - default_display_name = localpart + default_display_name = generated_localpart try: await self.register_with_store( user_id=user_id, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 1420d67729..a990727fc5 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -694,11 +694,6 @@ class RoomCreationHandler: if not is_requester_admin and not ( await self.spam_checker.user_may_create_room(user_id) - and await self.spam_checker.user_may_create_room_with_invites( - user_id, - invite_list, - invite_3pid_list, - ) ): raise SynapseError( 403, "You are not permitted to create rooms", Codes.FORBIDDEN diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index f880aa93d2..f8137ec04c 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -13,10 +13,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def generate_fake_event_id() -> str: - return "$fake_" + random_string(43) - - class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs @@ -182,11 +178,12 @@ class RoomBatchHandler: state_event_ids_at_start = [] auth_event_ids = initial_auth_event_ids.copy() - # Make the state events float off on their own so we don't have a - # bunch of `@mxid joined the room` noise between each batch - prev_event_id_for_state_chain = generate_fake_event_id() + # Make the state events float off on their own by specifying no + # prev_events for the first one in the chain so we don't have a bunch of + # `@mxid joined the room` noise between each batch. + prev_event_ids_for_state_chain: List[str] = [] - for state_event in state_events_at_start: + for index, state_event in enumerate(state_events_at_start): assert_params_in_dict( state_event, ["type", "origin_server_ts", "content", "sender"] ) @@ -222,7 +219,10 @@ class RoomBatchHandler: content=event_dict["content"], outlier=True, historical=True, - prev_event_ids=[prev_event_id_for_state_chain], + # Only the first event in the chain should be floating. + # The rest should hang off each other in a chain. + allow_no_prev_events=index == 0, + prev_event_ids=prev_event_ids_for_state_chain, # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same # reference and also update in the event when we append later. @@ -242,7 +242,10 @@ class RoomBatchHandler: event_dict, outlier=True, historical=True, - prev_event_ids=[prev_event_id_for_state_chain], + # Only the first event in the chain should be floating. + # The rest should hang off each other in a chain. + allow_no_prev_events=index == 0, + prev_event_ids=prev_event_ids_for_state_chain, # Make sure to use a copy of this list because we modify it # later in the loop here. Otherwise it will be the same # reference and also update in the event when we append later. @@ -253,7 +256,7 @@ class RoomBatchHandler: state_event_ids_at_start.append(event_id) auth_event_ids.append(event_id) # Connect all the state in a floating chain - prev_event_id_for_state_chain = event_id + prev_event_ids_for_state_chain = [event_id] return state_event_ids_at_start @@ -261,7 +264,6 @@ class RoomBatchHandler: self, events_to_create: List[JsonDict], room_id: str, - initial_prev_event_ids: List[str], inherited_depth: int, auth_event_ids: List[str], app_service_requester: Requester, @@ -277,9 +279,6 @@ class RoomBatchHandler: events_to_create: List of historical events to create in JSON dictionary format. room_id: Room where you want the events persisted in. - initial_prev_event_ids: These will be the prev_events for the first - event created. Each event created afterwards will point to the - previous event created. inherited_depth: The depth to create the events at (you will probably by calling inherit_depth_from_prev_ids(...)). auth_event_ids: Define which events allow you to create the given @@ -291,11 +290,14 @@ class RoomBatchHandler: """ assert app_service_requester.app_service - prev_event_ids = initial_prev_event_ids.copy() + # Make the historical event chain float off on its own by specifying no + # prev_events for the first event in the chain which causes the HS to + # ask for the state at the start of the batch later. + prev_event_ids: List[str] = [] event_ids = [] events_to_persist = [] - for ev in events_to_create: + for index, ev in enumerate(events_to_create): assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) assert self.hs.is_mine_id(ev["sender"]), "User must be our own: %s" % ( @@ -319,6 +321,9 @@ class RoomBatchHandler: ev["sender"], app_service_requester.app_service ), event_dict, + # Only the first event in the chain should be floating. + # The rest should hang off each other in a chain. + allow_no_prev_events=index == 0, prev_event_ids=event_dict.get("prev_events"), auth_event_ids=auth_event_ids, historical=True, @@ -370,7 +375,6 @@ class RoomBatchHandler: events_to_create: List[JsonDict], room_id: str, batch_id_to_connect_to: str, - initial_prev_event_ids: List[str], inherited_depth: int, auth_event_ids: List[str], app_service_requester: Requester, @@ -385,9 +389,6 @@ class RoomBatchHandler: room_id: Room where you want the events created in. batch_id_to_connect_to: The batch_id from the insertion event you want this batch to connect to. - initial_prev_event_ids: These will be the prev_events for the first - event created. Each event created afterwards will point to the - previous event created. inherited_depth: The depth to create the events at (you will probably by calling inherit_depth_from_prev_ids(...)). auth_event_ids: Define which events allow you to create the given @@ -436,7 +437,6 @@ class RoomBatchHandler: event_ids = await self.persist_historical_events( events_to_create=events_to_create, room_id=room_id, - initial_prev_event_ids=initial_prev_event_ids, inherited_depth=inherited_depth, auth_event_ids=auth_event_ids, app_service_requester=app_service_requester, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 3dd5e1b6e4..b2adc0f48b 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -82,6 +82,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.event_auth_handler = hs.get_event_auth_handler() self.member_linearizer: Linearizer = Linearizer(name="member") + self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter") self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() @@ -116,6 +117,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, ) + self._third_party_invite_limiter = Ratelimiter( + store=self.store, + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second, + burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count, + ) + self.request_ratelimiter = hs.get_request_ratelimiter() @abc.abstractmethod @@ -261,7 +269,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): target: UserID, room_id: str, membership: str, - prev_event_ids: List[str], + allow_no_prev_events: bool = False, + prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, txn_id: Optional[str] = None, ratelimit: bool = True, @@ -279,8 +288,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): target: room_id: membership: - prev_event_ids: The event IDs to use as the prev events + allow_no_prev_events: Whether to allow this event to be created an empty + list of prev_events. Normally this is prohibited just because most + events should have a prev_event and we should only use this in special + cases like MSC2716. + prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. Should normally be left as None, which will cause them to be calculated @@ -337,6 +350,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): "membership": membership, }, txn_id=txn_id, + allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, require_consent=require_consent, @@ -439,6 +453,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent: bool = True, outlier: bool = False, historical: bool = False, + allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: @@ -463,6 +478,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): historical: Indicates whether the message is being inserted back in time around some existing events. This is used to skip a few checks and mark the event as backfilled. + allow_no_prev_events: Whether to allow this event to be created an empty + list of prev_events. Normally this is prohibited just because most + events should have a prev_event and we should only use this in special + cases like MSC2716. prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. @@ -482,24 +501,32 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): key = (room_id,) - with (await self.member_linearizer.queue(key)): - result = await self.update_membership_locked( - requester, - target, - room_id, - action, - txn_id=txn_id, - remote_room_hosts=remote_room_hosts, - third_party_signed=third_party_signed, - ratelimit=ratelimit, - content=content, - new_room=new_room, - require_consent=require_consent, - outlier=outlier, - historical=historical, - prev_event_ids=prev_event_ids, - auth_event_ids=auth_event_ids, - ) + as_id = object() + if requester.app_service: + as_id = requester.app_service.id + + # We first linearise by the application service (to try to limit concurrent joins + # by application services), and then by room ID. + with (await self.member_as_limiter.queue(as_id)): + with (await self.member_linearizer.queue(key)): + result = await self.update_membership_locked( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + new_room=new_room, + require_consent=require_consent, + outlier=outlier, + historical=historical, + allow_no_prev_events=allow_no_prev_events, + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + ) return result @@ -518,6 +545,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent: bool = True, outlier: bool = False, historical: bool = False, + allow_no_prev_events: bool = False, prev_event_ids: Optional[List[str]] = None, auth_event_ids: Optional[List[str]] = None, ) -> Tuple[str, int]: @@ -544,6 +572,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): historical: Indicates whether the message is being inserted back in time around some existing events. This is used to skip a few checks and mark the event as backfilled. + allow_no_prev_events: Whether to allow this event to be created an empty + list of prev_events. Normally this is prohibited just because most + events should have a prev_event and we should only use this in special + cases like MSC2716. prev_event_ids: The event IDs to use as the prev events auth_event_ids: The event ids to use as the auth_events for the new event. @@ -673,6 +705,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): membership=effective_membership_state, txn_id=txn_id, ratelimit=ratelimit, + allow_no_prev_events=allow_no_prev_events, prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, content=content, @@ -1295,7 +1328,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - await self.request_ratelimiter.ratelimit(requester) + await self._third_party_invite_limiter.ratelimit(requester) can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 02bb5ae72f..0e0e58de02 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -14,8 +14,9 @@ import itertools import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +import attr from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership @@ -32,6 +33,20 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _SearchResult: + # The count of results. + count: int + # A mapping of event ID to the rank of that event. + rank_map: Dict[str, int] + # A list of the resulting events. + allowed_events: List[EventBase] + # A map of room ID to results. + room_groups: Dict[str, JsonDict] + # A set of event IDs to highlight. + highlights: Set[str] + + class SearchHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -43,6 +58,8 @@ class SearchHandler: self.state_store = self.storage.state self.auth = hs.get_auth() + self._msc3666_enabled = hs.config.experimental.msc3666_enabled + async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: """Retrieves room IDs of old rooms in the history of an upgraded room. @@ -98,7 +115,7 @@ class SearchHandler: """Performs a full text search for a user. Args: - user + user: The user performing the search. content: Search parameters batch: The next_batch parameter. Used for pagination. @@ -154,6 +171,8 @@ class SearchHandler: # Include context around each event? event_context = room_cat.get("event_context", None) + before_limit = after_limit = None + include_profile = False # Group results together? May allow clients to paginate within a # group @@ -180,6 +199,73 @@ class SearchHandler: % (set(group_keys) - {"room_id", "sender"},), ) + return await self._search( + user, + batch_group, + batch_group_key, + batch_token, + search_term, + keys, + filter_dict, + order_by, + include_state, + group_keys, + event_context, + before_limit, + after_limit, + include_profile, + ) + + async def _search( + self, + user: UserID, + batch_group: Optional[str], + batch_group_key: Optional[str], + batch_token: Optional[str], + search_term: str, + keys: List[str], + filter_dict: JsonDict, + order_by: str, + include_state: bool, + group_keys: List[str], + event_context: Optional[bool], + before_limit: Optional[int], + after_limit: Optional[int], + include_profile: bool, + ) -> JsonDict: + """Performs a full text search for a user. + + Args: + user: The user performing the search. + batch_group: Pagination information. + batch_group_key: Pagination information. + batch_token: Pagination information. + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + filter_dict: The JSON to build a filter out of. + order_by: How to order the results. Valid values ore "rank" and "recent". + include_state: True if the state of the room at each result should + be included. + group_keys: A list of ways to group the results. Valid values are + "room_id" and "sender". + event_context: True to include contextual events around results. + before_limit: + The number of events before a result to include as context. + + Only used if event_context is True. + after_limit: + The number of events after a result to include as context. + + Only used if event_context is True. + include_profile: True if historical profile information should be + included in the event context. + + Only used if event_context is True. + + Returns: + dict to be returned to the client with results of search + """ search_filter = Filter(self.hs, filter_dict) # TODO: Search through left rooms too @@ -214,261 +300,399 @@ class SearchHandler: } } - rank_map = {} # event_id -> rank of event - allowed_events = [] - # Holds result of grouping by room, if applicable - room_groups: Dict[str, JsonDict] = {} - # Holds result of grouping by sender, if applicable - sender_group: Dict[str, JsonDict] = {} + sender_group: Optional[Dict[str, JsonDict]] - # Holds the next_batch for the entire result set if one of those exists - global_next_batch = None + if order_by == "rank": + search_result, sender_group = await self._search_by_rank( + user, room_ids, search_term, keys, search_filter + ) + # Unused return values for rank search. + global_next_batch = None + elif order_by == "recent": + search_result, global_next_batch = await self._search_by_recent( + user, + room_ids, + search_term, + keys, + search_filter, + batch_group, + batch_group_key, + batch_token, + ) + # Unused return values for recent search. + sender_group = None + else: + # We should never get here due to the guard earlier. + raise NotImplementedError() - highlights = set() + logger.info("Found %d events to return", len(search_result.allowed_events)) - count = None + # If client has asked for "context" for each event (i.e. some surrounding + # events and state), fetch that + if event_context is not None: + # Note that before and after limit must be set in this case. + assert before_limit is not None + assert after_limit is not None + + contexts = await self._calculate_event_contexts( + user, + search_result.allowed_events, + before_limit, + after_limit, + include_profile, + ) + else: + contexts = {} - if order_by == "rank": - search_result = await self.store.search_msgs(room_ids, search_term, keys) + # TODO: Add a limit - count = search_result["count"] + state_results = {} + if include_state: + for room_id in {e.room_id for e in search_result.allowed_events}: + state = await self.state_handler.get_current_state(room_id) + state_results[room_id] = list(state.values()) - if search_result["highlights"]: - highlights.update(search_result["highlights"]) + aggregations = None + if self._msc3666_enabled: + aggregations = await self.store.get_bundled_aggregations( + # Generate an iterable of EventBase for all the events that will be + # returned, including contextual events. + itertools.chain( + # The events_before and events_after for each context. + itertools.chain.from_iterable( + itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type] + for context in contexts.values() + ), + # The returned events. + search_result.allowed_events, + ), + user.to_string(), + ) - results = search_result["results"] + # We're now about to serialize the events. We should not make any + # blocking calls after this. Otherwise, the 'age' will be wrong. - results_map = {r["event"].event_id: r for r in results} + time_now = self.clock.time_msec() - rank_map.update({r["event"].event_id: r["rank"] for r in results}) + for context in contexts.values(): + context["events_before"] = self._event_serializer.serialize_events( + context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + ) + context["events_after"] = self._event_serializer.serialize_events( + context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type] + ) - filtered_events = await search_filter.filter([r["event"] for r in results]) + results = [ + { + "rank": search_result.rank_map[e.event_id], + "result": self._event_serializer.serialize_event( + e, time_now, bundle_aggregations=aggregations + ), + "context": contexts.get(e.event_id, {}), + } + for e in search_result.allowed_events + ] - events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events - ) + rooms_cat_res: JsonDict = { + "results": results, + "count": search_result.count, + "highlights": list(search_result.highlights), + } - events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = events[: search_filter.limit] + if state_results: + rooms_cat_res["state"] = { + room_id: self._event_serializer.serialize_events(state_events, time_now) + for room_id, state_events in state_results.items() + } - for e in allowed_events: - rm = room_groups.setdefault( - e.room_id, {"results": [], "order": rank_map[e.event_id]} - ) - rm["results"].append(e.event_id) + if search_result.room_groups and "room_id" in group_keys: + rooms_cat_res.setdefault("groups", {})[ + "room_id" + ] = search_result.room_groups - s = sender_group.setdefault( - e.sender, {"results": [], "order": rank_map[e.event_id]} - ) - s["results"].append(e.event_id) + if sender_group and "sender" in group_keys: + rooms_cat_res.setdefault("groups", {})["sender"] = sender_group - elif order_by == "recent": - room_events: List[EventBase] = [] - i = 0 - - pagination_token = batch_token - - # We keep looping and we keep filtering until we reach the limit - # or we run out of things. - # But only go around 5 times since otherwise synapse will be sad. - while len(room_events) < search_filter.limit and i < 5: - i += 1 - search_result = await self.store.search_rooms( - room_ids, - search_term, - keys, - search_filter.limit * 2, - pagination_token=pagination_token, - ) + if global_next_batch: + rooms_cat_res["next_batch"] = global_next_batch - if search_result["highlights"]: - highlights.update(search_result["highlights"]) + return {"search_categories": {"room_events": rooms_cat_res}} - count = search_result["count"] + async def _search_by_rank( + self, + user: UserID, + room_ids: Collection[str], + search_term: str, + keys: Iterable[str], + search_filter: Filter, + ) -> Tuple[_SearchResult, Dict[str, JsonDict]]: + """ + Performs a full text search for a user ordering by rank. - results = search_result["results"] + Args: + user: The user performing the search. + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + search_filter: The event filter to use. - results_map = {r["event"].event_id: r for r in results} + Returns: + A tuple of: + The search results. + A map of sender ID to results. + """ + rank_map = {} # event_id -> rank of event + # Holds result of grouping by room, if applicable + room_groups: Dict[str, JsonDict] = {} + # Holds result of grouping by sender, if applicable + sender_group: Dict[str, JsonDict] = {} - rank_map.update({r["event"].event_id: r["rank"] for r in results}) + search_result = await self.store.search_msgs(room_ids, search_term, keys) - filtered_events = await search_filter.filter( - [r["event"] for r in results] - ) + if search_result["highlights"]: + highlights = search_result["highlights"] + else: + highlights = set() - events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events - ) + results = search_result["results"] - room_events.extend(events) - room_events = room_events[: search_filter.limit] + # event_id -> rank of event + rank_map = {r["event"].event_id: r["rank"] for r in results} - if len(results) < search_filter.limit * 2: - pagination_token = None - break - else: - pagination_token = results[-1]["pagination_token"] - - for event in room_events: - group = room_groups.setdefault(event.room_id, {"results": []}) - group["results"].append(event.event_id) - - if room_events and len(room_events) >= search_filter.limit: - last_event_id = room_events[-1].event_id - pagination_token = results_map[last_event_id]["pagination_token"] - - # We want to respect the given batch group and group keys so - # that if people blindly use the top level `next_batch` token - # it returns more from the same group (if applicable) rather - # than reverting to searching all results again. - if batch_group and batch_group_key: - global_next_batch = encode_base64( - ( - "%s\n%s\n%s" - % (batch_group, batch_group_key, pagination_token) - ).encode("ascii") - ) - else: - global_next_batch = encode_base64( - ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") - ) + filtered_events = await search_filter.filter([r["event"] for r in results]) - for room_id, group in room_groups.items(): - group["next_batch"] = encode_base64( - ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( - "ascii" - ) - ) + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events + ) - allowed_events.extend(room_events) + events.sort(key=lambda e: -rank_map[e.event_id]) + allowed_events = events[: search_filter.limit] - else: - # We should never get here due to the guard earlier. - raise NotImplementedError() + for e in allowed_events: + rm = room_groups.setdefault( + e.room_id, {"results": [], "order": rank_map[e.event_id]} + ) + rm["results"].append(e.event_id) - logger.info("Found %d events to return", len(allowed_events)) + s = sender_group.setdefault( + e.sender, {"results": [], "order": rank_map[e.event_id]} + ) + s["results"].append(e.event_id) + + return ( + _SearchResult( + search_result["count"], + rank_map, + allowed_events, + room_groups, + highlights, + ), + sender_group, + ) - # If client has asked for "context" for each event (i.e. some surrounding - # events and state), fetch that - if event_context is not None: - now_token = self.hs.get_event_sources().get_current_token() + async def _search_by_recent( + self, + user: UserID, + room_ids: Collection[str], + search_term: str, + keys: Iterable[str], + search_filter: Filter, + batch_group: Optional[str], + batch_group_key: Optional[str], + batch_token: Optional[str], + ) -> Tuple[_SearchResult, Optional[str]]: + """ + Performs a full text search for a user ordering by recent. - contexts = {} - for event in allowed_events: - res = await self.store.get_events_around( - event.room_id, event.event_id, before_limit, after_limit - ) + Args: + user: The user performing the search. + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + search_filter: The event filter to use. + batch_group: Pagination information. + batch_group_key: Pagination information. + batch_token: Pagination information. - logger.info( - "Context for search returned %d and %d events", - len(res.events_before), - len(res.events_after), - ) + Returns: + A tuple of: + The search results. + Optionally, a pagination token. + """ + rank_map = {} # event_id -> rank of event + # Holds result of grouping by room, if applicable + room_groups: Dict[str, JsonDict] = {} - events_before = await filter_events_for_client( - self.storage, user.to_string(), res.events_before - ) + # Holds the next_batch for the entire result set if one of those exists + global_next_batch = None - events_after = await filter_events_for_client( - self.storage, user.to_string(), res.events_after - ) + highlights = set() - context = { - "events_before": events_before, - "events_after": events_after, - "start": await now_token.copy_and_replace( - "room_key", res.start - ).to_string(self.store), - "end": await now_token.copy_and_replace( - "room_key", res.end - ).to_string(self.store), - } + room_events: List[EventBase] = [] + i = 0 + + pagination_token = batch_token + + # We keep looping and we keep filtering until we reach the limit + # or we run out of things. + # But only go around 5 times since otherwise synapse will be sad. + while len(room_events) < search_filter.limit and i < 5: + i += 1 + search_result = await self.store.search_rooms( + room_ids, + search_term, + keys, + search_filter.limit * 2, + pagination_token=pagination_token, + ) - if include_profile: - senders = { - ev.sender - for ev in itertools.chain(events_before, [event], events_after) - } + if search_result["highlights"]: + highlights.update(search_result["highlights"]) + + count = search_result["count"] - if events_after: - last_event_id = events_after[-1].event_id - else: - last_event_id = event.event_id + results = search_result["results"] - state_filter = StateFilter.from_types( - [(EventTypes.Member, sender) for sender in senders] - ) + results_map = {r["event"].event_id: r for r in results} + + rank_map.update({r["event"].event_id: r["rank"] for r in results}) + + filtered_events = await search_filter.filter([r["event"] for r in results]) + + events = await filter_events_for_client( + self.storage, user.to_string(), filtered_events + ) + + room_events.extend(events) + room_events = room_events[: search_filter.limit] + + if len(results) < search_filter.limit * 2: + break + else: + pagination_token = results[-1]["pagination_token"] + + for event in room_events: + group = room_groups.setdefault(event.room_id, {"results": []}) + group["results"].append(event.event_id) + + if room_events and len(room_events) >= search_filter.limit: + last_event_id = room_events[-1].event_id + pagination_token = results_map[last_event_id]["pagination_token"] + + # We want to respect the given batch group and group keys so + # that if people blindly use the top level `next_batch` token + # it returns more from the same group (if applicable) rather + # than reverting to searching all results again. + if batch_group and batch_group_key: + global_next_batch = encode_base64( + ( + "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token) + ).encode("ascii") + ) + else: + global_next_batch = encode_base64( + ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii") + ) - state = await self.state_store.get_state_for_event( - last_event_id, state_filter + for room_id, group in room_groups.items(): + group["next_batch"] = encode_base64( + ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode( + "ascii" ) + ) - context["profile_info"] = { - s.state_key: { - "displayname": s.content.get("displayname", None), - "avatar_url": s.content.get("avatar_url", None), - } - for s in state.values() - if s.type == EventTypes.Member and s.state_key in senders - } + return ( + _SearchResult(count, rank_map, room_events, room_groups, highlights), + global_next_batch, + ) - contexts[event.event_id] = context - else: - contexts = {} + async def _calculate_event_contexts( + self, + user: UserID, + allowed_events: List[EventBase], + before_limit: int, + after_limit: int, + include_profile: bool, + ) -> Dict[str, JsonDict]: + """ + Calculates the contextual events for any search results. - # TODO: Add a limit + Args: + user: The user performing the search. + allowed_events: The search results. + before_limit: + The number of events before a result to include as context. + after_limit: + The number of events after a result to include as context. + include_profile: True if historical profile information should be + included in the event context. - time_now = self.clock.time_msec() + Returns: + A map of event ID to contextual information. + """ + now_token = self.hs.get_event_sources().get_current_token() - for context in contexts.values(): - context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now # type: ignore[arg-type] + contexts = {} + for event in allowed_events: + res = await self.store.get_events_around( + event.room_id, event.event_id, before_limit, after_limit ) - context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now # type: ignore[arg-type] + + logger.info( + "Context for search returned %d and %d events", + len(res.events_before), + len(res.events_after), ) - state_results = {} - if include_state: - for room_id in {e.room_id for e in allowed_events}: - state = await self.state_handler.get_current_state(room_id) - state_results[room_id] = list(state.values()) + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before + ) - # We're now about to serialize the events. We should not make any - # blocking calls after this. Otherwise the 'age' will be wrong + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after + ) - results = [] - for e in allowed_events: - results.append( - { - "rank": rank_map[e.event_id], - "result": self._event_serializer.serialize_event(e, time_now), - "context": contexts.get(e.event_id, {}), + context: JsonDict = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace("room_key", res.end).to_string( + self.store + ), + } + + if include_profile: + senders = { + ev.sender + for ev in itertools.chain(events_before, [event], events_after) } - ) - rooms_cat_res = { - "results": results, - "count": count, - "highlights": list(highlights), - } + if events_after: + last_event_id = events_after[-1].event_id + else: + last_event_id = event.event_id - if state_results: - s = {} - for room_id, state_events in state_results.items(): - s[room_id] = self._event_serializer.serialize_events( - state_events, time_now + state_filter = StateFilter.from_types( + [(EventTypes.Member, sender) for sender in senders] ) - rooms_cat_res["state"] = s - - if room_groups and "room_id" in group_keys: - rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups + state = await self.state_store.get_state_for_event( + last_event_id, state_filter + ) - if sender_group and "sender" in group_keys: - rooms_cat_res.setdefault("groups", {})["sender"] = sender_group + context["profile_info"] = { + s.state_key: { + "displayname": s.content.get("displayname", None), + "avatar_url": s.content.get("avatar_url", None), + } + for s in state.values() + if s.type == EventTypes.Member and s.state_key in senders + } - if global_next_batch: - rooms_cat_res["next_batch"] = global_next_batch + contexts[event.event_id] = context - return {"search_categories": {"room_events": rooms_cat_res}} + return contexts diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 1a062a784c..a305a66860 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -106,7 +106,7 @@ async def _sendmail( factory = build_sender_factory(hostname=smtphost if enable_tls else None) reactor.connectTCP( - smtphost, # type: ignore[arg-type] + smtphost, smtpport, factory, timeout=30, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c72ed7c290..e6050cbce6 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1289,23 +1289,54 @@ class SyncHandler: # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = await self.store.get_users_who_share_room_with_user( - user_id - ) - - # Always tell the user about their own devices. We check as the user - # ID is almost certainly already included (unless they're not in any - # rooms) and taking a copy of the set is relatively expensive. - if user_id not in users_who_share_room: - users_who_share_room = set(users_who_share_room) - users_who_share_room.add(user_id) + users_that_have_changed = set() - tracked_users = users_who_share_room + joined_rooms = sync_result_builder.joined_room_ids - # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = await self.store.get_users_whose_devices_changed( - since_token.device_list_key, tracked_users + # Step 1a, check for changes in devices of users we share a room + # with + # + # We do this in two different ways depending on what we have cached. + # If we already have a list of all the user that have changed since + # the last sync then it's likely more efficient to compare the rooms + # they're in with the rooms the syncing user is in. + # + # If we don't have that info cached then we get all the users that + # share a room with our user and check if those users have changed. + changed_users = self.store.get_cached_device_list_changes( + since_token.device_list_key ) + if changed_users is not None: + result = await self.store.get_rooms_for_users_with_stream_ordering( + changed_users + ) + + for changed_user_id, entries in result.items(): + # Check if the changed user shares any rooms with the user, + # or if the changed user is the syncing user (as we always + # want to include device list updates of their own devices). + if user_id == changed_user_id or any( + e.room_id in joined_rooms for e in entries + ): + users_that_have_changed.add(changed_user_id) + else: + users_who_share_room = ( + await self.store.get_users_who_share_room_with_user(user_id) + ) + + # Always tell the user about their own devices. We check as the user + # ID is almost certainly already included (unless they're not in any + # rooms) and taking a copy of the set is relatively expensive. + if user_id not in users_who_share_room: + users_who_share_room = set(users_who_share_room) + users_who_share_room.add(user_id) + + tracked_users = users_who_share_room + users_that_have_changed = ( + await self.store.get_users_whose_devices_changed( + since_token.device_list_key, tracked_users + ) + ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: @@ -1329,7 +1360,14 @@ class SyncHandler: newly_left_users.update(left_users) # Remove any users that we still share a room with. - newly_left_users -= users_who_share_room + left_users_rooms = ( + await self.store.get_rooms_for_users_with_stream_ordering( + newly_left_users + ) + ) + for user_id, entries in left_users_rooms.items(): + if any(e.room_id in joined_rooms for e in entries): + newly_left_users.discard(user_id) return DeviceLists(changed=users_that_have_changed, left=newly_left_users) else: @@ -1348,8 +1386,8 @@ class SyncHandler: if sync_result_builder.since_token is not None: since_stream_id = int(sync_result_builder.since_token.to_device_key) - if since_stream_id != int(now_token.to_device_key): - messages, stream_id = await self.store.get_new_messages_for_device( + if device_id is not None and since_stream_id != int(now_token.to_device_key): + messages, stream_id = await self.store.get_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e43c22832d..e4bed1c937 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -446,7 +446,7 @@ class TypingWriterHandler(FollowerTypingHandler): class TypingNotificationEventSource(EventSource[int, JsonDict]): def __init__(self, hs: "HomeServer"): - self.hs = hs + self._main_store = hs.get_datastore() self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: # @@ -487,7 +487,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): continue if not await service.matches_user_in_member_list( - room_id, handler.store + room_id, self._main_store ): continue diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py index 13b0c61d2e..56eee4057f 100644 --- a/synapse/handlers/ui_auth/__init__.py +++ b/synapse/handlers/ui_auth/__init__.py @@ -38,4 +38,4 @@ class UIAuthSessionDataConstants: # used during registration to store the registration token used (if required) so that: # - we can prevent a token being used twice by one session # - we can 'use up' the token after registration has successfully completed - REGISTRATION_TOKEN = "org.matrix.msc3231.login.registration_token" + REGISTRATION_TOKEN = "m.login.registration_token" diff --git a/synapse/http/client.py b/synapse/http/client.py index 743a7ffcb1..c01d2326cf 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -20,6 +20,7 @@ from typing import ( TYPE_CHECKING, Any, BinaryIO, + Callable, Dict, Iterable, List, @@ -321,21 +322,20 @@ class SimpleHttpClient: self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist self._extra_treq_args = treq_args or {} - - self.user_agent = hs.version_string self.clock = hs.get_clock() + + user_agent = hs.version_string if hs.config.server.user_agent_suffix: - self.user_agent = "%s %s" % ( - self.user_agent, + user_agent = "%s %s" % ( + user_agent, hs.config.server.user_agent_suffix, ) + self.user_agent = user_agent.encode("ascii") # We use this for our body producers to ensure that they use the correct # reactor. self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor())) - self.user_agent = self.user_agent.encode("ascii") - if self._ip_blacklist: # If we have an IP blacklist, we need to use a DNS resolver which # filters out blacklisted IP addresses, to prevent DNS rebinding. @@ -693,12 +693,18 @@ class SimpleHttpClient: output_stream: BinaryIO, max_size: Optional[int] = None, headers: Optional[RawHeaders] = None, + is_allowed_content_type: Optional[Callable[[str], bool]] = None, ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]: """GETs a file from a given URL Args: url: The URL to GET output_stream: File to write the response body to. headers: A map from header name to a list of values for that header + is_allowed_content_type: A predicate to determine whether the + content type of the file we're downloading is allowed. If set and + it evaluates to False when called with the content type, the + request will be terminated before completing the download by + raising SynapseError. Returns: A tuple of the file length, dict of the response headers, absolute URI of the response and HTTP response code. @@ -726,6 +732,17 @@ class SimpleHttpClient: HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN ) + if is_allowed_content_type and b"Content-Type" in resp_headers: + content_type = resp_headers[b"Content-Type"][0].decode("ascii") + if not is_allowed_content_type(content_type): + raise SynapseError( + HTTPStatus.BAD_GATEWAY, + ( + "Requested file's content type not allowed for this operation: %s" + % content_type + ), + ) + # TODO: if our Content-Type is HTML or something, just read the first # N bytes into RAM rather than saving it all to disk only to read it # straight back in again diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 2e668363b2..c5f8fcbb2a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -334,12 +334,11 @@ class MatrixFederationHttpClient: user_agent = hs.version_string if hs.config.server.user_agent_suffix: user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix) - user_agent = user_agent.encode("ascii") federation_agent = MatrixFederationAgent( self.reactor, tls_client_options_factory, - user_agent, + user_agent.encode("ascii"), hs.config.server.federation_ip_range_whitelist, hs.config.server.federation_ip_range_blacklist, ) diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py deleted file mode 100644 index b9933a1528..0000000000 --- a/synapse/logging/_structured.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os.path -from typing import Any, Dict, Generator, Optional, Tuple - -from constantly import NamedConstant, Names - -from synapse.config._base import ConfigError - - -class DrainType(Names): - CONSOLE = NamedConstant() - CONSOLE_JSON = NamedConstant() - CONSOLE_JSON_TERSE = NamedConstant() - FILE = NamedConstant() - FILE_JSON = NamedConstant() - NETWORK_JSON_TERSE = NamedConstant() - - -DEFAULT_LOGGERS = {"synapse": {"level": "info"}} - - -def parse_drain_configs( - drains: dict, -) -> Generator[Tuple[str, Dict[str, Any]], None, None]: - """ - Parse the drain configurations. - - Args: - drains (dict): A list of drain configurations. - - Yields: - dict instances representing a logging handler. - - Raises: - ConfigError: If any of the drain configuration items are invalid. - """ - - for name, config in drains.items(): - if "type" not in config: - raise ConfigError("Logging drains require a 'type' key.") - - try: - logging_type = DrainType.lookupByName(config["type"].upper()) - except ValueError: - raise ConfigError( - "%s is not a known logging drain type." % (config["type"],) - ) - - # Either use the default formatter or the tersejson one. - if logging_type in ( - DrainType.CONSOLE_JSON, - DrainType.FILE_JSON, - ): - formatter: Optional[str] = "json" - elif logging_type in ( - DrainType.CONSOLE_JSON_TERSE, - DrainType.NETWORK_JSON_TERSE, - ): - formatter = "tersejson" - else: - # A formatter of None implies using the default formatter. - formatter = None - - if logging_type in [ - DrainType.CONSOLE, - DrainType.CONSOLE_JSON, - DrainType.CONSOLE_JSON_TERSE, - ]: - location = config.get("location") - if location is None or location not in ["stdout", "stderr"]: - raise ConfigError( - ( - "The %s drain needs the 'location' key set to " - "either 'stdout' or 'stderr'." - ) - % (logging_type,) - ) - - yield name, { - "class": "logging.StreamHandler", - "formatter": formatter, - "stream": "ext://sys." + location, - } - - elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]: - if "location" not in config: - raise ConfigError( - "The %s drain needs the 'location' key set." % (logging_type,) - ) - - location = config.get("location") - if os.path.abspath(location) != location: - raise ConfigError( - "File paths need to be absolute, '%s' is a relative path" - % (location,) - ) - - yield name, { - "class": "logging.FileHandler", - "formatter": formatter, - "filename": location, - } - - elif logging_type in [DrainType.NETWORK_JSON_TERSE]: - host = config.get("host") - port = config.get("port") - maximum_buffer = config.get("maximum_buffer", 1000) - - yield name, { - "class": "synapse.logging.RemoteHandler", - "formatter": formatter, - "host": host, - "port": port, - "maximum_buffer": maximum_buffer, - } - - else: - raise ConfigError( - "The %s drain type is currently not implemented." - % (config["type"].upper(),) - ) - - -def setup_structured_logging( - log_config: dict, -) -> dict: - """ - Convert a legacy structured logging configuration (from Synapse < v1.23.0) - to one compatible with the new standard library handlers. - """ - if "drains" not in log_config: - raise ConfigError("The logging configuration requires a list of drains.") - - new_config = { - "version": 1, - "formatters": { - "json": {"class": "synapse.logging.JsonFormatter"}, - "tersejson": {"class": "synapse.logging.TerseJsonFormatter"}, - }, - "handlers": {}, - "loggers": log_config.get("loggers", DEFAULT_LOGGERS), - "root": {"handlers": []}, - } - - for handler_name, handler in parse_drain_configs(log_config["drains"]): - new_config["handlers"][handler_name] = handler - - # Add each handler to the root logger. - new_config["root"]["handlers"].append(handler_name) - - return new_config diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index b240d2d21d..3ebed5c161 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -443,10 +443,14 @@ def start_active_span( start_time=None, ignore_active_span=False, finish_on_close=True, + *, + tracer=None, ): - """Starts an active opentracing span. Note, the scope doesn't become active - until it has been entered, however, the span starts from the time this - message is called. + """Starts an active opentracing span. + + Records the start time for the span, and sets it as the "active span" in the + scope manager. + Args: See opentracing.tracer Returns: @@ -456,7 +460,11 @@ def start_active_span( if opentracing is None: return noop_context_manager() # type: ignore[unreachable] - return opentracing.tracer.start_active_span( + if tracer is None: + # use the global tracer by default + tracer = opentracing.tracer + + return tracer.start_active_span( operation_name, child_of=child_of, references=references, @@ -468,21 +476,42 @@ def start_active_span( def start_active_span_follows_from( - operation_name: str, contexts: Collection, inherit_force_tracing=False + operation_name: str, + contexts: Collection, + child_of=None, + start_time: Optional[float] = None, + *, + inherit_force_tracing=False, + tracer=None, ): """Starts an active opentracing span, with additional references to previous spans Args: operation_name: name of the operation represented by the new span contexts: the previous spans to inherit from + + child_of: optionally override the parent span. If unset, the currently active + span will be the parent. (If there is no currently active span, the first + span in `contexts` will be the parent.) + + start_time: optional override for the start time of the created span. Seconds + since the epoch. + inherit_force_tracing: if set, and any of the previous contexts have had tracing forced, the new span will also have tracing forced. + tracer: override the opentracing tracer. By default the global tracer is used. """ if opentracing is None: return noop_context_manager() # type: ignore[unreachable] references = [opentracing.follows_from(context) for context in contexts] - scope = start_active_span(operation_name, references=references) + scope = start_active_span( + operation_name, + child_of=child_of, + references=references, + start_time=start_time, + tracer=tracer, + ) if inherit_force_tracing and any( is_context_forced_tracing(ctx) for ctx in contexts diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index db8ca2c049..d57e7c5324 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -28,8 +28,9 @@ class LogContextScopeManager(ScopeManager): The LogContextScopeManager tracks the active scope in opentracing by using the log contexts which are native to synapse. This is so that the basic opentracing api can be used across twisted defereds. - (I would love to break logcontexts and this into an OS package. but - let's wait for twisted's contexts to be released.) + + It would be nice just to use opentracing's ContextVarsScopeManager, + but currently that doesn't work due to https://twistedmatrix.com/trac/ticket/10301. """ def __init__(self, config): @@ -65,29 +66,45 @@ class LogContextScopeManager(ScopeManager): Scope.close() on the returned instance. """ - enter_logcontext = False ctx = current_context() if not ctx: - # We don't want this scope to affect. logger.error("Tried to activate scope outside of loggingcontext") return Scope(None, span) # type: ignore[arg-type] - elif ctx.scope is not None: - # We want the logging scope to look exactly the same so we give it - # a blank suffix + + if ctx.scope is not None: + # start a new logging context as a child of the existing one. + # Doing so -- rather than updating the existing logcontext -- means that + # creating several concurrent spans under the same logcontext works + # correctly. ctx = nested_logging_context("") enter_logcontext = True + else: + # if there is no span currently associated with the current logcontext, we + # just store the scope in it. + # + # This feels a bit dubious, but it does hack around a problem where a + # span outlasts its parent logcontext (which would otherwise lead to + # "Re-starting finished log context" errors). + enter_logcontext = False scope = _LogContextScope(self, span, ctx, enter_logcontext, finish_on_close) ctx.scope = scope + if enter_logcontext: + ctx.__enter__() + return scope class _LogContextScope(Scope): """ - A custom opentracing scope. The only significant difference is that it will - close the log context it's related to if the logcontext was created specifically - for this scope. + A custom opentracing scope, associated with a LogContext + + * filters out _DefGen_Return exceptions which arise from calling + `defer.returnValue` in Twisted code + + * When the scope is closed, the logcontext's active scope is reset to None. + and - if enter_logcontext was set - the logcontext is finished too. """ def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close): @@ -101,8 +118,7 @@ class _LogContextScope(Scope): logcontext (LogContext): the logcontext to which this scope is attached. enter_logcontext (Boolean): - if True the logcontext will be entered and exited when the scope - is entered and exited respectively + if True the logcontext will be exited when the scope is finished finish_on_close (Boolean): if True finish the span when the scope is closed """ @@ -111,26 +127,28 @@ class _LogContextScope(Scope): self._finish_on_close = finish_on_close self._enter_logcontext = enter_logcontext - def __enter__(self): - if self._enter_logcontext: - self.logcontext.__enter__() + def __exit__(self, exc_type, value, traceback): + if exc_type == twisted.internet.defer._DefGen_Return: + # filter out defer.returnValue() calls + exc_type = value = traceback = None + super().__exit__(exc_type, value, traceback) - return self - - def __exit__(self, type, value, traceback): - if type == twisted.internet.defer._DefGen_Return: - super().__exit__(None, None, None) - else: - super().__exit__(type, value, traceback) - if self._enter_logcontext: - self.logcontext.__exit__(type, value, traceback) - else: # the logcontext existed before the creation of the scope - self.logcontext.scope = None + def __str__(self): + return f"Scope<{self.span}>" def close(self): - if self.manager.active is not self: - logger.error("Tried to close a non-active scope!") - return + active_scope = self.manager.active + if active_scope is not self: + logger.error( + "Closing scope %s which is not the currently-active one %s", + self, + active_scope, + ) if self._finish_on_close: self.span.finish() + + self.logcontext.scope = None + + if self._enter_logcontext: + self.logcontext.__exit__(None, None, None) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index 9e6c1b2f3b..d321946aa2 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -30,9 +30,11 @@ from typing import ( Type, TypeVar, Union, + cast, ) import attr +from matrix_common.versionstring import get_distribution_version_string from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric from prometheus_client.core import ( REGISTRY, @@ -42,14 +44,14 @@ from prometheus_client.core import ( from twisted.python.threadpool import ThreadPool -import synapse.metrics._reactor_metrics +# This module is imported for its side effects; flake8 needn't warn that it's unused. +import synapse.metrics._reactor_metrics # noqa: F401 from synapse.metrics._exposition import ( MetricsResource, generate_latest, start_http_server, ) from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager -from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) @@ -60,7 +62,7 @@ all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {} HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") -class RegistryProxy: +class _RegistryProxy: @staticmethod def collect() -> Iterable[Metric]: for metric in REGISTRY.collect(): @@ -68,6 +70,13 @@ class RegistryProxy: yield metric +# A little bit nasty, but collect() above is static so a Protocol doesn't work. +# _RegistryProxy matches the signature of a CollectorRegistry instance enough +# for it to be usable in the contexts in which we use it. +# TODO Do something nicer about this. +RegistryProxy = cast(CollectorRegistry, _RegistryProxy) + + @attr.s(slots=True, hash=True, auto_attribs=True) class LaterGauge: @@ -409,7 +418,7 @@ build_info = Gauge( ) build_info.labels( " ".join([platform.python_implementation(), platform.python_version()]), - get_version_string(synapse), + get_distribution_version_string("matrix-synapse"), " ".join([platform.system(), platform.release()]), ).set(1) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 788b2e47d5..07020bfb8d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -48,7 +48,6 @@ from synapse.events.spamcheck import ( CHECK_USERNAME_FOR_SPAM_CALLBACK, USER_MAY_CREATE_ROOM_ALIAS_CALLBACK, USER_MAY_CREATE_ROOM_CALLBACK, - USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK, USER_MAY_INVITE_CALLBACK, USER_MAY_JOIN_ROOM_CALLBACK, USER_MAY_PUBLISH_ROOM_CALLBACK, @@ -71,7 +70,9 @@ from synapse.handlers.account_validity import ( from synapse.handlers.auth import ( CHECK_3PID_AUTH_CALLBACK, CHECK_AUTH_CALLBACK, + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK, GET_USERNAME_FOR_REGISTRATION_CALLBACK, + IS_3PID_ALLOWED_CALLBACK, ON_LOGGED_OUT_CALLBACK, AuthHandler, ) @@ -211,14 +212,12 @@ class ModuleApi: def register_spam_checker_callbacks( self, + *, check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, - user_may_create_room_with_invites: Optional[ - USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK - ] = None, user_may_create_room_alias: Optional[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = None, @@ -239,7 +238,6 @@ class ModuleApi: user_may_invite=user_may_invite, user_may_send_3pid_invite=user_may_send_3pid_invite, user_may_create_room=user_may_create_room, - user_may_create_room_with_invites=user_may_create_room_with_invites, user_may_create_room_alias=user_may_create_room_alias, user_may_publish_room=user_may_publish_room, check_username_for_spam=check_username_for_spam, @@ -249,6 +247,7 @@ class ModuleApi: def register_account_validity_callbacks( self, + *, is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None, on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None, on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None, @@ -269,6 +268,7 @@ class ModuleApi: def register_third_party_rules_callbacks( self, + *, check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, check_threepid_can_be_invited: Optional[ @@ -293,6 +293,7 @@ class ModuleApi: def register_presence_router_callbacks( self, + *, get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, ) -> None: @@ -307,14 +308,19 @@ class ModuleApi: def register_password_auth_provider_callbacks( self, + *, check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None, on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None, auth_checkers: Optional[ Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK] ] = None, + is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None, get_username_for_registration: Optional[ GET_USERNAME_FOR_REGISTRATION_CALLBACK ] = None, + get_displayname_for_registration: Optional[ + GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK + ] = None, ) -> None: """Registers callbacks for password auth provider capabilities. @@ -323,12 +329,15 @@ class ModuleApi: return self._password_auth_provider.register_password_auth_provider_callbacks( check_3pid_auth=check_3pid_auth, on_logged_out=on_logged_out, + is_3pid_allowed=is_3pid_allowed, auth_checkers=auth_checkers, get_username_for_registration=get_username_for_registration, + get_displayname_for_registration=get_displayname_for_registration, ) def register_background_update_controller_callbacks( self, + *, on_update: ON_UPDATE_CALLBACK, default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None, min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None, @@ -401,6 +410,32 @@ class ModuleApi: """ return self._hs.config.email.email_app_name + @property + def server_name(self) -> str: + """The server name for the local homeserver. + + Added in Synapse v1.53.0. + """ + return self._server_name + + @property + def worker_name(self) -> Optional[str]: + """The name of the worker this specific instance is running as per the + "worker_name" configuration setting, or None if it's the main process. + + Added in Synapse v1.53.0. + """ + return self._hs.config.worker.worker_name + + @property + def worker_app(self) -> Optional[str]: + """The name of the worker app this specific instance is running as per the + "worker_app" configuration setting, or None if it's the main process. + + Added in Synapse v1.53.0. + """ + return self._hs.config.worker.worker_app + async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: """Get user info by user_id @@ -618,7 +653,11 @@ class ModuleApi: Added in Synapse v1.9.0. Args: - auth_provider: identifier for the remote auth provider + auth_provider: identifier for the remote auth provider, see `sso` and + `oidc_providers` in the homeserver configuration. + + Note that no error is raised if the provided value is not in the + homeserver configuration. external_id: id on that system user_id: complete mxid that it is mapped to """ diff --git a/synapse/notifier.py b/synapse/notifier.py index 632b2245ef..753dd6b6a5 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -14,6 +14,7 @@ import logging from typing import ( + TYPE_CHECKING, Awaitable, Callable, Collection, @@ -32,7 +33,6 @@ from prometheus_client import Counter from twisted.internet import defer -import synapse.server from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.errors import AuthError from synapse.events import EventBase @@ -53,6 +53,9 @@ from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.metrics import Measure from synapse.visibility import filter_events_for_client +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) notified_events_counter = Counter("synapse_notifier_notified_events", "") @@ -82,7 +85,7 @@ class _NotificationListener: __slots__ = ["deferred"] - def __init__(self, deferred): + def __init__(self, deferred: "defer.Deferred"): self.deferred = deferred @@ -124,7 +127,7 @@ class _NotifierUserStream: stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int, - ): + ) -> None: """Notify any listeners for this user of a new event from an event source. Args: @@ -135,7 +138,7 @@ class _NotifierUserStream: self.current_token = self.current_token.copy_and_advance(stream_key, stream_id) self.last_notified_token = self.current_token self.last_notified_ms = time_now_ms - noify_deferred = self.notify_deferred + notify_deferred = self.notify_deferred log_kv( { @@ -150,9 +153,9 @@ class _NotifierUserStream: with PreserveLoggingContext(): self.notify_deferred = ObservableDeferred(defer.Deferred()) - noify_deferred.callback(self.current_token) + notify_deferred.callback(self.current_token) - def remove(self, notifier: "Notifier"): + def remove(self, notifier: "Notifier") -> None: """Remove this listener from all the indexes in the Notifier it knows about. """ @@ -188,7 +191,7 @@ class EventStreamResult: start_token: StreamToken end_token: StreamToken - def __bool__(self): + def __bool__(self) -> bool: return bool(self.events) @@ -212,7 +215,7 @@ class Notifier: UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): self.user_to_user_stream: Dict[str, _NotifierUserStream] = {} self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} @@ -248,7 +251,7 @@ class Notifier: # This is not a very cheap test to perform, but it's only executed # when rendering the metrics page, which is likely once per minute at # most when scraping it. - def count_listeners(): + def count_listeners() -> int: all_user_streams: Set[_NotifierUserStream] = set() for streams in list(self.room_to_user_streams.values()): @@ -270,7 +273,7 @@ class Notifier: "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) ) - def add_replication_callback(self, cb: Callable[[], None]): + def add_replication_callback(self, cb: Callable[[], None]) -> None: """Add a callback that will be called when some new data is available. Callback is not given any arguments. It should *not* return a Deferred - if it needs to do any asynchronous work, a background thread should be started and @@ -284,7 +287,7 @@ class Notifier: event_pos: PersistedEventPosition, max_room_stream_token: RoomStreamToken, extra_users: Optional[Collection[UserID]] = None, - ): + ) -> None: """Unwraps event and calls `on_new_room_event_args`.""" await self.on_new_room_event_args( event_pos=event_pos, @@ -307,7 +310,7 @@ class Notifier: event_pos: PersistedEventPosition, max_room_stream_token: RoomStreamToken, extra_users: Optional[Collection[UserID]] = None, - ): + ) -> None: """Used by handlers to inform the notifier something has happened in the room, room event wise. @@ -338,7 +341,9 @@ class Notifier: self.notify_replication() - def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken): + def _notify_pending_new_room_events( + self, max_room_stream_token: RoomStreamToken + ) -> None: """Notify for the room events that were queued waiting for a previous event to be persisted. Args: @@ -374,7 +379,7 @@ class Notifier: ) self._on_updated_room_token(max_room_stream_token) - def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken): + def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken) -> None: """Poke services that might care that the room position has been updated. """ @@ -386,13 +391,13 @@ class Notifier: if self.federation_sender: self.federation_sender.notify_new_events(max_room_stream_token) - def _notify_app_services(self, max_room_stream_token: RoomStreamToken): + def _notify_app_services(self, max_room_stream_token: RoomStreamToken) -> None: try: self.appservice_handler.notify_interested_services(max_room_stream_token) except Exception: logger.exception("Error notifying application services of event") - def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): + def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken) -> None: try: self._pusher_pool.on_new_notifications(max_room_stream_token) except Exception: @@ -461,7 +466,9 @@ class Notifier: users, ) except Exception: - logger.exception("Error notifying application services of event") + logger.exception( + "Error notifying application services of ephemeral events" + ) def on_new_replication_data(self) -> None: """Used to inform replication listeners that something has happened @@ -473,8 +480,8 @@ class Notifier: user_id: str, timeout: int, callback: Callable[[StreamToken, StreamToken], Awaitable[T]], - room_ids=None, - from_token=StreamToken.START, + room_ids: Optional[Collection[str]] = None, + from_token: StreamToken = StreamToken.START, ) -> T: """Wait until the callback returns a non empty response or the timeout fires. @@ -698,14 +705,14 @@ class Notifier: for expired_stream in expired_streams: expired_stream.remove(self) - def _register_with_keys(self, user_stream: _NotifierUserStream): + def _register_with_keys(self, user_stream: _NotifierUserStream) -> None: self.user_to_user_stream[user_stream.user_id] = user_stream for room in user_stream.rooms: s = self.room_to_user_streams.setdefault(room, set()) s.add(user_stream) - def _user_joined_room(self, user_id: str, room_id: str): + def _user_joined_room(self, user_id: str, room_id: str) -> None: new_user_stream = self.user_to_user_stream.get(user_id) if new_user_stream is not None: room_streams = self.room_to_user_streams.setdefault(room_id, set()) @@ -717,7 +724,7 @@ class Notifier: for cb in self.replication_callbacks: cb() - def notify_remote_server_up(self, server: str): + def notify_remote_server_up(self, server: str) -> None: """Notify any replication that a remote server has come back up""" # We call federation_sender directly rather than registering as a # callback as a) we already have a reference to it and b) it introduces diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 6211506990..832eaa34e9 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -20,15 +20,11 @@ from typing import Any, Dict, List from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP -def list_with_base_rules( - rawrules: List[Dict[str, Any]], use_new_defaults: bool = False -) -> List[Dict[str, Any]]: +def list_with_base_rules(rawrules: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Combine the list of rules set by the user with the default push rules Args: rawrules: The rules the user has modified or set. - use_new_defaults: Whether to use the new experimental default rules when - appending or prepending default rules. Returns: A new list with the rules set by the user combined with the defaults. @@ -48,9 +44,7 @@ def list_with_base_rules( ruleslist.extend( make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - use_new_defaults, + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules ) ) @@ -61,7 +55,6 @@ def list_with_base_rules( make_base_append_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, - use_new_defaults, ) ) current_prio_class -= 1 @@ -70,7 +63,6 @@ def list_with_base_rules( make_base_prepend_rules( PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules, - use_new_defaults, ) ) @@ -79,18 +71,14 @@ def list_with_base_rules( while current_prio_class > 0: ruleslist.extend( make_base_append_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - use_new_defaults, + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules ) ) current_prio_class -= 1 if current_prio_class > 0: ruleslist.extend( make_base_prepend_rules( - PRIORITY_CLASS_INVERSE_MAP[current_prio_class], - modified_base_rules, - use_new_defaults, + PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules ) ) @@ -98,24 +86,14 @@ def list_with_base_rules( def make_base_append_rules( - kind: str, - modified_base_rules: Dict[str, Dict[str, Any]], - use_new_defaults: bool = False, + kind: str, modified_base_rules: Dict[str, Dict[str, Any]] ) -> List[Dict[str, Any]]: rules = [] if kind == "override": - rules = ( - NEW_APPEND_OVERRIDE_RULES - if use_new_defaults - else BASE_APPEND_OVERRIDE_RULES - ) + rules = BASE_APPEND_OVERRIDE_RULES elif kind == "underride": - rules = ( - NEW_APPEND_UNDERRIDE_RULES - if use_new_defaults - else BASE_APPEND_UNDERRIDE_RULES - ) + rules = BASE_APPEND_UNDERRIDE_RULES elif kind == "content": rules = BASE_APPEND_CONTENT_RULES @@ -134,7 +112,6 @@ def make_base_append_rules( def make_base_prepend_rules( kind: str, modified_base_rules: Dict[str, Dict[str, Any]], - use_new_defaults: bool = False, ) -> List[Dict[str, Any]]: rules = [] @@ -153,7 +130,9 @@ def make_base_prepend_rules( return rules -BASE_APPEND_CONTENT_RULES = [ +# We have to annotate these types, otherwise mypy infers them as +# `List[Dict[str, Sequence[Collection[str]]]]`. +BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/content/.m.rule.contains_user_name", "conditions": [ @@ -172,7 +151,7 @@ BASE_APPEND_CONTENT_RULES = [ ] -BASE_PREPEND_OVERRIDE_RULES = [ +BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/override/.m.rule.master", "enabled": False, @@ -182,7 +161,7 @@ BASE_PREPEND_OVERRIDE_RULES = [ ] -BASE_APPEND_OVERRIDE_RULES = [ +BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/override/.m.rule.suppress_notices", "conditions": [ @@ -301,136 +280,7 @@ BASE_APPEND_OVERRIDE_RULES = [ ] -NEW_APPEND_OVERRIDE_RULES = [ - { - "rule_id": "global/override/.m.rule.encrypted", - "conditions": [ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.encrypted", - "_id": "_encrypted", - } - ], - "actions": ["notify"], - }, - { - "rule_id": "global/override/.m.rule.suppress_notices", - "conditions": [ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.message", - "_id": "_suppress_notices_type", - }, - { - "kind": "event_match", - "key": "content.msgtype", - "pattern": "m.notice", - "_id": "_suppress_notices", - }, - ], - "actions": [], - }, - { - "rule_id": "global/underride/.m.rule.suppress_edits", - "conditions": [ - { - "kind": "event_match", - "key": "m.relates_to.m.rel_type", - "pattern": "m.replace", - "_id": "_suppress_edits", - } - ], - "actions": [], - }, - { - "rule_id": "global/override/.m.rule.invite_for_me", - "conditions": [ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.member", - "_id": "_member", - }, - { - "kind": "event_match", - "key": "content.membership", - "pattern": "invite", - "_id": "_invite_member", - }, - {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, - ], - "actions": ["notify", {"set_tweak": "sound", "value": "default"}], - }, - { - "rule_id": "global/override/.m.rule.contains_display_name", - "conditions": [{"kind": "contains_display_name"}], - "actions": [ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, - ], - }, - { - "rule_id": "global/override/.m.rule.tombstone", - "conditions": [ - { - "kind": "event_match", - "key": "type", - "pattern": "m.room.tombstone", - "_id": "_tombstone", - }, - { - "kind": "event_match", - "key": "state_key", - "pattern": "", - "_id": "_tombstone_statekey", - }, - ], - "actions": [ - "notify", - {"set_tweak": "sound", "value": "default"}, - {"set_tweak": "highlight"}, - ], - }, - { - "rule_id": "global/override/.m.rule.roomnotif", - "conditions": [ - { - "kind": "event_match", - "key": "content.body", - "pattern": "@room", - "_id": "_roomnotif_content", - }, - { - "kind": "sender_notification_permission", - "key": "room", - "_id": "_roomnotif_pl", - }, - ], - "actions": [ - "notify", - {"set_tweak": "highlight"}, - {"set_tweak": "sound", "value": "default"}, - ], - }, - { - "rule_id": "global/override/.m.rule.call", - "conditions": [ - { - "kind": "event_match", - "key": "type", - "pattern": "m.call.invite", - "_id": "_call", - } - ], - "actions": ["notify", {"set_tweak": "sound", "value": "ring"}], - }, -] - - -BASE_APPEND_UNDERRIDE_RULES = [ +BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [ { "rule_id": "global/underride/.m.rule.call", "conditions": [ @@ -538,36 +388,6 @@ BASE_APPEND_UNDERRIDE_RULES = [ ] -NEW_APPEND_UNDERRIDE_RULES = [ - { - "rule_id": "global/underride/.m.rule.room_one_to_one", - "conditions": [ - {"kind": "room_member_count", "is": "2", "_id": "member_count"}, - { - "kind": "event_match", - "key": "content.body", - "pattern": "*", - "_id": "body", - }, - ], - "actions": ["notify", {"set_tweak": "sound", "value": "default"}], - }, - { - "rule_id": "global/underride/.m.rule.message", - "conditions": [ - { - "kind": "event_match", - "key": "content.body", - "pattern": "*", - "_id": "body", - }, - ], - "actions": ["notify"], - "enabled": False, - }, -] - - BASE_RULE_IDS = set() for r in BASE_APPEND_CONTENT_RULES: @@ -589,26 +409,3 @@ for r in BASE_APPEND_UNDERRIDE_RULES: r["priority_class"] = PRIORITY_CLASS_MAP["underride"] r["default"] = True BASE_RULE_IDS.add(r["rule_id"]) - - -NEW_RULE_IDS = set() - -for r in BASE_APPEND_CONTENT_RULES: - r["priority_class"] = PRIORITY_CLASS_MAP["content"] - r["default"] = True - NEW_RULE_IDS.add(r["rule_id"]) - -for r in BASE_PREPEND_OVERRIDE_RULES: - r["priority_class"] = PRIORITY_CLASS_MAP["override"] - r["default"] = True - NEW_RULE_IDS.add(r["rule_id"]) - -for r in NEW_APPEND_OVERRIDE_RULES: - r["priority_class"] = PRIORITY_CLASS_MAP["override"] - r["default"] = True - NEW_RULE_IDS.add(r["rule_id"]) - -for r in NEW_APPEND_UNDERRIDE_RULES: - r["priority_class"] = PRIORITY_CLASS_MAP["underride"] - r["default"] = True - NEW_RULE_IDS.add(r["rule_id"]) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 96559081d0..52c7ff3572 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -109,6 +109,7 @@ class HttpPusher(Pusher): self.data_minus_url = {} self.data_minus_url.update(self.data) del self.data_minus_url["url"] + self.badge_count_last_call: Optional[int] = None def on_started(self, should_check_for_notifs: bool) -> None: """Called when this pusher has been started. @@ -136,7 +137,9 @@ class HttpPusher(Pusher): self.user_id, group_by_room=self._group_unread_count_by_room, ) - await self._send_badge(badge) + if self.badge_count_last_call is None or self.badge_count_last_call != badge: + self.badge_count_last_call = badge + await self._send_badge(badge) def on_timer(self) -> None: self._start_processing() @@ -322,7 +325,7 @@ class HttpPusher(Pusher): # This was checked in the __init__, but mypy doesn't seem to know that. assert self.data is not None if self.data.get("format") == "event_id_only": - d = { + d: Dict[str, Any] = { "notification": { "event_id": event.event_id, "room_id": event.room_id, @@ -402,6 +405,8 @@ class HttpPusher(Pusher): rejected = [] if "rejected" in resp: rejected = resp["rejected"] + else: + self.badge_count_last_call = badge return rejected async def _send_badge(self, badge: int) -> None: diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 80786464c2..f43fbb5842 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -51,7 +51,7 @@ REQUIREMENTS = [ # we use the TYPE_CHECKER.redefine method added in jsonschema 3.0.0 "jsonschema>=3.0.0", # frozendict 2.1.2 is broken on Debian 10: https://github.com/Marco-Sulla/python-frozendict/issues/41 - "frozendict>=1,<2.1.2", + "frozendict>=1,!=2.1.2", "unpaddedbase64>=1.1.0", "canonicaljson>=1.4.0", # we use the type definitions added in signedjson 1.1. @@ -76,8 +76,7 @@ REQUIREMENTS = [ "msgpack>=0.5.2", "phonenumbers>=8.2.0", # we use GaugeHistogramMetric, which was added in prom-client 0.4.0. - # 0.13.0 has an incorrect type annotation, see #11832. - "prometheus_client>=0.4.0,<0.13.0", + "prometheus_client>=0.4.0", # we use `order`, which arrived in attrs 19.2.0. # Note: 21.1.0 broke `/sync`, see #9936 "attrs>=19.2.0,!=21.1.0", @@ -88,8 +87,9 @@ REQUIREMENTS = [ # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", - "ijson>=3.1", - "matrix-common==1.0.0", + # ijson 3.1.4 fixes a bug with "." in property names + "ijson>=3.1.4", + "matrix-common~=1.1.0", ] CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 1457d9d59b..aec040ee19 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -40,7 +40,7 @@ class ReplicationRestResource(JsonResource): super().__init__(hs, canonical_json=False, extract_context=True) self.register_servlets(hs) - def register_servlets(self, hs: "HomeServer"): + def register_servlets(self, hs: "HomeServer") -> None: send_event.register_servlets(hs, self) federation.register_servlets(hs, self) presence.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 585332b244..bc1d28dd19 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -15,16 +15,20 @@ import abc import logging import re -import urllib +import urllib.parse from inspect import signature from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple from prometheus_client import Counter, Gauge +from twisted.web.server import Request + from synapse.api.errors import HttpResponseException, SynapseError from synapse.http import RequestTimedOutError +from synapse.http.server import HttpServer from synapse.logging import opentracing from synapse.logging.opentracing import trace +from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string @@ -113,10 +117,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if hs.config.worker.worker_replication_secret: self._replication_secret = hs.config.worker.worker_replication_secret - def _check_auth(self, request) -> None: + def _check_auth(self, request: Request) -> None: # Get the authorization header. auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + if not auth_headers: + raise RuntimeError("Missing Authorization header.") if len(auth_headers) > 1: raise RuntimeError("Too many Authorization headers.") parts = auth_headers[0].split(b" ") @@ -129,7 +135,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): raise RuntimeError("Invalid Authorization header.") @abc.abstractmethod - async def _serialize_payload(**kwargs): + async def _serialize_payload(**kwargs) -> JsonDict: """Static method that is called when creating a request. Concrete implementations should have explicit parameters (rather than @@ -144,19 +150,20 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): return {} @abc.abstractmethod - async def _handle_request(self, request, **kwargs): + async def _handle_request( + self, request: Request, **kwargs: Any + ) -> Tuple[int, JsonDict]: """Handle incoming request. This is called with the request object and PATH_ARGS. Returns: - tuple[int, dict]: HTTP status code and a JSON serialisable dict - to be used as response body of request. + HTTP status code and a JSON serialisable dict to be used as response + body of request. """ - pass @classmethod - def make_client(cls, hs: "HomeServer"): + def make_client(cls, hs: "HomeServer") -> Callable: """Create a client that makes requests. Returns a callable that accepts the same parameters as @@ -182,7 +189,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): ) @trace(opname="outgoing_replication_request") - async def send_request(*, instance_name="master", **kwargs): + async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: with outgoing_gauge.track_inprogress(): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") @@ -268,7 +275,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): return send_request - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: """Called by the server to register this as a handler to the appropriate path. """ @@ -289,7 +296,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): self.__class__.__name__, ) - async def _check_auth_and_handle(self, request, **kwargs): + async def _check_auth_and_handle( + self, request: Request, **kwargs: Any + ) -> Tuple[int, JsonDict]: """Called on new incoming requests when caching is enabled. Checks if there is a cached response for the request and returns that, otherwise calls `_handle_request` and caches its response. diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py index 5f0f225aa9..310f609153 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py @@ -13,10 +13,14 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -48,14 +52,18 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload(user_id, account_data_type, content): + async def _serialize_payload( # type: ignore[override] + user_id: str, account_data_type: str, content: JsonDict + ) -> JsonDict: payload = { "content": content, } return payload - async def _handle_request(self, request, user_id, account_data_type): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) max_stream_id = await self.handler.add_account_data_for_user( @@ -89,14 +97,18 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload(user_id, room_id, account_data_type, content): + async def _serialize_payload( # type: ignore[override] + user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> JsonDict: payload = { "content": content, } return payload - async def _handle_request(self, request, user_id, room_id, account_data_type): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, room_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) max_stream_id = await self.handler.add_account_data_to_room( @@ -130,14 +142,18 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload(user_id, room_id, tag, content): + async def _serialize_payload( # type: ignore[override] + user_id: str, room_id: str, tag: str, content: JsonDict + ) -> JsonDict: payload = { "content": content, } return payload - async def _handle_request(self, request, user_id, room_id, tag): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, room_id: str, tag: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) max_stream_id = await self.handler.add_tag_to_room( @@ -173,11 +189,13 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload(user_id, room_id, tag): + async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] return {} - async def _handle_request(self, request, user_id, room_id, tag): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str, room_id: str, tag: str + ) -> Tuple[int, JsonDict]: max_stream_id = await self.handler.remove_tag_from_room( user_id, room_id, @@ -187,7 +205,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): return 200, {"max_stream_id": max_stream_id} -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationUserAccountDataRestServlet(hs).register(http_server) ReplicationRoomAccountDataRestServlet(hs).register(http_server) ReplicationAddTagRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py index 42dffb39cb..f2f40129fe 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py @@ -13,9 +13,13 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -63,14 +67,16 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload(user_id): + async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override] return {} - async def _handle_request(self, request, user_id): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str + ) -> Tuple[int, JsonDict]: user_devices = await self.device_list_updater.user_device_resync(user_id) return 200, user_devices -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationUserDevicesResyncRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 5ed535c90d..d529c8a19f 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -13,17 +13,22 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Tuple -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import make_event_from_dict +from twisted.web.server import Request + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -69,14 +74,18 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_event_handler = hs.get_federation_event_handler() @staticmethod - async def _serialize_payload(store, room_id, event_and_contexts, backfilled): + async def _serialize_payload( # type: ignore[override] + store: "DataStore", + room_id: str, + event_and_contexts: List[Tuple[EventBase, EventContext]], + backfilled: bool, + ) -> JsonDict: """ Args: store - room_id (str) - event_and_contexts (list[tuple[FrozenEvent, EventContext]]) - backfilled (bool): Whether or not the events are the result of - backfilling + room_id + event_and_contexts + backfilled: Whether or not the events are the result of backfilling """ event_payloads = [] for event, context in event_and_contexts: @@ -102,7 +111,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): return payload - async def _handle_request(self, request): + async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]: # type: ignore[override] with Measure(self.clock, "repl_fed_send_events_parse"): content = parse_json_object_from_request(request) @@ -163,10 +172,14 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): self.registry = hs.get_federation_registry() @staticmethod - async def _serialize_payload(edu_type, origin, content): + async def _serialize_payload( # type: ignore[override] + edu_type: str, origin: str, content: JsonDict + ) -> JsonDict: return {"origin": origin, "content": content} - async def _handle_request(self, request, edu_type): + async def _handle_request( # type: ignore[override] + self, request: Request, edu_type: str + ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_fed_send_edu_parse"): content = parse_json_object_from_request(request) @@ -175,9 +188,9 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): logger.info("Got %r edu from %s", edu_type, origin) - result = await self.registry.on_edu(edu_type, origin, edu_content) + await self.registry.on_edu(edu_type, origin, edu_content) - return 200, result + return 200, {} class ReplicationGetQueryRestServlet(ReplicationEndpoint): @@ -206,15 +219,17 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): self.registry = hs.get_federation_registry() @staticmethod - async def _serialize_payload(query_type, args): + async def _serialize_payload(query_type: str, args: JsonDict) -> JsonDict: # type: ignore[override] """ Args: - query_type (str) - args (dict): The arguments received for the given query type + query_type + args: The arguments received for the given query type """ return {"args": args} - async def _handle_request(self, request, query_type): + async def _handle_request( # type: ignore[override] + self, request: Request, query_type: str + ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_fed_query_parse"): content = parse_json_object_from_request(request) @@ -248,14 +263,16 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() @staticmethod - async def _serialize_payload(room_id, args): + async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] """ Args: - room_id (str) + room_id """ return {} - async def _handle_request(self, request, room_id): + async def _handle_request( # type: ignore[override] + self, request: Request, room_id: str + ) -> Tuple[int, JsonDict]: await self.store.clean_room_for_join(room_id) return 200, {} @@ -283,17 +300,19 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() @staticmethod - async def _serialize_payload(room_id, room_version): + async def _serialize_payload(room_id: str, room_version: RoomVersion) -> JsonDict: # type: ignore[override] return {"room_version": room_version.identifier} - async def _handle_request(self, request, room_id): + async def _handle_request( # type: ignore[override] + self, request: Request, room_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) room_version = KNOWN_ROOM_VERSIONS[content["room_version"]] await self.store.maybe_store_room_on_outlier_membership(room_id, room_version) return 200, {} -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationFederationSendEventsRestServlet(hs).register(http_server) ReplicationFederationSendEduRestServlet(hs).register(http_server) ReplicationGetQueryRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index daacc34cea..c68e18da12 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -13,10 +13,14 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple, cast +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -39,25 +43,24 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - async def _serialize_payload( - user_id, - device_id, - initial_display_name, - is_guest, - is_appservice_ghost, - should_issue_refresh_token, - auth_provider_id, - auth_provider_session_id, - ): + async def _serialize_payload( # type: ignore[override] + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + is_guest: bool, + is_appservice_ghost: bool, + should_issue_refresh_token: bool, + auth_provider_id: Optional[str], + auth_provider_session_id: Optional[str], + ) -> JsonDict: """ Args: - user_id (int) - device_id (str|None): Device ID to use, if None a new one is - generated. - initial_display_name (str|None) - is_guest (bool) - is_appservice_ghost (bool) - should_issue_refresh_token (bool) + user_id + device_id: Device ID to use, if None a new one is generated. + initial_display_name + is_guest + is_appservice_ghost + should_issue_refresh_token """ return { "device_id": device_id, @@ -69,7 +72,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): "auth_provider_session_id": auth_provider_session_id, } - async def _handle_request(self, request, user_id): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) device_id = content["device_id"] @@ -91,8 +96,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): auth_provider_session_id=auth_provider_session_id, ) - return 200, res + return 200, cast(JsonDict, res) -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RegisterDeviceReplicationServlet(hs).register(http_server) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 7371c240b2..0145858e47 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple from twisted.web.server import Request +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.replication.http._base import ReplicationEndpoint @@ -53,7 +54,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload( # type: ignore + async def _serialize_payload( # type: ignore[override] requester: Requester, room_id: str, user_id: str, @@ -77,7 +78,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request( # type: ignore + async def _handle_request( # type: ignore[override] self, request: SynapseRequest, room_id: str, user_id: str ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) @@ -122,13 +123,13 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload( # type: ignore + async def _serialize_payload( # type: ignore[override] requester: Requester, room_id: str, user_id: str, remote_room_hosts: List[str], content: JsonDict, - ): + ) -> JsonDict: """ Args: requester: The user making the request, according to the access token. @@ -143,12 +144,12 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request( # type: ignore + async def _handle_request( # type: ignore[override] self, request: SynapseRequest, room_id: str, user_id: str, - ): + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] @@ -192,7 +193,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): self.member_handler = hs.get_room_member_handler() @staticmethod - async def _serialize_payload( # type: ignore + async def _serialize_payload( # type: ignore[override] invite_event_id: str, txn_id: Optional[str], requester: Requester, @@ -215,7 +216,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request( # type: ignore + async def _handle_request( # type: ignore[override] self, request: SynapseRequest, invite_event_id: str ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) @@ -262,12 +263,12 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): self.member_handler = hs.get_room_member_handler() @staticmethod - async def _serialize_payload( # type: ignore + async def _serialize_payload( # type: ignore[override] knock_event_id: str, txn_id: Optional[str], requester: Requester, content: JsonDict, - ): + ) -> JsonDict: """ Args: knock_event_id: The ID of the knock to be rescinded. @@ -281,11 +282,11 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): "content": content, } - async def _handle_request( # type: ignore + async def _handle_request( # type: ignore[override] self, request: SynapseRequest, knock_event_id: str, - ): + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) txn_id = content["txn_id"] @@ -329,7 +330,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): self.distributor = hs.get_distributor() @staticmethod - async def _serialize_payload( # type: ignore + async def _serialize_payload( # type: ignore[override] room_id: str, user_id: str, change: str ) -> JsonDict: """ @@ -345,7 +346,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): return {} - async def _handle_request( # type: ignore + async def _handle_request( # type: ignore[override] self, request: Request, room_id: str, user_id: str, change: str ) -> Tuple[int, JsonDict]: logger.info("user membership change: %s in %s", user_id, room_id) @@ -360,7 +361,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py index 63143085d5..4a5b08f56f 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py @@ -13,11 +13,14 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint -from synapse.types import UserID +from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -49,18 +52,17 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload(user_id): + async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override] return {} - async def _handle_request(self, request, user_id): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str + ) -> Tuple[int, JsonDict]: await self._presence_handler.bump_presence_active_time( UserID.from_string(user_id) ) - return ( - 200, - {}, - ) + return (200, {}) class ReplicationPresenceSetState(ReplicationEndpoint): @@ -92,16 +94,21 @@ class ReplicationPresenceSetState(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload( - user_id, state, ignore_status_msg=False, force_notify=False - ): + async def _serialize_payload( # type: ignore[override] + user_id: str, + state: JsonDict, + ignore_status_msg: bool = False, + force_notify: bool = False, + ) -> JsonDict: return { "state": state, "ignore_status_msg": ignore_status_msg, "force_notify": force_notify, } - async def _handle_request(self, request, user_id): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) await self._presence_handler.set_state( @@ -111,12 +118,9 @@ class ReplicationPresenceSetState(ReplicationEndpoint): content["force_notify"], ) - return ( - 200, - {}, - ) + return (200, {}) -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationBumpPresenceActiveTime(hs).register(http_server) ReplicationPresenceSetState(hs).register(http_server) diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py index 6c8db3061e..af5c2f66a7 100644 --- a/synapse/replication/http/push.py +++ b/synapse/replication/http/push.py @@ -13,10 +13,14 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -48,7 +52,7 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): self.pusher_pool = hs.get_pusherpool() @staticmethod - async def _serialize_payload(app_id, pushkey, user_id): + async def _serialize_payload(app_id: str, pushkey: str, user_id: str) -> JsonDict: # type: ignore[override] payload = { "app_id": app_id, "pushkey": pushkey, @@ -56,7 +60,9 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): return payload - async def _handle_request(self, request, user_id): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) app_id = content["app_id"] @@ -67,5 +73,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationRemovePusherRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 7adfbb666f..c7f751b70d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -13,10 +13,14 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple +from twisted.web.server import Request + +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -36,34 +40,34 @@ class ReplicationRegisterServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - async def _serialize_payload( - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - address, - shadow_banned, - ): + async def _serialize_payload( # type: ignore[override] + user_id: str, + password_hash: Optional[str], + was_guest: bool, + make_guest: bool, + appservice_id: Optional[str], + create_profile_with_displayname: Optional[str], + admin: bool, + user_type: Optional[str], + address: Optional[str], + shadow_banned: bool, + ) -> JsonDict: """ Args: - user_id (str): The desired user ID to register. - password_hash (str|None): Optional. The password hash for this user. - was_guest (bool): Optional. Whether this is a guest account being - upgraded to a non-guest account. - make_guest (boolean): True if the the new user should be guest, - false to add a regular user account. - appservice_id (str|None): The ID of the appservice registering the user. - create_profile_with_displayname (unicode|None): Optionally create a - profile for the user, setting their displayname to the given value - admin (boolean): is an admin user? - user_type (str|None): type of user. One of the values from - api.constants.UserTypes, or None for a normal user. - address (str|None): the IP address used to perform the regitration. - shadow_banned (bool): Whether to shadow-ban the user + user_id: The desired user ID to register. + password_hash: Optional. The password hash for this user. + was_guest: Optional. Whether this is a guest account being upgraded + to a non-guest account. + make_guest: True if the the new user should be guest, false to add a + regular user account. + appservice_id: The ID of the appservice registering the user. + create_profile_with_displayname: Optionally create a profile for the + user, setting their displayname to the given value + admin: is an admin user? + user_type: type of user. One of the values from api.constants.UserTypes, + or None for a normal user. + address: the IP address used to perform the regitration. + shadow_banned: Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -77,7 +81,9 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "shadow_banned": shadow_banned, } - async def _handle_request(self, request, user_id): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) await self.registration_handler.check_registration_ratelimit(content["address"]) @@ -110,18 +116,21 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - async def _serialize_payload(user_id, auth_result, access_token): + async def _serialize_payload( # type: ignore[override] + user_id: str, auth_result: JsonDict, access_token: Optional[str] + ) -> JsonDict: """ Args: - user_id (str): The user ID that consented - auth_result (dict): The authenticated credentials of the newly - registered user. - access_token (str|None): The access token of the newly logged in + user_id: The user ID that consented + auth_result: The authenticated credentials of the newly registered user. + access_token: The access token of the newly logged in device, or None if `inhibit_login` enabled. """ return {"auth_result": auth_result, "access_token": access_token} - async def _handle_request(self, request, user_id): + async def _handle_request( # type: ignore[override] + self, request: Request, user_id: str + ) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) auth_result = content["auth_result"] @@ -134,6 +143,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): return 200, {} -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationRegisterServlet(hs).register(http_server) ReplicationPostRegisterActionsServlet(hs).register(http_server) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 9f6851d059..33e98daf8a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -13,18 +13,22 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Tuple + +from twisted.web.server import Request from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext +from synapse.http.server import HttpServer from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -70,18 +74,24 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - async def _serialize_payload( - event_id, store, event, context, requester, ratelimit, extra_users - ): + async def _serialize_payload( # type: ignore[override] + event_id: str, + store: "DataStore", + event: EventBase, + context: EventContext, + requester: Requester, + ratelimit: bool, + extra_users: List[UserID], + ) -> JsonDict: """ Args: - event_id (str) - store (DataStore) - requester (Requester) - event (FrozenEvent) - context (EventContext) - ratelimit (bool) - extra_users (list(UserID)): Any extra users to notify about event + event_id + store + requester + event + context + ratelimit + extra_users: Any extra users to notify about event """ serialized_context = await context.serialize(event, store) @@ -100,7 +110,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): return payload - async def _handle_request(self, request, event_id): + async def _handle_request( # type: ignore[override] + self, request: Request, event_id: str + ) -> Tuple[int, JsonDict]: with Measure(self.clock, "repl_send_event_parse"): content = parse_json_object_from_request(request) @@ -120,8 +132,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] - request.requester = requester - logger.info( "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) @@ -139,5 +149,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ) -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationSendEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index 3223bc2432..c065225362 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -13,11 +13,15 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple + +from twisted.web.server import Request from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import parse_integer from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -57,10 +61,14 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): self.streams = hs.get_replication_streams() @staticmethod - async def _serialize_payload(stream_name, from_token, upto_token): + async def _serialize_payload( # type: ignore[override] + stream_name: str, from_token: int, upto_token: int + ) -> JsonDict: return {"from_token": from_token, "upto_token": upto_token} - async def _handle_request(self, request, stream_name): + async def _handle_request( # type: ignore[override] + self, request: Request, stream_name: str + ) -> Tuple[int, JsonDict]: stream = self.streams.get(stream_name) if stream is None: raise SynapseError(400, "Unknown stream") @@ -78,5 +86,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): ) -def register_servlets(hs: "HomeServer", http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReplicationGetStreamUpdates(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index fa132d10b4..8f3f953ed4 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -40,7 +40,7 @@ class SlavedIdTracker(AbstractStreamIdTracker): for table, column in extra_tables: self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, instance_name: Optional[str], new_id: int): + def advance(self, instance_name: Optional[str], new_id: int) -> None: self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self) -> int: diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index bc888ce1a8..b5b84c09ae 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -37,7 +37,9 @@ class SlavedClientIpStore(BaseSlavedStore): cache_name="client_ip_last_seen", max_size=50000 ) - async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + async def insert_client_ip( + self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str + ) -> None: now = int(self._clock.time_msec()) key = (user_id, access_token, ip) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index a2aff75b70..0ffd34f1da 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -60,7 +60,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) @@ -70,7 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) - def _invalidate_caches_for_devices(self, token, rows): + def _invalidate_caches_for_devices( + self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] + ) -> None: for row in rows: # The entities are either user IDs (starting with '@') whose devices # have changed, or remote servers that we need to tell about diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 9d90e26375..d6f37d7479 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker @@ -44,10 +44,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): self._group_updates_id_gen.get_current_token(), ) - def get_group_stream_token(self): + def get_group_stream_token(self) -> int: return self._group_updates_id_gen.get_current_token() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == GroupServerStream.NAME: self._group_updates_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 7541e21de9..52ee3f7e58 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Iterable from synapse.replication.tcp.streams import PushRulesStream from synapse.storage.databases.main.push_rule import PushRulesWorkerStore @@ -20,10 +21,12 @@ from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_max_push_rules_stream_id(self): + def get_max_push_rules_stream_id(self) -> int: return self._push_rules_stream_id_gen.get_current_token() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == PushRulesStream.NAME: self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index cea90c0f1b..de642bba71 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Iterable from synapse.replication.tcp.streams import PushersStream from synapse.storage.database import DatabasePool, LoggingDatabaseConnection @@ -41,8 +41,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): return self._pushers_id_gen.get_current_token() def process_replication_rows( - self, stream_name: str, instance_name: str, token, rows + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == PushersStream.NAME: - self._pushers_id_gen.advance(instance_name, token) # type: ignore + self._pushers_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e29ae1e375..d59ce7ccf9 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,10 +14,12 @@ """A replication client for use by synapse workers. """ import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IAddress, IConnector from twisted.internet.protocol import ReconnectingClientFactory +from twisted.python.failure import Failure from synapse.api.constants import EventTypes from synapse.federation import send_queue @@ -79,10 +81,10 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory): hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) - def startedConnecting(self, connector): + def startedConnecting(self, connector: IConnector) -> None: logger.info("Connecting to replication: %r", connector.getDestination()) - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol: logger.info("Connected to replication: %r", addr) return ClientReplicationStreamProtocol( self.hs, @@ -92,11 +94,11 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory): self.command_handler, ) - def clientConnectionLost(self, connector, reason): + def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: logger.error("Lost replication conn: %r", reason) ReconnectingClientFactory.clientConnectionLost(self, connector, reason) - def clientConnectionFailed(self, connector, reason): + def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: logger.error("Failed to connect to replication: %r", reason) ReconnectingClientFactory.clientConnectionFailed(self, connector, reason) @@ -131,7 +133,7 @@ class ReplicationDataHandler: async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: """Called to handle a batch of replication data with a given stream token. By default this just pokes the slave store. Can be overridden in subclasses to @@ -252,14 +254,16 @@ class ReplicationDataHandler: # loop. (This maintains the order so no need to resort) waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] - async def on_position(self, stream_name: str, instance_name: str, token: int): + async def on_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: await self.on_rdata(stream_name, instance_name, token, []) # We poke the generic "replication" notifier to wake anything up that # may be streaming. self.notifier.notify_replication() - def on_remote_server_up(self, server: str): + def on_remote_server_up(self, server: str) -> None: """Called when get a new REMOTE_SERVER_UP command.""" # Let's wake up the transaction queue for the server in case we have @@ -269,7 +273,7 @@ class ReplicationDataHandler: async def wait_for_stream_position( self, instance_name: str, stream_name: str, position: int - ): + ) -> None: """Wait until this instance has received updates up to and including the given stream position. """ @@ -304,7 +308,7 @@ class ReplicationDataHandler: "Finished waiting for repl stream %r to reach %s", stream_name, position ) - def stop_pusher(self, user_id, app_id, pushkey): + def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: if not self._notify_pushers: return @@ -316,13 +320,13 @@ class ReplicationDataHandler: logger.info("Stopping pusher %r / %r", user_id, key) pusher.on_stop() - async def start_pusher(self, user_id, app_id, pushkey): + async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: if not self._notify_pushers: return key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) class FederationSenderHandler: @@ -353,10 +357,12 @@ class FederationSenderHandler: self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") - def wake_destination(self, server: str): + def wake_destination(self, server: str) -> None: self.federation_sender.wake_destination(server) - async def process_replication_rows(self, stream_name, token, rows): + async def process_replication_rows( + self, stream_name: str, token: int, rows: list + ) -> None: # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": @@ -384,11 +390,12 @@ class FederationSenderHandler: for host in hosts: self.federation_sender.send_device_messages(host) - async def _on_new_receipts(self, rows): + async def _on_new_receipts( + self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow] + ) -> None: """ Args: - rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]): - new receipts to be processed + rows: new receipts to be processed """ for receipt in rows: # we only want to send on receipts for our own users @@ -408,7 +415,7 @@ class FederationSenderHandler: ) await self.federation_sender.send_read_receipt(receipt_info) - async def update_token(self, token): + async def update_token(self, token: int) -> None: """Update the record of where we have processed to in the federation stream. Called after we have processed a an update received over replication. Sends @@ -428,7 +435,7 @@ class FederationSenderHandler: run_as_background_process("_save_and_send_ack", self._save_and_send_ack) - async def _save_and_send_ack(self): + async def _save_and_send_ack(self) -> None: """Save the current federation position in the database and send an ACK to master with where we're up to. """ diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 1311b013da..3654f6c03c 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -18,12 +18,15 @@ allowed to be sent by which side. """ import abc import logging -from typing import Tuple, Type +from typing import Optional, Tuple, Type, TypeVar +from synapse.replication.tcp.streams._base import StreamRow from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) +T = TypeVar("T", bound="Command") + class Command(metaclass=abc.ABCMeta): """The base command class. @@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def from_line(cls, line): + def from_line(cls: Type[T], line: str) -> T: """Deserialises a line from the wire into this command. `line` does not include the command. """ @@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta): prefix. """ - def get_logcontext_id(self): + def get_logcontext_id(self) -> str: """Get a suitable string for the logcontext when processing this command""" # by default, we just use the command name. return self.NAME +SC = TypeVar("SC", bound="_SimpleCommand") + + class _SimpleCommand(Command): """An implementation of Command whose argument is just a 'data' string.""" - def __init__(self, data): + def __init__(self, data: str): self.data = data @classmethod - def from_line(cls, line): + def from_line(cls: Type[SC], line: str) -> SC: return cls(line) def to_line(self) -> str: @@ -109,14 +115,16 @@ class RdataCommand(Command): NAME = "RDATA" - def __init__(self, stream_name, instance_name, token, row): + def __init__( + self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow + ): self.stream_name = stream_name self.instance_name = instance_name self.token = token self.row = row @classmethod - def from_line(cls, line): + def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand": stream_name, instance_name, token, row_json = line.split(" ", 3) return cls( stream_name, @@ -125,7 +133,7 @@ class RdataCommand(Command): json_decoder.decode(row_json), ) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.stream_name, @@ -135,7 +143,7 @@ class RdataCommand(Command): ) ) - def get_logcontext_id(self): + def get_logcontext_id(self) -> str: return "RDATA-" + self.stream_name @@ -164,18 +172,20 @@ class PositionCommand(Command): NAME = "POSITION" - def __init__(self, stream_name, instance_name, prev_token, new_token): + def __init__( + self, stream_name: str, instance_name: str, prev_token: int, new_token: int + ): self.stream_name = stream_name self.instance_name = instance_name self.prev_token = prev_token self.new_token = new_token @classmethod - def from_line(cls, line): + def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand": stream_name, instance_name, prev_token, new_token = line.split(" ", 3) return cls(stream_name, instance_name, int(prev_token), int(new_token)) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.stream_name, @@ -218,14 +228,14 @@ class ReplicateCommand(Command): NAME = "REPLICATE" - def __init__(self): + def __init__(self) -> None: pass @classmethod - def from_line(cls, line): + def from_line(cls: Type[T], line: str) -> T: return cls() - def to_line(self): + def to_line(self) -> str: return "" @@ -247,14 +257,16 @@ class UserSyncCommand(Command): NAME = "USER_SYNC" - def __init__(self, instance_id, user_id, is_syncing, last_sync_ms): + def __init__( + self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + ): self.instance_id = instance_id self.user_id = user_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod - def from_line(cls, line): + def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand": instance_id, user_id, state, last_sync_ms = line.split(" ", 3) if state not in ("start", "end"): @@ -262,7 +274,7 @@ class UserSyncCommand(Command): return cls(instance_id, user_id, state == "start", int(last_sync_ms)) - def to_line(self): + def to_line(self) -> str: return " ".join( ( self.instance_id, @@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command): NAME = "CLEAR_USER_SYNC" - def __init__(self, instance_id): + def __init__(self, instance_id: str): self.instance_id = instance_id @classmethod - def from_line(cls, line): + def from_line( + cls: Type["ClearUserSyncsCommand"], line: str + ) -> "ClearUserSyncsCommand": return cls(line) - def to_line(self): + def to_line(self) -> str: return self.instance_id @@ -316,7 +330,9 @@ class FederationAckCommand(Command): self.token = token @classmethod - def from_line(cls, line: str) -> "FederationAckCommand": + def from_line( + cls: Type["FederationAckCommand"], line: str + ) -> "FederationAckCommand": instance_name, token = line.split(" ") return cls(instance_name, int(token)) @@ -334,7 +350,15 @@ class UserIpCommand(Command): NAME = "USER_IP" - def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen): + def __init__( + self, + user_id: str, + access_token: str, + ip: str, + user_agent: str, + device_id: str, + last_seen: int, + ): self.user_id = user_id self.access_token = access_token self.ip = ip @@ -343,14 +367,14 @@ class UserIpCommand(Command): self.last_seen = last_seen @classmethod - def from_line(cls, line): + def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand": user_id, jsn = line.split(" ", 1) access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) - def to_line(self): + def to_line(self) -> str: return ( self.user_id + " " diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 21293038ef..17e1572393 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -261,7 +261,7 @@ class ReplicationCommandHandler: "process-replication-data", self._unsafe_process_queue, stream_name ) - async def _unsafe_process_queue(self, stream_name: str): + async def _unsafe_process_queue(self, stream_name: str) -> None: """Processes the command queue for the given stream, until it is empty Does not check if there is already a thread processing the queue, hence "unsafe" @@ -294,7 +294,7 @@ class ReplicationCommandHandler: # This shouldn't be possible raise Exception("Unrecognised command %s in stream queue", cmd.NAME) - def start_replication(self, hs: "HomeServer"): + def start_replication(self, hs: "HomeServer") -> None: """Helper method to start a replication connection to the remote server using TCP. """ @@ -318,7 +318,7 @@ class ReplicationCommandHandler: hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host, # type: ignore[arg-type] + hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, timeout=30, @@ -330,7 +330,7 @@ class ReplicationCommandHandler: host = hs.config.worker.worker_replication_host port = hs.config.worker.worker_replication_port hs.get_reactor().connectTCP( - host, # type: ignore[arg-type] + host, port, self._factory, timeout=30, @@ -345,10 +345,10 @@ class ReplicationCommandHandler: """Get a list of streams that this instances replicates.""" return self._streams_to_replicate - def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None: self.send_positions_to_connection(conn) - def send_positions_to_connection(self, conn: IReplicationConnection): + def send_positions_to_connection(self, conn: IReplicationConnection) -> None: """Send current position of all streams this process is source of to the connection. """ @@ -392,7 +392,7 @@ class ReplicationCommandHandler: def on_FEDERATION_ACK( self, conn: IReplicationConnection, cmd: FederationAckCommand - ): + ) -> None: federation_ack_counter.inc() if self._federation_sender: @@ -408,7 +408,7 @@ class ReplicationCommandHandler: else: return None - async def _handle_user_ip(self, cmd: UserIpCommand): + async def _handle_user_ip(self, cmd: UserIpCommand) -> None: await self._store.insert_client_ip( cmd.user_id, cmd.access_token, @@ -421,7 +421,7 @@ class ReplicationCommandHandler: assert self._server_notices_sender is not None await self._server_notices_sender.on_user_ip(cmd.user_id) - def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand): + def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None: if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -497,7 +497,7 @@ class ReplicationCommandHandler: async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list - ): + ) -> None: """Called to handle a batch of replication data with a given stream token. Args: @@ -512,7 +512,7 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand): + def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None: if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return @@ -581,7 +581,7 @@ class ReplicationCommandHandler: def on_REMOTE_SERVER_UP( self, conn: IReplicationConnection, cmd: RemoteServerUpCommand - ): + ) -> None: """Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -604,7 +604,7 @@ class ReplicationCommandHandler: # between two instances, but that is not currently supported). self.send_command(cmd, ignore_conn=conn) - def new_connection(self, connection: IReplicationConnection): + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -631,7 +631,7 @@ class ReplicationCommandHandler: UserSyncCommand(self._instance_id, user_id, True, now) ) - def lost_connection(self, connection: IReplicationConnection): + def lost_connection(self, connection: IReplicationConnection) -> None: """Called when a connection is closed/lost.""" # we no longer need _streams_by_connection for this connection. streams = self._streams_by_connection.pop(connection, None) @@ -653,7 +653,7 @@ class ReplicationCommandHandler: def send_command( self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None - ): + ) -> None: """Send a command to all connected connections. Args: @@ -680,7 +680,7 @@ class ReplicationCommandHandler: else: logger.warning("Dropping command as not connected: %r", cmd.NAME) - def send_federation_ack(self, token: int): + def send_federation_ack(self, token: int) -> None: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. """ @@ -688,7 +688,7 @@ class ReplicationCommandHandler: def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int - ): + ) -> None: """Poke the master that a user has started/stopped syncing.""" self.send_command( UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) @@ -702,15 +702,15 @@ class ReplicationCommandHandler: user_agent: str, device_id: str, last_seen: int, - ): + ) -> None: """Tell the master that the user made a request.""" cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) self.send_command(cmd) - def send_remote_server_up(self, server: str): + def send_remote_server_up(self, server: str) -> None: self.send_command(RemoteServerUpCommand(server)) - def stream_update(self, stream_name: str, token: str, data: Any): + def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None: """Called when a new update is available to stream to clients. We need to check if the client is interested in the stream or not diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 7bae36db16..7763ffb2d0 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -49,7 +49,7 @@ import fcntl import logging import struct from inspect import isawaitable -from typing import TYPE_CHECKING, Collection, List, Optional +from typing import TYPE_CHECKING, Any, Collection, List, Optional from prometheus_client import Counter from zope.interface import Interface, implementer @@ -123,7 +123,7 @@ class ConnectionStates: class IReplicationConnection(Interface): """An interface for replication connections.""" - def send_command(cmd: Command): + def send_command(cmd: Command) -> None: """Send the command down the connection""" @@ -190,7 +190,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): "replication-conn", self.conn_id ) - def connectionMade(self): + def connectionMade(self) -> None: logger.info("[%s] Connection established", self.id()) self.state = ConnectionStates.ESTABLISHED @@ -207,11 +207,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # Always send the initial PING so that the other side knows that they # can time us out. - self.send_command(PingCommand(self.clock.time_msec())) + self.send_command(PingCommand(str(self.clock.time_msec()))) self.command_handler.new_connection(self) - def send_ping(self): + def send_ping(self) -> None: """Periodically sends a ping and checks if we should close the connection due to the other side timing out. """ @@ -226,7 +226,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.transport.abortConnection() else: if now - self.last_sent_command >= PING_TIME: - self.send_command(PingCommand(now)) + self.send_command(PingCommand(str(now))) if ( self.received_ping @@ -239,12 +239,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): ) self.send_error("ping timeout") - def lineReceived(self, line: bytes): + def lineReceived(self, line: bytes) -> None: """Called when we've received a line""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_line(line) - def _parse_and_dispatch_line(self, line: bytes): + def _parse_and_dispatch_line(self, line: bytes) -> None: if line.strip() == "": # Ignore blank lines return @@ -309,24 +309,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if not handled: logger.warning("Unhandled command: %r", cmd) - def close(self): + def close(self) -> None: logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() assert self.transport is not None self.transport.loseConnection() self.on_connection_closed() - def send_error(self, error_string, *args): + def send_error(self, error_string: str, *args: Any) -> None: """Send an error to remote and close the connection.""" self.send_command(ErrorCommand(error_string % args)) self.close() - def send_command(self, cmd, do_buffer=True): + def send_command(self, cmd: Command, do_buffer: bool = True) -> None: """Send a command if connection has been established. Args: - cmd (Command) - do_buffer (bool): Whether to buffer the message or always attempt + cmd + do_buffer: Whether to buffer the message or always attempt to send the command. This is mostly used to send an error message if we're about to close the connection due our buffers becoming full. @@ -357,7 +357,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_sent_command = self.clock.time_msec() - def _queue_command(self, cmd): + def _queue_command(self, cmd: Command) -> None: """Queue the command until the connection is ready to write to again.""" logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) @@ -370,20 +370,20 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False) self.close() - def _send_pending_commands(self): + def _send_pending_commands(self) -> None: """Send any queued commandes""" pending = self.pending_commands self.pending_commands = [] for cmd in pending: self.send_command(cmd) - def on_PING(self, line): + def on_PING(self, cmd: PingCommand) -> None: self.received_ping = True - def on_ERROR(self, cmd): + def on_ERROR(self, cmd: ErrorCommand) -> None: logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) - def pauseProducing(self): + def pauseProducing(self) -> None: """This is called when both the kernel send buffer and the twisted tcp connection send buffers have become full. @@ -394,26 +394,26 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): logger.info("[%s] Pause producing", self.id()) self.state = ConnectionStates.PAUSED - def resumeProducing(self): + def resumeProducing(self) -> None: """The remote has caught up after we started buffering!""" logger.info("[%s] Resume producing", self.id()) self.state = ConnectionStates.ESTABLISHED self._send_pending_commands() - def stopProducing(self): + def stopProducing(self) -> None: """We're never going to send any more data (normally because either we or the remote has closed the connection) """ logger.info("[%s] Stop producing", self.id()) self.on_connection_closed() - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: # type: ignore[override] logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): assert reason.type is not None connection_close_counter.labels(reason.type.__name__).inc() else: - connection_close_counter.labels(reason.__class__.__name__).inc() + connection_close_counter.labels(reason.__class__.__name__).inc() # type: ignore[unreachable] try: # Remove us from list of connections to be monitored @@ -427,7 +427,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.on_connection_closed() - def on_connection_closed(self): + def on_connection_closed(self) -> None: logger.info("[%s] Connection was closed", self.id()) self.state = ConnectionStates.CLOSED @@ -445,7 +445,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # the sentinel context is now active, which may not be correct. # PreserveLoggingContext() will restore the correct logging context. - def __str__(self): + def __str__(self) -> str: addr = None if self.transport: addr = str(self.transport.getPeer()) @@ -455,10 +455,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): addr, ) - def id(self): + def id(self) -> str: return "%s-%s" % (self.name, self.conn_id) - def lineLengthExceeded(self, line): + def lineLengthExceeded(self, line: str) -> None: """Called when we receive a line that is above the maximum line length""" self.send_error("Line length exceeded") @@ -474,11 +474,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.server_name = server_name - def connectionMade(self): + def connectionMade(self) -> None: self.send_command(ServerCommand(self.server_name)) super().connectionMade() - def on_NAME(self, cmd): + def on_NAME(self, cmd: NameCommand) -> None: logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data @@ -500,19 +500,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.client_name = client_name self.server_name = server_name - def connectionMade(self): + def connectionMade(self) -> None: self.send_command(NameCommand(self.client_name)) super().connectionMade() # Once we've connected subscribe to the necessary streams self.replicate() - def on_SERVER(self, cmd): + def on_SERVER(self, cmd: ServerCommand) -> None: if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") - def replicate(self): + def replicate(self) -> None: """Send the subscription request to the server""" logger.info("[%s] Subscribing to replication streams", self.id()) @@ -529,7 +529,7 @@ pending_commands = LaterGauge( ) -def transport_buffer_size(protocol): +def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int: if protocol.transport: size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen return size @@ -544,7 +544,9 @@ transport_send_buffer = LaterGauge( ) -def transport_kernel_read_buffer_size(protocol, read=True): +def transport_kernel_read_buffer_size( + protocol: BaseReplicationStreamProtocol, read: bool = True +) -> int: SIOCINQ = 0x541B SIOCOUTQ = 0x5411 diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 8d28bd3f3f..3170f7c59b 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -14,7 +14,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast import attr import txredisapi @@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]): def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V: return self.constant - def __set__(self, obj: Optional[T], value: V): + def __set__(self, obj: Optional[T], value: V) -> None: pass @@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): synapse_stream_name: str synapse_outbound_redis_connection: txredisapi.RedisProtocol - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # a logcontext which we use for processing incoming commands. We declare it as a @@ -108,12 +108,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): "replication_command_handler" ) - def connectionMade(self): + def connectionMade(self) -> None: logger.info("Connected to redis") super().connectionMade() run_as_background_process("subscribe-replication", self._send_subscribe) - async def _send_subscribe(self): + async def _send_subscribe(self) -> None: # it's important to make sure that we only send the REPLICATE command once we # have successfully subscribed to the stream - otherwise we might miss the # POSITION response sent back by the other end. @@ -131,12 +131,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): # otherside won't know we've connected and so won't issue a REPLICATE. self.synapse_handler.send_positions_to_connection(self) - def messageReceived(self, pattern: str, channel: str, message: str): + def messageReceived(self, pattern: str, channel: str, message: str) -> None: """Received a message from redis.""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_message(message) - def _parse_and_dispatch_message(self, message: str): + def _parse_and_dispatch_message(self, message: str) -> None: if message.strip() == "": # Ignore blank lines return @@ -181,7 +181,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): "replication-" + cmd.get_logcontext_id(), lambda: res ) - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: # type: ignore[override] logger.info("Lost connection to redis") super().connectionLost(reason) self.synapse_handler.lost_connection(self) @@ -193,17 +193,17 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): # the sentinel context is now active, which may not be correct. # PreserveLoggingContext() will restore the correct logging context. - def send_command(self, cmd: Command): + def send_command(self, cmd: Command) -> None: """Send a command if connection has been established. Args: - cmd (Command) + cmd: The command to send """ run_as_background_process( "send-cmd", self._async_send_command, cmd, bg_start_span=False ) - async def _async_send_command(self, cmd: Command): + async def _async_send_command(self, cmd: Command) -> None: """Encode a replication command and send it over our outbound connection""" string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: @@ -259,7 +259,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory): hs.get_clock().looping_call(self._send_ping, 30 * 1000) @wrap_as_background_process("redis_ping") - async def _send_ping(self): + async def _send_ping(self) -> None: for connection in self.pool: try: await make_deferred_yieldable(connection.ping()) @@ -269,13 +269,13 @@ class SynapseRedisFactory(txredisapi.RedisFactory): # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but # it's rubbish. We add our own here. - def startedConnecting(self, connector: IConnector): + def startedConnecting(self, connector: IConnector) -> None: logger.info( "Connecting to redis server %s", format_address(connector.getDestination()) ) super().startedConnecting(connector) - def clientConnectionFailed(self, connector: IConnector, reason: Failure): + def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: logger.info( "Connection to redis server %s failed: %s", format_address(connector.getDestination()), @@ -283,7 +283,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory): ) super().clientConnectionFailed(connector, reason) - def clientConnectionLost(self, connector: IConnector, reason: Failure): + def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: logger.info( "Connection to redis server %s lost: %s", format_address(connector.getDestination()), @@ -330,7 +330,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): self.synapse_outbound_redis_connection = outbound_redis_connection - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> RedisSubscriber: p = super().buildProtocol(addr) p = cast(RedisSubscriber, p) @@ -373,7 +373,7 @@ def lazyConnection( reactor = hs.get_reactor() reactor.connectTCP( - host, # type: ignore[arg-type] + host, port, factory, timeout=30, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index a9d85f4f6c..ecd6190f5b 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -16,16 +16,18 @@ import logging import random -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List, Optional, Tuple from prometheus_client import Counter +from twisted.internet.interfaces import IAddress from twisted.internet.protocol import ServerFactory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import PositionCommand from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol from synapse.replication.tcp.streams import EventsStream +from synapse.replication.tcp.streams._base import StreamRow, Token from synapse.util.metrics import Measure if TYPE_CHECKING: @@ -56,7 +58,7 @@ class ReplicationStreamProtocolFactory(ServerFactory): # listener config again or always starting a `ReplicationStreamer`.) hs.get_replication_streamer() - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol: return ServerReplicationStreamProtocol( self.server_name, self.clock, self.command_handler ) @@ -105,7 +107,7 @@ class ReplicationStreamer: if any(EventsStream.NAME == s.NAME for s in self.streams): self.clock.looping_call(self.on_notifier_poke, 1000) - def on_notifier_poke(self): + def on_notifier_poke(self) -> None: """Checks if there is actually any new data and sends it to the connections if there are. @@ -137,7 +139,7 @@ class ReplicationStreamer: run_as_background_process("replication_notifier", self._run_notifier_loop) - async def _run_notifier_loop(self): + async def _run_notifier_loop(self) -> None: self.is_looping = True try: @@ -238,7 +240,9 @@ class ReplicationStreamer: self.is_looping = False -def _batch_updates(updates): +def _batch_updates( + updates: List[Tuple[Token, StreamRow]] +) -> List[Tuple[Optional[Token], StreamRow]]: """Takes a list of updates of form [(token, row)] and sets the token to None for all rows where the next row has the same token. This is used to implement batching. @@ -254,7 +258,7 @@ def _batch_updates(updates): if not updates: return [] - new_updates = [] + new_updates: List[Tuple[Optional[Token], StreamRow]] = [] for i, update in enumerate(updates[:-1]): if update[0] == updates[i + 1][0]: new_updates.append((None, update[1])) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 5a2d90c530..914b9eae84 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -90,7 +90,7 @@ class Stream: ROW_TYPE: Any = None @classmethod - def parse_row(cls, row: StreamRow): + def parse_row(cls, row: StreamRow) -> Any: """Parse a row received over replication By default, assumes that the row data is an array object and passes its contents @@ -139,7 +139,7 @@ class Stream: # The token from which we last asked for updates self.last_token = self.current_token(self.local_instance_name) - def discard_updates_and_advance(self): + def discard_updates_and_advance(self) -> None: """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ @@ -200,7 +200,7 @@ def current_token_without_instance( return lambda instance_name: current_token() -def make_http_update_function(hs, stream_name: str) -> UpdateFunction: +def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction: """Makes a suitable function for use as an `update_function` that queries the master process for updates. """ diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 4f4f1ad453..50c4a5ba03 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import heapq -from collections.abc import Iterable -from typing import TYPE_CHECKING, Optional, Tuple, Type +from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast import attr -from ._base import Stream, StreamUpdateResult, Token +from synapse.replication.tcp.streams._base import ( + Stream, + StreamRow, + StreamUpdateResult, + Token, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -58,6 +62,9 @@ class EventsStreamRow: data: "BaseEventsStreamRow" +T = TypeVar("T", bound="BaseEventsStreamRow") + + class BaseEventsStreamRow: """Base class for rows to be sent in the events stream. @@ -68,7 +75,7 @@ class BaseEventsStreamRow: TypeId: str @classmethod - def from_data(cls, data): + def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T: """Parse the data from the replication stream into a row. By default we just call the constructor with the data list as arguments @@ -221,7 +228,7 @@ class EventsStream(Stream): return updates, upper_limit, limited @classmethod - def parse_row(cls, row): - (typ, data) = row - data = TypeToRow[typ].from_data(data) - return EventsStreamRow(typ, data) + def parse_row(cls, row: StreamRow) -> "EventsStreamRow": + (typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row) + event_stream_row_data = TypeToRow[typ].from_data(data) + return EventsStreamRow(typ, event_stream_row_data) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 9be9e33c8e..ba0d989d81 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -20,7 +20,8 @@ import platform from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple -import synapse +from matrix_common.versionstring import get_distribution_version_string + from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -88,7 +89,6 @@ from synapse.rest.admin.users import ( WhoisRestServlet, ) from synapse.types import JsonDict, RoomStreamToken -from synapse.util.versionstring import get_version_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -101,7 +101,7 @@ class VersionServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.res = { - "server_version": get_version_string(synapse), + "server_version": get_distribution_version_string("matrix-synapse"), "python_version": platform.python_version(), } diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 6b272658fc..efe299e698 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -385,7 +385,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): + if not await check_3pid_allowed(self.hs, "email", email): raise SynapseError( 403, "Your email domain is not authorized on this server", @@ -468,7 +468,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(country, phone_number) - if not check_3pid_allowed(self.hs, "msisdn", msisdn): + if not await check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( 403, "Account phone numbers are not authorized on this server", @@ -883,7 +883,9 @@ class WhoamiRestServlet(RestServlet): response = { "user_id": requester.user.to_string(), # MSC: https://github.com/matrix-org/matrix-doc/pull/3069 + # Entered spec in Matrix 1.2 "org.matrix.msc3069.is_guest": bool(requester.is_guest), + "is_guest": bool(requester.is_guest), } # Appservices and similar accounts do not have device IDs diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 9c15a04338..e0b2b80e5b 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet): if stagetype == LoginType.RECAPTCHA: html = self.recaptcha_template.render( session=session, - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.captcha.recaptcha_public_key, ) @@ -74,7 +74,7 @@ class AuthRestServlet(RestServlet): self.hs.config.server.public_baseurl, self.hs.config.consent.user_consent_version, ), - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), ) @@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet): # Authentication failed, let user try again html = self.recaptcha_template.render( session=session, - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), sitekey=self.hs.config.captcha.recaptcha_public_key, error=e.msg, @@ -143,7 +143,7 @@ class AuthRestServlet(RestServlet): self.hs.config.server.public_baseurl, self.hs.config.consent.user_consent_version, ), - myurl="%s/r0/auth/%s/fallback/web" + myurl="%s/v3/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), error=e.msg, ) diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 5c0e3a5680..e05c926b6f 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Tuple from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, MSC3244_CAPABILITIES @@ -54,6 +55,15 @@ class CapabilitiesRestServlet(RestServlet): }, }, "m.change_password": {"enabled": change_password}, + "m.set_displayname": { + "enabled": self.config.registration.enable_set_displayname + }, + "m.set_avatar_url": { + "enabled": self.config.registration.enable_set_avatar_url + }, + "m.3pid_changes": { + "enabled": self.config.registration.enable_3pid_changes + }, } } @@ -62,21 +72,10 @@ class CapabilitiesRestServlet(RestServlet): "org.matrix.msc3244.room_capabilities" ] = MSC3244_CAPABILITIES - if self.config.experimental.msc3283_enabled: - response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.registration.enable_set_displayname - } - response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.registration.enable_set_avatar_url - } - response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.registration.enable_3pid_changes - } - if self.config.experimental.msc3440_enabled: response["capabilities"]["io.element.thread"] = {"enabled": True} - return 200, response + return HTTPStatus.OK, response def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 6f796d5e50..8fe75bd750 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -29,7 +29,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS +from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns @@ -61,10 +61,6 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None - self._users_new_default_push_rules = ( - hs.config.server.users_new_default_push_rules - ) - async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -217,12 +213,7 @@ class PushRuleRestServlet(RestServlet): rule_id = spec.rule_id is_default_rule = rule_id.startswith(".") if is_default_rule: - if user_id in self._users_new_default_push_rules: - rule_ids = NEW_RULE_IDS - else: - rule_ids = BASE_RULE_IDS - - if namespaced_rule_id not in rule_ids: + if namespaced_rule_id not in BASE_RULE_IDS: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) await self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index e3492f9f93..b8a5135e02 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -112,7 +112,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): + if not await check_3pid_allowed(self.hs, "email", email, registration=True): raise SynapseError( 403, "Your email domain is not authorized to register on this server", @@ -192,7 +192,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(country, phone_number) - if not check_3pid_allowed(self.hs, "msisdn", msisdn): + if not await check_3pid_allowed(self.hs, "msisdn", msisdn, registration=True): raise SynapseError( 403, "Phone numbers are not authorized to register on this server", @@ -368,7 +368,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): Example: - GET /_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity?token=abcd + GET /_matrix/client/v1/register/m.login.registration_token/validity?token=abcd 200 OK @@ -378,9 +378,8 @@ class RegistrationTokenValidityRestServlet(RestServlet): """ PATTERNS = client_patterns( - f"/org.matrix.msc3231/register/{LoginType.REGISTRATION_TOKEN}/validity", - releases=(), - unstable=True, + f"/register/{LoginType.REGISTRATION_TOKEN}/validity", + releases=("v1",), ) def __init__(self, hs: "HomeServer"): @@ -617,7 +616,9 @@ class RegisterRestServlet(RestServlet): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - if not check_3pid_allowed(self.hs, medium, address): + if not await check_3pid_allowed( + self.hs, medium, address, registration=True + ): raise SynapseError( 403, "Third party identifiers (email/phone numbers)" @@ -693,11 +694,18 @@ class RegisterRestServlet(RestServlet): session_id ) + display_name = await ( + self.password_auth_provider.get_displayname_for_registration( + auth_result, params + ) + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, + default_display_name=display_name, address=client_addr, user_agent_ips=entries, ) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 8cf5ebaa07..2cab83c4e6 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -32,14 +32,45 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) -from synapse.types import JsonDict +from synapse.types import JsonDict, RoomStreamToken, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) +async def _parse_token( + store: "DataStore", token: Optional[str] +) -> Optional[StreamToken]: + """ + For backwards compatibility support RelationPaginationToken, but new pagination + tokens are generated as full StreamTokens, to be compatible with /sync and /messages. + """ + if not token: + return None + # Luckily the format for StreamToken and RelationPaginationToken differ enough + # that they can easily be separated. An "_" appears in the serialization of + # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses + # "-" only for separators. + if "_" in token: + return await StreamToken.from_string(store, token) + else: + relation_token = RelationPaginationToken.from_string(token) + return StreamToken( + room_key=RoomStreamToken(relation_token.topological, relation_token.stream), + presence_key=0, + typing_key=0, + receipt_key=0, + account_data_key=0, + push_rules_key=0, + to_device_key=0, + device_list_key=0, + groups_key=0, + ) + + class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -80,6 +111,9 @@ class RelationPaginationServlet(RestServlet): raise SynapseError(404, "Unknown parent event.") limit = parse_integer(request, "limit", default=5) + direction = parse_string( + request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"] + ) from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") @@ -88,13 +122,8 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - from_token = None - if from_token_str: - from_token = RelationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = RelationPaginationToken.from_string(to_token_str) + from_token = await _parse_token(self.store, from_token_str) + to_token = await _parse_token(self.store, to_token_str) pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, @@ -102,6 +131,7 @@ class RelationPaginationServlet(RestServlet): relation_type=relation_type, event_type=event_type, limit=limit, + direction=direction, from_token=from_token, to_token=to_token, ) @@ -125,7 +155,7 @@ class RelationPaginationServlet(RestServlet): events, now, bundle_aggregations=aggregations ) - return_value = pagination_chunk.to_dict() + return_value = await pagination_chunk.to_dict(self.store) return_value["chunk"] = serialized_events return_value["original_event"] = original_event @@ -216,7 +246,7 @@ class RelationAggregationPaginationServlet(RestServlet): to_token=to_token, ) - return 200, pagination_chunk.to_dict() + return 200, await pagination_chunk.to_dict(self.store) class RelationAggregationGroupPaginationServlet(RestServlet): @@ -287,13 +317,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - from_token = None - if from_token_str: - from_token = RelationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = RelationPaginationToken.from_string(to_token_str) + from_token = await _parse_token(self.store, from_token_str) + to_token = await _parse_token(self.store, to_token_str) result = await self.store.get_relations_for_event( event_id=parent_id, @@ -313,7 +338,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): now = self.clock.time_msec() serialized_events = self._event_serializer.serialize_events(events, now) - return_value = result.to_dict() + return_value = await result.to_dict(self.store) return_value["chunk"] = serialized_events return 200, return_value diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index e4c9451ae0..4b6be38327 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -131,6 +131,14 @@ class RoomBatchSendEventRestServlet(RestServlet): prev_event_ids_from_query ) + if not auth_event_ids: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "No auth events found for given prev_event query parameter. The prev_event=%s probably does not exist." + % prev_event_ids_from_query, + errcode=Codes.INVALID_PARAM, + ) + state_event_ids_at_start = [] # Create and persist all of the state events that float off on their own # before the batch. These will most likely be all of the invite/member @@ -197,21 +205,12 @@ class RoomBatchSendEventRestServlet(RestServlet): EventContentFields.MSC2716_NEXT_BATCH_ID ] - # Also connect the historical event chain to the end of the floating - # state chain, which causes the HS to ask for the state at the start of - # the batch later. If there is no state chain to connect to, just make - # the insertion event float itself. - prev_event_ids = [] - if len(state_event_ids_at_start): - prev_event_ids = [state_event_ids_at_start[-1]] - # Create and persist all of the historical events as well as insertion # and batch meta events to make the batch navigable in the DAG. event_ids, next_batch_id = await self.room_batch_handler.handle_batch_of_events( events_to_create=events_to_create, room_id=room_id, batch_id_to_connect_to=batch_id_to_connect_to, - initial_prev_event_ids=prev_event_ids, inherited_depth=inherited_depth, auth_event_ids=auth_event_ids, app_service_requester=requester, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2290c57c12..00f29344a8 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -73,6 +73,8 @@ class VersionsRestServlet(RestServlet): "r0.5.0", "r0.6.0", "r0.6.1", + "v1.1", + "v1.2", ], # as per MSC1497: "unstable_features": { diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index efd84ced8f..8d3d1e54dc 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -403,6 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource): output_stream=output_stream, max_size=self.max_spider_size, headers={"Accept-Language": self.url_preview_accept_language}, + is_allowed_content_type=_is_previewable, ) except SynapseError: # Pass SynapseErrors through directly, so that the servlet @@ -761,3 +762,10 @@ def _is_html(content_type: str) -> bool: def _is_json(content_type: str) -> bool: return content_type.lower().startswith("application/json") + + +def _is_previewable(content_type: str) -> bool: + """Returns True for content types for which we will perform URL preview and False + otherwise.""" + + return _is_html(content_type) or _is_media(content_type) or _is_json(content_type) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 8162094cf6..fde28d08cb 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -49,10 +49,14 @@ class UploadResource(DirectServeJsonResource): async def _async_render_POST(self, request: SynapseRequest) -> None: requester = await self.auth.get_user_by_req(request) - content_length = request.getHeader("Content-Length") - if content_length is None: + raw_content_length = request.getHeader("Content-Length") + if raw_content_length is None: raise SynapseError(msg="Request must specify a Content-Length", code=400) - if int(content_length) > self.max_upload_size: + try: + content_length = int(raw_content_length) + except ValueError: + raise SynapseError(msg="Content-Length value is invalid", code=400) + if content_length > self.max_upload_size: raise SynapseError( msg="Upload request body is too large", code=413, @@ -66,7 +70,8 @@ class UploadResource(DirectServeJsonResource): upload_name: Optional[str] = upload_name_bytes.decode("utf8") except UnicodeDecodeError: raise SynapseError( - msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 + msg="Invalid UTF-8 filename parameter: %r" % (upload_name_bytes,), + code=400, ) # If the name is falsey (e.g. an empty byte string) ensure it is None. diff --git a/synapse/server.py b/synapse/server.py index 3032f0b738..564afdcb96 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -233,8 +233,8 @@ class HomeServer(metaclass=abc.ABCMeta): self, hostname: str, config: HomeServerConfig, - reactor=None, - version_string="Synapse", + reactor: Optional[ISynapseReactor] = None, + version_string: str = "Synapse", ): """ Args: @@ -244,7 +244,7 @@ class HomeServer(metaclass=abc.ABCMeta): if not reactor: from twisted.internet import reactor as _reactor - reactor = _reactor + reactor = cast(ISynapseReactor, _reactor) self._reactor = reactor self.hostname = hostname @@ -264,7 +264,7 @@ class HomeServer(metaclass=abc.ABCMeta): self._module_web_resources: Dict[str, Resource] = {} self._module_web_resources_consumed = False - def register_module_web_resource(self, path: str, resource: Resource): + def register_module_web_resource(self, path: str, resource: Resource) -> None: """Allows a module to register a web resource to be served at the given path. If multiple modules register a resource for the same path, the module that diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 7967011afd..8df80664a2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,7 +57,7 @@ class SQLBaseStore(metaclass=ABCMeta): pass def _invalidate_state_caches( - self, room_id: str, members_changed: Iterable[str] + self, room_id: str, members_changed: Collection[str] ) -> None: """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -66,11 +66,16 @@ class SQLBaseStore(metaclass=ABCMeta): room_id: Room where state changed members_changed: The user_ids of members that have changed """ + # If there were any membership changes, purge the appropriate caches. for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) + if members_changed: + self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache( + "get_users_in_room_with_profiles", (room_id,) + ) - self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) - self._attempt_to_invalidate_cache("get_users_in_room_with_profiles", (room_id,)) + # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 5bfa408f74..52146aacc8 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -106,6 +106,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) "AccountDataAndTagsChangeCache", account_max ) + self.db_pool.updates.register_background_update_handler( + "delete_account_data_for_deactivated_users", + self._delete_account_data_for_deactivated_users, + ) + def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream @@ -549,72 +554,121 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) async def purge_account_data_for_user(self, user_id: str) -> None: """ - Removes the account data for a user. + Removes ALL the account data for a user. + Intended to be used upon user deactivation. - This is intended to be used upon user deactivation and also removes any - derived information from account data (e.g. push rules and ignored users). + Also purges the user from the ignored_users cache table + and the push_rules cache tables. + """ - Args: - user_id: The user ID to remove data for. + await self.db_pool.runInteraction( + "purge_account_data_for_user_txn", + self._purge_account_data_for_user_txn, + user_id, + ) + + def _purge_account_data_for_user_txn( + self, txn: LoggingTransaction, user_id: str + ) -> None: """ + See `purge_account_data_for_user`. + """ + # Purge from the primary account_data tables. + self.db_pool.simple_delete_txn( + txn, table="account_data", keyvalues={"user_id": user_id} + ) - def purge_account_data_for_user_txn(txn: LoggingTransaction) -> None: - # Purge from the primary account_data tables. - self.db_pool.simple_delete_txn( - txn, table="account_data", keyvalues={"user_id": user_id} - ) + self.db_pool.simple_delete_txn( + txn, table="room_account_data", keyvalues={"user_id": user_id} + ) - self.db_pool.simple_delete_txn( - txn, table="room_account_data", keyvalues={"user_id": user_id} - ) + # Purge from ignored_users where this user is the ignorer. + # N.B. We don't purge where this user is the ignoree, because that + # interferes with other users' account data. + # It's also not this user's data to delete! + self.db_pool.simple_delete_txn( + txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id} + ) - # Purge from ignored_users where this user is the ignorer. - # N.B. We don't purge where this user is the ignoree, because that - # interferes with other users' account data. - # It's also not this user's data to delete! - self.db_pool.simple_delete_txn( - txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id} - ) + # Remove the push rules + self.db_pool.simple_delete_txn( + txn, table="push_rules", keyvalues={"user_name": user_id} + ) + self.db_pool.simple_delete_txn( + txn, table="push_rules_enable", keyvalues={"user_name": user_id} + ) + self.db_pool.simple_delete_txn( + txn, table="push_rules_stream", keyvalues={"user_id": user_id} + ) - # Remove the push rules - self.db_pool.simple_delete_txn( - txn, table="push_rules", keyvalues={"user_name": user_id} - ) - self.db_pool.simple_delete_txn( - txn, table="push_rules_enable", keyvalues={"user_name": user_id} - ) - self.db_pool.simple_delete_txn( - txn, table="push_rules_stream", keyvalues={"user_id": user_id} - ) + # Invalidate caches as appropriate + self._invalidate_cache_and_stream( + txn, self.get_account_data_for_room_and_type, (user_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_account_data_for_user, (user_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_global_account_data_by_type_for_user, (user_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_account_data_for_room, (user_id,) + ) + self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,)) + self._invalidate_cache_and_stream( + txn, self.get_push_rules_enabled_for_user, (user_id,) + ) + # This user might be contained in the ignored_by cache for other users, + # so we have to invalidate it all. + self._invalidate_all_cache_and_stream(txn, self.ignored_by) - # Invalidate caches as appropriate - self._invalidate_cache_and_stream( - txn, self.get_account_data_for_room_and_type, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_account_data_for_user, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_global_account_data_by_type_for_user, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_account_data_for_room, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_push_rules_for_user, (user_id,) - ) - self._invalidate_cache_and_stream( - txn, self.get_push_rules_enabled_for_user, (user_id,) - ) - # This user might be contained in the ignored_by cache for other users, - # so we have to invalidate it all. - self._invalidate_all_cache_and_stream(txn, self.ignored_by) + async def _delete_account_data_for_deactivated_users( + self, progress: dict, batch_size: int + ) -> int: + """ + Retroactively purges account data for users that have already been deactivated. + Gets run as a background update caused by a schema delta. + """ - await self.db_pool.runInteraction( - "purge_account_data_for_user_txn", - purge_account_data_for_user_txn, + last_user: str = progress.get("last_user", "") + + def _delete_account_data_for_deactivated_users_txn( + txn: LoggingTransaction, + ) -> int: + sql = """ + SELECT name FROM users + WHERE deactivated = ? and name > ? + ORDER BY name ASC + LIMIT ? + """ + + txn.execute(sql, (1, last_user, batch_size)) + users = [row[0] for row in txn] + + for user in users: + self._purge_account_data_for_user_txn(txn, user_id=user) + + if users: + self.db_pool.updates._background_update_progress_txn( + txn, + "delete_account_data_for_deactivated_users", + {"last_user": users[-1]}, + ) + + return len(users) + + number_deleted = await self.db_pool.runInteraction( + "_delete_account_data_for_deactivated_users", + _delete_account_data_for_deactivated_users_txn, ) + if number_deleted < batch_size: + await self.db_pool.updates._end_background_update( + "delete_account_data_for_deactivated_users" + ) + + return number_deleted + class AccountDataStore(AccountDataWorkerStore): pass diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 2bb5288431..304814af5d 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore( service: ApplicationService, events: List[EventBase], ephemeral: List[JsonDict], + to_device_messages: List[JsonDict], ) -> AppServiceTransaction: """Atomically creates a new transaction for this application service with the given list of events. Ephemeral events are NOT persisted to the @@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore( service: The service who the transaction is for. events: A list of persistent events to put in the transaction. ephemeral: A list of ephemeral events to put in the transaction. + to_device_messages: A list of to-device messages to put in the transaction. Returns: A new transaction. @@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore( (service.id, new_txn_id, event_ids), ) return AppServiceTransaction( - service=service, id=new_txn_id, events=events, ephemeral=ephemeral + service=service, + id=new_txn_id, + events=events, + ephemeral=ephemeral, + to_device_messages=to_device_messages, ) return await self.db_pool.runInteraction( @@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore( events = await self.get_events_as_list(event_ids) return AppServiceTransaction( - service=service, id=entry["txn_id"], events=events, ephemeral=[] + service=service, + id=entry["txn_id"], + events=events, + ephemeral=[], + to_device_messages=[], ) def _get_last_txn(self, txn, service_id: Optional[str]) -> int: @@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore( async def get_type_stream_id_for_appservice( self, service: ApplicationService, type: str ) -> int: - if type not in ("read_receipt", "presence"): + if type not in ("read_receipt", "presence", "to_device"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (type,) @@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore( "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn ) - async def set_type_stream_id_for_appservice( + async def set_appservice_stream_type_pos( self, service: ApplicationService, stream_type: str, pos: Optional[int] ) -> None: - if stream_type not in ("read_receipt", "presence"): + if stream_type not in ("read_receipt", "presence", "to_device"): raise ValueError( "Expected type to be a valid application stream id type, got %s" % (stream_type,) ) - def set_type_stream_id_for_appservice_txn(txn): + def set_appservice_stream_type_pos_txn(txn): stream_id_type = "%s_stream_id" % stream_type txn.execute( "UPDATE application_services_state SET %s = ? WHERE as_id=?" @@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore( ) await self.db_pool.runInteraction( - "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn + "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 0024348067..c428dd5596 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -15,7 +15,7 @@ import itertools import logging -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple from synapse.api.constants import EventTypes from synapse.replication.tcp.streams import BackfillStream, CachesStream @@ -25,7 +25,11 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, ) from synapse.storage._base import SQLBaseStore -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.util.iterutils import batch_iter @@ -236,7 +240,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate_all) self._send_invalidation_to_replication(txn, cache_func.__name__, None) - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + def _invalidate_state_caches_and_stream( + self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str] + ) -> None: """Special case invalidation of caches based on current state. We special case this so that we can batch the cache invalidations into a @@ -244,8 +250,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): Args: txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed + room_id: Room where state changed + members_changed: The user_ids of members that have changed """ txn.call_after(self._invalidate_state_caches, room_id, members_changed) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 4eca97189b..1392363de1 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast from synapse.logging import issue9533_logger from synapse.logging.opentracing import log_kv, set_tag, trace @@ -24,6 +24,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( @@ -136,63 +137,263 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() - async def get_new_messages_for_device( + async def get_messages_for_user_devices( + self, + user_ids: Collection[str], + from_stream_id: int, + to_stream_id: int, + ) -> Dict[Tuple[str, str], List[JsonDict]]: + """ + Retrieve to-device messages for a given set of users. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + + Args: + user_ids: The users to retrieve to-device messages for. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + + Returns: + A dictionary of (user id, device id) -> list of to-device messages. + """ + # We expect the stream ID returned by _get_device_messages to always + # be to_stream_id. So, no need to return it from this function. + ( + user_id_device_id_to_messages, + last_processed_stream_id, + ) = await self._get_device_messages( + user_ids=user_ids, + from_stream_id=from_stream_id, + to_stream_id=to_stream_id, + ) + + assert ( + last_processed_stream_id == to_stream_id + ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`" + + return user_id_device_id_to_messages + + async def get_messages_for_device( self, user_id: str, - device_id: Optional[str], - last_stream_id: int, - current_stream_id: int, + device_id: str, + from_stream_id: int, + to_stream_id: int, limit: int = 100, - ) -> Tuple[List[dict], int]: + ) -> Tuple[List[JsonDict], int]: """ + Retrieve to-device messages for a single user device. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + Args: - user_id: The recipient user_id. - device_id: The recipient device_id. - last_stream_id: The last stream ID checked. - current_stream_id: The current position of the to device - message stream. - limit: The maximum number of messages to retrieve. + user_id: The ID of the user to retrieve messages for. + device_id: The ID of the device to retrieve to-device messages for. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + limit: A limit on the number of to-device messages returned. Returns: A tuple containing: - * A list of messages for the device. - * The max stream token of these messages. There may be more to retrieve - if the given limit was reached. + * A list of to-device messages within the given stream id range intended for + the given user / device combo. + * The last-processed stream ID. Subsequent calls of this function with the + same device should pass this value as 'from_stream_id'. """ - has_changed = self._device_inbox_stream_cache.has_entity_changed( - user_id, last_stream_id + ( + user_id_device_id_to_messages, + last_processed_stream_id, + ) = await self._get_device_messages( + user_ids=[user_id], + device_id=device_id, + from_stream_id=from_stream_id, + to_stream_id=to_stream_id, + limit=limit, ) - if not has_changed: - return [], current_stream_id - def get_new_messages_for_device_txn(txn): - sql = ( - "SELECT stream_id, message_json FROM device_inbox" - " WHERE user_id = ? AND device_id = ?" - " AND ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - " LIMIT ?" + if not user_id_device_id_to_messages: + # There were no messages! + return [], to_stream_id + + # Extract the messages, no need to return the user and device ID again + to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), []) + + return to_device_messages, last_processed_stream_id + + async def _get_device_messages( + self, + user_ids: Collection[str], + from_stream_id: int, + to_stream_id: int, + device_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: + """ + Retrieve pending to-device messages for a collection of user devices. + + Only to-device messages with stream ids between the given boundaries + (from < X <= to) are returned. + + Note that a stream ID can be shared by multiple copies of the same message with + different recipient devices. Stream IDs are only unique in the context of a single + user ID / device ID pair. Thus, applying a limit (of messages to return) when working + with a sliding window of stream IDs is only possible when querying messages of a + single user device. + + Finally, note that device IDs are not unique across users. + + Args: + user_ids: The user IDs to filter device messages by. + from_stream_id: The lower boundary of stream id to filter with (exclusive). + to_stream_id: The upper boundary of stream id to filter with (inclusive). + device_id: A device ID to query to-device messages for. If not provided, to-device + messages from all device IDs for the given user IDs will be queried. May not be + provided if `user_ids` contains more than one entry. + limit: The maximum number of to-device messages to return. Can only be used when + passing a single user ID / device ID tuple. + + Returns: + A tuple containing: + * A dict of (user_id, device_id) -> list of to-device messages + * The last-processed stream ID. If this is less than `to_stream_id`, then + there may be more messages to retrieve. If `limit` is not set, then this + is always equal to 'to_stream_id'. + """ + if not user_ids: + logger.warning("No users provided upon querying for device IDs") + return {}, to_stream_id + + # Prevent a query for one user's device also retrieving another user's device with + # the same device ID (device IDs are not unique across users). + if len(user_ids) > 1 and device_id is not None: + raise AssertionError( + "Programming error: 'device_id' cannot be supplied to " + "_get_device_messages when >1 user_id has been provided" ) - txn.execute( - sql, (user_id, device_id, last_stream_id, current_stream_id, limit) + + # A limit can only be applied when querying for a single user ID / device ID tuple. + # See the docstring of this function for more details. + if limit is not None and device_id is None: + raise AssertionError( + "Programming error: _get_device_messages was passed 'limit' " + "without a specific user_id/device_id" ) - messages = [] - stream_pos = current_stream_id + user_ids_to_query: Set[str] = set() + device_ids_to_query: Set[str] = set() + + # Note that a device ID could be an empty str + if device_id is not None: + # If a device ID was passed, use it to filter results. + # Otherwise, device IDs will be derived from the given collection of user IDs. + device_ids_to_query.add(device_id) + + # Determine which users have devices with pending messages + for user_id in user_ids: + if self._device_inbox_stream_cache.has_entity_changed( + user_id, from_stream_id + ): + # This user has new messages sent to them. Query messages for them + user_ids_to_query.add(user_id) + + def get_device_messages_txn(txn: LoggingTransaction): + # Build a query to select messages from any of the given devices that + # are between the given stream id bounds. + + # If a list of device IDs was not provided, retrieve all devices IDs + # for the given users. We explicitly do not query hidden devices, as + # hidden devices should not receive to-device messages. + # Note that this is more efficient than just dropping `device_id` from the query, + # since device_inbox has an index on `(user_id, device_id, stream_id)` + if not device_ids_to_query: + user_device_dicts = self.db_pool.simple_select_many_txn( + txn, + table="devices", + column="user_id", + iterable=user_ids_to_query, + keyvalues={"user_id": user_id, "hidden": False}, + retcols=("device_id",), + ) - for row in txn: - stream_pos = row[0] - messages.append(db_to_json(row[1])) + device_ids_to_query.update( + {row["device_id"] for row in user_device_dicts} + ) - # If the limit was not reached we know that there's no more data for this - # user/device pair up to current_stream_id. - if len(messages) < limit: - stream_pos = current_stream_id + if not device_ids_to_query: + # We've ended up with no devices to query. + return {}, to_stream_id - return messages, stream_pos + # We include both user IDs and device IDs in this query, as we have an index + # (device_inbox_user_stream_id) for them. + user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause( + self.database_engine, "user_id", user_ids_to_query + ) + ( + device_id_many_clause_sql, + device_id_many_clause_args, + ) = make_in_list_sql_clause( + self.database_engine, "device_id", device_ids_to_query + ) + + sql = f""" + SELECT stream_id, user_id, device_id, message_json FROM device_inbox + WHERE {user_id_many_clause_sql} + AND {device_id_many_clause_sql} + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + sql_args = ( + *user_id_many_clause_args, + *device_id_many_clause_args, + from_stream_id, + to_stream_id, + ) + + # If a limit was provided, limit the data retrieved from the database + if limit is not None: + sql += "LIMIT ?" + sql_args += (limit,) + + txn.execute(sql, sql_args) + + # Create and fill a dictionary of (user ID, device ID) -> list of messages + # intended for each device. + last_processed_stream_pos = to_stream_id + recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} + rowcount = 0 + for row in txn: + rowcount += 1 + + last_processed_stream_pos = row[0] + recipient_user_id = row[1] + recipient_device_id = row[2] + message_dict = db_to_json(row[3]) + + # Store the device details + recipient_device_to_messages.setdefault( + (recipient_user_id, recipient_device_id), [] + ).append(message_dict) + + if limit is not None and rowcount == limit: + # We ended up bumping up against the message limit. There may be more messages + # to retrieve. Return what we have, as well as the last stream position that + # was processed. + # + # The caller is expected to set this as the lower (exclusive) bound + # for the next query of this device. + return recipient_device_to_messages, last_processed_stream_pos + + # The limit was not reached, thus we know that recipient_device_to_messages + # contains all to-device messages for the given device and stream id range. + # + # We return to_stream_id, which the caller should then provide as the lower + # (exclusive) bound on the next query of this device. + return recipient_device_to_messages, to_stream_id return await self.db_pool.runInteraction( - "get_new_messages_for_device", get_new_messages_for_device_txn + "get_device_messages", get_device_messages_txn ) @trace diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index b2a5cd9a65..3b3a089b76 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } + def get_cached_device_list_changes( + self, + from_key: int, + ) -> Optional[Set[str]]: + """Get set of users whose devices have changed since `from_key`, or None + if that information is not in our cache. + """ + + return self._device_list_stream_cache.get_all_entities_changed(from_key) + async def get_users_whose_devices_changed( self, from_key: int, user_ids: Iterable[str] ) -> Set[str]: @@ -1496,13 +1506,23 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) async def add_device_change_to_streams( - self, user_id: str, device_ids: Collection[str], hosts: List[str] - ) -> int: + self, user_id: str, device_ids: Collection[str], hosts: Collection[str] + ) -> Optional[int]: """Persist that a user's devices have been updated, and which hosts (if any) should be poked. + + Args: + user_id: The ID of the user whose device changed. + device_ids: The IDs of any changed devices. If empty, this function will + return None. + hosts: The remote destinations that should be notified of the change. + + Returns: + The maximum stream ID of device list updates that were added to the database, or + None if no updates were added. """ if not device_ids: - return + return None async with self._device_list_id_gen.get_next_mult( len(device_ids) @@ -1573,11 +1593,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn: LoggingTransaction, user_id: str, - device_ids: Collection[str], - hosts: List[str], + device_ids: Iterable[str], + hosts: Collection[str], stream_ids: List[str], context: Dict[str, str], - ): + ) -> None: for host in hosts: txn.call_after( self._device_list_federation_stream_cache.entity_has_changed, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ca71f073fc..277e6422eb 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -16,9 +16,10 @@ import logging from queue import Empty, PriorityQueue from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +import attr from prometheus_client import Counter, Gauge -from synapse.api.constants import MAX_DEPTH +from synapse.api.constants import MAX_DEPTH, EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict @@ -60,6 +61,15 @@ pdus_pruned_from_federation_queue = Counter( logger = logging.getLogger(__name__) +# All the info we need while iterating the DAG while backfilling +@attr.s(frozen=True, slots=True, auto_attribs=True) +class BackfillQueueNavigationItem: + depth: int + stream_ordering: int + event_id: str + type: str + + class _NoChainCoverIndex(Exception): def __init__(self, room_id: str): super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) @@ -74,6 +84,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ): super().__init__(database, db_conn, hs) + self.hs = hs + if hs.config.worker.run_background_tasks: hs.get_clock().looping_call( self._delete_old_forward_extrem_cache, 60 * 60 * 1000 @@ -109,7 +121,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id: str, event_ids: Collection[str], include_given: bool = False, - ) -> List[str]: + ) -> Set[str]: """Get auth events for given event_ids. The events *must* be state events. Args: @@ -118,7 +130,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas include_given: include the given events in result Returns: - list of event_ids + set of event_ids """ # Check if we have indexed the room so we can use the chain cover @@ -147,7 +159,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas def _get_auth_chain_ids_using_cover_index_txn( self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool - ) -> List[str]: + ) -> Set[str]: """Calculates the auth chain IDs using the chain index.""" # First we look up the chain ID/sequence numbers for the given events. @@ -260,11 +272,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (chain_id, max_no)) results.update(r for r, in txn) - return list(results) + return results def _get_auth_chain_ids_txn( self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool - ) -> List[str]: + ) -> Set[str]: """Calculates the auth chain IDs. This is used when we don't have a cover index for the room. @@ -319,7 +331,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas front = new_front results.update(front) - return list(results) + return results async def get_auth_chain_difference( self, room_id: str, state_sets: List[Set[str]] @@ -737,7 +749,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) - async def get_insertion_event_backwards_extremities_in_room( + async def get_insertion_event_backward_extremities_in_room( self, room_id ) -> Dict[str, int]: """Get the insertion events we know about that we haven't backfilled yet. @@ -754,7 +766,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas Map from event_id to depth """ - def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id): + def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): sql = """ SELECT b.event_id, MAX(e.depth) FROM insertion_events as i /* We only want insertion events that are also marked as backwards extremities */ @@ -770,8 +782,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas return dict(txn) return await self.db_pool.runInteraction( - "get_insertion_event_backwards_extremities_in_room", - get_insertion_event_backwards_extremities_in_room_txn, + "get_insertion_event_backward_extremities_in_room", + get_insertion_event_backward_extremities_in_room_txn, room_id, ) @@ -997,143 +1009,242 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - async def get_backfill_events(self, room_id: str, event_list: list, limit: int): - """Get a list of Events for a given topic that occurred before (and - including) the events in event_list. Return a list of max size `limit` + def _get_connected_batch_event_backfill_results_txn( + self, txn: LoggingTransaction, insertion_event_id: str, limit: int + ) -> List[BackfillQueueNavigationItem]: + """ + Find any batch connections of a given insertion event. + A batch event points at a insertion event via: + batch_event.content[MSC2716_BATCH_ID] -> insertion_event.content[MSC2716_NEXT_BATCH_ID] Args: - room_id - event_list - limit + txn: The database transaction to use + insertion_event_id: The event ID to navigate from. We will find + batch events that point back at this insertion event. + limit: Max number of event ID's to query for and return + + Returns: + List of batch events that the backfill queue can process + """ + batch_connection_query = """ + SELECT e.depth, e.stream_ordering, c.event_id, e.type FROM insertion_events AS i + /* Find the batch that connects to the given insertion event */ + INNER JOIN batch_events AS c + ON i.next_batch_id = c.batch_id + /* Get the depth of the batch start event from the events table */ + INNER JOIN events AS e USING (event_id) + /* Find an insertion event which matches the given event_id */ + WHERE i.event_id = ? + LIMIT ? """ - event_ids = await self.db_pool.runInteraction( - "get_backfill_events", - self._get_backfill_events, - room_id, - event_list, - limit, - ) - events = await self.get_events_as_list(event_ids) - return sorted(events, key=lambda e: -e.depth) - def _get_backfill_events(self, txn, room_id, event_list, limit): - logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit) + # Find any batch connections for the given insertion event + txn.execute( + batch_connection_query, + (insertion_event_id, limit), + ) + return [ + BackfillQueueNavigationItem( + depth=row[0], + stream_ordering=row[1], + event_id=row[2], + type=row[3], + ) + for row in txn + ] - event_results = set() + def _get_connected_prev_event_backfill_results_txn( + self, txn: LoggingTransaction, event_id: str, limit: int + ) -> List[BackfillQueueNavigationItem]: + """ + Find any events connected by prev_event the specified event_id. - # We want to make sure that we do a breadth-first, "depth" ordered - # search. + Args: + txn: The database transaction to use + event_id: The event ID to navigate from + limit: Max number of event ID's to query for and return + Returns: + List of prev events that the backfill queue can process + """ # Look for the prev_event_id connected to the given event_id - query = """ - SELECT depth, prev_event_id FROM event_edges - /* Get the depth of the prev_event_id from the events table */ + connected_prev_event_query = """ + SELECT depth, stream_ordering, prev_event_id, events.type FROM event_edges + /* Get the depth and stream_ordering of the prev_event_id from the events table */ INNER JOIN events ON prev_event_id = events.event_id - /* Find an event which matches the given event_id */ + /* Look for an edge which matches the given event_id */ WHERE event_edges.event_id = ? AND event_edges.is_state = ? + /* Because we can have many events at the same depth, + * we want to also tie-break and sort on stream_ordering */ + ORDER BY depth DESC, stream_ordering DESC LIMIT ? """ - # Look for the "insertion" events connected to the given event_id - connected_insertion_event_query = """ - SELECT e.depth, i.event_id FROM insertion_event_edges AS i - /* Get the depth of the insertion event from the events table */ - INNER JOIN events AS e USING (event_id) - /* Find an insertion event which points via prev_events to the given event_id */ - WHERE i.insertion_prev_event_id = ? - LIMIT ? + txn.execute( + connected_prev_event_query, + (event_id, False, limit), + ) + return [ + BackfillQueueNavigationItem( + depth=row[0], + stream_ordering=row[1], + event_id=row[2], + type=row[3], + ) + for row in txn + ] + + async def get_backfill_events( + self, room_id: str, seed_event_id_list: list, limit: int + ): + """Get a list of Events for a given topic that occurred before (and + including) the events in seed_event_id_list. Return a list of max size `limit` + + Args: + room_id + seed_event_id_list + limit """ + event_ids = await self.db_pool.runInteraction( + "get_backfill_events", + self._get_backfill_events, + room_id, + seed_event_id_list, + limit, + ) + events = await self.get_events_as_list(event_ids) + return sorted( + events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) + ) - # Find any batch connections of a given insertion event - batch_connection_query = """ - SELECT e.depth, c.event_id FROM insertion_events AS i - /* Find the batch that connects to the given insertion event */ - INNER JOIN batch_events AS c - ON i.next_batch_id = c.batch_id - /* Get the depth of the batch start event from the events table */ - INNER JOIN events AS e USING (event_id) - /* Find an insertion event which matches the given event_id */ - WHERE i.event_id = ? - LIMIT ? + def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): + """ + We want to make sure that we do a breadth-first, "depth" ordered search. + We also handle navigating historical branches of history connected by + insertion and batch events. """ + logger.debug( + "_get_backfill_events(room_id=%s): seeding backfill with seed_event_id_list=%s limit=%s", + room_id, + seed_event_id_list, + limit, + ) + + event_id_results = set() # In a PriorityQueue, the lowest valued entries are retrieved first. - # We're using depth as the priority in the queue. - # Depth is lowest at the oldest-in-time message and highest and - # newest-in-time message. We add events to the queue with a negative depth so that - # we process the newest-in-time messages first going backwards in time. + # We're using depth as the priority in the queue and tie-break based on + # stream_ordering. Depth is lowest at the oldest-in-time message and + # highest and newest-in-time message. We add events to the queue with a + # negative depth so that we process the newest-in-time messages first + # going backwards in time. stream_ordering follows the same pattern. queue = PriorityQueue() - for event_id in event_list: - depth = self.db_pool.simple_select_one_onecol_txn( + for seed_event_id in seed_event_id_list: + event_lookup_result = self.db_pool.simple_select_one_txn( txn, table="events", - keyvalues={"event_id": event_id, "room_id": room_id}, - retcol="depth", + keyvalues={"event_id": seed_event_id, "room_id": room_id}, + retcols=( + "type", + "depth", + "stream_ordering", + ), allow_none=True, ) - if depth: - queue.put((-depth, event_id)) + if event_lookup_result is not None: + logger.debug( + "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s", + room_id, + seed_event_id, + event_lookup_result["depth"], + event_lookup_result["stream_ordering"], + event_lookup_result["type"], + ) + + if event_lookup_result["depth"]: + queue.put( + ( + -event_lookup_result["depth"], + -event_lookup_result["stream_ordering"], + seed_event_id, + event_lookup_result["type"], + ) + ) - while not queue.empty() and len(event_results) < limit: + while not queue.empty() and len(event_id_results) < limit: try: - _, event_id = queue.get_nowait() + _, _, event_id, event_type = queue.get_nowait() except Empty: break - if event_id in event_results: + if event_id in event_id_results: continue - event_results.add(event_id) + event_id_results.add(event_id) # Try and find any potential historical batches of message history. - # - # First we look for an insertion event connected to the current - # event (by prev_event). If we find any, we need to go and try to - # find any batch events connected to the insertion event (by - # batch_id). If we find any, we'll add them to the queue and - # navigate up the DAG like normal in the next iteration of the loop. - txn.execute( - connected_insertion_event_query, (event_id, limit - len(event_results)) - ) - connected_insertion_event_id_results = txn.fetchall() - logger.debug( - "_get_backfill_events: connected_insertion_event_query %s", - connected_insertion_event_id_results, - ) - for row in connected_insertion_event_id_results: - connected_insertion_event_depth = row[0] - connected_insertion_event = row[1] - queue.put((-connected_insertion_event_depth, connected_insertion_event)) + if self.hs.config.experimental.msc2716_enabled: + # We need to go and try to find any batch events connected + # to a given insertion event (by batch_id). If we find any, we'll + # add them to the queue and navigate up the DAG like normal in the + # next iteration of the loop. + if event_type == EventTypes.MSC2716_INSERTION: + # Find any batch connections for the given insertion event + connected_batch_event_backfill_results = ( + self._get_connected_batch_event_backfill_results_txn( + txn, event_id, limit - len(event_id_results) + ) + ) + logger.debug( + "_get_backfill_events(room_id=%s): connected_batch_event_backfill_results=%s", + room_id, + connected_batch_event_backfill_results, + ) + for ( + connected_batch_event_backfill_item + ) in connected_batch_event_backfill_results: + if ( + connected_batch_event_backfill_item.event_id + not in event_id_results + ): + queue.put( + ( + -connected_batch_event_backfill_item.depth, + -connected_batch_event_backfill_item.stream_ordering, + connected_batch_event_backfill_item.event_id, + connected_batch_event_backfill_item.type, + ) + ) - # Find any batch connections for the given insertion event - txn.execute( - batch_connection_query, - (connected_insertion_event, limit - len(event_results)), - ) - batch_start_event_id_results = txn.fetchall() - logger.debug( - "_get_backfill_events: batch_start_event_id_results %s", - batch_start_event_id_results, + # Now we just look up the DAG by prev_events as normal + connected_prev_event_backfill_results = ( + self._get_connected_prev_event_backfill_results_txn( + txn, event_id, limit - len(event_id_results) ) - for row in batch_start_event_id_results: - if row[1] not in event_results: - queue.put((-row[0], row[1])) - - txn.execute(query, (event_id, False, limit - len(event_results))) - prev_event_id_results = txn.fetchall() + ) logger.debug( - "_get_backfill_events: prev_event_ids %s", prev_event_id_results + "_get_backfill_events(room_id=%s): connected_prev_event_backfill_results=%s", + room_id, + connected_prev_event_backfill_results, ) + for ( + connected_prev_event_backfill_item + ) in connected_prev_event_backfill_results: + if connected_prev_event_backfill_item.event_id not in event_id_results: + queue.put( + ( + -connected_prev_event_backfill_item.depth, + -connected_prev_event_backfill_item.stream_ordering, + connected_prev_event_backfill_item.event_id, + connected_prev_event_backfill_item.type, + ) + ) - for row in prev_event_id_results: - if row[1] not in event_results: - queue.put((-row[0], row[1])) - - return event_results + return event_id_results async def get_missing_events(self, room_id, earliest_events, latest_events, limit): ids = await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b7554154ac..a1d7a9b413 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -975,6 +975,17 @@ class PersistEventsStore: to_delete = delta_state.to_delete to_insert = delta_state.to_insert + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invalidate the caches for all + # those users. + members_changed = { + state_key + for ev_type, state_key in itertools.chain(to_delete, to_insert) + if ev_type == EventTypes.Member + } + if delta_state.no_longer_in_room: # Server is no longer in the room so we delete the room from # current_state_events, being careful we've already updated the @@ -993,6 +1004,11 @@ class PersistEventsStore: """ txn.execute(sql, (stream_id, self._instance_name, room_id)) + # We also want to invalidate the membership caches for users + # that were in the room. + users_in_room = self.store.get_users_in_room_txn(txn, room_id) + members_changed.update(users_in_room) + self.db_pool.simple_delete_txn( txn, table="current_state_events", @@ -1102,17 +1118,6 @@ class PersistEventsStore: # Invalidate the various caches - # Figure out the changes of membership to invalidate the - # `get_rooms_for_user` cache. - # We find out which membership events we may have deleted - # and which we have added, then we invalidate the caches for all - # those users. - members_changed = { - state_key - for ev_type, state_key in itertools.chain(to_delete, to_insert) - if ev_type == EventTypes.Member - } - for member in members_changed: txn.call_after( self.store.get_rooms_for_user_with_stream_ordering.invalidate, @@ -1801,9 +1806,7 @@ class PersistEventsStore: ) if rel_type == RelationTypes.REPLACE: - txn.call_after( - self.store.get_applicable_edit.invalidate, (parent_id, event.room_id) - ) + txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) if rel_type == RelationTypes.THREAD: txn.call_after( @@ -1814,7 +1817,7 @@ class PersistEventsStore: # potentially error-prone) so it is always invalidated. txn.call_after( self.store.get_thread_participated.invalidate, - (parent_id, event.room_id, event.sender), + (parent_id, event.sender), ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): @@ -2215,9 +2218,14 @@ class PersistEventsStore: " SELECT 1 FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" " )" + # 1. Don't add an event as a extremity again if we already persisted it + # as a non-outlier. + # 2. Don't add an outlier as an extremity if it has no prev_events " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" + " SELECT 1 FROM events" + " LEFT JOIN event_edges edge" + " ON edge.event_id = events.event_id" + " WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = ? OR edge.event_id IS NULL)" " )" ) @@ -2243,6 +2251,10 @@ class PersistEventsStore: (ev.event_id, ev.room_id) for ev in events if not ev.internal_metadata.is_outlier() + # If we encountered an event with no prev_events, then we might + # as well remove it now because it won't ever have anything else + # to backfill from. + or len(ev.prev_event_ids()) == 0 ], ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 8d4287045a..2a255d1031 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore): include the previous states content in the unsigned field. allow_rejected: If True, return rejected events. Otherwise, - omits rejeted events from the response. + omits rejected events from the response. Returns: A mapping from event_id to event. @@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore): forward_edge_query = """ SELECT 1 FROM event_edges /* Check to make sure the event referencing our event in question is not rejected */ - LEFT JOIN rejections ON event_edges.event_id == rejections.event_id + LEFT JOIN rejections ON event_edges.event_id = rejections.event_id WHERE event_edges.room_id = ? AND event_edges.prev_event_id = ? diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 4f05811a77..d3c4611686 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,15 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter @@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) # Used by `PresenceStore._get_active_presence()` @@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + self._presence_id_gen: AbstractStreamIdGenerator + self._can_persist_presence = ( - hs.get_instance_name() in hs.config.worker.writers.presence + self._instance_name in hs.config.worker.writers.presence ) if isinstance(database.engine, PostgresEngine): @@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): return stream_orderings[-1], self._presence_id_gen.get_current_token() - def _update_presence_txn(self, txn, stream_orderings, presence_states): + def _update_presence_txn( + self, txn: LoggingTransaction, stream_orderings, presence_states + ) -> None: for stream_id, state in zip(stream_orderings, presence_states): txn.call_after( self.presence_stream_cache.entity_has_changed, state.user_id, stream_id @@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore): if last_id == current_id: return [], current_id, False - def get_all_presence_updates_txn(txn): + def get_all_presence_updates_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, user_id, state, last_active_ts, last_federation_update_ts, last_user_sync_ts, - status_msg, - currently_active + status_msg, currently_active FROM presence_stream WHERE ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC LIMIT ? """ txn.execute(sql, (last_id, current_id, limit)) - updates = [(row[0], row[1:]) for row in txn] + updates = cast( + List[Tuple[int, list]], + [(row[0], row[1:]) for row in txn], + ) upper_bound = current_id limited = False @@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): ) @cached() - def _get_presence_for_user(self, user_id): + def _get_presence_for_user(self, user_id: str) -> None: raise NotImplementedError() @cachedList( @@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): list_name="user_ids", num_args=1, ) - async def get_presence_for_users(self, user_ids): + async def get_presence_for_users( + self, user_ids: Iterable[str] + ) -> Dict[str, UserPresenceState]: rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", @@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore): True if the user should have full presence sent to them, False otherwise. """ - def _should_user_receive_full_presence_with_token_txn(txn): + def _should_user_receive_full_presence_with_token_txn( + txn: LoggingTransaction, + ) -> bool: sql = """ SELECT 1 FROM users_to_send_full_presence_to WHERE user_id = ? @@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore): _should_user_receive_full_presence_with_token_txn, ) - async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]): + async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: """Adds to the list of users who should receive a full snapshot of presence upon their next sync. @@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore): return users_to_state - def get_current_presence_token(self): + def get_current_presence_token(self) -> int: return self._presence_id_gen.get_current_token() - def _get_active_presence(self, db_conn: Connection): + def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: """Fetch non-offline presence from the database so that we can register the appropriate time outs. """ @@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore): return [UserPresenceState(**row) for row in rows] - def take_presence_startup_info(self): + def take_presence_startup_info(self) -> List[UserPresenceState]: active_on_startup = self._presence_on_startup - self._presence_on_startup = None + self._presence_on_startup = [] return active_on_startup - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: if stream_name == PresenceStream.NAME: self._presence_id_gen.advance(instance_name, token) for row in rows: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index e87a8fb85d..2e3818e432 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -13,9 +13,10 @@ # limitations under the License. import logging -from typing import Any, List, Set, Tuple +from typing import Any, List, Set, Tuple, cast from synapse.api.errors import SynapseError +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken @@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_history_txn( - self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool + self, + txn: LoggingTransaction, + room_id: str, + token: RoomStreamToken, + delete_local_events: bool, ) -> Set[int]: # Tables that should be pruned: # event_auth @@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): """, (room_id,), ) - (min_depth,) = txn.fetchone() + (min_depth,) = cast(Tuple[int], txn.fetchone()) logger.info("[purge] updating room_depth to %d", min_depth) @@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "purge_room", self._purge_room_txn, room_id ) - def _purge_room_txn(self, txn, room_id: str) -> List[int]: + def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: # First we fetch all the state groups that should be deleted, before # we delete that information. txn.execute( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index e01c94930a..92539f5d41 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _load_rules(rawrules, enabled_map, use_new_defaults=False): +def _load_rules(rawrules, enabled_map): ruleslist = [] for rawrule in rawrules: rule = dict(rawrule) @@ -52,7 +52,7 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False): ruleslist.append(rule) # We're going to be mutating this a lot, so do a deep copy - rules = list(list_with_base_rules(ruleslist, use_new_defaults)) + rules = list(list_with_base_rules(ruleslist)) for i, rule in enumerate(rules): rule_id = rule["rule_id"] @@ -112,10 +112,6 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - self._users_new_default_push_rules = ( - hs.config.server.users_new_default_push_rules - ) - @abc.abstractmethod def get_max_push_rules_stream_id(self): """Get the position of the push rules stream. @@ -145,9 +141,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - use_new_defaults = user_id in self._users_new_default_push_rules - - return _load_rules(rows, enabled_map, use_new_defaults) + return _load_rules(rows, enabled_map) @cached(max_entries=5000) async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]: @@ -206,13 +200,7 @@ class PushRulesWorkerStore( enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids) for user_id, rules in results.items(): - use_new_defaults = user_id in self._users_new_default_push_rules - - results[user_id] = _load_rules( - rules, - enabled_map_by_user.get(user_id, {}), - use_new_defaults, - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index aac94fa464..17110bb033 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) -> None: """Record a mapping from an external user id to a mxid + See notes in _record_user_external_id_txn about what constitutes valid data. + Args: auth_provider: identifier for the remote auth provider external_id: id on that system user_id: complete mxid that it is mapped to + Raises: ExternalIDReuseException if the new external_id could not be mapped. """ @@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): external_id: str, user_id: str, ) -> None: + """ + Record a mapping from an external user id to a mxid. + + Note that the auth provider IDs (and the external IDs) are not validated + against configured IdPs as Synapse does not know its relationship to + external systems. For example, it might be useful to pre-configure users + before enabling a new IdP or an IdP might be temporarily offline, but + still valid. + + Args: + txn: The database transaction. + auth_provider: identifier for the remote auth provider + external_id: id on that system + user_id: complete mxid that it is mapped to + """ self.db_pool.simple_insert_txn( txn, @@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """Replace mappings from external user ids to a mxid in a single transaction. All mappings are deleted and the new ones are created. + See notes in _record_user_external_id_txn about what constitutes valid data. + Args: record_external_ids: List with tuple of auth_provider and external_id to record user_id: complete mxid that it is mapped to + Raises: ExternalIDReuseException if the new external_id could not be mapped. """ diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 37468a5183..36aa1092f6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,12 +13,23 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, + cast, +) import attr from frozendict import frozendict -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -28,24 +39,27 @@ from synapse.storage.database import ( make_in_list_sql_clause, ) from synapse.storage.databases.main.stream import generate_pagination_where_clause -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.types import JsonDict -from synapse.util.caches.descriptors import cached +from synapse.storage.engines import PostgresEngine +from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: + # The latest event in the thread. latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. count: int + # True if the current user has sent an event to the thread. current_user_participated: bool @@ -87,8 +101,8 @@ class RelationsWorkerStore(SQLBaseStore): aggregation_key: Optional[str] = None, limit: int = 5, direction: str = "b", - from_token: Optional[RelationPaginationToken] = None, - to_token: Optional[RelationPaginationToken] = None, + from_token: Optional[StreamToken] = None, + to_token: Optional[StreamToken] = None, ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. @@ -127,8 +141,10 @@ class RelationsWorkerStore(SQLBaseStore): pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] - to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] + from_token=from_token.room_key.as_historical_tuple() + if from_token + else None, + to_token=to_token.room_key.as_historical_tuple() if to_token else None, engine=self.database_engine, ) @@ -166,12 +182,27 @@ class RelationsWorkerStore(SQLBaseStore): last_topo_id = row[1] last_stream_id = row[2] - next_batch = None + # If there are more events, generate the next pagination key. + next_token = None if len(events) > limit and last_topo_id and last_stream_id: - next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + next_key = RoomStreamToken(last_topo_id, last_stream_id) + if from_token: + next_token = from_token.copy_and_replace("room_key", next_key) + else: + next_token = StreamToken( + room_key=next_key, + presence_key=0, + typing_key=0, + receipt_key=0, + account_data_key=0, + push_rules_key=0, + to_device_key=0, + device_list_key=0, + groups_key=0, + ) return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token ) return await self.db_pool.runInteraction( @@ -340,20 +371,24 @@ class RelationsWorkerStore(SQLBaseStore): ) @cached() - async def get_applicable_edit( - self, event_id: str, room_id: str - ) -> Optional[EventBase]: + def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") + async def _get_applicable_edits( + self, event_ids: Collection[str] + ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given - event. + events. Correctly handles checking whether edits were allowed to happen. Args: - event_id: The original event ID - room_id: The original event's room ID + event_ids: The original event IDs Returns: - The most recent edit, if any. + A map of the most recent edit for each event. If there are no edits, + the event will map to None. """ # We only allow edits for `m.room.message` events that have the same sender @@ -362,139 +397,244 @@ class RelationsWorkerStore(SQLBaseStore): # Fetches latest edit that has the same type and sender as the # original, and is an `m.room.message`. - sql = """ - SELECT edit.event_id FROM events AS edit - INNER JOIN event_relations USING (event_id) - INNER JOIN events AS original ON - original.event_id = relates_to_id - AND edit.type = original.type - AND edit.sender = original.sender - WHERE - relates_to_id = ? - AND relation_type = ? - AND edit.room_id = ? - AND edit.type = 'm.room.message' - ORDER by edit.origin_server_ts DESC, edit.event_id DESC - LIMIT 1 - """ + if isinstance(self.database_engine, PostgresEngine): + # The `DISTINCT ON` clause will pick the *first* row it encounters, + # so ordering by origin server ts + event ID desc will ensure we get + # the latest edit. + sql = """ + SELECT DISTINCT ON (original.event_id) original.event_id, edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + AND edit.room_id = original.room_id + WHERE + %s + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC + """ + else: + # SQLite uses a simplified query which returns all edits for an + # original event. The results are then de-duplicated when turned into + # a dict. Due to the chosen ordering, the latest edit stomps on + # earlier edits. + sql = """ + SELECT original.event_id, edit.event_id FROM events AS edit + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS original ON + original.event_id = relates_to_id + AND edit.type = original.type + AND edit.sender = original.sender + AND edit.room_id = original.room_id + WHERE + %s + AND relation_type = ? + AND edit.type = 'm.room.message' + ORDER by edit.origin_server_ts, edit.event_id + """ - def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: - txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id)) - row = txn.fetchone() - if row: - return row[0] - return None + def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]: + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REPLACE) - edit_id = await self.db_pool.runInteraction( - "get_applicable_edit", _get_applicable_edit_txn + txn.execute(sql % (clause,), args) + return dict(cast(Iterable[Tuple[str, str]], txn.fetchall())) + + edit_ids = await self.db_pool.runInteraction( + "get_applicable_edits", _get_applicable_edits_txn ) - if not edit_id: - return None + edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined] - return await self.get_event(edit_id, allow_none=True) # type: ignore[attr-defined] + # Map to the original event IDs to the edit events. + # + # There might not be an edit event due to there being no edits or + # due to the event not being known, either case is treated the same. + return { + original_event_id: edits.get(edit_ids.get(original_event_id)) + for original_event_id in event_ids + } @cached() - async def get_thread_summary( - self, event_id: str, room_id: str - ) -> Tuple[int, Optional[EventBase]]: - """Get the number of threaded replies and the latest reply (if any) for the given event. + def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") + async def _get_thread_summaries( + self, event_ids: Collection[str] + ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: + """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. Args: - event_id: Summarize the thread related to this event ID. - room_id: The room the event belongs to. + event_ids: Summarize the thread related to this event ID. Returns: - The number of items in the thread and the most recent response, if any. + A map of the thread summary each event. A missing event implies there + are no threaded replies. + + Each summary is a tuple of: + The number of events in the thread. + The most recent event in the thread. + The most recent edit to the most recent event in the thread, if applicable. """ - def _get_thread_summary_txn( + def _get_thread_summaries_txn( txn: LoggingTransaction, - ) -> Tuple[int, Optional[str]]: - # Fetch the latest event ID in the thread. + ) -> Tuple[Dict[str, int], Dict[str, str]]: + # Fetch the count of threaded events and the latest event ID. # TODO Should this only allow m.room.message events. - sql = """ - SELECT event_id - FROM event_relations - INNER JOIN events USING (event_id) - WHERE - relates_to_id = ? - AND room_id = ? - AND relation_type = ? - ORDER BY topological_ordering DESC, stream_ordering DESC - LIMIT 1 - """ + if isinstance(self.database_engine, PostgresEngine): + # The `DISTINCT ON` clause will pick the *first* row it encounters, + # so ordering by topological ordering + stream ordering desc will + # ensure we get the latest event in the thread. + sql = """ + SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id + WHERE + %s + AND relation_type = ? + ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC + """ + else: + # SQLite uses a simplified query which returns all entries for a + # thread. The first result for each thread is chosen to and subsequent + # results for a thread are ignored. + sql = """ + SELECT parent.event_id, child.event_id FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id + WHERE + %s + AND relation_type = ? + ORDER BY child.topological_ordering DESC, child.stream_ordering DESC + """ - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) - row = txn.fetchone() - if row is None: - return 0, None + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.THREAD) - latest_event_id = row[0] + txn.execute(sql % (clause,), args) + latest_event_ids = {} + for parent_event_id, child_event_id in txn: + # Only consider the latest threaded reply (by topological ordering). + if parent_event_id not in latest_event_ids: + latest_event_ids[parent_event_id] = child_event_id + + # If no threads were found, bail. + if not latest_event_ids: + return {}, latest_event_ids # Fetch the number of threaded replies. sql = """ - SELECT COUNT(event_id) - FROM event_relations - INNER JOIN events USING (event_id) + SELECT parent.event_id, COUNT(child.event_id) FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id WHERE - relates_to_id = ? - AND room_id = ? + %s AND relation_type = ? + GROUP BY parent.event_id """ - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD)) - count = cast(Tuple[int], txn.fetchone())[0] - return count, latest_event_id + # Regenerate the arguments since only threads found above could + # possibly have any replies. + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", latest_event_ids.keys() + ) + args.append(RelationTypes.THREAD) - count, latest_event_id = await self.db_pool.runInteraction( - "get_thread_summary", _get_thread_summary_txn + txn.execute(sql % (clause,), args) + counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) + + return counts, latest_event_ids + + counts, latest_event_ids = await self.db_pool.runInteraction( + "get_thread_summaries", _get_thread_summaries_txn ) - latest_event = None - if latest_event_id: - latest_event = await self.get_event(latest_event_id, allow_none=True) # type: ignore[attr-defined] + latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] + + # Check to see if any of those events are edited. + latest_edits = await self._get_applicable_edits(latest_event_ids.values()) - return count, latest_event + # Map to the event IDs to the thread summary. + # + # There might not be a summary due to there not being a thread or + # due to the latest event not being known, either case is treated the same. + summaries = {} + for parent_event_id, latest_event_id in latest_event_ids.items(): + latest_event = latest_events.get(latest_event_id) + + summary = None + if latest_event: + latest_edit = latest_edits.get(latest_event_id) + summary = (counts[parent_event_id], latest_event, latest_edit) + summaries[parent_event_id] = summary + + return summaries @cached() - async def get_thread_participated( - self, event_id: str, room_id: str, user_id: str - ) -> bool: - """Get whether the requesting user participated in a thread. + def get_thread_participated(self, event_id: str, user_id: str) -> bool: + raise NotImplementedError() + + @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") + async def _get_threads_participated( + self, event_ids: Collection[str], user_id: str + ) -> Dict[str, bool]: + """Get whether the requesting user participated in the given threads. - This is separate from get_thread_summary since that can be cached across - all users while this value is specific to the requeser. + This is separate from get_thread_summaries since that can be cached across + all users while this value is specific to the requester. Args: - event_id: The thread related to this event ID. - room_id: The room the event belongs to. + event_ids: The thread related to these event IDs. user_id: The user requesting the summary. Returns: - True if the requesting user participated in the thread, otherwise false. + A map of event ID to a boolean which represents if the requesting + user participated in that event's thread, otherwise false. """ - def _get_thread_summary_txn(txn: LoggingTransaction) -> bool: + def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]: # Fetch whether the requester has participated or not. sql = """ - SELECT 1 - FROM event_relations - INNER JOIN events USING (event_id) + SELECT DISTINCT relates_to_id + FROM events AS child + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = child.room_id WHERE - relates_to_id = ? - AND room_id = ? + %s AND relation_type = ? - AND sender = ? + AND child.sender = ? """ - txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id)) - return bool(txn.fetchone()) + clause, args = make_in_list_sql_clause( + txn.database_engine, "relates_to_id", event_ids + ) + args.extend((RelationTypes.THREAD, user_id)) - return await self.db_pool.runInteraction( + txn.execute(sql % (clause,), args) + return {row[0] for row in txn.fetchall()} + + participated_threads = await self.db_pool.runInteraction( "get_thread_summary", _get_thread_summary_txn ) + return {event_id: event_id in participated_threads for event_id in event_ids} + async def events_have_relations( self, parent_ids: List[str], @@ -612,9 +752,6 @@ class RelationsWorkerStore(SQLBaseStore): The bundled aggregations for an event, if bundled aggregations are enabled and the event can have bundled aggregations. """ - # State events and redacted events do not get bundled aggregations. - if event.is_state() or event.internal_metadata.is_redacted(): - return None # Do not bundle aggregations for an event which represents an edit or an # annotation. It does not make sense for them to have related events. @@ -634,43 +771,21 @@ class RelationsWorkerStore(SQLBaseStore): annotations = await self.get_aggregation_groups_for_event(event_id, room_id) if annotations.chunk: - aggregations.annotations = annotations.to_dict() + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) references = await self.get_relations_for_event( event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations.references = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.get_applicable_edit(event_id, room_id) - - if edit: - aggregations.replace = edit - - # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - thread_count, latest_thread_event = await self.get_thread_summary( - event_id, room_id - ) - participated = await self.get_thread_participated( - event_id, room_id, user_id - ) - if latest_thread_event: - aggregations.thread = _ThreadAggregation( - latest_event=latest_thread_event, - count=thread_count, - current_user_participated=participated, - ) + aggregations.references = await references.to_dict(cast("DataStore", self)) # Store the bundled aggregations in the event metadata for later use. return aggregations async def get_bundled_aggregations( - self, - events: Iterable[EventBase], - user_id: str, + self, events: Iterable[EventBase], user_id: str ) -> Dict[str, BundledAggregations]: """Generate bundled aggregations for events. @@ -682,14 +797,60 @@ class RelationsWorkerStore(SQLBaseStore): A map of event ID to the bundled aggregation for the event. Not all events may have bundled aggregations in the results. """ + # The already processed event IDs. Tracked separately from the result + # since the result omits events which do not have bundled aggregations. + seen_event_ids = set() - # TODO Parallelize. - results = {} + # State events and redacted events do not get bundled aggregations. + events = [ + event + for event in events + if not event.is_state() and not event.internal_metadata.is_redacted() + ] + + # event ID -> bundled aggregation in non-serialized form. + results: Dict[str, BundledAggregations] = {} + + # Fetch other relations per event. for event in events: + # De-duplicate events by ID to handle the same event requested multiple + # times. The caches that _get_bundled_aggregation_for_event use should + # capture this, but best to reduce work. + if event.event_id in seen_event_ids: + continue + seen_event_ids.add(event.event_id) + event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result: results[event.event_id] = event_result + # Fetch any edits. + edits = await self._get_applicable_edits(seen_event_ids) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Fetch thread summaries. + if self._msc3440_enabled: + summaries = await self._get_thread_summaries(seen_event_ids) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._get_threads_participated( + summaries.keys(), user_id + ) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) + return results diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 95167116c9..0416df64ce 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") async def upsert_room_on_join( - self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] + self, room_id: str, room_version: RoomVersion, state_events: List[EventBase] ) -> None: """Ensure that the room is stored in the table @@ -1511,7 +1511,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): has_auth_chain_index = await self.has_auth_chain_index(room_id) create_event = None - for e in auth_events: + for e in state_events: if (e.type, e.state_key) == (EventTypes.Create, ""): create_event = e break diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 4489732fda..e48ec5f495 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore): for room_id, instance, stream_id in txn ) + @cachedList( + cached_method_name="get_rooms_for_user_with_stream_ordering", + list_name="user_ids", + ) + async def get_rooms_for_users_with_stream_ordering( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + """A batched version of `get_rooms_for_user_with_stream_ordering`. + + Returns: + Map from user_id to set of rooms that is currently in. + """ + return await self.db_pool.runInteraction( + "get_rooms_for_users_with_stream_ordering", + self._get_rooms_for_users_with_stream_ordering_txn, + user_ids, + ) + + def _get_rooms_for_users_with_stream_ordering_txn( + self, txn, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: + + clause, args = make_in_list_sql_clause( + self.database_engine, + "c.state_key", + user_ids, + ) + + if self._current_state_events_membership_up_to_date: + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND c.membership = ? + AND {clause} + """ + else: + sql = f""" + SELECT c.state_key, room_id, e.instance_name, e.stream_ordering + FROM current_state_events AS c + INNER JOIN room_memberships AS m USING (room_id, event_id) + INNER JOIN events AS e USING (room_id, event_id) + WHERE + c.type = 'm.room.member' + AND m.membership = ? + AND {clause} + """ + + txn.execute(sql, [Membership.JOIN] + args) + + result = {user_id: set() for user_id in user_ids} + for user_id, room_id, instance, stream_id in txn: + result[user_id].add( + GetRoomsForUserWithStreamOrdering( + room_id, PersistedEventPosition(instance, stream_id) + ) + ) + + return {user_id: frozenset(v) for user_id, v in result.items()} + async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2d085a5764..acea300ed3 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -28,6 +28,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import JsonDict if TYPE_CHECKING: from synapse.server import HomeServer @@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore): ): super().__init__(database, db_conn, hs) - async def search_msgs(self, room_ids, search_term, keys): + async def search_msgs( + self, room_ids: Collection[str], search_term: str, keys: Iterable[str] + ) -> JsonDict: """Performs a full text search over events with given keys. Args: - room_ids (list): List of room ids to search in - search_term (str): Search term to search for - keys (list): List of keys to search in, currently supports + room_ids: List of room ids to search in + search_term: Search term to search for + keys: List of keys to search in, currently supports "content.body", "content.name", "content.topic" Returns: - list of dicts + Dictionary of results """ clauses = [] @@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore): self, room_ids: Collection[str], search_term: str, - keys: List[str], + keys: Iterable[str], limit, pagination_token: Optional[str] = None, - ) -> List[dict]: + ) -> JsonDict: """Performs a full text search over events with given keys. Args: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index f7c778bdf2..e7fddd2426 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", - ): + ) -> None: super().__init__(database, db_conn, hs) self.server_name = hs.hostname @@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): processed_event_count = 0 for room_id, event_count in rooms_to_work_on: - is_in_room = await self.is_host_joined(room_id, self.server_name) + is_in_room = await self.is_host_joined(room_id, self.server_name) # type: ignore[attr-defined] if is_in_room: - users_with_profile = await self.get_users_in_room_with_profiles(room_id) + users_with_profile = await self.get_users_in_room_with_profiles(room_id) # type: ignore[attr-defined] # Throw away users excluded from the directory. users_with_profile = { user_id: profile @@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): for user_id in users_to_work_on: if await self.should_include_local_user_in_dir(user_id): - profile = await self.get_profileinfo(get_localpart_from_id(user_id)) + profile = await self.get_profileinfo(get_localpart_from_id(user_id)) # type: ignore[attr-defined] await self.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) @@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # technically it could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice sender can be # contacted. - if self.get_app_service_by_user_id(user) is not None: + if self.get_app_service_by_user_id(user) is not None: # type: ignore[attr-defined] return False # We're opting to exclude appservice users (anyone matching the user @@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # they could be DM-able. In the future, this could potentially # be configurable per-appservice whether the appservice users can be # contacted. - if self.get_if_app_services_interested_in_user(user): + if self.get_if_app_services_interested_in_user(user): # type: ignore[attr-defined] # TODO we might want to make this configurable for each app service return False # Support users are for diagnostics and should not appear in the user directory. - if await self.is_support_user(user): + if await self.is_support_user(user): # type: ignore[attr-defined] return False # Deactivated users aren't contactable, so should not appear in the user directory. try: - if await self.get_user_deactivated_status(user): + if await self.get_user_deactivated_status(user): # type: ignore[attr-defined] return False except StoreError: # No such user in the users table. No need to do this when calling @@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( + current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) if join_rules_id: - join_rule_ev = await self.get_event(join_rules_id, allow_none=True) + join_rule_ev = await self.get_event(join_rules_id, allow_none=True) # type: ignore[attr-defined] if join_rule_ev: if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: return True hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) if hist_vis_id: - hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) + hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True) # type: ignore[attr-defined] if hist_vis_ev: if ( hist_vis_ev.content.get("history_visibility") diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7614d76ac6..3af69a2076 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -13,11 +13,23 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + Optional, + Sequence, + Set, + Tuple, +) import attr +from twisted.internet import defer + from synapse.api.constants import EventTypes +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap +from synapse.util import unwrapFirstError +from synapse.util.async_helpers import ( + AbstractObservableDeferred, + ObservableDeferred, + yieldable_gather_results, +) from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -37,7 +55,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - MAX_STATE_DELTA_HOPS = 100 @@ -106,6 +123,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): 500000, ) + # Current ongoing get_state_for_groups in-flight requests + # {group ID -> {StateFilter -> ObservableDeferred}} + self._state_group_inflight_requests: Dict[ + int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]] + ] = {} + def get_max_state_group_txn(txn: Cursor) -> int: txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") return txn.fetchone()[0] # type: ignore @@ -157,7 +180,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) async def _get_state_groups_from_groups( - self, groups: List[int], state_filter: StateFilter + self, groups: Sequence[int], state_filter: StateFilter ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -228,6 +251,150 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types + def _get_state_for_group_gather_inflight_requests( + self, group: int, state_filter_left_over: StateFilter + ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]: + """ + Attempts to gather in-flight requests and re-use them to retrieve state + for the given state group, filtered with the given state filter. + + Used as part of _get_state_for_group_using_inflight_cache. + + Returns: + Tuple of two values: + A sequence of ObservableDeferreds to observe + A StateFilter representing what else needs to be requested to fulfill the request + """ + + inflight_requests = self._state_group_inflight_requests.get(group) + if inflight_requests is None: + # no requests for this group, need to retrieve it all ourselves + return (), state_filter_left_over + + # The list of ongoing requests which will help narrow the current request. + reusable_requests = [] + for (request_state_filter, request_deferred) in inflight_requests.items(): + new_state_filter_left_over = state_filter_left_over.approx_difference( + request_state_filter + ) + if new_state_filter_left_over == state_filter_left_over: + # Reusing this request would not gain us anything, so don't bother. + continue + + reusable_requests.append(request_deferred) + state_filter_left_over = new_state_filter_left_over + if state_filter_left_over == StateFilter.none(): + # we have managed to collect enough of the in-flight requests + # to cover our StateFilter and give us the state we need. + break + + return reusable_requests, state_filter_left_over + + async def _get_state_for_group_fire_request( + self, group: int, state_filter: StateFilter + ) -> StateMap[str]: + """ + Fires off a request to get the state at a state group, + potentially filtering by type and/or state key. + + This request will be tracked in the in-flight request cache and automatically + removed when it is finished. + + Used as part of _get_state_for_group_using_inflight_cache. + + Args: + group: ID of the state group for which we want to get state + state_filter: the state filter used to fetch state from the database + """ + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + async def _the_request() -> StateMap[str]: + group_to_state_dict = await self._get_state_groups_from_groups( + (group,), state_filter=db_state_filter + ) + + # Now let's update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # Remove ourselves from the in-flight cache + group_request_dict = self._state_group_inflight_requests[group] + del group_request_dict[db_state_filter] + if not group_request_dict: + # If there are no more requests in-flight for this group, + # clean up the cache by removing the empty dictionary + del self._state_group_inflight_requests[group] + + return group_to_state_dict[group] + + # We don't immediately await the result, so must use run_in_background + # But we DO await the result before the current log context (request) + # finishes, so don't need to run it as a background process. + request_deferred = run_in_background(_the_request) + observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True) + + # Insert the ObservableDeferred into the cache + group_request_dict = self._state_group_inflight_requests.setdefault(group, {}) + group_request_dict[db_state_filter] = observable_deferred + + return await make_deferred_yieldable(observable_deferred.observe()) + + async def _get_state_for_group_using_inflight_cache( + self, group: int, state_filter: StateFilter + ) -> MutableStateMap[str]: + """ + Gets the state at a state group, potentially filtering by type and/or + state key. + + 1. Calls _get_state_for_group_gather_inflight_requests to gather any + ongoing requests which might overlap with the current request. + 2. Fires a new request, using _get_state_for_group_fire_request, + for any state which cannot be gathered from ongoing requests. + + Args: + group: ID of the state group for which we want to get state + state_filter: the state filter used to fetch state from the database + Returns: + state map + """ + + # first, figure out whether we can re-use any in-flight requests + # (and if so, what would be left over) + ( + reusable_requests, + state_filter_left_over, + ) = self._get_state_for_group_gather_inflight_requests(group, state_filter) + + if state_filter_left_over != StateFilter.none(): + # Fetch remaining state + remaining = await self._get_state_for_group_fire_request( + group, state_filter_left_over + ) + assembled_state: MutableStateMap[str] = dict(remaining) + else: + assembled_state = {} + + gathered = await make_deferred_yieldable( + defer.gatherResults( + (r.observe() for r in reusable_requests), consumeErrors=True + ) + ).addErrback(unwrapFirstError) + + # assemble our result. + for result_piece in gathered: + assembled_state.update(result_piece) + + # Filter out any state that may be more than what we asked for. + return state_filter.filter_state(assembled_state) + async def _get_state_for_groups( self, groups: Iterable[int], state_filter: Optional[StateFilter] = None ) -> Dict[int, MutableStateMap[str]]: @@ -269,31 +436,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): if not incomplete_groups: return state - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = await self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) + async def get_from_cache(group: int, state_filter: StateFilter) -> None: + state[group] = await self._get_state_for_group_using_inflight_cache( + group, state_filter + ) - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, + await yieldable_gather_results( + get_from_cache, + incomplete_groups, + state_filter, ) - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in group_to_state_dict.items(): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - return state def _get_state_for_groups_using_cache( diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index b1536c1ca4..36ca2b8273 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -13,13 +13,16 @@ # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import attr from synapse.api.errors import SynapseError from synapse.types import JsonDict +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + logger = logging.getLogger(__name__) @@ -39,14 +42,14 @@ class PaginationChunk: next_batch: Optional[Any] = None prev_batch: Optional[Any] = None - def to_dict(self) -> Dict[str, Any]: + async def to_dict(self, store: "DataStore") -> Dict[str, Any]: d = {"chunk": self.chunk} if self.next_batch: - d["next_batch"] = self.next_batch.to_string() + d["next_batch"] = await self.next_batch.to_string(store) if self.prev_batch: - d["prev_batch"] = self.prev_batch.to_string() + d["prev_batch"] = await self.prev_batch.to_string(store) return d @@ -75,7 +78,7 @@ class RelationPaginationToken: except ValueError: raise SynapseError(400, "Invalid relation pagination token") - def to_string(self) -> str: + async def to_string(self, store: "DataStore") -> str: return "%d-%d" % (self.topological, self.stream) def as_tuple(self) -> Tuple[Any, ...]: @@ -105,7 +108,7 @@ class AggregationPaginationToken: except ValueError: raise SynapseError(400, "Invalid aggregation pagination token") - def to_string(self) -> str: + async def to_string(self, store: "DataStore") -> str: return "%d-%d" % (self.count, self.stream) def as_tuple(self) -> Tuple[Any, ...]: diff --git a/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql new file mode 100644 index 0000000000..bbf0af5311 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/02_msc2409_add_device_id_appservice_stream_type.sql @@ -0,0 +1,21 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a column to track what to_device stream id that this application +-- service has been caught up to. + +-- NULL indicates that this appservice has never received any to_device messages. This +-- can be used, for example, to avoid sending a huge dump of messages at startup. +ALTER TABLE application_services_state ADD COLUMN to_device_stream_id BIGINT; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/68/03_delete_account_data_for_deactivated_accounts.sql b/synapse/storage/schema/main/delta/68/03_delete_account_data_for_deactivated_accounts.sql new file mode 100644 index 0000000000..e124933843 --- /dev/null +++ b/synapse/storage/schema/main/delta/68/03_delete_account_data_for_deactivated_accounts.sql @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- We want to retroactively delete account data for users that were already +-- deactivated. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6803, 'delete_account_data_for_deactivated_users', '{}'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 913448f0f9..e79ecf64a0 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -204,13 +204,16 @@ class StateFilter: if get_all_members: # We want to return everything. return StateFilter.all() - else: + elif EventTypes.Member in self.types: # We want to return all non-members, but only particular # memberships return StateFilter( types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), include_others=True, ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. @@ -528,6 +531,9 @@ class StateFilter: _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 21591d0bfd..4ec2a713cf 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -37,14 +37,16 @@ class _EventSourcesInner: account_data: AccountDataEventSource def get_sources(self) -> Iterator[Tuple[str, EventSource]]: - for attribute in _EventSourcesInner.__attrs_attrs__: # type: ignore[attr-defined] + for attribute in attr.fields(_EventSourcesInner): yield attribute.name, getattr(self, attribute.name) class EventSources: def __init__(self, hs: "HomeServer"): self.sources = _EventSourcesInner( - *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__) # type: ignore[attr-defined] + # mypy thinks attribute.type is `Optional`, but we know it's never `None` here since + # all the attributes of `_EventSourcesInner` are annotated. + *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner)) # type: ignore[misc] ) self.store = hs.get_datastore() diff --git a/synapse/types.py b/synapse/types.py index f89fb216a6..53be3583a0 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore + from synapse.storage.databases.main import DataStore, PurgeEventsStore # Define a state map type from type/state_key to T (usually an event ID or # event) @@ -485,7 +485,7 @@ class RoomStreamToken: ) @classmethod - async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken": + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -502,7 +502,7 @@ class RoomStreamToken: instance_id = int(key) pos = int(value) - instance_name = await store.get_name_from_instance_id(instance_id) + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] instance_map[instance_name] = pos return cls( diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 15debd6c46..1cbc180eda 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -56,6 +56,7 @@ response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["n class EvictionReason(Enum): size = auto() time = auto() + invalidation = auto() @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 377c9a282a..1d6ec22191 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -81,13 +81,14 @@ class DeferredCache(Generic[KT, VT]): Args: name: The name of the cache max_entries: Maximum amount of entries that the cache will hold - keylen: The length of the tuple used as the cache key. Ignored unless - `tree` is True. tree: Use a TreeCache instead of a dict as the underlying cache type iterable: If True, count each item in the cached object as an entry, rather than each cached object apply_cache_factor_from_config: Whether cache factors specified in the config file affect `max_entries` + prune_unread_entries: If True, cache entries that haven't been read recently + will be evicted from the cache in the background. Set to False to + opt-out of this behaviour. """ cache_type = TreeCache if tree else dict diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 375cd443f1..df4fb156c2 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -254,9 +254,17 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): return r1 + r2 Args: + orig: + max_entries: num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. + tree: + cache_context: + iterable: + prune_unread_entries: If True, cache entries that haven't been read recently + will be evicted from the cache in the background. Set to False to opt-out + of this behaviour. """ def __init__( diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 67ee4c693b..c6a5d0dfc0 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -133,6 +133,11 @@ class ExpiringCache(Generic[KT, VT]): raise KeyError(key) return default + if self.iterable: + self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value)) + else: + self.metrics.inc_evictions(EvictionReason.invalidation) + return value.value def __contains__(self, key: KT) -> bool: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 3f11a2f9dd..45ff0de638 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -340,6 +340,12 @@ class LruCache(Generic[KT, VT]): apply_cache_factor_from_config (bool): If true, `max_size` will be multiplied by a cache factor derived from the homeserver config + + clock: + + prune_unread_entries: If True, cache entries that haven't been read recently + will be evicted from the cache in the background. Set to False to + opt-out of this behaviour. """ # Default `clock` to something sensible. Note that we rename it to # `real_clock` so that mypy doesn't think its still `Optional`. @@ -554,8 +560,10 @@ class LruCache(Generic[KT, VT]): def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]: node = cache.get(key, None) if node: - delete_node(node) + evicted_len = delete_node(node) cache.pop(node.key, None) + if metrics: + metrics.inc_evictions(EvictionReason.invalidation, evicted_len) return node.value else: return default diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py index de04f34e4e..031880ec39 100644 --- a/synapse/util/daemonize.py +++ b/synapse/util/daemonize.py @@ -20,7 +20,7 @@ import os import signal import sys from types import FrameType, TracebackType -from typing import NoReturn, Type +from typing import NoReturn, Optional, Type def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None: @@ -100,7 +100,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - # also catch any other uncaught exceptions before we get that far.) def excepthook( - type_: Type[BaseException], value: BaseException, traceback: TracebackType + type_: Type[BaseException], + value: BaseException, + traceback: Optional[TracebackType], ) -> None: logger.critical("Unhanded exception", exc_info=(type_, value, traceback)) @@ -123,7 +125,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") - sys.exit(1) # write a log line on SIGTERM. - def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn: + def sigterm(signum: int, frame: Optional[FrameType]) -> NoReturn: logger.warning("Caught signal %s. Stopping daemon." % signum) sys.exit(0) diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py index 1f18654d47..6d4b0b7c5a 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py @@ -14,7 +14,7 @@ import functools import sys -from typing import Any, Callable, Generator, List, TypeVar +from typing import Any, Callable, Generator, List, TypeVar, cast from twisted.internet import defer from twisted.internet.defer import Deferred @@ -174,7 +174,9 @@ def _check_yield_points( ) ) changes.append(err) - return getattr(e, "value", None) + # The `StopIteration` or `_DefGen_Return` contains the return value from the + # generator. + return cast(T, e.value) frame = gen.gi_frame diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index ea1032b4fc..b26546aecd 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -16,8 +16,7 @@ import itertools import re import secrets import string -from collections.abc import Iterable -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple from netaddr import valid_ipv6 @@ -197,7 +196,7 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. - Otherwise, return the stringification of a a list with the first maxitems items, + Otherwise, return the stringification of a list with the first maxitems items, followed by "...". Args: diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index 389adf00f6..1e9c2faa64 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -32,7 +32,12 @@ logger = logging.getLogger(__name__) MAX_EMAIL_ADDRESS_LENGTH = 500 -def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: +async def check_3pid_allowed( + hs: "HomeServer", + medium: str, + address: str, + registration: bool = False, +) -> bool: """Checks whether a given format of 3PID is allowed to be used on this HS Args: @@ -40,9 +45,15 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: medium: 3pid medium - e.g. email, msisdn address: address within that medium (e.g. "wotan@matrix.org") msisdns need to first have been canonicalised + registration: whether we want to bind the 3PID as part of registering a new user. + Returns: bool: whether the 3PID medium/address is allowed to be added to this HS """ + if not await hs.get_password_auth_provider().is_3pid_allowed( + medium, address, registration + ): + return False if hs.config.registration.allowed_local_3pids: for constraint in hs.config.registration.allowed_local_3pids: diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py deleted file mode 100644 index c144ff62c1..0000000000 --- a/synapse/util/versionstring.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import subprocess -from types import ModuleType -from typing import Dict - -logger = logging.getLogger(__name__) - -version_cache: Dict[ModuleType, str] = {} - - -def get_version_string(module: ModuleType) -> str: - """Given a module calculate a git-aware version string for it. - - If called on a module not in a git checkout will return `__version__`. - - Args: - module: The module to check the version of. Must declare a __version__ - attribute. - - Returns: - The module version (as a string). - """ - - cached_version = version_cache.get(module) - if cached_version is not None: - return cached_version - - # We want this to fail loudly with an AttributeError. Type-ignore this so - # mypy only considers the happy path. - version_string = module.__version__ # type: ignore[attr-defined] - - try: - cwd = os.path.dirname(os.path.abspath(module.__file__)) - - def _run_git_command(prefix: str, *params: str) -> str: - try: - result = ( - subprocess.check_output( - ["git", *params], stderr=subprocess.DEVNULL, cwd=cwd - ) - .strip() - .decode("ascii") - ) - return prefix + result - except (subprocess.CalledProcessError, FileNotFoundError): - return "" - - git_branch = _run_git_command("b=", "rev-parse", "--abbrev-ref", "HEAD") - git_tag = _run_git_command("t=", "describe", "--exact-match") - git_commit = _run_git_command("", "rev-parse", "--short", "HEAD") - - dirty_string = "-this_is_a_dirty_checkout" - is_dirty = _run_git_command("", "describe", "--dirty=" + dirty_string).endswith( - dirty_string - ) - git_dirty = "dirty" if is_dirty else "" - - if git_branch or git_tag or git_commit or git_dirty: - git_version = ",".join( - s for s in (git_branch, git_tag, git_commit, git_dirty) if s - ) - - version_string = f"{version_string} ({git_version})" - except Exception as e: - logger.info("Failed to check for git repository: %s", e) - - version_cache[module] = version_string - - return version_string diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index ba2a2bfd64..9bd6275e92 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.appservice import ApplicationService, Namespace from tests import unittest +from tests.test_utils import simple_async_mock def _regex(regex: str, exclusive: bool = True) -> Namespace: @@ -39,13 +40,19 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.store = Mock() + self.store.get_aliases_for_room = simple_async_mock([]) + self.store.get_users_in_room = simple_async_mock([]) @defer.inlineCallbacks def test_regex_user_id_prefix_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -53,7 +60,11 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -63,7 +74,11 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" self.assertTrue( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -73,7 +88,11 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.assertTrue( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -83,7 +102,11 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.assertFalse( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) ) @defer.inlineCallbacks @@ -91,10 +114,10 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = defer.succeed( + self.store.get_aliases_for_room = simple_async_mock( ["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room.return_value = defer.succeed([]) + self.store.get_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -144,10 +167,10 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = defer.succeed( + self.store.get_aliases_for_room = simple_async_mock( ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_users_in_room.return_value = defer.succeed([]) + self.store.get_users_in_room = simple_async_mock([]) self.assertFalse( ( yield defer.ensureDeferred( @@ -163,10 +186,8 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room.return_value = defer.succeed( - ["#irc_barfoo:matrix.org"] - ) - self.store.get_users_in_room.return_value = defer.succeed([]) + self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) + self.store.get_users_in_room = simple_async_mock([]) self.assertTrue( ( yield defer.ensureDeferred( @@ -184,17 +205,21 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender self.assertTrue( - (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) ) @defer.inlineCallbacks def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. - self.store.get_users_in_room.return_value = defer.succeed( + self.store.get_users_in_room = simple_async_mock( ["@alice:here", "@irc_fo:here", "@bob:here"] ) - self.store.get_aliases_for_room.return_value = defer.succeed([]) + self.store.get_aliases_for_room = simple_async_mock([]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 55f0899bae..8fb6687f89 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -11,23 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING from unittest.mock import Mock from twisted.internet import defer from synapse.appservice import ApplicationServiceState from synapse.appservice.scheduler import ( + ApplicationServiceScheduler, _Recoverer, - _ServiceQueuer, _TransactionController, ) from synapse.logging.context import make_deferred_yieldable +from synapse.server import HomeServer +from synapse.util import Clock from tests import unittest from tests.test_utils import simple_async_mock from ..utils import MockClock +if TYPE_CHECKING: + from twisted.internet.testing import MemoryReactor + class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): def setUp(self): @@ -58,7 +64,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] # txn made and saved + service=service, + events=events, + ephemeral=[], + to_device_messages=[], # txn made and saved ) self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made txn.complete.assert_called_once_with(self.store) # txn completed @@ -79,7 +88,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] # txn made and saved + service=service, + events=events, + ephemeral=[], + to_device_messages=[], # txn made and saved ) self.assertEquals(0, txn.send.call_count) # txn not sent though self.assertEquals(0, txn.complete.call_count) # or completed @@ -102,7 +114,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( - service=service, events=events, ephemeral=[] + service=service, events=events, ephemeral=[], to_device_messages=[] ) self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made self.assertEquals(1, self.recoverer.recover.call_count) # and invoked @@ -189,38 +201,41 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.callback.assert_called_once_with(self.recoverer) -class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): - def setUp(self): +class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer): + self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() self.txn_ctrl.send = simple_async_mock() - self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) + + # Replace instantiated _TransactionController instances with our Mock + self.scheduler.txn_ctrl = self.txn_ctrl + self.scheduler.queuer.txn_ctrl = self.txn_ctrl def test_send_single_event_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4) event = Mock() - self.queuer.enqueue_event(service, event) - self.txn_ctrl.send.assert_called_once_with(service, [event], []) + self.scheduler.enqueue_for_appservice(service, events=[event]) + self.txn_ctrl.send.assert_called_once_with(service, [event], [], []) def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") event3 = Mock(event_id="third") # Send an event and don't resolve it just yet. - self.queuer.enqueue_event(service, event) + self.scheduler.enqueue_for_appservice(service, events=[event]) # Send more events: expect send() to NOT be called multiple times. - self.queuer.enqueue_event(service, event2) - self.queuer.enqueue_event(service, event3) - self.txn_ctrl.send.assert_called_with(service, [event], []) + # (call enqueue_for_appservice multiple times deliberately) + self.scheduler.enqueue_for_appservice(service, events=[event2]) + self.scheduler.enqueue_for_appservice(service, events=[event3]) + self.txn_ctrl.send.assert_called_with(service, [event], [], []) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve the send event: expect the queued events to be sent d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [event2, event3], []) + self.txn_ctrl.send.assert_called_with(service, [event2, event3], [], []) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_multiple_service_queues(self): @@ -238,23 +253,23 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y, z): + def do_send(*args, **kwargs): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent - self.queuer.enqueue_event(srv1, srv_1_event) - self.queuer.enqueue_event(srv1, srv_1_event2) - self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], []) - self.queuer.enqueue_event(srv2, srv_2_event) - self.queuer.enqueue_event(srv2, srv_2_event2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], []) + self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event]) + self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2]) + self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], []) + self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event]) + self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2]) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], []) # make sure callbacks for a service only send queued events for THAT # service srv_2_defer.callback(srv2) - self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], []) + self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], []) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_large_txns(self): @@ -262,7 +277,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): srv_2_defer = defer.Deferred() send_return_list = [srv_1_defer, srv_2_defer] - def do_send(x, y, z): + def do_send(*args, **kwargs): return make_deferred_yieldable(send_return_list.pop(0)) self.txn_ctrl.send = Mock(side_effect=do_send) @@ -270,67 +285,65 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): service = Mock(id=4, name="service") event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)] for event in event_list: - self.queuer.enqueue_event(service, event) + self.scheduler.enqueue_for_appservice(service, [event], []) # Expect the first event to be sent immediately. - self.txn_ctrl.send.assert_called_with(service, [event_list[0]], []) + self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [], []) srv_1_defer.callback(service) # Then send the next 100 events - self.txn_ctrl.send.assert_called_with(service, event_list[1:101], []) + self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [], []) srv_2_defer.callback(service) # Then the final 99 events - self.txn_ctrl.send.assert_called_with(service, event_list[101:], []) + self.txn_ctrl.send.assert_called_with(service, event_list[101:], [], []) self.assertEquals(3, self.txn_ctrl.send.call_count) def test_send_single_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event")] - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) def test_send_multiple_ephemeral_no_queue(self): # Expect the event to be sent immediately. service = Mock(id=4, name="service") event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")] - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], event_list) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], event_list, []) def test_send_single_ephemeral_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) service = Mock(id=4) event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")] event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")] event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")] # Send an event and don't resolve it just yet. - self.queuer.enqueue_ephemeral(service, event_list_1) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_1) # Send more events: expect send() to NOT be called multiple times. - self.queuer.enqueue_ephemeral(service, event_list_2) - self.queuer.enqueue_ephemeral(service, event_list_3) - self.txn_ctrl.send.assert_called_with(service, [], event_list_1) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3) + self.txn_ctrl.send.assert_called_with(service, [], event_list_1, []) self.assertEquals(1, self.txn_ctrl.send.call_count) # Resolve txn_ctrl.send d.callback(service) # Expect the queued events to be sent - self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3) + self.txn_ctrl.send.assert_called_with( + service, [], event_list_2 + event_list_3, [] + ) self.assertEquals(2, self.txn_ctrl.send.call_count) def test_send_large_txns_ephemeral(self): d = defer.Deferred() - self.txn_ctrl.send = Mock( - side_effect=lambda x, y, z: make_deferred_yieldable(d) - ) + self.txn_ctrl.send = Mock(return_value=make_deferred_yieldable(d)) # Expect the event to be sent immediately. service = Mock(id=4, name="service") first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)] second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)] event_list = first_chunk + second_chunk - self.queuer.enqueue_ephemeral(service, event_list) - self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk) + self.scheduler.enqueue_for_appservice(service, ephemeral=event_list) + self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk, []) d.callback(service) - self.txn_ctrl.send.assert_called_with(service, [], second_chunk) + self.txn_ctrl.send.assert_called_with(service, [], second_chunk, []) self.assertEquals(2, self.txn_ctrl.send.call_count) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 7b486aba4a..e40ef95874 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -47,7 +47,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): ) # Get the room complexity - channel = self.make_request( + channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) self.assertEquals(200, channel.code) @@ -59,7 +59,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23) # Get the room complexity again -- make sure it's our artificial value - channel = self.make_request( + channel = self.make_signed_federation_request( "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) ) self.assertEquals(200, channel.code) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 03e1e11f49..d084919ef7 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -16,12 +16,21 @@ import logging from parameterized import parameterized +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events import make_event_from_dict from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest +from tests.unittest import override_config class FederationServerTests(unittest.FederatingHomeserverTestCase): @@ -113,7 +122,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): room_1 = self.helper.create_room_as(u1, tok=u1_token) self.inject_room_member(room_1, "@user:other.example.com", "join") - channel = self.make_request( + channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) self.assertEquals(200, channel.code, channel.result) @@ -145,13 +154,152 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): room_1 = self.helper.create_room_as(u1, tok=u1_token) - channel = self.make_request( + channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/state/%s" % (room_1,) ) self.assertEquals(403, channel.code, channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") +class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + super().prepare(reactor, clock, hs) + + # create the room + creator_user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + self._room_id = self.helper.create_room_as( + room_creator=creator_user_id, tok=tok + ) + + # a second member on the orgin HS + second_member_user_id = self.register_user("fozzie", "bear") + tok2 = self.login("fozzie", "bear") + self.helper.join(self._room_id, second_member_user_id, tok=tok2) + + def _make_join(self, user_id) -> JsonDict: + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" + f"?ver={DEFAULT_ROOM_VERSION}", + ) + self.assertEquals(channel.code, 200, channel.json_body) + return channel.json_body + + def test_send_join(self): + """happy-path test of send_join""" + joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME + join_result = self._make_join(joining_user) + + join_event_dict = join_result["event"] + add_hashes_and_signatures( + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + join_event_dict, + signature_name=self.OTHER_SERVER_NAME, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/x", + content=join_event_dict, + ) + self.assertEquals(channel.code, 200, channel.json_body) + + # we should get complete room state back + returned_state = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["state"] + ] + self.assertCountEqual( + returned_state, + [ + ("m.room.create", ""), + ("m.room.power_levels", ""), + ("m.room.join_rules", ""), + ("m.room.history_visibility", ""), + ("m.room.member", "@kermit:test"), + ("m.room.member", "@fozzie:test"), + # nb: *not* the joining user + ], + ) + + # also check the auth chain + returned_auth_chain_events = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"] + ] + self.assertCountEqual( + returned_auth_chain_events, + [ + ("m.room.create", ""), + ("m.room.member", "@kermit:test"), + ("m.room.power_levels", ""), + ("m.room.join_rules", ""), + ], + ) + + # the room should show that the new user is a member + r = self.get_success( + self.hs.get_state_handler().get_current_state(self._room_id) + ) + self.assertEqual(r[("m.room.member", joining_user)].membership, "join") + + @override_config({"experimental_features": {"msc3706_enabled": True}}) + def test_send_join_partial_state(self): + """When MSC3706 support is enabled, /send_join should return partial state""" + joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME + join_result = self._make_join(joining_user) + + join_event_dict = join_result["event"] + add_hashes_and_signatures( + KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + join_event_dict, + signature_name=self.OTHER_SERVER_NAME, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", + content=join_event_dict, + ) + self.assertEquals(channel.code, 200, channel.json_body) + + # expect a reduced room state + returned_state = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["state"] + ] + self.assertCountEqual( + returned_state, + [ + ("m.room.create", ""), + ("m.room.power_levels", ""), + ("m.room.join_rules", ""), + ("m.room.history_visibility", ""), + ], + ) + + # the auth chain should not include anything already in "state" + returned_auth_chain_events = [ + (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"] + ] + self.assertCountEqual( + returned_auth_chain_events, + [ + ("m.room.member", "@kermit:test"), + ], + ) + + # the room should show that the new user is a member + r = self.get_success( + self.hs.get_state_handler().get_current_state(self._room_id) + ) + self.assertEqual(r[("m.room.member", joining_user)].membership, "join") + + def _create_acl_event(content): return make_event_from_dict( { diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py index a7031a55f2..c2320ce133 100644 --- a/tests/federation/transport/test_client.py +++ b/tests/federation/transport/test_client.py @@ -62,3 +62,35 @@ class SendJoinParserTestCase(TestCase): self.assertEqual(len(parsed_response.state), 1, parsed_response) self.assertEqual(parsed_response.event_dict, {}, parsed_response) self.assertIsNone(parsed_response.event, parsed_response) + self.assertFalse(parsed_response.partial_state, parsed_response) + self.assertEqual(parsed_response.servers_in_room, None, parsed_response) + + def test_partial_state(self) -> None: + """Check that the partial_state flag is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = { + "org.matrix.msc3706.partial_state": True, + } + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertTrue(parsed_response.partial_state) + + def test_servers_in_room(self) -> None: + """Check that the servers_in_room field is correctly parsed""" + parser = SendJoinParser(RoomVersions.V1, False) + response = {"org.matrix.msc3706.servers_in_room": ["hs1", "hs2"]} + + serialised_response = json.dumps(response).encode() + + # Send data to the parser + parser.write(serialised_response) + + # Retrieve and check the parsed SendJoinResponse + parsed_response = parser.finish() + self.assertEqual(parsed_response.servers_in_room, ["hs1", "hs2"]) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index bfa156eebb..686f42ab48 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -245,7 +245,7 @@ class FederationKnockingTestCase( self.hs, room_id, user_id ) - channel = self.make_request( + channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/make_knock/%s/%s?ver=%s" % ( @@ -288,7 +288,7 @@ class FederationKnockingTestCase( ) # Send the signed knock event into the room - channel = self.make_request( + channel = self.make_signed_federation_request( "PUT", "/_matrix/federation/v1/send_knock/%s/%s" % (room_id, signed_knock_event.event_id), diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py index 84fa72b9ff..eb62addda8 100644 --- a/tests/federation/transport/test_server.py +++ b/tests/federation/transport/test_server.py @@ -22,10 +22,9 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): """Test that unauthenticated requests to the public rooms directory 403 when allow_public_rooms_over_federation is False. """ - channel = self.make_request( + channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/publicRooms", - federation_auth_origin=b"example.com", ) self.assertEquals(403, channel.code) @@ -34,9 +33,8 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase): """Test that unauthenticated requests to the public rooms directory 200 when allow_public_rooms_over_federation is True. """ - channel = self.make_request( + channel = self.make_signed_federation_request( "GET", "/_matrix/federation/v1/publicRooms", - federation_auth_origin=b"example.com", ) self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index d6f14e2dba..fe57ff2671 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -1,4 +1,4 @@ -# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2015-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Optional from unittest.mock import Mock from twisted.internet import defer +import synapse.rest.admin +import synapse.storage +from synapse.appservice import ApplicationService from synapse.handlers.appservice import ApplicationServicesHandler +from synapse.rest.client import login, receipts, room, sendtodevice from synapse.types import RoomStreamToken +from synapse.util.stringutils import random_string -from tests.test_utils import make_awaitable +from tests import unittest +from tests.test_utils import make_awaitable, simple_async_mock from tests.utils import MockClock -from .. import unittest - class AppServiceHandlerTestCase(unittest.TestCase): """Tests the ApplicationServicesHandler.""" @@ -36,6 +41,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_datastore.return_value = self.mock_store self.mock_store.get_received_ts.return_value = make_awaitable(0) self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) + self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( + None + ) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -63,8 +71,8 @@ class AppServiceHandlerTestCase(unittest.TestCase): ] self.handler.notify_interested_services(RoomStreamToken(None, 1)) - self.mock_scheduler.submit_event_for_as.assert_called_once_with( - interested_service, event + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, events=[event] ) def test_query_user_exists_unknown_user(self): @@ -261,7 +269,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): """ interested_service = self._mkservice(is_interested=True) services = [interested_service] - self.mock_store.get_app_services.return_value = services self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( 579 @@ -275,10 +282,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.handler.notify_interested_services_ephemeral( "receipt_key", 580, ["@fakerecipient:example.com"] ) - self.mock_scheduler.submit_ephemeral_events_for_as.assert_called_once_with( - interested_service, [event] + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, ephemeral=[event] ) - self.mock_store.set_type_stream_id_for_appservice.assert_called_once_with( + self.mock_store.set_appservice_stream_type_pos.assert_called_once_with( interested_service, "read_receipt", 580, @@ -305,7 +312,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.handler.notify_interested_services_ephemeral( "receipt_key", 580, ["@fakerecipient:example.com"] ) - self.mock_scheduler.submit_ephemeral_events_for_as.assert_not_called() + # This method will be called, but with an empty list of events + self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( + interested_service, ephemeral=[] + ) def _mkservice(self, is_interested, protocols=None): service = Mock() @@ -321,3 +331,252 @@ class AppServiceHandlerTestCase(unittest.TestCase): service.token = "mock_service_token" service.url = "mock_service_url" return service + + +class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): + """ + Tests that the ApplicationServicesHandler sends events to application + services correctly. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + room.register_servlets, + sendtodevice.register_servlets, + receipts.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + # Mock the ApplicationServiceScheduler's _TransactionController's send method so that + # we can track any outgoing ephemeral events + self.send_mock = simple_async_mock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock + + # Mock out application services, and allow defining our own in tests + self._services: List[ApplicationService] = [] + self.hs.get_datastore().get_app_services = Mock(return_value=self._services) + + # A user on the homeserver. + self.local_user_device_id = "local_device" + self.local_user = self.register_user("local_user", "password") + self.local_user_token = self.login( + "local_user", "password", self.local_user_device_id + ) + + # A user on the homeserver which lies within an appservice's exclusive user namespace. + self.exclusive_as_user_device_id = "exclusive_as_device" + self.exclusive_as_user = self.register_user("exclusive_as_user", "password") + self.exclusive_as_user_token = self.login( + "exclusive_as_user", "password", self.exclusive_as_user_device_id + ) + + @unittest.override_config( + {"experimental_features": {"msc2409_to_device_messages_enabled": True}} + ) + def test_application_services_receive_local_to_device(self): + """ + Test that when a user sends a to-device message to another user + that is an application service's user namespace, the + application service will receive it. + """ + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + }, + ) + + # Have local_user send a to-device message to exclusive_as_user + message_content = {"some_key": "some really interesting value"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={ + "messages": { + self.exclusive_as_user: { + self.exclusive_as_user_device_id: message_content + } + } + }, + access_token=self.local_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + # Have exclusive_as_user send a to-device message to local_user + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/4", + content={ + "messages": { + self.local_user: {self.local_user_device_id: message_content} + } + }, + access_token=self.exclusive_as_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + # Check if our application service - that is interested in exclusive_as_user - received + # the to-device message as part of an AS transaction. + # Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS. + # + # The uninterested application service should not have been notified at all. + self.send_mock.assert_called_once() + service, _events, _ephemeral, to_device_messages = self.send_mock.call_args[0] + + # Assert that this was the same to-device message that local_user sent + self.assertEqual(service, interested_appservice) + self.assertEqual(to_device_messages[0]["type"], "m.room_key_request") + self.assertEqual(to_device_messages[0]["sender"], self.local_user) + + # Additional fields 'to_user_id' and 'to_device_id' specifically for + # to-device messages via the AS API + self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user) + self.assertEqual( + to_device_messages[0]["to_device_id"], self.exclusive_as_user_device_id + ) + self.assertEqual(to_device_messages[0]["content"], message_content) + + @unittest.override_config( + {"experimental_features": {"msc2409_to_device_messages_enabled": True}} + ) + def test_application_services_receive_bursts_of_to_device(self): + """ + Test that when a user sends >100 to-device messages at once, any + interested AS's will receive them in separate transactions. + + Also tests that uninterested application services do not receive messages. + """ + # Register two application services with exclusive interest in a user + interested_appservices = [] + for _ in range(2): + appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + }, + ) + interested_appservices.append(appservice) + + # ...and an application service which does not have any user interest. + self._register_application_service() + + to_device_message_content = { + "some key": "some interesting value", + } + + # We need to send a large burst of to-device messages. We also would like to + # include them all in the same application service transaction so that we can + # test large transactions. + # + # To do this, we can send a single to-device message to many user devices at + # once. + # + # We insert number_of_messages - 1 messages into the database directly. We'll then + # send a final to-device message to the real device, which will also kick off + # an AS transaction (as just inserting messages into the DB won't). + number_of_messages = 150 + fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)] + messages = { + self.exclusive_as_user: { + device_id: to_device_message_content for device_id in fake_device_ids + } + } + + # Create a fake device per message. We can't send to-device messages to + # a device that doesn't exist. + self.get_success( + self.hs.get_datastore().db_pool.simple_insert_many( + desc="test_application_services_receive_burst_of_to_device", + table="devices", + keys=("user_id", "device_id"), + values=[ + ( + self.exclusive_as_user, + device_id, + ) + for device_id in fake_device_ids + ], + ) + ) + + # Seed the device_inbox table with our fake messages + self.get_success( + self.hs.get_datastore().add_messages_to_device_inbox(messages, {}) + ) + + # Now have local_user send a final to-device message to exclusive_as_user. All unsent + # to-device messages should be sent to any application services + # interested in exclusive_as_user. + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/4", + content={ + "messages": { + self.exclusive_as_user: { + self.exclusive_as_user_device_id: to_device_message_content + } + } + }, + access_token=self.local_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + self.send_mock.assert_called() + + # Count the total number of to-device messages that were sent out per-service. + # Ensure that we only sent to-device messages to interested services, and that + # each interested service received the full count of to-device messages. + service_id_to_message_count: Dict[str, int] = {} + + for call in self.send_mock.call_args_list: + service, _events, _ephemeral, to_device_messages = call[0] + + # Check that this was made to an interested service + self.assertIn(service, interested_appservices) + + # Add to the count of messages for this application service + service_id_to_message_count.setdefault(service.id, 0) + service_id_to_message_count[service.id] += len(to_device_messages) + + # Assert that each interested service received the full count of messages + for count in service_id_to_message_count.values(): + self.assertEqual(count, number_of_messages) + + def _register_application_service( + self, + namespaces: Optional[Dict[str, Iterable[Dict]]] = None, + ) -> ApplicationService: + """ + Register a new application service, with the given namespaces of interest. + + Args: + namespaces: A dictionary containing any user, room or alias namespaces that + the application service is interested in. + + Returns: + The registered application service. + """ + # Create an application service + appservice = ApplicationService( + token=random_string(10), + hostname="example.com", + id=random_string(10), + sender="@as:example.com", + rate_limited=False, + namespaces=namespaces, + supports_ephemeral=True, + ) + + # Register the application service + self._services.append(appservice) + + return appservice diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index 3da597768c..01096a1581 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -217,3 +217,109 @@ class DeactivateAccountTestCase(HomeserverTestCase): self.assertEqual( self.get_success(self._store.ignored_by("@sheltie:test")), set() ) + + def _rerun_retroactive_account_data_deletion_update(self) -> None: + # Reset the 'all done' flag + self._store.db_pool.updates._all_done = False + + self.get_success( + self._store.db_pool.simple_insert( + "background_updates", + { + "update_name": "delete_account_data_for_deactivated_users", + "progress_json": "{}", + }, + ) + ) + + self.wait_for_background_updates() + + def test_account_data_deleted_retroactively_by_background_update_if_deactivated( + self, + ) -> None: + """ + Tests that a user, who deactivated their account before account data was + deleted automatically upon deactivation, has their account data retroactively + scrubbed by the background update. + """ + + # Request the deactivation of our account + self._deactivate_my_account() + + # Add some account data + # (we do this after the deactivation so that the act of deactivating doesn't + # clear it out. This emulates a user that was deactivated before this was cleared + # upon deactivation.) + self.get_success( + self._store.add_account_data_for_user( + self.user, + AccountDataTypes.DIRECT, + {"@someone:remote": ["!somewhere:remote"]}, + ) + ) + + # Check that the account data is there. + self.assertIsNotNone( + self.get_success( + self._store.get_global_account_data_by_type_for_user( + self.user, + AccountDataTypes.DIRECT, + ) + ), + ) + + # Re-run the retroactive deletion update + self._rerun_retroactive_account_data_deletion_update() + + # Check that the account data was cleared. + self.assertIsNone( + self.get_success( + self._store.get_global_account_data_by_type_for_user( + self.user, + AccountDataTypes.DIRECT, + ) + ), + ) + + def test_account_data_preserved_by_background_update_if_not_deactivated( + self, + ) -> None: + """ + Tests that the background update does not scrub account data for users that have + not been deactivated. + """ + + # Add some account data + # (we do this after the deactivation so that the act of deactivating doesn't + # clear it out. This emulates a user that was deactivated before this was cleared + # upon deactivation.) + self.get_success( + self._store.add_account_data_for_user( + self.user, + AccountDataTypes.DIRECT, + {"@someone:remote": ["!somewhere:remote"]}, + ) + ) + + # Check that the account data is there. + self.assertIsNotNone( + self.get_success( + self._store.get_global_account_data_by_type_for_user( + self.user, + AccountDataTypes.DIRECT, + ) + ), + ) + + # Re-run the retroactive deletion update + self._rerun_retroactive_account_data_deletion_update() + + # Check that the account data was NOT cleared. + self.assertIsNotNone( + self.get_success( + self._store.get_global_account_data_by_type_for_user( + self.user, + AccountDataTypes.DIRECT, + ) + ), + ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index cfe3de5266..a552d8182e 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -155,7 +155,7 @@ class OidcHandlerTestCase(HomeserverTestCase): def make_homeserver(self, reactor, clock): self.http_client = Mock(spec=["get_json"]) self.http_client.get_json.side_effect = get_json - self.http_client.user_agent = "Synapse Test" + self.http_client.user_agent = b"Synapse Test" hs = self.setup_test_homeserver(proxied_http_client=self.http_client) @@ -438,12 +438,9 @@ class OidcHandlerTestCase(HomeserverTestCase): state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" - user_agent = "Browser" ip_address = "10.0.0.1" session = self._generate_oidc_session_token(state, nonce, client_redirect_url) - request = _build_callback_request( - code, state, session, user_agent=user_agent, ip_address=ip_address - ) + request = _build_callback_request(code, state, session, ip_address=ip_address) self.get_success(self.handler.handle_oidc_callback(request)) @@ -1274,7 +1271,6 @@ def _build_callback_request( code: str, state: str, session: str, - user_agent: str = "Browser", ip_address: str = "10.0.0.1", ): """Builds a fake SynapseRequest to mock the browser callback @@ -1289,7 +1285,6 @@ def _build_callback_request( query param. Should be the same as was embedded in the session in _build_oidc_session. session: the "session" which would have been passed around in the cookie. - user_agent: the user-agent to present ip_address: the IP address to pretend the request came from """ request = Mock( diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 94809cb8be..49d832de81 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -21,13 +21,15 @@ from twisted.internet import defer import synapse from synapse.api.constants import LoginType +from synapse.api.errors import Codes from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.module_api import ModuleApi -from synapse.rest.client import devices, login, logout, register +from synapse.rest.client import account, devices, login, logout, register from synapse.types import JsonDict, UserID from tests import unittest from tests.server import FakeChannel +from tests.test_utils import make_awaitable from tests.unittest import override_config # (possibly experimental) login flows we expect to appear in the list after the normal @@ -82,7 +84,7 @@ class CustomAuthProvider: def __init__(self, config, api: ModuleApi): api.register_password_auth_provider_callbacks( - auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + auth_checkers={("test.login_type", ("test_field",)): self.check_auth} ) def check_auth(self, *args): @@ -120,7 +122,7 @@ class PasswordCustomAuthProvider: auth_checkers={ ("test.login_type", ("test_field",)): self.check_auth, ("m.login.password", ("password",)): self.check_auth, - }, + } ) pass @@ -158,8 +160,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): devices.register_servlets, logout.register_servlets, register.register_servlets, + account.register_servlets, ] + CALLBACK_USERNAME = "get_username_for_registration" + CALLBACK_DISPLAYNAME = "get_displayname_for_registration" + def setUp(self): # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() @@ -751,7 +757,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """Tests that the get_username_for_registration callback can define the username of a user when registering. """ - self._setup_get_username_for_registration() + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) username = "rin" channel = self.make_request( @@ -774,52 +782,181 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): """Tests that the get_username_for_registration callback is only called at the end of the UIA flow. """ - m = self._setup_get_username_for_registration() + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_USERNAME, + ) + + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) + + mxid = res["user_id"] + self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") + + # Check that the callback has been called. + m.assert_called_once() + + # Set some email configuration so the test doesn't fail because of its absence. + @override_config({"email": {"notif_from": "noreply@test"}}) + def test_3pid_allowed(self): + """Tests that an is_3pid_allowed_callbacks forbidding a 3PID makes Synapse refuse + to bind the new 3PID, and that one allowing a 3PID makes Synapse accept to bind + the 3PID. Also checks that the module is passed a boolean indicating whether the + user to bind this 3PID to is currently registering. + """ + self._test_3pid_allowed("rin", False) + self._test_3pid_allowed("kitay", True) + + def test_displayname(self): + """Tests that the get_displayname_for_registration callback can define the + display name of a user when registering. + """ + self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, + ) - # Initiate the UIA flow. username = "rin" channel = self.make_request( "POST", - "register", - {"username": username, "type": "m.login.password", "password": "bar"}, + "/register", + { + "username": username, + "password": "bar", + "auth": {"type": LoginType.DUMMY}, + }, ) - self.assertEqual(channel.code, 401) - self.assertIn("session", channel.json_body) + self.assertEqual(channel.code, 200) - # Check that the callback hasn't been called yet. - m.assert_not_called() + # Our callback takes the username and appends "-foo" to it, check that's what we + # have. + user_id = UserID.from_string(channel.json_body["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) - # Finish the UIA flow. - session = channel.json_body["session"] - channel = self.make_request( - "POST", - "register", - {"auth": {"session": session, "type": LoginType.DUMMY}}, + self.assertEqual(display_name, username + "-foo") + + def test_displayname_uia(self): + """Tests that the get_displayname_for_registration callback is only called at the + end of the UIA flow. + """ + m = self._setup_get_name_for_registration( + callback_name=self.CALLBACK_DISPLAYNAME, ) - self.assertEqual(channel.code, 200, channel.json_body) - mxid = channel.json_body["user_id"] - self.assertEqual(UserID.from_string(mxid).localpart, username + "-foo") + + username = "rin" + res = self._do_uia_assert_mock_not_called(username, m) + + user_id = UserID.from_string(res["user_id"]) + display_name = self.get_success( + self.hs.get_profile_handler().get_displayname(user_id) + ) + + self.assertEqual(display_name, username + "-foo") # Check that the callback has been called. m.assert_called_once() - def _setup_get_username_for_registration(self) -> Mock: - """Registers a get_username_for_registration callback that appends "-foo" to the - username the client is trying to register. + def _test_3pid_allowed(self, username: str, registration: bool): + """Tests that the "is_3pid_allowed" module callback is called correctly, using + either /register or /account URLs depending on the arguments. + + Args: + username: The username to use for the test. + registration: Whether to test with registration URLs. """ + self.hs.get_identity_handler().send_threepid_validation = Mock( + return_value=make_awaitable(0), + ) + + m = Mock(return_value=make_awaitable(False)) + self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] + + self.register_user(username, "password") + tok = self.login(username, "password") + + if registration: + url = "/register/email/requestToken" + else: + url = "/account/3pid/email/requestToken" + + channel = self.make_request( + "POST", + url, + { + "client_secret": "foo", + "email": "foo@test.com", + "send_attempt": 0, + }, + access_token=tok, + ) + self.assertEqual(channel.code, 403, channel.result) + self.assertEqual( + channel.json_body["errcode"], + Codes.THREEPID_DENIED, + channel.json_body, + ) + + m.assert_called_once_with("email", "foo@test.com", registration) + + m = Mock(return_value=make_awaitable(True)) + self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] + + channel = self.make_request( + "POST", + url, + { + "client_secret": "foo", + "email": "bar@test.com", + "send_attempt": 0, + }, + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertIn("sid", channel.json_body) - async def get_username_for_registration(uia_results, params): + m.assert_called_once_with("email", "bar@test.com", registration) + + def _setup_get_name_for_registration(self, callback_name: str) -> Mock: + """Registers either a get_username_for_registration callback or a + get_displayname_for_registration callback that appends "-foo" to the username the + client is trying to register. + """ + + async def callback(uia_results, params): self.assertIn(LoginType.DUMMY, uia_results) username = params["username"] return username + "-foo" - m = Mock(side_effect=get_username_for_registration) + m = Mock(side_effect=callback) password_auth_provider = self.hs.get_password_auth_provider() - password_auth_provider.get_username_for_registration_callbacks.append(m) + getattr(password_auth_provider, callback_name + "_callbacks").append(m) return m + def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict: + # Initiate the UIA flow. + channel = self.make_request( + "POST", + "register", + {"username": username, "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel.code, 401) + self.assertIn("session", channel.json_body) + + # Check that the callback hasn't been called yet. + m.assert_not_called() + + # Finish the UIA flow. + session = channel.json_body["session"] + channel = self.make_request( + "POST", + "register", + {"auth": {"session": session, "type": LoginType.DUMMY}}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + return channel.json_body + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 70c621b825..482c90ef68 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -169,7 +169,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Register an AS user. user = self.register_user("user", "pass") token = self.login(user, "pass") - as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + as_user, _ = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) # Join the AS user to rooms owned by the normal user. public, private = self._create_rooms_and_inject_memberships( @@ -388,7 +390,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_handle_local_profile_change_with_appservice_user(self) -> None: # create user - as_user_id = self.register_appservice_user( + as_user_id, _ = self.register_appservice_user( "as_user_alice", self.appservice.token ) diff --git a/tests/http/test_webclient.py b/tests/http/test_webclient.py deleted file mode 100644 index ee5cf299f6..0000000000 --- a/tests/http/test_webclient.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2022 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from http import HTTPStatus -from typing import Dict - -from twisted.web.resource import Resource - -from synapse.app.homeserver import SynapseHomeServer -from synapse.config.server import HttpListenerConfig, HttpResourceConfig, ListenerConfig -from synapse.http.site import SynapseSite - -from tests.server import make_request -from tests.unittest import HomeserverTestCase, create_resource_tree, override_config - - -class WebClientTests(HomeserverTestCase): - @override_config( - { - "web_client_location": "https://example.org", - } - ) - def test_webclient_resolves_with_client_resource(self): - """ - Tests that both client and webclient resources can be accessed simultaneously. - - This is a regression test created in response to https://github.com/matrix-org/synapse/issues/11763. - """ - for resource_name_order_list in [ - ["webclient", "client"], - ["client", "webclient"], - ]: - # Create a dictionary from path regex -> resource - resource_dict: Dict[str, Resource] = {} - - for resource_name in resource_name_order_list: - resource_dict.update( - SynapseHomeServer._configure_named_resource(self.hs, resource_name) - ) - - # Create a root resource which ties the above resources together into one - root_resource = Resource() - create_resource_tree(resource_dict, root_resource) - - # Create a site configured with this resource to make HTTP requests against - listener_config = ListenerConfig( - port=8008, - bind_addresses=["127.0.0.1"], - type="http", - http_options=HttpListenerConfig( - resources=[HttpResourceConfig(names=resource_name_order_list)] - ), - ) - test_site = SynapseSite( - logger_name="synapse.access.http.fake", - site_tag=self.hs.config.server.server_name, - config=listener_config, - resource=root_resource, - server_version_string="1", - max_request_body_size=1234, - reactor=self.reactor, - ) - - # Attempt to make requests to endpoints on both the webclient and client resources - # on test_site. - self._request_client_and_webclient_resources(test_site) - - def _request_client_and_webclient_resources(self, test_site: SynapseSite) -> None: - """Make a request to an endpoint on both the webclient and client-server resources - of the given SynapseSite. - - Args: - test_site: The SynapseSite object to make requests against. - """ - - # Ensure that the *webclient* resource is behaving as expected (we get redirected to - # the configured web_client_location) - channel = make_request( - self.reactor, - site=test_site, - method="GET", - path="/_matrix/client", - ) - # Check that we are being redirected to the webclient location URI. - self.assertEqual(channel.code, HTTPStatus.FOUND) - self.assertEqual( - channel.headers.getRawHeaders("Location"), ["https://example.org"] - ) - - # Ensure that a request to the *client* resource works. - channel = make_request( - self.reactor, - site=test_site, - method="GET", - path="/_matrix/client/v3/login", - ) - self.assertEqual(channel.code, HTTPStatus.OK) - self.assertIn("flows", channel.json_body) diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py new file mode 100644 index 0000000000..e430941d27 --- /dev/null +++ b/tests/logging/test_opentracing.py @@ -0,0 +1,184 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.logging.context import ( + LoggingContext, + make_deferred_yieldable, + run_in_background, +) +from synapse.logging.opentracing import ( + start_active_span, + start_active_span_follows_from, +) +from synapse.util import Clock + +try: + from synapse.logging.scopecontextmanager import LogContextScopeManager +except ImportError: + LogContextScopeManager = None # type: ignore + +try: + import jaeger_client +except ImportError: + jaeger_client = None # type: ignore + +from tests.unittest import TestCase + + +class LogContextScopeManagerTestCase(TestCase): + if LogContextScopeManager is None: + skip = "Requires opentracing" # type: ignore[unreachable] + if jaeger_client is None: + skip = "Requires jaeger_client" # type: ignore[unreachable] + + def setUp(self) -> None: + # since this is a unit test, we don't really want to mess around with the + # global variables that power opentracing. We create our own tracer instance + # and test with it. + + scope_manager = LogContextScopeManager({}) + config = jaeger_client.config.Config( + config={}, service_name="test", scope_manager=scope_manager + ) + + self._reporter = jaeger_client.reporter.InMemoryReporter() + + self._tracer = config.create_tracer( + sampler=jaeger_client.ConstSampler(True), + reporter=self._reporter, + ) + + def test_start_active_span(self) -> None: + # the scope manager assumes a logging context of some sort. + with LoggingContext("root context"): + self.assertIsNone(self._tracer.active_span) + + # start_active_span should start and activate a span. + scope = start_active_span("span", tracer=self._tracer) + span = scope.span + self.assertEqual(self._tracer.active_span, span) + self.assertIsNotNone(span.start_time) + + # entering the context doesn't actually do a whole lot. + with scope as ctx: + self.assertIs(ctx, scope) + self.assertEqual(self._tracer.active_span, span) + + # ... but leaving it unsets the active span, and finishes the span. + self.assertIsNone(self._tracer.active_span) + self.assertIsNotNone(span.end_time) + + # the span should have been reported + self.assertEqual(self._reporter.get_spans(), [span]) + + def test_nested_spans(self) -> None: + """Starting two spans off inside each other should work""" + + with LoggingContext("root context"): + with start_active_span("root span", tracer=self._tracer) as root_scope: + self.assertEqual(self._tracer.active_span, root_scope.span) + + scope1 = start_active_span( + "child1", + tracer=self._tracer, + ) + self.assertEqual( + self._tracer.active_span, scope1.span, "child1 was not activated" + ) + self.assertEqual( + scope1.span.context.parent_id, root_scope.span.context.span_id + ) + + scope2 = start_active_span_follows_from( + "child2", + contexts=(scope1,), + tracer=self._tracer, + ) + self.assertEqual(self._tracer.active_span, scope2.span) + self.assertEqual( + scope2.span.context.parent_id, scope1.span.context.span_id + ) + + with scope1, scope2: + pass + + # the root scope should be restored + self.assertEqual(self._tracer.active_span, root_scope.span) + self.assertIsNotNone(scope2.span.end_time) + self.assertIsNotNone(scope1.span.end_time) + + self.assertIsNone(self._tracer.active_span) + + # the spans should be reported in order of their finishing. + self.assertEqual( + self._reporter.get_spans(), [scope2.span, scope1.span, root_scope.span] + ) + + def test_overlapping_spans(self) -> None: + """Overlapping spans which are not neatly nested should work""" + reactor = MemoryReactorClock() + clock = Clock(reactor) + + scopes = [] + + async def task(i: int): + scope = start_active_span( + f"task{i}", + tracer=self._tracer, + ) + scopes.append(scope) + + self.assertEqual(self._tracer.active_span, scope.span) + await clock.sleep(4) + self.assertEqual(self._tracer.active_span, scope.span) + scope.close() + + async def root(): + with start_active_span("root span", tracer=self._tracer) as root_scope: + self.assertEqual(self._tracer.active_span, root_scope.span) + scopes.append(root_scope) + + d1 = run_in_background(task, 1) + await clock.sleep(2) + d2 = run_in_background(task, 2) + + # because we did run_in_background, the active span should still be the + # root. + self.assertEqual(self._tracer.active_span, root_scope.span) + + await make_deferred_yieldable( + defer.gatherResults([d1, d2], consumeErrors=True) + ) + + self.assertEqual(self._tracer.active_span, root_scope.span) + + with LoggingContext("root context"): + # start the test off + d1 = defer.ensureDeferred(root()) + + # let the tasks complete + reactor.pump((2,) * 8) + + self.successResultOf(d1) + self.assertIsNone(self._tracer.active_span) + + # the spans should be reported in order of their finishing: task 1, task 2, + # root. + self.assertEqual( + self._reporter.get_spans(), + [scopes[1].span, scopes[2].span, scopes[0].span], + ) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index c068d329a9..e1e3fb97c5 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -571,9 +571,7 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # # This push should still only contain an unread count of 1 (for 1 unread room) - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 1 - ) + self._check_push_attempt(6, 1) @override_config({"push": {"group_unread_count_by_room": False}}) def test_push_unread_count_message_count(self): @@ -585,11 +583,9 @@ class HTTPPusherTests(HomeserverTestCase): # Carry out our option-value specific test # - # We're counting every unread message, so there should now be 4 since the + # We're counting every unread message, so there should now be 3 since the # last read receipt - self.assertEqual( - self.push_attempts[5][2]["notification"]["counts"]["unread"], 4 - ) + self._check_push_attempt(6, 3) def _test_push_unread_count(self): """ @@ -597,8 +593,9 @@ class HTTPPusherTests(HomeserverTestCase): Note that: * Sending messages will cause push notifications to go out to relevant users - * Sending a read receipt will cause a "badge update" notification to go out to - the user that sent the receipt + * Sending a read receipt will cause the HTTP pusher to check whether the unread + count has changed since the last push notification. If so, a "badge update" + notification goes out to the user that sent the receipt """ # Register the user who gets notified user_id = self.register_user("user", "pass") @@ -642,24 +639,74 @@ class HTTPPusherTests(HomeserverTestCase): # position in the room. We'll set the read position to this event in a moment first_message_event_id = response["event_id"] - # Advance time a bit (so the pusher will register something has happened) and - # make the push succeed - self.push_attempts[0][0].callback({}) + expected_push_attempts = 1 + self._check_push_attempt(expected_push_attempts, 0) + + self._send_read_request(access_token, first_message_event_id, room_id) + + # Unread count has not changed. Therefore, ensure that read request does not + # trigger a push notification. + self.assertEqual(len(self.push_attempts), 1) + + # Send another message + response2 = self.helper.send( + room_id, body="How's the weather today?", tok=other_access_token + ) + second_message_event_id = response2["event_id"] + + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 1) + + self._send_read_request(access_token, second_message_event_id, room_id) + expected_push_attempts += 1 + + self._check_push_attempt(expected_push_attempts, 0) + + # If we're grouping by room, sending more messages shouldn't increase the + # unread count, as they're all being sent in the same room. Otherwise, it + # should. Therefore, the last call to _check_push_attempt is done in the + # caller method. + self.helper.send(room_id, body="Hello?", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="Hello??", tok=other_access_token) + expected_push_attempts += 1 + + self._advance_time_and_make_push_succeed(expected_push_attempts) + + self.helper.send(room_id, body="HELLO???", tok=other_access_token) + + def _advance_time_and_make_push_succeed(self, expected_push_attempts): self.pump() + self.push_attempts[expected_push_attempts - 1][0].callback({}) + def _check_push_attempt( + self, expected_push_attempts: int, expected_unread_count_last_push: int + ) -> None: + """ + Makes sure that the last expected push attempt succeeds and checks whether + it contains the expected unread count. + """ + self._advance_time_and_make_push_succeed(expected_push_attempts) # Check our push made it - self.assertEqual(len(self.push_attempts), 1) + self.assertEqual(len(self.push_attempts), expected_push_attempts) + _, push_url, push_body = self.push_attempts[expected_push_attempts - 1] self.assertEqual( - self.push_attempts[0][1], "http://example.com/_matrix/push/v1/notify" + push_url, + "http://example.com/_matrix/push/v1/notify", ) - # Check that the unread count for the room is 0 # # The unread count is zero as the user has no read receipt in the room yet self.assertEqual( - self.push_attempts[0][2]["notification"]["counts"]["unread"], 0 + push_body["notification"]["counts"]["unread"], + expected_unread_count_last_push, ) + def _send_read_request(self, access_token, message_event_id, room_id): # Now set the user's read receipt position to the first event # # This will actually trigger a new notification to be sent out so that @@ -667,56 +714,8 @@ class HTTPPusherTests(HomeserverTestCase): # count goes down channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id), + "/rooms/%s/receipt/m.read/%s" % (room_id, message_event_id), {}, access_token=access_token, ) self.assertEqual(channel.code, 200, channel.json_body) - - # Advance time and make the push succeed - self.push_attempts[1][0].callback({}) - self.pump() - - # Unread count is still zero as we've read the only message in the room - self.assertEqual(len(self.push_attempts), 2) - self.assertEqual( - self.push_attempts[1][2]["notification"]["counts"]["unread"], 0 - ) - - # Send another message - self.helper.send( - room_id, body="How's the weather today?", tok=other_access_token - ) - - # Advance time and make the push succeed - self.push_attempts[2][0].callback({}) - self.pump() - - # This push should contain an unread count of 1 as there's now been one - # message since our last read receipt - self.assertEqual(len(self.push_attempts), 3) - self.assertEqual( - self.push_attempts[2][2]["notification"]["counts"]["unread"], 1 - ) - - # Since we're grouping by room, sending more messages shouldn't increase the - # unread count, as they're all being sent in the same room - self.helper.send(room_id, body="Hello?", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[3][0].callback({}) - - self.helper.send(room_id, body="Hello??", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[4][0].callback({}) - - self.helper.send(room_id, body="HELLO???", tok=other_access_token) - - # Advance time and make the push succeed - self.pump() - self.push_attempts[5][0].callback({}) - - self.assertEqual(len(self.push_attempts), 6) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index cb02eddf07..9fc50f8852 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -14,6 +14,7 @@ import logging from typing import Any, Dict, List, Optional, Tuple +from twisted.internet.address import IPv4Address from twisted.internet.protocol import Protocol from twisted.web.resource import Resource @@ -53,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server_factory = ReplicationStreamProtocolFactory(hs) self.streamer = hs.get_replication_streamer() self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol( - None + IPv4Address("TCP", "127.0.0.1", 0) ) # Make a new HomeServer object for the worker @@ -345,7 +346,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.clock, repl_handler, ) - server = self.server_factory.buildProtocol(None) + server = self.server_factory.buildProtocol( + IPv4Address("TCP", "127.0.0.1", 0) + ) client_transport = FakeTransport(server, self.reactor) client.makeConnection(client_transport) diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py index 262c35cef3..545f11acd1 100644 --- a/tests/replication/tcp/test_remote_server_up.py +++ b/tests/replication/tcp/test_remote_server_up.py @@ -14,6 +14,7 @@ from typing import Tuple +from twisted.internet.address import IPv4Address from twisted.internet.interfaces import IProtocol from twisted.test.proto_helpers import StringTransport @@ -29,7 +30,7 @@ class RemoteServerUpTestCase(HomeserverTestCase): def _make_client(self) -> Tuple[IProtocol, StringTransport]: """Create a new direct TCP replication connection""" - proto = self.factory.buildProtocol(("127.0.0.1", 0)) + proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0)) transport = StringTransport() proto.makeConnection(transport) diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 89d85b0a17..51146c471d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -486,8 +486,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) @@ -505,8 +506,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": True, + "is_guest": True, }, ) @@ -528,8 +530,9 @@ class WhoamiTestCase(unittest.HomeserverTestCase): whoami, { "user_id": user_id, - # Unstable until MSC3069 enters spec + # MSC3069 entered spec in Matrix 1.2 but maintained compatibility "org.matrix.msc3069.is_guest": False, + "is_guest": False, }, ) self.assertFalse(hasattr(whoami, "device_id")) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 249808b031..989e801768 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus + import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.rest.client import capabilities, login @@ -28,7 +30,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - self.url = b"/_matrix/client/r0/capabilities" + self.url = b"/capabilities" hs = self.setup_test_homeserver() self.config = hs.config self.auth_handler = hs.get_auth_handler() @@ -96,39 +98,20 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertFalse(capabilities["m.change_password"]["enabled"]) - def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self): - """Test that per default msc3283 is disabled server returns `m.change_password`.""" + def test_get_change_users_attributes_capabilities(self): + """Test that server returns capabilities by default.""" access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities) - self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities) - self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities) - - @override_config({"experimental_features": {"msc3283_enabled": True}}) - def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self): - """Test if msc3283 is enabled server returns capabilities.""" - access_token = self.login(self.localpart, self.password) - - channel = self.make_request("GET", self.url, access_token=access_token) - capabilities = channel.json_body["capabilities"] + self.assertTrue(capabilities["m.set_displayname"]["enabled"]) + self.assertTrue(capabilities["m.set_avatar_url"]["enabled"]) + self.assertTrue(capabilities["m.3pid_changes"]["enabled"]) - self.assertEqual(channel.code, 200) - self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) - self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) - - @override_config( - { - "enable_set_displayname": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) + @override_config({"enable_set_displayname": False}) def test_get_set_displayname_capabilities_displayname_disabled(self): """Test if set displayname is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -136,15 +119,10 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"]) - - @override_config( - { - "enable_set_avatar_url": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertFalse(capabilities["m.set_displayname"]["enabled"]) + + @override_config({"enable_set_avatar_url": False}) def test_get_set_avatar_url_capabilities_avatar_url_disabled(self): """Test if set avatar_url is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) @@ -152,24 +130,19 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"]) - - @override_config( - { - "enable_3pid_changes": False, - "experimental_features": {"msc3283_enabled": True}, - } - ) - def test_change_3pid_capabilities_3pid_disabled(self): + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertFalse(capabilities["m.set_avatar_url"]["enabled"]) + + @override_config({"enable_3pid_changes": False}) + def test_get_change_3pid_capabilities_3pid_disabled(self): """Test if change 3pid is disabled that the server responds it.""" access_token = self.login(self.localpart, self.password) channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] - self.assertEqual(channel.code, 200) - self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"]) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertFalse(capabilities["m.3pid_changes"]["enabled"]) @override_config({"experimental_features": {"msc3244_enabled": False}}) def test_get_does_not_include_msc3244_fields_when_disabled(self): diff --git a/tests/rest/client/test_device_lists.py b/tests/rest/client/test_device_lists.py new file mode 100644 index 0000000000..16070cf027 --- /dev/null +++ b/tests/rest/client/test_device_lists.py @@ -0,0 +1,155 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.rest import admin, devices, room, sync +from synapse.rest.client import account, login, register + +from tests import unittest + + +class DeviceListsTestCase(unittest.HomeserverTestCase): + """Tests regarding device list changes.""" + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + account.register_servlets, + room.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_receiving_local_device_list_changes(self): + """Tests that a local users that share a room receive each other's device list + changes. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # Create a room for them to coexist peacefully in + new_room_id = self.helper.create_room_as( + alice_user_id, is_public=True, tok=alice_access_token + ) + self.assertIsNotNone(new_room_id) + + # Have Bob join the room + self.helper.invite( + new_room_id, alice_user_id, bob_user_id, tok=alice_access_token + ) + self.helper.join(new_room_id, bob_user_id, tok=bob_access_token) + + # Now have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=30000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that bob's incremental sync contains the updated device list. + # If not, the client would only receive the device list update on the + # *next* sync. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body) + + def test_not_receiving_local_device_list_changes(self): + """Tests a local users DO NOT receive device updates from each other if they do not + share a room. + """ + # Register two users + test_device_id = "TESTDEVICE" + alice_user_id = self.register_user("alice", "correcthorse") + alice_access_token = self.login( + alice_user_id, "correcthorse", device_id=test_device_id + ) + + bob_user_id = self.register_user("bob", "ponyponypony") + bob_access_token = self.login(bob_user_id, "ponyponypony") + + # These users do not share a room. They are lonely. + + # Have Bob initiate an initial sync (in order to get a since token) + channel = self.make_request( + "GET", + "/sync", + access_token=bob_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch_token = channel.json_body["next_batch"] + + # ...and then an incremental sync. This should block until the sync stream is woken up, + # which we hope will happen as a result of Alice updating their device list. + bob_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch_token}&timeout=1000", + access_token=bob_access_token, + # Start the request, then continue on. + await_result=False, + ) + + # Have alice update their device list + channel = self.make_request( + "PUT", + f"/devices/{test_device_id}", + { + "display_name": "New Device Name", + }, + access_token=alice_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that bob's incremental sync does not contain the updated device list. + bob_sync_channel.await_result() + self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body) + + changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get( + "changed", [] + ) + self.assertNotIn( + alice_user_id, changed_device_lists, bob_sync_channel.json_body + ) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 407dd32a73..0f1c47dcbb 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -1154,7 +1154,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase): class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): servlets = [register.register_servlets] - url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity" + url = "/_matrix/client/v1/register/m.login.registration_token/validity" def default_config(self): config = super().default_config() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 96ae7790bb..dfd9ffcb93 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -21,7 +21,8 @@ from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync -from synapse.types import JsonDict +from synapse.storage.relations import RelationPaginationToken +from synapse.types import JsonDict, StreamToken from tests import unittest from tests.server import FakeChannel @@ -168,24 +169,28 @@ class RelationsTestCase(unittest.HomeserverTestCase): """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) + first_annotation_id = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") self.assertEquals(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] + second_annotation_id = channel.json_body["event_id"] channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1" - % (self.room, self.parent_id), + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - # We expect to get back a single pagination result, which is the full - # relation event we sent above. + # We expect to get back a single pagination result, which is the latest + # full relation event we sent above. self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) self.assert_dict( - {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"}, + { + "event_id": second_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, channel.json_body["chunk"][0], ) @@ -200,6 +205,36 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body.get("next_batch"), str, channel.json_body ) + # Request the relations again, but with a different direction. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to get back a single pagination result, which is the earliest + # full relation event we sent above. + self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": first_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) + + def _stream_token_to_relation_token(self, token: str) -> str: + """Convert a StreamToken into a legacy token (RelationPaginationToken).""" + room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key + return self.get_success( + RelationPaginationToken( + topological=room_key.topological, stream=room_key.stream + ).to_string(self.store) + ) + def test_repeated_paginate_relations(self): """Test that if we paginate using a limit and tokens then we get the expected events. @@ -213,7 +248,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) - prev_token: Optional[str] = None + prev_token = "" found_event_ids: List[str] = [] for _ in range(20): from_token = "" @@ -222,8 +257,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s" - % (self.room, self.parent_id, from_token), + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -241,6 +275,93 @@ class RelationsTestCase(unittest.HomeserverTestCase): found_event_ids.reverse() self.assertEquals(found_event_ids, expected_event_ids) + # Reset and try again, but convert the tokens to the legacy format. + prev_token = "" + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + self._stream_token_to_relation_token(prev_token) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + def test_pagination_from_sync_and_messages(self): + """Pagination tokens from /sync and /messages can be used to paginate /relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") + self.assertEquals(200, channel.code, channel.json_body) + annotation_id = channel.json_body["event_id"] + # Send an event after the relation events. + self.helper.send(self.room, body="Latest event", tok=self.user_token) + + # Request /sync, limiting it such that only the latest event is returned + # (and not the relation). + filter = urllib.parse.quote_plus( + '{"room": {"timeline": {"limit": 1}}}'.encode() + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + sync_prev_batch = room_timeline["prev_batch"] + self.assertIsNotNone(sync_prev_batch) + # Ensure the relation event is not in the batch returned from /sync. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in room_timeline["events"]] + ) + + # Request /messages, limiting it such that only the latest event is + # returned (and not the relation). + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b&limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + messages_end = channel.json_body["end"] + self.assertIsNotNone(messages_end) + # Ensure the relation event is not in the chunk returned from /messages. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + # Request /relations with the pagination tokens received from both the + # /sync and /messages responses above, in turn. + # + # This is a tiny bit silly since the client wouldn't know the parent ID + # from the requests above; consider the parent ID to be known from a + # previous /sync. + for from_token in (sync_prev_batch, messages_end): + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # The relation should be in the returned chunk. + self.assertIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + def test_aggregation_pagination_groups(self): """Test that we can paginate annotation groups correctly.""" @@ -337,7 +458,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) - prev_token: Optional[str] = None + prev_token = "" found_event_ids: List[str] = [] encoded_key = urllib.parse.quote_plus("👍".encode()) for _ in range(20): @@ -347,15 +468,42 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/_matrix/client/unstable/rooms/%s" - "/aggregations/%s/%s/m.reaction/%s?limit=1%s" - % ( - self.room, - self.parent_id, - RelationTypes.ANNOTATION, - encoded_key, - from_token, - ), + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEquals(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEquals(found_event_ids, expected_event_ids) + + # Reset and try again, but convert the tokens to the legacy format. + prev_token = "" + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + self._stream_token_to_relation_token(prev_token) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) @@ -453,7 +601,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(400, channel.code, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + @unittest.override_config( + {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} + ) def test_bundled_aggregations(self): """ Test that annotations, references, and threads get correctly bundled. @@ -579,6 +729,23 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included when directly requested. @@ -759,6 +926,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) self.assertNotIn("m.relations", channel.json_body["unsigned"]) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) def test_edit(self): """Test that a simple edit works.""" @@ -825,6 +993,23 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertTrue(room_timeline["limited"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + def test_multi_edit(self): """Test that multiple edits, including attempts by people who shouldn't be allowed, are correctly handled. @@ -938,6 +1123,48 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_edit_thread(self): + """Test that editing a thread works.""" + + # Create a thread and edit the last event. + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "A threaded reply!"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + threaded_event_id = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=threaded_event_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # Fetch the thread root, to get the bundled aggregation for the thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect that the edit message appears in the thread summary in the + # unsigned relations section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.THREAD, relations_dict) + + thread_summary = relations_dict[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event_in_thread = thread_summary["latest_event"] + self.assertEquals( + latest_event_in_thread["content"]["body"], "I've been edited!" + ) + def test_edit_edit(self): """Test that an edit cannot be edited.""" new_body = {"msgtype": "m.text", "body": "Initial edit"} diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 721454c187..e9f8704035 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -89,7 +89,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): self.clock = clock self.storage = hs.get_storage() - self.virtual_user_id = self.register_appservice_user( + self.virtual_user_id, _ = self.register_appservice_user( "as_user_potato", self.appservice.token ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 10a4a4dc5e..b7f086927b 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,7 +18,7 @@ """Tests REST events for /rooms paths.""" import json -from typing import Dict, Iterable, List, Optional +from typing import Iterable, List from unittest.mock import Mock, call from urllib import parse as urlparse @@ -35,7 +35,7 @@ from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync -from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester +from synapse.types import JsonDict, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -674,121 +674,6 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) - def test_spamchecker_invites(self): - """Tests the user_may_create_room_with_invites spam checker callback.""" - - # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an - # IS. - async def do_3pid_invite( - room_id: str, - inviter: UserID, - medium: str, - address: str, - id_server: str, - requester: Requester, - txn_id: Optional[str], - id_access_token: Optional[str] = None, - ) -> int: - return 0 - - do_3pid_invite_mock = Mock(side_effect=do_3pid_invite) - self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock - - # Add a mock callback for user_may_create_room_with_invites. Make it allow any - # room creation request for now. - return_value = True - - async def user_may_create_room_with_invites( - user: str, - invites: List[str], - threepid_invites: List[Dict[str, str]], - ) -> bool: - return return_value - - callback_mock = Mock(side_effect=user_may_create_room_with_invites) - self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append( - callback_mock, - ) - - # The MXIDs we'll try to invite. - invited_mxids = [ - "@alice1:red", - "@alice2:red", - "@alice3:red", - "@alice4:red", - ] - - # The 3PIDs we'll try to invite. - invited_3pids = [ - { - "id_server": "example.com", - "id_access_token": "sometoken", - "medium": "email", - "address": "alice1@example.com", - }, - { - "id_server": "example.com", - "id_access_token": "sometoken", - "medium": "email", - "address": "alice2@example.com", - }, - { - "id_server": "example.com", - "id_access_token": "sometoken", - "medium": "email", - "address": "alice3@example.com", - }, - ] - - # Create a room and invite the Matrix users, and check that it succeeded. - channel = self.make_request( - "POST", - "/createRoom", - json.dumps({"invite": invited_mxids}).encode("utf8"), - ) - self.assertEqual(200, channel.code) - - # Check that the callback was called with the right arguments. - expected_call_args = ((self.user_id, invited_mxids, []),) - self.assertEquals( - callback_mock.call_args, - expected_call_args, - callback_mock.call_args, - ) - - # Create a room and invite the 3PIDs, and check that it succeeded. - channel = self.make_request( - "POST", - "/createRoom", - json.dumps({"invite_3pid": invited_3pids}).encode("utf8"), - ) - self.assertEqual(200, channel.code) - - # Check that do_3pid_invite was called the right amount of time - self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) - - # Check that the callback was called with the right arguments. - expected_call_args = ((self.user_id, [], invited_3pids),) - self.assertEquals( - callback_mock.call_args, - expected_call_args, - callback_mock.call_args, - ) - - # Now deny any room creation. - return_value = False - - # Create a room and invite the 3PIDs, and check that it failed. - channel = self.make_request( - "POST", - "/createRoom", - json.dumps({"invite_3pid": invited_3pids}).encode("utf8"), - ) - self.assertEqual(403, channel.code) - - # Check that do_3pid_invite wasn't called this time. - self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) - def test_spam_checker_may_join_room(self): """Tests that the user_may_join_room spam checker callback is correctly bypassed when creating a new room. diff --git a/tests/rest/client/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py index 6db7062a8e..e2ed14457f 100644 --- a/tests/rest/client/test_sendtodevice.py +++ b/tests/rest/client/test_sendtodevice.py @@ -198,3 +198,43 @@ class SendToDeviceTestCase(HomeserverTestCase): "content": {"idx": 3}, }, ) + + def test_limited_sync(self): + """If a limited sync for to-devices happens the next /sync should respond immediately.""" + + self.register_user("u1", "pass") + user1_tok = self.login("u1", "pass", "d1") + + user2 = self.register_user("u2", "pass") + user2_tok = self.login("u2", "pass", "d2") + + # Do an initial sync + channel = self.make_request("GET", "/sync", access_token=user2_tok) + self.assertEqual(channel.code, 200, channel.result) + sync_token = channel.json_body["next_batch"] + + # Send 150 to-device messages. We limit to 100 in `/sync` + for i in range(150): + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + f"/_matrix/client/r0/sendToDevice/m.test/1234-{i}", + content={"messages": {user2: {"d2": test_msg}}}, + access_token=user1_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + channel = self.make_request( + "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 100) + sync_token = channel.json_body["next_batch"] + + channel = self.make_request( + "GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok + ) + self.assertEqual(channel.code, 200, channel.result) + messages = channel.json_body.get("to_device", {}).get("events", []) + self.assertEqual(len(messages), 50) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index c427686376..cd4af2b1f3 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -23,7 +23,7 @@ from synapse.api.constants import ( ReadReceiptEventFields, RelationTypes, ) -from synapse.rest.client import knock, login, read_marker, receipts, room, sync +from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync from tests import unittest from tests.federation.transport.test_knocking import ( @@ -710,3 +710,58 @@ class SyncCacheTestCase(unittest.HomeserverTestCase): channel.await_result(timeout_ms=9900) channel.await_result(timeout_ms=200) self.assertEqual(channel.code, 200, channel.json_body) + + +class DeviceListSyncTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + devices.register_servlets, + ] + + def test_user_with_no_rooms_receives_self_device_list_updates(self): + """Tests that a user with no rooms still receives their own device list updates""" + device_id = "TESTDEVICE" + + # Register a user and login, creating a device + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey", device_id=device_id) + + # Request an initial sync + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.json_body) + next_batch = channel.json_body["next_batch"] + + # Now, make an incremental sync request. + # It won't return until something has happened + incremental_sync_channel = self.make_request( + "GET", + f"/sync?since={next_batch}&timeout=30000", + access_token=self.tok, + await_result=False, + ) + + # Change our device's display name + channel = self.make_request( + "PUT", + f"devices/{device_id}", + { + "display_name": "freeze ray", + }, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # The sync should now have returned + incremental_sync_channel.await_result(timeout_ms=20000) + self.assertEqual(incremental_sync_channel.code, 200, channel.json_body) + + # We should have received notification that the (user's) device has changed + device_list_changes = incremental_sync_channel.json_body.get( + "device_lists", {} + ).get("changed", []) + + self.assertIn( + self.user_id, device_list_changes, incremental_sync_channel.json_body + ) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 4e71b6ec12..ac6b86ff6b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -107,6 +107,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): return hs def prepare(self, reactor, clock, homeserver): + super().prepare(reactor, clock, homeserver) # Create some users and a room to play with during the tests self.user_id = self.register_user("kermit", "monkey") self.invitee = self.register_user("invitee", "hackme") @@ -473,8 +474,6 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def _send_event_over_federation(self) -> None: """Send a dummy event over federation and check that the request succeeds.""" body = { - "origin": self.hs.config.server.server_name, - "origin_server_ts": self.clock.time_msec(), "pdus": [ { "sender": self.user_id, @@ -492,11 +491,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): ], } - channel = self.make_request( + channel = self.make_signed_federation_request( method="PUT", path="/_matrix/federation/v1/send/1", content=body, - federation_auth_origin=self.hs.config.server.server_name.encode("utf8"), ) self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 8424383580..2b3fdadffa 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -31,6 +31,7 @@ from typing import ( overload, ) from unittest.mock import patch +from urllib.parse import urlencode import attr from typing_extensions import Literal @@ -105,9 +106,13 @@ class RestHelper: default room version. tok: The access token to use in the request. expect_code: The expected HTTP response code. + extra_content: Extra keys to include in the body of the /createRoom request. + Note that if is_public is set, the "visibility" key will be overridden. + If room_version is set, the "room_version" key will be overridden. + custom_headers: HTTP headers to include in the request. Returns: - The ID of the newly created room. + The ID of the newly created room, or None if the request failed. """ temp_id = self.auth_user_id self.auth_user_id = room_creator @@ -147,12 +152,20 @@ class RestHelper: expect_code=expect_code, ) - def join(self, room=None, user=None, expect_code=200, tok=None): + def join( + self, + room: str, + user: Optional[str] = None, + expect_code: int = 200, + tok: Optional[str] = None, + appservice_user_id: Optional[str] = None, + ) -> None: self.change_membership( room=room, src=user, targ=user, tok=tok, + appservice_user_id=appservice_user_id, membership=Membership.JOIN, expect_code=expect_code, ) @@ -209,11 +222,12 @@ class RestHelper: def change_membership( self, room: str, - src: str, - targ: str, + src: Optional[str], + targ: Optional[str], membership: str, extra_data: Optional[dict] = None, tok: Optional[str] = None, + appservice_user_id: Optional[str] = None, expect_code: int = 200, expect_errcode: Optional[str] = None, ) -> None: @@ -227,15 +241,26 @@ class RestHelper: membership: The type of membership event extra_data: Extra information to include in the content of the event tok: The user access token to use + appservice_user_id: The `user_id` URL parameter to pass. + This allows driving an application service user + using an application service access token in `tok`. expect_code: The expected HTTP response code expect_errcode: The expected Matrix error code """ temp_id = self.auth_user_id self.auth_user_id = src - path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ) + path = f"/_matrix/client/r0/rooms/{room}/state/m.room.member/{targ}" + url_params: Dict[str, str] = {} + if tok: - path = path + "?access_token=%s" % tok + url_params["access_token"] = tok + + if appservice_user_id: + url_params["user_id"] = appservice_user_id + + if url_params: + path += "?" + urlencode(url_params) data = {"membership": membership} data.update(extra_data or {}) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 53f6186213..da2c533260 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -243,6 +243,78 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430") + def test_video_rejected(self): + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + end_content = b"anything" + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: video/mp4\r\n\r\n" + ) + % (len(end_content)) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 502) + self.assertEqual( + channel.json_body, + { + "errcode": "M_UNKNOWN", + "error": "Requested file's content type not allowed for this operation: video/mp4", + }, + ) + + def test_audio_rejected(self): + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + end_content = b"anything" + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b"Content-Type: audio/aac\r\n\r\n" + ) + % (len(end_content)) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 502) + self.assertEqual( + channel.json_body, + { + "errcode": "M_UNKNOWN", + "error": "Requested file's content type not allowed for this operation: audio/aac", + }, + ) + def test_non_ascii_preview_content_type(self): self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] diff --git a/tests/storage/databases/test_state_store.py b/tests/storage/databases/test_state_store.py new file mode 100644 index 0000000000..3a4a4a3a29 --- /dev/null +++ b/tests/storage/databases/test_state_store.py @@ -0,0 +1,283 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import typing +from typing import Dict, List, Sequence, Tuple +from unittest.mock import patch + +from twisted.internet.defer import Deferred, ensureDeferred +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes +from synapse.storage.state import StateFilter +from synapse.types import StateMap +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + +if typing.TYPE_CHECKING: + from synapse.server import HomeServer + +# StateFilter for ALL non-m.room.member state events +ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze( + types={EventTypes.Member: set()}, + include_others=True, +) + +FAKE_STATE = { + (EventTypes.Member, "@alice:test"): "join", + (EventTypes.Member, "@bob:test"): "leave", + (EventTypes.Member, "@charlie:test"): "invite", + ("test.type", "a"): "AAA", + ("test.type", "b"): "BBB", + ("other.event.type", "state.key"): "123", +} + + +class StateGroupInflightCachingTestCase(HomeserverTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: "HomeServer" + ) -> None: + self.state_storage = homeserver.get_storage().state + self.state_datastore = homeserver.get_datastores().state + # Patch out the `_get_state_groups_from_groups`. + # This is useful because it lets us pretend we have a slow database. + get_state_groups_patch = patch.object( + self.state_datastore, + "_get_state_groups_from_groups", + self._fake_get_state_groups_from_groups, + ) + get_state_groups_patch.start() + + self.addCleanup(get_state_groups_patch.stop) + self.get_state_group_calls: List[ + Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]] + ] = [] + + def _fake_get_state_groups_from_groups( + self, groups: Sequence[int], state_filter: StateFilter + ) -> "Deferred[Dict[int, StateMap[str]]]": + d: Deferred[Dict[int, StateMap[str]]] = Deferred() + self.get_state_group_calls.append((tuple(groups), state_filter, d)) + return d + + def _complete_request_fake( + self, + groups: Tuple[int, ...], + state_filter: StateFilter, + d: "Deferred[Dict[int, StateMap[str]]]", + ) -> None: + """ + Assemble a fake database response and complete the database request. + """ + + # Return a filtered copy of the fake state + d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups}) + + def test_duplicate_requests_deduplicated(self) -> None: + """ + Tests that duplicate requests for state are deduplicated. + + This test: + - requests some state (state group 42, 'all' state filter) + - requests it again, before the first request finishes + - checks to see that only one database query was made + - completes the database query + - checks that both requests see the same retrieved state + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # No more calls should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + self.assertEqual(sf, StateFilter.all()) + + # Now we can complete the request + self._complete_request_fake(groups, sf, d) + + self.assertEqual(self.get_success(req1), FAKE_STATE) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_smaller_request_deduplicated(self) -> None: + """ + Tests that duplicate requests for state are deduplicated. + + This test: + - requests some state (state group 42, 'all' state filter) + - requests a subset of that state, before the first request finishes + - checks to see that only one database query was made + - completes the database query + - checks that both requests see the correct retrieved state + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", None),)) + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", "b"),)) + ) + ) + self.pump(by=0.1) + + # No more calls should have gone to the database, because the second + # request was already in the in-flight cache! + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + # The state filter is expanded internally for increased cache hit rate, + # so we the database sees a wider state filter than requested. + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) + + # Now we can complete the request + self._complete_request_fake(groups, sf, d) + + self.assertEqual( + self.get_success(req1), + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, + ) + self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"}) + + def test_partially_overlapping_request_deduplicated(self) -> None: + """ + Tests that partially-overlapping requests are partially deduplicated. + + This test: + - requests a single type of wildcard state + (This is internally expanded to be all non-member state) + - requests the entire state in parallel + - checks to see that two database queries were made, but that the second + one is only for member state. + - completes the database queries + - checks that both requests have the correct result. + """ + + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.from_types((("test.type", None),)) + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # Because it only partially overlaps, this also went to the database + self.assertEqual(len(self.get_state_group_calls), 2) + self.assertFalse(req1.called) + self.assertFalse(req2.called) + + # First request: + groups, sf, d = self.get_state_group_calls[0] + self.assertEqual(groups, (42,)) + # The state filter is expanded internally for increased cache hit rate, + # so we the database sees a wider state filter than requested. + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) + self._complete_request_fake(groups, sf, d) + + # Second request: + groups, sf, d = self.get_state_group_calls[1] + self.assertEqual(groups, (42,)) + # The state filter is narrowed to only request membership state, because + # the remainder of the state is already being queried in the first request! + self.assertEqual( + sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False) + ) + self._complete_request_fake(groups, sf, d) + + # Check the results are correct + self.assertEqual( + self.get_success(req1), + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, + ) + self.assertEqual(self.get_success(req2), FAKE_STATE) + + def test_in_flight_requests_stop_being_in_flight(self) -> None: + """ + Tests that in-flight request deduplication doesn't somehow 'hold on' + to completed requests: once they're done, they're taken out of the + in-flight cache. + """ + req1 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # This should have gone to the database + self.assertEqual(len(self.get_state_group_calls), 1) + self.assertFalse(req1.called) + + # Complete the request right away. + self._complete_request_fake(*self.get_state_group_calls[0]) + self.assertTrue(req1.called) + + # Send off another request + req2 = ensureDeferred( + self.state_datastore._get_state_for_group_using_inflight_cache( + 42, StateFilter.all() + ) + ) + self.pump(by=0.1) + + # It should have gone to the database again, because the previous request + # isn't in-flight and therefore isn't available for deduplication. + self.assertEqual(len(self.get_state_group_calls), 2) + self.assertFalse(req2.called) + + # Complete the request right away. + self._complete_request_fake(*self.get_state_group_calls[1]) + self.assertTrue(req2.called) + groups, sf, d = self.get_state_group_calls[0] + + self.assertEqual(self.get_success(req1), FAKE_STATE) + self.assertEqual(self.get_success(req2), FAKE_STATE) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 329490caad..ddcb7f5549 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -266,7 +266,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): service = Mock(id=self.as_list[0]["id"]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) txn = self.get_success( - defer.ensureDeferred(self.store.create_appservice_txn(service, events, [])) + defer.ensureDeferred( + self.store.create_appservice_txn(service, events, [], []) + ) ) self.assertEquals(txn.id, 1) self.assertEquals(txn.events, events) @@ -280,7 +282,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._set_last_txn(service.id, 9643)) # AS is falling behind self.get_success(self._insert_txn(service.id, 9644, events)) self.get_success(self._insert_txn(service.id, 9645, events)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + txn = self.get_success( + self.store.create_appservice_txn(service, events, [], []) + ) self.assertEquals(txn.id, 9646) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -291,7 +295,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): service = Mock(id=self.as_list[0]["id"]) events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")]) self.get_success(self._set_last_txn(service.id, 9643)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + txn = self.get_success( + self.store.create_appservice_txn(service, events, [], []) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -313,7 +319,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.get_success(self._insert_txn(self.as_list[2]["id"], 10, events)) self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events)) - txn = self.get_success(self.store.create_appservice_txn(service, events, [])) + txn = self.get_success( + self.store.create_appservice_txn(service, events, [], []) + ) self.assertEquals(txn.id, 9644) self.assertEquals(txn.events, events) self.assertEquals(txn.service, service) @@ -481,10 +489,10 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ValueError, ) - def test_set_type_stream_id_for_appservice(self) -> None: + def test_set_appservice_stream_type_pos(self) -> None: read_receipt_value = 1024 self.get_success( - self.store.set_type_stream_id_for_appservice( + self.store.set_appservice_stream_type_pos( self.service, "read_receipt", read_receipt_value ) ) @@ -494,7 +502,7 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): self.assertEqual(result, read_receipt_value) self.get_success( - self.store.set_type_stream_id_for_appservice( + self.store.set_appservice_stream_type_pos( self.service, "presence", read_receipt_value ) ) @@ -503,9 +511,9 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase): ) self.assertEqual(result, read_receipt_value) - def test_set_type_stream_id_for_appservice_invalid_type(self) -> None: + def test_set_appservice_stream_type_pos_invalid_type(self) -> None: self.get_failure( - self.store.set_type_stream_id_for_appservice(self.service, "foobar", 1024), + self.store.set_appservice_stream_type_pos(self.service, "foobar", 1024), ValueError, ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2bc89512f8..667ca90a4d 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -260,16 +260,16 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) - self.assertEqual(auth_chain_ids, ["k"]) + self.assertEqual(auth_chain_ids, {"k"}) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) - self.assertEqual(auth_chain_ids, ["j"]) + self.assertEqual(auth_chain_ids, {"j"}) # j and k have no parents. auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) - self.assertEqual(auth_chain_ids, []) + self.assertEqual(auth_chain_ids, set()) auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) - self.assertEqual(auth_chain_ids, []) + self.assertEqual(auth_chain_ids, set()) # More complex input sequences. auth_chain_ids = self.get_success( diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index f462a8b1c7..a8639d8f82 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -329,3 +329,110 @@ class ExtremPruneTestCase(HomeserverTestCase): # Check the new extremity is just the new remote event. self.assert_extremities([local_message_event_id, remote_event_2.event_id]) + + +class InvalideUsersInRoomCacheTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.state = self.hs.get_state_handler() + self.persistence = self.hs.get_storage().persistence + self.store = self.hs.get_datastore() + + def test_remote_user_rooms_cache_invalidated(self): + """Test that if the server leaves a room the `get_rooms_for_user` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_rooms_for_user` to add the remote user to the cache + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), {room_id}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + rooms = self.get_success(self.store.get_rooms_for_user(remote_user)) + self.assertEqual(set(rooms), set()) + + def test_room_remote_user_cache_invalidated(self): + """Test that if the server leaves a room the `get_users_in_room` cache + is invalidated for remote users. + """ + + # Set up a room with a local and remote user in it. + user_id = self.register_user("user", "pass") + token = self.login("user", "pass") + + room_id = self.helper.create_room_as( + "user", room_version=RoomVersions.V6.identifier, tok=token + ) + + body = self.helper.send(room_id, body="Test", tok=token) + local_message_event_id = body["event_id"] + + # Fudge a join event for a remote user. + remote_user = "@user:other" + remote_event_1 = event_from_pdu_json( + { + "type": EventTypes.Member, + "state_key": remote_user, + "content": {"membership": Membership.JOIN}, + "room_id": room_id, + "sender": remote_user, + "depth": 5, + "prev_events": [local_message_event_id], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + RoomVersions.V6, + ) + + context = self.get_success(self.state.compute_event_context(remote_event_1)) + self.get_success(self.persistence.persist_event(remote_event_1, context)) + + # Call `get_users_in_room` to add the remote user to the cache + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(set(users), {user_id, remote_user}) + + # Now we have the local server leave the room, and check that calling + # `get_user_in_room` for the remote user no longer includes the room. + self.helper.leave(room_id, user_id, tok=token) + + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertEqual(users, []) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 70d52b088c..28c767ecfd 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -992,3 +992,112 @@ class StateFilterDifferenceTestCase(TestCase): StateFilter.none(), StateFilter.all(), ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self): + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f5b28aed8..48f1e9d841 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -341,7 +341,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): # Register an AS user. user = self.register_user("user", "pass") token = self.login(user, "pass") - as_user = self.register_appservice_user("as_user_potato", self.appservice.token) + as_user, _ = self.register_appservice_user( + "as_user_potato", self.appservice.token + ) # Join the AS user to rooms owned by the normal user. public, private = self._create_rooms_and_inject_memberships( diff --git a/tests/unittest.py b/tests/unittest.py index 1431848367..a71892cb9d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -17,6 +17,7 @@ import gc import hashlib import hmac import inspect +import json import logging import secrets import time @@ -36,9 +37,11 @@ from typing import ( ) from unittest.mock import Mock, patch -from canonicaljson import json +import canonicaljson +import signedjson.key +import unpaddedbase64 -from twisted.internet.defer import Deferred, ensureDeferred, succeed +from twisted.internet.defer import Deferred, ensureDeferred from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool from twisted.test.proto_helpers import MemoryReactor @@ -49,8 +52,7 @@ from twisted.web.server import Request from synapse import events from synapse.api.constants import EventTypes, Membership from synapse.config.homeserver import HomeServerConfig -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server as federation_server +from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import ( @@ -61,10 +63,10 @@ from synapse.logging.context import ( ) from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer +from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree -from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver from tests.test_utils import event_injection, setup_awaitable_errors @@ -620,18 +622,19 @@ class HomeserverTestCase(TestCase): self, username: str, appservice_token: str, - ) -> str: + ) -> Tuple[str, str]: """Register an appservice user as an application service. Requires the client-facing registration API be registered. Args: username: the user to be registered by an application service. - Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart" + Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname" appservice_token: the acccess token for that application service. Raises: if the request to '/register' does not return 200 OK. - Returns: the MXID of the new user. + Returns: + The MXID of the new user, the device ID of the new user's first device. """ channel = self.make_request( "POST", @@ -643,7 +646,7 @@ class HomeserverTestCase(TestCase): access_token=appservice_token, ) self.assertEqual(channel.code, 200, channel.json_body) - return channel.json_body["user_id"] + return channel.json_body["user_id"], channel.json_body["device_id"] def login( self, @@ -754,42 +757,116 @@ class HomeserverTestCase(TestCase): class FederatingHomeserverTestCase(HomeserverTestCase): """ - A federating homeserver that authenticates incoming requests as `other.example.com`. + A federating homeserver, set up to validate incoming federation requests """ - def create_resource_dict(self) -> Dict[str, Resource]: - d = super().create_resource_dict() - d["/_matrix/federation"] = TestTransportLayerServer(self.hs) - return d + OTHER_SERVER_NAME = "other.example.com" + OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test") + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + super().prepare(reactor, clock, hs) -class TestTransportLayerServer(JsonResource): - """A test implementation of TransportLayerServer + # poke the other server's signing key into the key store, so that we don't + # make requests for it + verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY) + verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) - authenticates incoming requests as `other.example.com`. - """ + self.get_success( + hs.get_datastore().store_server_verify_keys( + from_server=self.OTHER_SERVER_NAME, + ts_added_ms=clock.time_msec(), + verify_keys=[ + ( + self.OTHER_SERVER_NAME, + verify_key_id, + FetchKeyResult( + verify_key=verify_key, + valid_until_ts=clock.time_msec() + 1000, + ), + ) + ], + ) + ) + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TransportLayerServer(self.hs) + return d - def __init__(self, hs): - super().__init__(hs) + def make_signed_federation_request( + self, + method: str, + path: str, + content: Optional[JsonDict] = None, + await_result: bool = True, + custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, + client_ip: str = "127.0.0.1", + ) -> FakeChannel: + """Make an inbound signed federation request to this server - class Authenticator: - def authenticate_request(self, request, content): - return succeed("other.example.com") + The request is signed as if it came from "other.example.com", which our HS + already has the keys for. + """ - authenticator = Authenticator() + if custom_headers is None: + custom_headers = [] + else: + custom_headers = list(custom_headers) + + custom_headers.append( + ( + "Authorization", + _auth_header_for_request( + origin=self.OTHER_SERVER_NAME, + destination=self.hs.hostname, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + method=method, + path=path, + content=content, + ), + ) + ) - ratelimiter = FederationRateLimiter( - hs.get_clock(), - FederationRateLimitConfig( - window_size=1, - sleep_limit=1, - sleep_delay=1, - reject_limit=1000, - concurrent=1000, - ), + return make_request( + self.reactor, + self.site, + method=method, + path=path, + content=content, + shorthand=False, + await_result=await_result, + custom_headers=custom_headers, + client_ip=client_ip, ) - federation_server.register_servlets(hs, self, authenticator, ratelimiter) + +def _auth_header_for_request( + origin: str, + destination: str, + signing_key: signedjson.key.SigningKey, + method: str, + path: str, + content: Optional[JsonDict], +) -> str: + """Build a suitable Authorization header for an outgoing federation request""" + request_description: JsonDict = { + "method": method, + "uri": path, + "destination": destination, + "origin": origin, + } + if content is not None: + request_description["content"] = content + signature_base64 = unpaddedbase64.encode_base64( + signing_key.sign( + canonicaljson.encode_canonical_json(request_description) + ).signature + ) + return ( + f"X-Matrix origin={origin}," + f"key={signing_key.alg}:{signing_key.version}," + f"sig={signature_base64}" + ) def override_config(extra_config): diff --git a/tox.ini b/tox.ini index 32679e9106..436ecf7552 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,11 @@ envlist = packaging, py37, py38, py39, py310, check_codestyle, check_isort # we require tox>=2.3.2 for the fix to https://github.com/tox-dev/tox/issues/208 minversion = 2.3.2 +# the tox-venv plugin makes tox use python's built-in `venv` module rather than +# the legacy `virtualenv` tool. `virtualenv` embeds its own `pip`, `setuptools`, +# etc, and ends up being rather unreliable. +requires = tox-venv + [base] deps = python-subunit @@ -119,6 +124,9 @@ usedevelop = false deps = Automat == 0.8.0 lxml + # markupsafe 2.1 introduced a change that breaks Jinja 2.x. Since we depend on + # Jinja >= 2.9, it means this test suite will fail if markupsafe >= 2.1 is installed. + markupsafe < 2.1 {[base]deps} commands = @@ -158,7 +166,7 @@ commands = [testenv:check_isort] extras = lint -commands = isort -c --df --sp setup.cfg {[base]lint_targets} +commands = isort -c --df {[base]lint_targets} [testenv:check-newsfragment] skip_install = true |