diff options
135 files changed, 2592 insertions, 1646 deletions
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 6c3a998499..0dfab4e087 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,6 +8,7 @@ - Use markdown where necessary, mostly for `code blocks`. - End with either a period (.) or an exclamation mark (!). - Start with a capital letter. + - Feel free to credit yourself, by adding a sentence "Contributed by @github_username." or "Contributed by [Your Name]." to the end of the entry. * [ ] Pull request includes a [sign off](https://matrix-org.github.io/synapse/latest/development/contributing_guide.html#sign-off) * [ ] [Code style](https://matrix-org.github.io/synapse/latest/code_style.html) is correct (run the [linters](https://matrix-org.github.io/synapse/latest/development/contributing_guide.html#run-the-linters)) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cb72e1a233..4f58069702 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -366,6 +366,8 @@ jobs: # Build initial Synapse image - run: docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile . working-directory: synapse + env: + DOCKER_BUILDKIT: 1 # Build a ready-to-run Synapse image based on the initial image above. # This new image includes a config file, keys for signing and TLS, and @@ -374,7 +376,8 @@ jobs: working-directory: complement/dockerfiles # Run Complement - - run: go test -v -tags synapse_blacklist,msc2403 ./tests/... + - run: set -o pipefail && go test -v -json -tags synapse_blacklist,msc2403 ./tests/... 
2>&1 | gotestfmt + shell: bash env: COMPLEMENT_BASE_IMAGE: complement-synapse:latest working-directory: complement diff --git a/.gitignore b/.gitignore index fe137f3370..3bd6b1a08c 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,7 @@ __pycache__/ # docs book/ + +# complement +/complement-* +/master.tar.gz diff --git a/CHANGES.md b/CHANGES.md index b8f3588ec9..6da56a11f5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,12 +1,103 @@ -Synapse 1.50.2 (2022-01-24) +Synapse 1.51.0 (2022-01-25) =========================== +No significant changes since 1.51.0rc2. + +Synapse 1.51.0 deprecates `webclient` listeners and non-HTTP(S) `web_client_location`s. Support for these will be removed in Synapse 1.53.0, at which point Synapse will not be capable of directly serving a web client for Matrix. + +Synapse 1.51.0rc2 (2022-01-24) +============================== + Bugfixes -------- - Fix a bug introduced in Synapse 1.40.0 that caused Synapse to fail to process incoming federation traffic after handling a large amount of events in a v1 room. ([\#11806](https://github.com/matrix-org/synapse/issues/11806)) +Synapse 1.51.0rc1 (2022-01-21) +============================== + +Features +-------- + +- Add `track_puppeted_user_ips` config flag to record client IP addresses against puppeted users, and include the puppeted users in monthly active user counts. ([\#11561](https://github.com/matrix-org/synapse/issues/11561), [\#11749](https://github.com/matrix-org/synapse/issues/11749), [\#11757](https://github.com/matrix-org/synapse/issues/11757)) +- Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440). ([\#11577](https://github.com/matrix-org/synapse/issues/11577)) +- Return an `M_FORBIDDEN` error code instead of `M_UNKNOWN` when a spam checker module prevents a user from creating a room. 
([\#11672](https://github.com/matrix-org/synapse/issues/11672)) +- Add a flag to the `synapse_review_recent_signups` script to ignore and filter appservice users. ([\#11675](https://github.com/matrix-org/synapse/issues/11675), [\#11770](https://github.com/matrix-org/synapse/issues/11770)) + + +Bugfixes +-------- + +- Fix a long-standing issue which could cause Synapse to incorrectly accept data in the unsigned field of events + received over federation. ([\#11530](https://github.com/matrix-org/synapse/issues/11530)) +- Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices. ([\#11587](https://github.com/matrix-org/synapse/issues/11587)) +- Fix an error that occurs whilst trying to get the federation status of a destination server that was working normally. This admin API was newly introduced in Synapse v1.49.0. ([\#11593](https://github.com/matrix-org/synapse/issues/11593)) +- Fix bundled aggregations not being included in the `/sync` response, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675). ([\#11612](https://github.com/matrix-org/synapse/issues/11612), [\#11659](https://github.com/matrix-org/synapse/issues/11659), [\#11791](https://github.com/matrix-org/synapse/issues/11791)) +- Fix the `/_matrix/client/v1/room/{roomId}/hierarchy` endpoint returning incorrect fields which have been present since Synapse 1.49.0. ([\#11667](https://github.com/matrix-org/synapse/issues/11667)) +- Fix preview of some GIF URLs (like tenor.com). Contributed by Philippe Daouadi. ([\#11669](https://github.com/matrix-org/synapse/issues/11669)) +- Fix a bug where only the first 50 rooms from a space were returned from the `/hierarchy` API. This has existed since the introduction of the API in Synapse v1.41.0. 
([\#11695](https://github.com/matrix-org/synapse/issues/11695)) +- Fix a bug introduced in Synapse v1.18.0 where password reset and address validation emails would not be sent if their subject was configured to use the 'app' template variable. Contributed by @br4nnigan. ([\#11710](https://github.com/matrix-org/synapse/issues/11710), [\#11745](https://github.com/matrix-org/synapse/issues/11745)) +- Make the 'List Rooms' Admin API sort stable. Contributed by Daniël Sonck. ([\#11737](https://github.com/matrix-org/synapse/issues/11737)) +- Fix a long-standing bug where space hierarchy over federation would only work correctly some of the time. ([\#11775](https://github.com/matrix-org/synapse/issues/11775)) +- Fix a bug introduced in Synapse v1.46.0 that prevented `on_logged_out` module callbacks from being correctly awaited by Synapse. ([\#11786](https://github.com/matrix-org/synapse/issues/11786)) + + +Improved Documentation +---------------------- + +- Warn against using a Let's Encrypt certificate for TLS/DTLS TURN server client connections, and suggest using a ZeroSSL certificate instead. This works around client-side connectivity errors caused by WebRTC libraries that reject Let's Encrypt certificates. Contributed by @AndrewFerr. ([\#11686](https://github.com/matrix-org/synapse/issues/11686)) +- Document the new `SYNAPSE_TEST_PERSIST_SQLITE_DB` environment variable in the contributing guide. ([\#11715](https://github.com/matrix-org/synapse/issues/11715)) +- Document that the minimum supported PostgreSQL version is now 10. ([\#11725](https://github.com/matrix-org/synapse/issues/11725)) +- Fix typo in demo docs: differnt. ([\#11735](https://github.com/matrix-org/synapse/issues/11735)) +- Update room spec URL in config files. ([\#11739](https://github.com/matrix-org/synapse/issues/11739)) +- Mention `python3-venv` and `libpq-dev` dependencies in the contribution guide. 
([\#11740](https://github.com/matrix-org/synapse/issues/11740)) +- Update documentation for configuring login with Facebook. ([\#11755](https://github.com/matrix-org/synapse/issues/11755)) +- Update installation instructions to note that Python 3.6 is no longer supported. ([\#11781](https://github.com/matrix-org/synapse/issues/11781)) + + +Deprecations and Removals +------------------------- + +- Remove the unstable `/send_relation` endpoint. ([\#11682](https://github.com/matrix-org/synapse/issues/11682)) +- Remove `python_twisted_reactor_pending_calls` Prometheus metric. ([\#11724](https://github.com/matrix-org/synapse/issues/11724)) +- Remove the `password_hash` field from the response dictionaries of the [Users Admin API](https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html). ([\#11576](https://github.com/matrix-org/synapse/issues/11576)) +- **Deprecate support for `webclient` listeners and non-HTTP(S) `web_client_location` configuration. ([\#11774](https://github.com/matrix-org/synapse/issues/11774), [\#11783](https://github.com/matrix-org/synapse/issues/11783))** + + +Internal Changes +---------------- + +- Run `pyupgrade --py37-plus --keep-percent-format` on Synapse. ([\#11685](https://github.com/matrix-org/synapse/issues/11685)) +- Use buildkit's cache feature to speed up docker builds. ([\#11691](https://github.com/matrix-org/synapse/issues/11691)) +- Use `auto_attribs` and native type hints for attrs classes. ([\#11692](https://github.com/matrix-org/synapse/issues/11692), [\#11768](https://github.com/matrix-org/synapse/issues/11768)) +- Remove debug logging for #4422, which has been closed since Synapse 0.99. ([\#11693](https://github.com/matrix-org/synapse/issues/11693)) +- Remove fallback code for Python 2. ([\#11699](https://github.com/matrix-org/synapse/issues/11699)) +- Add a test for [an edge case](https://github.com/matrix-org/synapse/pull/11532#discussion_r769104461) in the `/sync` logic. 
([\#11701](https://github.com/matrix-org/synapse/issues/11701)) +- Add the option to write SQLite test dbs to disk when running tests. ([\#11702](https://github.com/matrix-org/synapse/issues/11702)) +- Improve Complement test output for GitHub Actions. ([\#11707](https://github.com/matrix-org/synapse/issues/11707)) +- Fix docstring on `add_account_data_for_user`. ([\#11716](https://github.com/matrix-org/synapse/issues/11716)) +- Complement environment variable name change and update `.gitignore`. ([\#11718](https://github.com/matrix-org/synapse/issues/11718)) +- Simplify calculation of Prometheus metrics for garbage collection. ([\#11723](https://github.com/matrix-org/synapse/issues/11723)) +- Improve accuracy of `python_twisted_reactor_tick_time` Prometheus metric. ([\#11724](https://github.com/matrix-org/synapse/issues/11724), [\#11771](https://github.com/matrix-org/synapse/issues/11771)) +- Minor efficiency improvements when inserting many values into the database. ([\#11742](https://github.com/matrix-org/synapse/issues/11742)) +- Invite PR authors to give themselves credit in the changelog. ([\#11744](https://github.com/matrix-org/synapse/issues/11744)) +- Add optional debugging to investigate [issue 8631](https://github.com/matrix-org/synapse/issues/8631). ([\#11760](https://github.com/matrix-org/synapse/issues/11760)) +- Remove `log_function` utility function and its uses. ([\#11761](https://github.com/matrix-org/synapse/issues/11761)) +- Add a unit test that checks both `client` and `webclient` resources will function when simultaneously enabled. ([\#11765](https://github.com/matrix-org/synapse/issues/11765)) +- Allow overriding complement commit using `COMPLEMENT_REF`. ([\#11766](https://github.com/matrix-org/synapse/issues/11766)) +- Add some comments and type annotations for `_update_outliers_txn`. 
([\#11776](https://github.com/matrix-org/synapse/issues/11776)) + + +Synapse 1.50.2 (2022-01-24) +=========================== + +Bugfixes +-------- + +- Backport the sole fix from v1.51.0rc2. This fixes a bug introduced in Synapse 1.40.0 that caused Synapse to fail to process incoming federation traffic after handling a large amount of events in a v1 room. ([\#11806](https://github.com/matrix-org/synapse/issues/11806)) + + Synapse 1.50.1 (2022-01-18) =========================== diff --git a/contrib/prometheus/consoles/synapse.html b/contrib/prometheus/consoles/synapse.html index cd9ad15231..d17c8a08d9 100644 --- a/contrib/prometheus/consoles/synapse.html +++ b/contrib/prometheus/consoles/synapse.html @@ -92,22 +92,6 @@ new PromConsole.Graph({ }) </script> -<h3>Pending calls per tick</h3> -<div id="reactor_pending_calls"></div> -<script> -new PromConsole.Graph({ - node: document.querySelector("#reactor_pending_calls"), - expr: "rate(python_twisted_reactor_pending_calls_sum[30s]) / rate(python_twisted_reactor_pending_calls_count[30s])", - name: "[[job]]-[[index]]", - min: 0, - renderer: "line", - height: 150, - yAxisFormatter: PromConsole.NumberFormatter.humanize, - yHoverFormatter: PromConsole.NumberFormatter.humanize, - yTitle: "Pending Calls" -}) -</script> - <h1>Storage</h1> <h3>Queries</h3> diff --git a/debian/changelog b/debian/changelog index c790c21877..3a598c4148 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,21 @@ +matrix-synapse-py3 (1.51.0) stable; urgency=medium + + * New synapse release 1.51.0. + + -- Synapse Packaging team <packages@matrix.org> Tue, 25 Jan 2022 11:28:51 +0000 + +matrix-synapse-py3 (1.51.0~rc2) stable; urgency=medium + + * New synapse release 1.51.0~rc2. + + -- Synapse Packaging team <packages@matrix.org> Mon, 24 Jan 2022 12:25:00 +0000 + +matrix-synapse-py3 (1.51.0~rc1) stable; urgency=medium + + * New synapse release 1.51.0~rc1. 
+ + -- Synapse Packaging team <packages@matrix.org> Fri, 21 Jan 2022 10:46:02 +0000 + matrix-synapse-py3 (1.50.2) stable; urgency=medium * New synapse release 1.50.2. diff --git a/demo/README b/demo/README index 0bec820ad6..a5a95bd196 100644 --- a/demo/README +++ b/demo/README @@ -22,5 +22,5 @@ Logs and sqlitedb will be stored in demo/808{0,1,2}.{log,db} -Also note that when joining a public room on a differnt HS via "#foo:bar.net", then you are (in the current impl) joining a room with room_id "foo". This means that it won't work if your HS already has a room with that name. +Also note that when joining a public room on a different HS via "#foo:bar.net", then you are (in the current impl) joining a room with room_id "foo". This means that it won't work if your HS already has a room with that name. diff --git a/docker/Dockerfile b/docker/Dockerfile index 2bdc607e66..306f75ae56 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,14 +1,17 @@ # Dockerfile to build the matrixdotorg/synapse docker images. # +# Note that it uses features which are only available in BuildKit - see +# https://docs.docker.com/go/buildkit/ for more information. +# # To build the image, run `docker build` command from the root of the # synapse repository: # -# docker build -f docker/Dockerfile . +# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile . # # There is an optional PYTHON_VERSION build argument which sets the # version of python to build against: for example: # -# docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.6 . +# DOCKER_BUILDKIT=1 docker build -f docker/Dockerfile --build-arg PYTHON_VERSION=3.9 . 
# ARG PYTHON_VERSION=3.8 @@ -19,7 +22,16 @@ ARG PYTHON_VERSION=3.8 FROM docker.io/python:${PYTHON_VERSION}-slim as builder # install the OS build deps -RUN apt-get update && apt-get install -y \ +# +# RUN --mount is specific to buildkit and is documented at +# https://github.com/moby/buildkit/blob/master/frontend/dockerfile/docs/syntax.md#build-mounts-run---mount. +# Here we use it to set up a cache for apt, to improve rebuild speeds on +# slow connections. +# +RUN \ + --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt-get update && apt-get install -y \ build-essential \ libffi-dev \ libjpeg-dev \ @@ -44,7 +56,8 @@ COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py # used while you develop on the source # # This is aiming at installing the `install_requires` and `extras_require` from `setup.py` -RUN pip install --prefix="/install" --no-warn-script-location \ +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --prefix="/install" --no-warn-script-location \ /synapse[all] # Copy over the rest of the project @@ -66,7 +79,10 @@ LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/syna LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git' LABEL org.opencontainers.image.licenses='Apache-2.0' -RUN apt-get update && apt-get install -y \ +RUN \ + --mount=type=cache,target=/var/cache/apt,sharing=locked \ + --mount=type=cache,target=/var/lib/apt,sharing=locked \ + apt-get update && apt-get install -y \ curl \ gosu \ libjpeg62-turbo \ diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 74933d2fcf..c514cadb9d 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -15,9 +15,10 @@ server admin: [Admin API](../usage/administration/admin_api) It returns a JSON body like the following: -```json +```jsonc { - "displayname": "User", + "name": 
"@user:example.com", + "displayname": "User", // can be null if not set "threepids": [ { "medium": "email", @@ -32,11 +33,11 @@ It returns a JSON body like the following: "validated_at": 1586458409743 } ], - "avatar_url": "<avatar_url>", + "avatar_url": "<avatar_url>", // can be null if not set + "is_guest": 0, "admin": 0, "deactivated": 0, "shadow_banned": 0, - "password_hash": "$2b$12$p9B4GkqYdRTPGD", "creation_ts": 1560432506, "appservice_id": null, "consent_server_notice_sent": null, diff --git a/docs/development/contributing_guide.md b/docs/development/contributing_guide.md index abdb808438..c142981693 100644 --- a/docs/development/contributing_guide.md +++ b/docs/development/contributing_guide.md @@ -20,7 +20,9 @@ recommended for development. More information about WSL can be found at <https://docs.microsoft.com/en-us/windows/wsl/install>. Running Synapse natively on Windows is not officially supported. -The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://wiki.python.org/moin/BeginnersGuide/Download). +The code of Synapse is written in Python 3. To do pretty much anything, you'll need [a recent version of Python 3](https://www.python.org/downloads/). Your Python also needs support for [virtual environments](https://docs.python.org/3/library/venv.html). This is usually built-in, but some Linux distributions like Debian and Ubuntu split it out into its own package. Running `sudo apt install python3-venv` should be enough. + +Synapse can connect to PostgreSQL via the [psycopg2](https://pypi.org/project/psycopg2/) Python library. Building this library from source requires access to PostgreSQL's C header files. On Debian or Ubuntu Linux, these can be installed with `sudo apt install libpq-dev`. The source code of Synapse is hosted on GitHub. You will also need [a recent version of git](https://github.com/git-guides/install-git). 
@@ -169,6 +171,27 @@ To increase the log level for the tests, set `SYNAPSE_TEST_LOG_LEVEL`: SYNAPSE_TEST_LOG_LEVEL=DEBUG trial tests ``` +By default, tests will use an in-memory SQLite database for test data. For additional +help with debugging, one can use an on-disk SQLite database file instead, in order to +review database state during and after running tests. This can be done by setting +the `SYNAPSE_TEST_PERSIST_SQLITE_DB` environment variable. Doing so will cause the +database state to be stored in a file named `test.db` under the trial process' +working directory. Typically, this ends up being `_trial_temp/test.db`. For example: + +```sh +SYNAPSE_TEST_PERSIST_SQLITE_DB=1 trial tests +``` + +The database file can then be inspected with: + +```sh +sqlite3 _trial_temp/test.db +``` + +Note that the database file is cleared at the beginning of each test run. Thus it +will always only contain the data generated by the *last run test*. Though generally +when debugging, one is only running a single test anyway. + ### Running tests under PostgreSQL Invoking `trial` as above will use an in-memory SQLite database. This is great for diff --git a/docs/development/url_previews.md b/docs/development/url_previews.md index aff3813609..154b9a5e12 100644 --- a/docs/development/url_previews.md +++ b/docs/development/url_previews.md @@ -35,7 +35,12 @@ When Synapse is asked to preview a URL it does the following: 5. If the media is HTML: 1. Decodes the HTML via the stored file. 2. Generates an Open Graph response from the HTML. - 3. If an image exists in the Open Graph response: + 3. If a JSON oEmbed URL was found in the HTML via autodiscovery: + 1. Downloads the URL and stores it into a file via the media storage provider + and saves the local media metadata. + 2. Convert the oEmbed response to an Open Graph response. + 3. Override any Open Graph data from the HTML with data from oEmbed. + 4. If an image exists in the Open Graph response: 1. 
Downloads the URL and stores it into a file via the media storage provider and saves the local media metadata. 2. Generates thumbnails. diff --git a/docs/openid.md b/docs/openid.md index ff9de9d5b8..171ea3b712 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -390,9 +390,6 @@ oidc_providers: ### Facebook -Like Github, Facebook provide a custom OAuth2 API rather than an OIDC-compliant -one so requires a little more configuration. - 0. You will need a Facebook developer account. You can register for one [here](https://developers.facebook.com/async/registration/). 1. On the [apps](https://developers.facebook.com/apps/) page of the developer @@ -412,24 +409,28 @@ Synapse config: idp_name: Facebook idp_brand: "facebook" # optional: styling hint for clients discover: false - issuer: "https://facebook.com" + issuer: "https://www.facebook.com" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED scopes: ["openid", "email"] - authorization_endpoint: https://facebook.com/dialog/oauth - token_endpoint: https://graph.facebook.com/v9.0/oauth/access_token - user_profile_method: "userinfo_endpoint" - userinfo_endpoint: "https://graph.facebook.com/v9.0/me?fields=id,name,email,picture" + authorization_endpoint: "https://facebook.com/dialog/oauth" + token_endpoint: "https://graph.facebook.com/v9.0/oauth/access_token" + jwks_uri: "https://www.facebook.com/.well-known/oauth/openid/jwks/" user_mapping_provider: config: - subject_claim: "id" display_name_template: "{{ user.name }}" + email_template: "{{ '{{ user.email }}' }}" ``` Relevant documents: - * https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow - * Using Facebook's Graph API: https://developers.facebook.com/docs/graph-api/using-graph-api/ - * Reference to the User endpoint: https://developers.facebook.com/docs/graph-api/reference/user + * [Manually Build a Login Flow](https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow) + * 
[Using Facebook's Graph API](https://developers.facebook.com/docs/graph-api/using-graph-api/) + * [Reference to the User endpoint](https://developers.facebook.com/docs/graph-api/reference/user) + +Facebook do have an [OIDC discovery endpoint](https://www.facebook.com/.well-known/openid-configuration), +but it has a `response_types_supported` which excludes "code" (which we rely on, and +is even mentioned in their [documentation](https://developers.facebook.com/docs/facebook-login/manually-build-a-login-flow#login)), +so we have to disable discovery and configure the URIs manually. ### Gitea diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 810a14b077..1b86d0295d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -74,13 +74,7 @@ server_name: "SERVERNAME" # pid_file: DATADIR/homeserver.pid -# The absolute URL to the web client which /_matrix/client will redirect -# to if 'webclient' is configured under the 'listeners' configuration. -# -# This option can be also set to the filesystem path to the web client -# which will be served at /_matrix/client/ if 'webclient' is configured -# under the 'listeners' configuration, however this is a security risk: -# https://github.com/matrix-org/synapse#security-note +# The absolute URL to the web client which / will redirect to. # #web_client_location: https://riot.example.com/ @@ -164,7 +158,7 @@ presence: # The default room version for newly created rooms. # # Known room versions are listed here: -# https://matrix.org/docs/spec/#complete-list-of-room-versions +# https://spec.matrix.org/latest/rooms/#complete-list-of-room-versions # # For example, for room version 1, default_room_version should be set # to "1". @@ -310,8 +304,6 @@ presence: # static: static resources under synapse/static (/_matrix/static). (Mostly # useful for 'fallback authentication'.) # -# webclient: A web client. Requires web_client_location to be set. 
-# listeners: # TLS-enabled listener: for when matrix traffic is sent directly to synapse. # @@ -1503,6 +1495,21 @@ room_prejoin_state: #additional_event_types: # - org.example.custom.event.type +# We record the IP address of clients used to access the API for various +# reasons, including displaying it to the user in the "Where you're signed in" +# dialog. +# +# By default, when puppeting another user via the admin API, the client IP +# address is recorded against the user who created the access token (ie, the +# admin user), and *not* the puppeted user. +# +# Uncomment the following to also record the IP address against the puppeted +# user. (This also means that the puppeted user will count as an "active" user +# for the purpose of monthly active user tracking - see 'limit_usage_by_mau' etc +# above.) +# +#track_puppeted_user_ips: true + # A list of application service config files to use # @@ -1870,10 +1877,13 @@ saml2_config: # Defaults to false. Avoid this in production. # # user_profile_method: Whether to fetch the user profile from the userinfo -# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. +# endpoint, or to rely on the data returned in the id_token from the +# token_endpoint. +# +# Valid values are: 'auto' or 'userinfo_endpoint'. # -# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is -# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the +# Defaults to 'auto', which uses the userinfo endpoint if 'openid' is +# not included in 'scopes'. Set to 'userinfo_endpoint' to always use the # userinfo endpoint. 
# # allow_existing_users: set to 'true' to allow a user logging in via OIDC to diff --git a/docs/setup/installation.md b/docs/setup/installation.md index 210c80dace..fe657a15df 100644 --- a/docs/setup/installation.md +++ b/docs/setup/installation.md @@ -194,7 +194,7 @@ When following this route please make sure that the [Platform-specific prerequis System requirements: - POSIX-compliant system (tested on Linux & OS X) -- Python 3.6 or later, up to Python 3.9. +- Python 3.7 or later, up to Python 3.9. - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org To install the Synapse homeserver run: diff --git a/docs/turn-howto.md b/docs/turn-howto.md index e32aaa1850..eba7ca6124 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -137,6 +137,10 @@ This will install and start a systemd service called `coturn`. # TLS private key file pkey=/path/to/privkey.pem + + # Ensure the configuration lines that disable TLS/DTLS are commented-out or removed + #no-tls + #no-dtls ``` In this case, replace the `turn:` schemes in the `turn_uris` settings below @@ -145,6 +149,14 @@ This will install and start a systemd service called `coturn`. We recommend that you only try to set up TLS/DTLS once you have set up a basic installation and got it working. + NB: If your TLS certificate was provided by Let's Encrypt, TLS/DTLS will + not work with any Matrix client that uses Chromium's WebRTC library. This + currently includes Element Android & iOS; for more details, see their + [respective](https://github.com/vector-im/element-android/issues/1533) + [issues](https://github.com/vector-im/element-ios/issues/2712) as well as the underlying + [WebRTC issue](https://bugs.chromium.org/p/webrtc/issues/detail?id=11710). + Consider using a ZeroSSL certificate for your TURN server as a working alternative. + 1. 
Ensure your firewall allows traffic into the TURN server on the ports you've configured it to listen on (By default: 3478 and 5349 for TURN traffic (remember to allow both TCP and UDP traffic), and ports 49152-65535 @@ -250,6 +262,10 @@ Here are a few things to try: * Check that you have opened your firewall to allow UDP traffic to the UDP relay ports (49152-65535 by default). + * Try disabling `coturn`'s TLS/DTLS listeners and enable only its (unencrypted) + TCP/UDP listeners. (This will only leave signaling traffic unencrypted; + voice & video WebRTC traffic is always encrypted.) + * Some WebRTC implementations (notably, that of Google Chrome) appear to get confused by TURN servers which are reachable over IPv6 (this appears to be an unexpected side-effect of its handling of multiple IP addresses as diff --git a/docs/upgrade.md b/docs/upgrade.md index 30bb0dcd9c..f455d257ba 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,17 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.51.0 + +## Deprecation of `webclient` listeners and non-HTTP(S) `web_client_location` + +Listeners of type `webclient` are deprecated and scheduled to be removed in +Synapse v1.53.0. + +Similarly, a non-HTTP(S) `web_client_location` configuration is deprecated and +will become a configuration error in Synapse v1.53.0. + + # Upgrading to v1.50.0 ## Dropping support for old Python and Postgres versions diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 53295b58fc..e08ffedaf3 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -8,7 +8,8 @@ # By default the script will fetch the latest Complement master branch and # run tests with that. This can be overridden to use a custom Complement # checkout by setting the COMPLEMENT_DIR environment variable to the -# filepath of a local Complement checkout. 
+# filepath of a local Complement checkout or by setting the COMPLEMENT_REF +# environment variable to pull a different branch or commit. # # By default Synapse is run in monolith mode. This can be overridden by # setting the WORKERS environment variable. @@ -23,16 +24,20 @@ # Exit if a line returns a non-zero exit code set -e +# enable buildkit for the docker builds +export DOCKER_BUILDKIT=1 + # Change to the repository root cd "$(dirname $0)/.." # Check for a user-specified Complement checkout if [[ -z "$COMPLEMENT_DIR" ]]; then - echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..." - wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz - tar -xzf master.tar.gz - COMPLEMENT_DIR=complement-master - echo "Checkout available at 'complement-master'" + COMPLEMENT_REF=${COMPLEMENT_REF:-master} + echo "COMPLEMENT_DIR not set. Fetching Complement checkout from ${COMPLEMENT_REF}..." + wget -Nq https://github.com/matrix-org/complement/archive/${COMPLEMENT_REF}.tar.gz + tar -xzf ${COMPLEMENT_REF}.tar.gz + COMPLEMENT_DIR=complement-${COMPLEMENT_REF} + echo "Checkout available at 'complement-${COMPLEMENT_REF}'" fi # Build the base Synapse image from the local checkout @@ -47,7 +52,7 @@ if [[ -n "$WORKERS" ]]; then COMPLEMENT_DOCKERFILE=SynapseWorkers.Dockerfile # And provide some more configuration to complement. export COMPLEMENT_CA=true - export COMPLEMENT_VERSION_CHECK_ITERATIONS=500 + export COMPLEMENT_SPAWN_HS_TIMEOUT_SECS=25 else export COMPLEMENT_BASE_IMAGE=complement-synapse COMPLEMENT_DOCKERFILE=Synapse.Dockerfile @@ -65,4 +70,5 @@ if [[ -n "$1" ]]; then fi # Run the tests! +echo "Images built; running complement" go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/... 
diff --git a/synapse/__init__.py b/synapse/__init__.py index 5ef294cd42..26bdfec33a 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.50.2" +__version__ = "1.51.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py index 093af4327a..e207f154f3 100644 --- a/synapse/_scripts/review_recent_signups.py +++ b/synapse/_scripts/review_recent_signups.py @@ -46,7 +46,9 @@ class UserInfo: ips: List[str] = attr.Factory(list) -def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]: +def get_recent_users( + txn: LoggingTransaction, since_ms: int, exclude_app_service: bool +) -> List[UserInfo]: """Fetches recently registered users and some info on them.""" sql = """ @@ -56,6 +58,9 @@ def get_recent_users(txn: LoggingTransaction, since_ms: int) -> List[UserInfo]: AND deactivated = 0 """ + if exclude_app_service: + sql += " AND appservice_id IS NULL" + txn.execute(sql, (since_ms / 1000,)) user_infos = [UserInfo(user_id, creation_ts) for user_id, creation_ts in txn] @@ -113,7 +118,7 @@ def main() -> None: "-e", "--exclude-emails", action="store_true", - help="Exclude users that have validated email addresses", + help="Exclude users that have validated email addresses.", ) parser.add_argument( "-u", @@ -121,6 +126,12 @@ def main() -> None: action="store_true", help="Only print user IDs that match.", ) + parser.add_argument( + "-a", + "--exclude-app-service", + help="Exclude appservice users.", + action="store_true", + ) config = ReviewConfig() @@ -133,6 +144,7 @@ def main() -> None: since_ms = time.time() * 1000 - Config.parse_duration(config_args.since) exclude_users_with_email = config_args.exclude_emails + exclude_users_with_appservice = config_args.exclude_app_service include_context = not 
config_args.only_users for database_config in config.database.databases: @@ -143,7 +155,7 @@ def main() -> None: with make_conn(database_config, engine, "review_recent_signups") as db_conn: # This generates a type of Cursor, not LoggingTransaction. - user_infos = get_recent_users(db_conn.cursor(), since_ms) # type: ignore[arg-type] + user_infos = get_recent_users(db_conn.cursor(), since_ms, exclude_users_with_appservice) # type: ignore[arg-type] for user_info in user_infos: if exclude_users_with_email and user_info.emails: diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 4a32d430bd..683241201c 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -71,6 +71,7 @@ class Auth: self._auth_blocking = AuthBlocking(self.hs) self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips + self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips self._macaroon_secret_key = hs.config.key.macaroon_secret_key self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users @@ -246,6 +247,18 @@ class Auth: user_agent=user_agent, device_id=device_id, ) + # Track also the puppeted user client IP if enabled and the user is puppeting + if ( + user_info.user_id != user_info.token_owner + and self._track_puppeted_user_ips + ): + await self.store.insert_client_ip( + user_id=user_info.user_id, + access_token=access_token, + ip=ip_addr, + user_agent=user_agent, + device_id=device_id, + ) if is_guest and not allow_guest: raise AuthError( diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 0a895bba48..a747a40814 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -46,41 +46,41 @@ class RoomDisposition: UNSTABLE = "unstable" -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class RoomVersion: """An object which describes the unique attributes of a room version.""" - identifier = attr.ib(type=str) # the identifier for this version - 
disposition = attr.ib(type=str) # one of the RoomDispositions - event_format = attr.ib(type=int) # one of the EventFormatVersions - state_res = attr.ib(type=int) # one of the StateResolutionVersions - enforce_key_validity = attr.ib(type=bool) + identifier: str # the identifier for this version + disposition: str # one of the RoomDispositions + event_format: int # one of the EventFormatVersions + state_res: int # one of the StateResolutionVersions + enforce_key_validity: bool # Before MSC2432, m.room.aliases had special auth rules and redaction rules - special_case_aliases_auth = attr.ib(type=bool) + special_case_aliases_auth: bool # Strictly enforce canonicaljson, do not allow: # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] # * Floats # * NaN, Infinity, -Infinity - strict_canonicaljson = attr.ib(type=bool) + strict_canonicaljson: bool # MSC2209: Check 'notifications' key while verifying # m.room.power_levels auth rules. - limit_notifications_power_levels = attr.ib(type=bool) + limit_notifications_power_levels: bool # MSC2174/MSC2176: Apply updated redaction rules algorithm. - msc2176_redaction_rules = attr.ib(type=bool) + msc2176_redaction_rules: bool # MSC3083: Support the 'restricted' join_rule. - msc3083_join_rules = attr.ib(type=bool) + msc3083_join_rules: bool # MSC3375: Support for the proper redaction rules for MSC3083. This mustn't # be enabled if MSC3083 is not. - msc3375_redaction_rules = attr.ib(type=bool) + msc3375_redaction_rules: bool # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending # m.room.membership event with membership 'knock'. 
- msc2403_knocking = attr.ib(type=bool) + msc2403_knocking: bool # MSC2716: Adds m.room.power_levels -> content.historical field to control # whether "insertion", "chunk", "marker" events can be sent - msc2716_historical = attr.ib(type=bool) + msc2716_historical: bool # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events - msc2716_redactions = attr.ib(type=bool) + msc2716_redactions: bool class RoomVersions: diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 5fc59c1be1..579adbbca0 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -60,7 +60,7 @@ from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.logging.context import PreserveLoggingContext -from synapse.metrics import register_threadpool +from synapse.metrics import install_gc_manager, register_threadpool from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats from synapse.types import ISynapseReactor @@ -159,6 +159,7 @@ def start_reactor( change_resource_limit(soft_file_limit) if gc_thresholds: gc.set_threshold(*gc_thresholds) + install_gc_manager() run_command() # make sure that we run the reactor with the sentinel log context, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index dd76e07321..efedcc8889 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -131,9 +131,18 @@ class SynapseHomeServer(HomeServer): resources.update(self._module_web_resources) self._module_web_resources_consumed = True - # try to find something useful to redirect '/' to - if WEB_CLIENT_PREFIX in resources: - root_resource: Resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) + # Try to find something useful to serve at '/': + # + # 1. Redirect to the web client if it is an HTTP(S) URL. 
+ # 2. Redirect to the web client served via Synapse. + # 3. Redirect to the static "Synapse is running" page. + # 4. Do not redirect and use a blank resource. + if self.config.server.web_client_location_is_redirect: + root_resource: Resource = RootOptionsRedirectResource( + self.config.server.web_client_location + ) + elif WEB_CLIENT_PREFIX in resources: + root_resource = RootOptionsRedirectResource(WEB_CLIENT_PREFIX) elif STATIC_PREFIX in resources: root_resource = RootOptionsRedirectResource(STATIC_PREFIX) else: @@ -262,15 +271,15 @@ class SynapseHomeServer(HomeServer): resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "webclient": + # webclient listeners are deprecated as of Synapse v1.51.0, remove it + # in > v1.53.0. webclient_loc = self.config.server.web_client_location if webclient_loc is None: logger.warning( "Not enabling webclient resource, as web_client_location is unset." ) - elif webclient_loc.startswith("http://") or webclient_loc.startswith( - "https://" - ): + elif self.config.server.web_client_location_is_redirect: resources[WEB_CLIENT_PREFIX] = RootRedirect(webclient_loc) else: logger.warning( diff --git a/synapse/config/api.py b/synapse/config/api.py index 25538b82d5..8133b6b624 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -29,6 +29,7 @@ class ApiConfig(Config): def read_config(self, config: JsonDict, **kwargs): validate_config(_MAIN_SCHEMA, config, ()) self.room_prejoin_state = list(self._get_prejoin_state_types(config)) + self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) def generate_config_section(cls, **kwargs) -> str: formatted_default_state_types = "\n".join( @@ -59,6 +60,21 @@ class ApiConfig(Config): # #additional_event_types: # - org.example.custom.event.type + + # We record the IP address of clients used to access the API for various + # reasons, including displaying it to the user in the "Where you're signed in" + # dialog. 
+ # + # By default, when puppeting another user via the admin API, the client IP + # address is recorded against the user who created the access token (ie, the + # admin user), and *not* the puppeted user. + # + # Uncomment the following to also record the IP address against the puppeted + # user. (This also means that the puppeted user will count as an "active" user + # for the purpose of monthly active user tracking - see 'limit_usage_by_mau' etc + # above.) + # + #track_puppeted_user_ips: true """ % { "formatted_default_state_types": formatted_default_state_types } @@ -138,5 +154,8 @@ _MAIN_SCHEMA = { "properties": { "room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA, "room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA, + "track_puppeted_user_ips": { + "type": "boolean", + }, }, } diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 510b647c63..949d7dd5ac 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -55,19 +55,19 @@ https://matrix-org.github.io/synapse/latest/templates.html ---------------------------------------------------------------------------------------""" -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class EmailSubjectConfig: - message_from_person_in_room = attr.ib(type=str) - message_from_person = attr.ib(type=str) - messages_from_person = attr.ib(type=str) - messages_in_room = attr.ib(type=str) - messages_in_room_and_others = attr.ib(type=str) - messages_from_person_and_others = attr.ib(type=str) - invite_from_person = attr.ib(type=str) - invite_from_person_to_room = attr.ib(type=str) - invite_from_person_to_space = attr.ib(type=str) - password_reset = attr.ib(type=str) - email_validation = attr.ib(type=str) + message_from_person_in_room: str + message_from_person: str + messages_from_person: str + messages_in_room: str + messages_in_room_and_others: str + messages_from_person_and_others: str + invite_from_person: str + 
invite_from_person_to_room: str + invite_from_person_to_space: str + password_reset: str + email_validation: str class EmailConfig(Config): diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 79c400fe30..e783b11315 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -148,10 +148,13 @@ class OIDCConfig(Config): # Defaults to false. Avoid this in production. # # user_profile_method: Whether to fetch the user profile from the userinfo - # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. + # endpoint, or to rely on the data returned in the id_token from the + # token_endpoint. # - # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is - # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the + # Valid values are: 'auto' or 'userinfo_endpoint'. + # + # Defaults to 'auto', which uses the userinfo endpoint if 'openid' is + # not included in 'scopes'. Set to 'userinfo_endpoint' to always use the # userinfo endpoint. # # allow_existing_users: set to 'true' to allow a user logging in via OIDC to diff --git a/synapse/config/server.py b/synapse/config/server.py index 1de2dea9b0..f200d0c1f1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -200,8 +200,8 @@ class HttpListenerConfig: """Object describing the http-specific parts of the config of a listener""" x_forwarded: bool = False - resources: List[HttpResourceConfig] = attr.ib(factory=list) - additional_resources: Dict[str, dict] = attr.ib(factory=dict) + resources: List[HttpResourceConfig] = attr.Factory(list) + additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None @@ -259,7 +259,6 @@ class ServerConfig(Config): raise ConfigError(str(e)) self.pid_file = self.abspath(config.get("pid_file")) - self.web_client_location = config.get("web_client_location", None) self.soft_file_limit = config.get("soft_file_limit", 0) self.daemonize = config.get("daemonize") self.print_pidfile = config.get("print_pidfile") 
@@ -506,8 +505,17 @@ class ServerConfig(Config): l2.append(listener) self.listeners = l2 - if not self.web_client_location: - _warn_if_webclient_configured(self.listeners) + self.web_client_location = config.get("web_client_location", None) + self.web_client_location_is_redirect = self.web_client_location and ( + self.web_client_location.startswith("http://") + or self.web_client_location.startswith("https://") + ) + # A non-HTTP(S) web client location is deprecated. + if self.web_client_location and not self.web_client_location_is_redirect: + logger.warning(NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING) + + # Warn if webclient is configured for a worker. + _warn_if_webclient_configured(self.listeners) self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None)) self.gc_seconds = self.read_gc_intervals(config.get("gc_min_interval", None)) @@ -793,13 +801,7 @@ class ServerConfig(Config): # pid_file: %(pid_file)s - # The absolute URL to the web client which /_matrix/client will redirect - # to if 'webclient' is configured under the 'listeners' configuration. - # - # This option can be also set to the filesystem path to the web client - # which will be served at /_matrix/client/ if 'webclient' is configured - # under the 'listeners' configuration, however this is a security risk: - # https://github.com/matrix-org/synapse#security-note + # The absolute URL to the web client which / will redirect to. # #web_client_location: https://riot.example.com/ @@ -883,7 +885,7 @@ class ServerConfig(Config): # The default room version for newly created rooms. # # Known room versions are listed here: - # https://matrix.org/docs/spec/#complete-list-of-room-versions + # https://spec.matrix.org/latest/rooms/#complete-list-of-room-versions # # For example, for room version 1, default_room_version should be set # to "1". @@ -1011,8 +1013,6 @@ class ServerConfig(Config): # static: static resources under synapse/static (/_matrix/static). 
(Mostly # useful for 'fallback authentication'.) # - # webclient: A web client. Requires web_client_location to be set. - # listeners: # TLS-enabled listener: for when matrix traffic is sent directly to synapse. # @@ -1349,9 +1349,15 @@ def parse_listener_def(listener: Any) -> ListenerConfig: return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) +NO_MORE_NONE_HTTP_WEB_CLIENT_LOCATION_WARNING = """ +Synapse no longer supports serving a web client. To remove this warning, +configure 'web_client_location' with an HTTP(S) URL. +""" + + NO_MORE_WEB_CLIENT_WARNING = """ -Synapse no longer includes a web client. To enable a web client, configure -web_client_location. To remove this warning, remove 'webclient' from the 'listeners' +Synapse no longer includes a web client. To redirect the root resource to a web client, configure +'web_client_location'. To remove this warning, remove 'webclient' from the 'listeners' configuration. """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 576f519188..bdaba6db37 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -51,12 +51,12 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: return obj -@attr.s +@attr.s(auto_attribs=True) class InstanceLocationConfig: """The host and port to talk to an instance via HTTP replication.""" - host = attr.ib(type=str) - port = attr.ib(type=int) + host: str + port: int @attr.s @@ -77,34 +77,28 @@ class WriterLocations: can only be a single instance. 
""" - events = attr.ib( + events: List[str] = attr.ib( default=["master"], - type=List[str], converter=_instance_to_list_converter, ) - typing = attr.ib( + typing: List[str] = attr.ib( default=["master"], - type=List[str], converter=_instance_to_list_converter, ) - to_device = attr.ib( + to_device: List[str] = attr.ib( default=["master"], - type=List[str], converter=_instance_to_list_converter, ) - account_data = attr.ib( + account_data: List[str] = attr.ib( default=["master"], - type=List[str], converter=_instance_to_list_converter, ) - receipts = attr.ib( + receipts: List[str] = attr.ib( default=["master"], - type=List[str], converter=_instance_to_list_converter, ) - presence = attr.ib( + presence: List[str] = attr.ib( default=["master"], - type=List[str], converter=_instance_to_list_converter, ) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 993b04099e..72d4a69aac 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -58,7 +58,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -@attr.s(slots=True, cmp=False) +@attr.s(slots=True, frozen=True, cmp=False, auto_attribs=True) class VerifyJsonRequest: """ A request to verify a JSON object. @@ -78,10 +78,10 @@ class VerifyJsonRequest: key_ids: The set of key_ids to that could be used to verify the JSON object """ - server_name = attr.ib(type=str) - get_json_object = attr.ib(type=Callable[[], JsonDict]) - minimum_valid_until_ts = attr.ib(type=int) - key_ids = attr.ib(type=List[str]) + server_name: str + get_json_object: Callable[[], JsonDict] + minimum_valid_until_ts: int + key_ids: List[str] @staticmethod def from_json_object( @@ -124,7 +124,7 @@ class KeyLookupError(ValueError): pass -@attr.s(slots=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class _FetchKeyRequest: """A request for keys for a given server. 
@@ -138,9 +138,9 @@ class _FetchKeyRequest: key_ids: The IDs of the keys to attempt to fetch """ - server_name = attr.ib(type=str) - minimum_valid_until_ts = attr.ib(type=int) - key_ids = attr.ib(type=List[str]) + server_name: str + minimum_valid_until_ts: int + key_ids: List[str] class Keyring: diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f251402ed8..0eab1aefd6 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from synapse.storage.databases.main import DataStore -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class EventContext: """ Holds information relevant to persisting an event @@ -103,15 +103,15 @@ class EventContext: accessed via get_prev_state_ids. """ - rejected = attr.ib(default=False, type=Union[bool, str]) - _state_group = attr.ib(default=None, type=Optional[int]) - state_group_before_event = attr.ib(default=None, type=Optional[int]) - prev_group = attr.ib(default=None, type=Optional[int]) - delta_ids = attr.ib(default=None, type=Optional[StateMap[str]]) - app_service = attr.ib(default=None, type=Optional[ApplicationService]) + rejected: Union[bool, str] = False + _state_group: Optional[int] = None + state_group_before_event: Optional[int] = None + prev_group: Optional[int] = None + delta_ids: Optional[StateMap[str]] = None + app_service: Optional[ApplicationService] = None - _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) - _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) + _current_state_ids: Optional[StateMap[str]] = None + _prev_state_ids: Optional[StateMap[str]] = None @staticmethod def with_state( diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 2038e72924..918adeecf8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,17 +14,7 @@ # limitations under the License. 
import collections.abc import re -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Union, -) +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from frozendict import frozendict @@ -32,14 +22,10 @@ from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict -from synapse.util.async_helpers import yieldable_gather_results from synapse.util.frozenutils import unfreeze from . import EventBase -if TYPE_CHECKING: - from synapse.server import HomeServer - # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (?<!stuff) matches if the current position in the string is not preceded # by a match for 'stuff'. @@ -385,17 +371,12 @@ class EventClientSerializer: clients. """ - def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() - self._msc1849_enabled = hs.config.experimental.msc1849_enabled - self._msc3440_enabled = hs.config.experimental.msc3440_enabled - - async def serialize_event( + def serialize_event( self, event: Union[JsonDict, EventBase], time_now: int, *, - bundle_aggregations: bool = False, + bundle_aggregations: Optional[Dict[str, JsonDict]] = None, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -418,66 +399,41 @@ class EventClientSerializer: serialized_event = serialize_event(event, time_now, **kwargs) # Check if there are any bundled aggregations to include with the event. - # - # Do not bundle aggregations if any of the following at true: - # - # * Support is disabled via the configuration or the caller. - # * The event is a state event. - # * The event has been redacted. 
- if ( - self._msc1849_enabled - and bundle_aggregations - and not event.is_state() - and not event.internal_metadata.is_redacted() - ): - await self._injected_bundled_aggregations(event, time_now, serialized_event) + if bundle_aggregations: + event_aggregations = bundle_aggregations.get(event.event_id) + if event_aggregations: + self._inject_bundled_aggregations( + event, + time_now, + bundle_aggregations[event.event_id], + serialized_event, + ) return serialized_event - async def _injected_bundled_aggregations( - self, event: EventBase, time_now: int, serialized_event: JsonDict + def _inject_bundled_aggregations( + self, + event: EventBase, + time_now: int, + aggregations: JsonDict, + serialized_event: JsonDict, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. Args: event: The event being serialized. time_now: The current time in milliseconds + aggregations: The bundled aggregation to serialize. serialized_event: The serialized event which may be modified. """ - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include. - aggregations = {} - - annotations = await self.store.get_aggregation_groups_for_event( - event_id, room_id - ) - if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + # Make a copy in-case the object is cached. 
+ aggregations = aggregations.copy() - references = await self.store.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() - - edit = None - if event.type == EventTypes.Message: - edit = await self.store.get_applicable_edit(event_id, room_id) - - if edit: + if RelationTypes.REPLACE in aggregations: # If there is an edit replace the content, preserving existing # relations. + edit = aggregations[RelationTypes.REPLACE] # Ensure we take copies of the edit content, otherwise we risk modifying # the original event. @@ -502,27 +458,19 @@ class EventClientSerializer: } # If this event is the start of a thread, include a summary of the replies. - if self._msc3440_enabled: - ( - thread_count, - latest_thread_event, - ) = await self.store.get_thread_summary(event_id, room_id) - if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - # Don't bundle aggregations as this could recurse forever. - "latest_event": await self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=False - ), - "count": thread_count, - } - - # If any bundled aggregations were found, include them. - if aggregations: - serialized_event["unsigned"].setdefault("m.relations", {}).update( - aggregations + if RelationTypes.THREAD in aggregations: + # Serialize the latest thread event. + latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"] + + # Don't bundle aggregations as this could recurse forever. + aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( + latest_thread_event, time_now, bundle_aggregations=None ) - async def serialize_events( + # Include the bundled aggregations in the event. 
+ serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) + + def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any ) -> List[JsonDict]: """Serializes multiple events. @@ -535,9 +483,9 @@ class EventClientSerializer: Returns: The list of serialized events """ - return await yieldable_gather_results( - self.serialize_event, events, time_now=time_now, **kwargs - ) + return [ + self.serialize_event(event, time_now=time_now, **kwargs) for event in events + ] def copy_power_levels_contents( diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index addc0bf000..896168c05c 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -230,6 +230,10 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB # origin, etc etc) assert_params_in_dict(pdu_json, ("type", "depth")) + # Strip any unauthorized values from "unsigned" if they exist + if "unsigned" in pdu_json: + _strip_unsigned_values(pdu_json) + depth = pdu_json["depth"] if not isinstance(depth, int): raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) @@ -245,3 +249,24 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB event = make_event_from_dict(pdu_json, room_version) return event + + +def _strip_unsigned_values(pdu_dict: JsonDict) -> None: + """ + Strip any unsigned values unless specifically allowed, as defined by the whitelist. + + pdu: the json dict to strip values from. 
Note that the dict is mutated by this + function + """ + unsigned = pdu_dict["unsigned"] + + if not isinstance(unsigned, dict): + pdu_dict["unsigned"] = {} + + if pdu_dict["type"] == "m.room.member": + whitelist = ["knock_room_state", "invite_room_state", "age"] + else: + whitelist = ["age"] + + filtered_unsigned = {k: v for k, v in unsigned.items() if k in whitelist} + pdu_dict["unsigned"] = filtered_unsigned diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 6ea4edfc71..74f17aa4da 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -56,7 +56,6 @@ from synapse.api.room_versions import ( from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse -from synapse.logging.utils import log_function from synapse.types import JsonDict, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -119,7 +118,8 @@ class FederationClient(FederationBase): # It is a map of (room ID, suggested-only) -> the response of # get_room_hierarchy. 
self._get_room_hierarchy_cache: ExpiringCache[ - Tuple[str, bool], Tuple[JsonDict, Sequence[JsonDict], Sequence[str]] + Tuple[str, bool], + Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]], ] = ExpiringCache( cache_name="get_room_hierarchy_cache", clock=self._clock, @@ -144,7 +144,6 @@ class FederationClient(FederationBase): if destination_dict: self.pdu_destination_tried[event_id] = destination_dict - @log_function async def make_query( self, destination: str, @@ -178,7 +177,6 @@ class FederationClient(FederationBase): ignore_backoff=ignore_backoff, ) - @log_function async def query_client_keys( self, destination: str, content: JsonDict, timeout: int ) -> JsonDict: @@ -196,7 +194,6 @@ class FederationClient(FederationBase): destination, content, timeout ) - @log_function async def query_user_devices( self, destination: str, user_id: str, timeout: int = 30000 ) -> JsonDict: @@ -208,7 +205,6 @@ class FederationClient(FederationBase): destination, user_id, timeout ) - @log_function async def claim_client_keys( self, destination: str, content: JsonDict, timeout: int ) -> JsonDict: @@ -1338,7 +1334,7 @@ class FederationClient(FederationBase): destinations: Iterable[str], room_id: str, suggested_only: bool, - ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: + ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]: """ Call other servers to get a hierarchy of the given room. @@ -1353,7 +1349,8 @@ class FederationClient(FederationBase): Returns: A tuple of: - The room as a JSON dictionary. + The room as a JSON dictionary, without a "children_state" key. + A list of `m.space.child` state events. A list of children rooms, as JSON dictionaries. A list of inaccessible children room IDs. 
@@ -1368,7 +1365,7 @@ class FederationClient(FederationBase): async def send_request( destination: str, - ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[str]]: + ) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]: try: res = await self.transport_layer.get_room_hierarchy( destination=destination, @@ -1397,7 +1394,7 @@ class FederationClient(FederationBase): raise InvalidResponseError("'room' must be a dict") # Validate children_state of the room. - children_state = room.get("children_state", []) + children_state = room.pop("children_state", []) if not isinstance(children_state, Sequence): raise InvalidResponseError("'room.children_state' must be a list") if any(not isinstance(e, dict) for e in children_state): @@ -1426,7 +1423,7 @@ class FederationClient(FederationBase): "Invalid room ID in 'inaccessible_children' list" ) - return room, children, inaccessible_children + return room, children_state, children, inaccessible_children try: result = await self._try_destination_list( @@ -1474,8 +1471,6 @@ class FederationClient(FederationBase): if event.room_id == room_id: children_events.append(event.data) children_room_ids.add(event.state_key) - # And add them under the requested room. - requested_room["children_state"] = children_events # Find the children rooms. children = [] @@ -1485,7 +1480,7 @@ class FederationClient(FederationBase): # It isn't clear from the response whether some of the rooms are # not accessible. - result = (requested_room, children, ()) + result = (requested_room, children_events, children, ()) # Cache the result to avoid fetching data over federation every time. 
self._get_room_hierarchy_cache[(room_id, suggested_only)] = result diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ee71f289c8..af9cb98f67 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -58,7 +58,6 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace -from synapse.logging.utils import log_function from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, @@ -859,7 +858,6 @@ class FederationServer(FederationBase): res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} return 200, res - @log_function async def on_query_client_keys( self, origin: str, content: Dict[str, str] ) -> Tuple[int, Dict[str, Any]]: @@ -940,7 +938,6 @@ class FederationServer(FederationBase): return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} - @log_function async def on_openid_userinfo(self, token: str) -> Optional[str]: ts_now_ms = self._clock.time_msec() return await self.store.get_user_id_for_open_id_token(token, ts_now_ms) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 523ab1c51e..60e2e6cf01 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -23,7 +23,6 @@ import logging from typing import Optional, Tuple from synapse.federation.units import Transaction -from synapse.logging.utils import log_function from synapse.storage.databases.main import DataStore from synapse.types import JsonDict @@ -36,7 +35,6 @@ class TransactionActions: def __init__(self, datastore: DataStore): self.store = datastore - @log_function async def have_responded( self, origin: str, transaction: Transaction ) -> Optional[Tuple[int, JsonDict]]: @@ -53,7 +51,6 @@ class TransactionActions: return await 
self.store.get_received_txn_response(transaction_id, origin) - @log_function async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 391b30fbb5..8152e80b88 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -607,18 +607,18 @@ class PerDestinationQueue: self._pending_pdus = [] -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _TransactionQueueManager: """A helper async context manager for pulling stuff off the queues and tracking what was last successfully sent, etc. """ - queue = attr.ib(type=PerDestinationQueue) + queue: PerDestinationQueue - _device_stream_id = attr.ib(type=Optional[int], default=None) - _device_list_id = attr.ib(type=Optional[int], default=None) - _last_stream_ordering = attr.ib(type=Optional[int], default=None) - _pdus = attr.ib(type=List[EventBase], factory=list) + _device_stream_id: Optional[int] = None + _device_list_id: Optional[int] = None + _last_stream_ordering: Optional[int] = None + _pdus: List[EventBase] = attr.Factory(list) async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: # First we calculate the EDUs we want to send, if any. 
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index ab935e5a7e..742ee57255 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -35,6 +35,7 @@ if TYPE_CHECKING: import synapse.server logger = logging.getLogger(__name__) +issue_8631_logger = logging.getLogger("synapse.8631_debug") last_pdu_ts_metric = Gauge( "synapse_federation_last_sent_pdu_time", @@ -124,6 +125,17 @@ class TransactionManager: len(pdus), len(edus), ) + if issue_8631_logger.isEnabledFor(logging.DEBUG): + DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"} + device_list_updates = [ + edu.content for edu in edus if edu.edu_type in DEVICE_UPDATE_EDUS + ] + if device_list_updates: + issue_8631_logger.debug( + "about to send txn [%s] including device list updates: %s", + transaction.transaction_id, + device_list_updates, + ) # Actually send the transaction diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9fc4c31c93..8782586cd6 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -44,7 +44,6 @@ from synapse.api.urls import ( from synapse.events import EventBase, make_event_from_dict from synapse.federation.units import Transaction from synapse.http.matrixfederationclient import ByteParser -from synapse.logging.utils import log_function from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -62,7 +61,6 @@ class TransportLayerClient: self.server_name = hs.hostname self.client = hs.get_federation_http_client() - @log_function async def get_room_state_ids( self, destination: str, room_id: str, event_id: str ) -> JsonDict: @@ -88,7 +86,6 @@ class TransportLayerClient: try_trailing_slash_on_400=True, ) - @log_function async def get_event( self, destination: str, event_id: str, timeout: Optional[int] = None ) -> JsonDict: @@ -111,7 +108,6 @@ class 
TransportLayerClient: destination, path=path, timeout=timeout, try_trailing_slash_on_400=True ) - @log_function async def backfill( self, destination: str, room_id: str, event_tuples: Collection[str], limit: int ) -> Optional[JsonDict]: @@ -149,7 +145,6 @@ class TransportLayerClient: destination, path=path, args=args, try_trailing_slash_on_400=True ) - @log_function async def timestamp_to_event( self, destination: str, room_id: str, timestamp: int, direction: str ) -> Union[JsonDict, List]: @@ -185,7 +180,6 @@ class TransportLayerClient: return remote_response - @log_function async def send_transaction( self, transaction: Transaction, @@ -234,7 +228,6 @@ class TransportLayerClient: try_trailing_slash_on_400=True, ) - @log_function async def make_query( self, destination: str, @@ -254,7 +247,6 @@ class TransportLayerClient: ignore_backoff=ignore_backoff, ) - @log_function async def make_membership_event( self, destination: str, @@ -317,7 +309,6 @@ class TransportLayerClient: ignore_backoff=ignore_backoff, ) - @log_function async def send_join_v1( self, room_version: RoomVersion, @@ -336,7 +327,6 @@ class TransportLayerClient: max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) - @log_function async def send_join_v2( self, room_version: RoomVersion, @@ -355,7 +345,6 @@ class TransportLayerClient: max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) - @log_function async def send_leave_v1( self, destination: str, room_id: str, event_id: str, content: JsonDict ) -> Tuple[int, JsonDict]: @@ -372,7 +361,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def send_leave_v2( self, destination: str, room_id: str, event_id: str, content: JsonDict ) -> JsonDict: @@ -389,7 +377,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def send_knock_v1( self, destination: str, @@ -423,7 +410,6 @@ class TransportLayerClient: destination=destination, path=path, data=content ) - @log_function async def send_invite_v1( self, destination: str, 
room_id: str, event_id: str, content: JsonDict ) -> Tuple[int, JsonDict]: @@ -433,7 +419,6 @@ class TransportLayerClient: destination=destination, path=path, data=content, ignore_backoff=True ) - @log_function async def send_invite_v2( self, destination: str, room_id: str, event_id: str, content: JsonDict ) -> JsonDict: @@ -443,7 +428,6 @@ class TransportLayerClient: destination=destination, path=path, data=content, ignore_backoff=True ) - @log_function async def get_public_rooms( self, remote_server: str, @@ -516,7 +500,6 @@ class TransportLayerClient: return response - @log_function async def exchange_third_party_invite( self, destination: str, room_id: str, event_dict: JsonDict ) -> JsonDict: @@ -526,7 +509,6 @@ class TransportLayerClient: destination=destination, path=path, data=event_dict ) - @log_function async def get_event_auth( self, destination: str, room_id: str, event_id: str ) -> JsonDict: @@ -534,7 +516,6 @@ class TransportLayerClient: return await self.client.get_json(destination=destination, path=path) - @log_function async def query_client_keys( self, destination: str, query_content: JsonDict, timeout: int ) -> JsonDict: @@ -576,7 +557,6 @@ class TransportLayerClient: destination=destination, path=path, data=query_content, timeout=timeout ) - @log_function async def query_user_devices( self, destination: str, user_id: str, timeout: int ) -> JsonDict: @@ -616,7 +596,6 @@ class TransportLayerClient: destination=destination, path=path, timeout=timeout ) - @log_function async def claim_client_keys( self, destination: str, query_content: JsonDict, timeout: int ) -> JsonDict: @@ -655,7 +634,6 @@ class TransportLayerClient: destination=destination, path=path, data=query_content, timeout=timeout ) - @log_function async def get_missing_events( self, destination: str, @@ -680,7 +658,6 @@ class TransportLayerClient: timeout=timeout, ) - @log_function async def get_group_profile( self, destination: str, group_id: str, requester_user_id: str ) -> JsonDict: @@ 
-694,7 +671,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def update_group_profile( self, destination: str, group_id: str, requester_user_id: str, content: JsonDict ) -> JsonDict: @@ -716,7 +692,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def get_group_summary( self, destination: str, group_id: str, requester_user_id: str ) -> JsonDict: @@ -730,7 +705,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def get_rooms_in_group( self, destination: str, group_id: str, requester_user_id: str ) -> JsonDict: @@ -798,7 +772,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def get_users_in_group( self, destination: str, group_id: str, requester_user_id: str ) -> JsonDict: @@ -812,7 +785,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def get_invited_users_in_group( self, destination: str, group_id: str, requester_user_id: str ) -> JsonDict: @@ -826,7 +798,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def accept_group_invite( self, destination: str, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: @@ -837,7 +808,6 @@ class TransportLayerClient: destination=destination, path=path, data=content, ignore_backoff=True ) - @log_function def join_group( self, destination: str, group_id: str, user_id: str, content: JsonDict ) -> Awaitable[JsonDict]: @@ -848,7 +818,6 @@ class TransportLayerClient: destination=destination, path=path, data=content, ignore_backoff=True ) - @log_function async def invite_to_group( self, destination: str, @@ -868,7 +837,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def invite_to_group_notification( self, destination: str, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: @@ -882,7 +850,6 @@ class TransportLayerClient: destination=destination, path=path, data=content, ignore_backoff=True ) - @log_function async def 
remove_user_from_group( self, destination: str, @@ -902,7 +869,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def remove_user_from_group_notification( self, destination: str, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: @@ -916,7 +882,6 @@ class TransportLayerClient: destination=destination, path=path, data=content, ignore_backoff=True ) - @log_function async def renew_group_attestation( self, destination: str, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: @@ -930,7 +895,6 @@ class TransportLayerClient: destination=destination, path=path, data=content, ignore_backoff=True ) - @log_function async def update_group_summary_room( self, destination: str, @@ -959,7 +923,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def delete_group_summary_room( self, destination: str, @@ -986,7 +949,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def get_group_categories( self, destination: str, group_id: str, requester_user_id: str ) -> JsonDict: @@ -1000,7 +962,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def get_group_category( self, destination: str, group_id: str, requester_user_id: str, category_id: str ) -> JsonDict: @@ -1014,7 +975,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def update_group_category( self, destination: str, @@ -1034,7 +994,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def delete_group_category( self, destination: str, group_id: str, requester_user_id: str, category_id: str ) -> JsonDict: @@ -1048,7 +1007,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def get_group_roles( self, destination: str, group_id: str, requester_user_id: str ) -> JsonDict: @@ -1062,7 +1020,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def get_group_role( self, destination: str, group_id: str, 
requester_user_id: str, role_id: str ) -> JsonDict: @@ -1076,7 +1033,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def update_group_role( self, destination: str, @@ -1096,7 +1052,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def delete_group_role( self, destination: str, group_id: str, requester_user_id: str, role_id: str ) -> JsonDict: @@ -1110,7 +1065,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def update_group_summary_user( self, destination: str, @@ -1136,7 +1090,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def set_group_join_policy( self, destination: str, group_id: str, requester_user_id: str, content: JsonDict ) -> JsonDict: @@ -1151,7 +1104,6 @@ class TransportLayerClient: ignore_backoff=True, ) - @log_function async def delete_group_summary_user( self, destination: str, diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 77bfd88ad0..beadfa422b 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -36,6 +36,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) +issue_8631_logger = logging.getLogger("synapse.8631_debug") class BaseFederationServerServlet(BaseFederationServlet): @@ -95,6 +96,20 @@ class FederationSendServlet(BaseFederationServerServlet): len(transaction_data.get("edus", [])), ) + if issue_8631_logger.isEnabledFor(logging.DEBUG): + DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"} + device_list_updates = [ + edu.content + for edu in transaction_data.get("edus", []) + if edu.edu_type in DEVICE_UPDATE_EDUS + ] + if device_list_updates: + issue_8631_logger.debug( + "received transaction [%s] including device list updates: %s", + transaction_id, + 
device_list_updates, + ) + except Exception as e: logger.exception(e) return 400, {"error": "Invalid transaction"} diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 96273e2f81..bad48713bc 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -77,7 +77,7 @@ class AccountDataHandler: async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: - """Add some account_data to a room for a user. + """Add some global account_data for a user. Args: user_id: The user to add a tag for. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 85157a138b..00ab5e79bf 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -55,21 +55,47 @@ class AdminHandler: async def get_user(self, user: UserID) -> Optional[JsonDict]: """Function to get user details""" - ret = await self.store.get_user_by_id(user.to_string()) - if ret: - profile = await self.store.get_profileinfo(user.localpart) - threepids = await self.store.user_get_threepids(user.to_string()) - external_ids = [ - ({"auth_provider": auth_provider, "external_id": external_id}) - for auth_provider, external_id in await self.store.get_external_ids_by_user( - user.to_string() - ) - ] - ret["displayname"] = profile.display_name - ret["avatar_url"] = profile.avatar_url - ret["threepids"] = threepids - ret["external_ids"] = external_ids - return ret + user_info_dict = await self.store.get_user_by_id(user.to_string()) + if user_info_dict is None: + return None + + # Restrict returned information to a known set of fields. This prevents additional + # fields added to get_user_by_id from modifying Synapse's external API surface. + user_info_to_return = { + "name", + "admin", + "deactivated", + "shadow_banned", + "creation_ts", + "appservice_id", + "consent_server_notice_sent", + "consent_version", + "user_type", + "is_guest", + } + + # Restrict returned keys to a known set. 
+ user_info_dict = { + key: value + for key, value in user_info_dict.items() + if key in user_info_to_return + } + + # Add additional user metadata + profile = await self.store.get_profileinfo(user.localpart) + threepids = await self.store.user_get_threepids(user.to_string()) + external_ids = [ + ({"auth_provider": auth_provider, "external_id": external_id}) + for auth_provider, external_id in await self.store.get_external_ids_by_user( + user.to_string() + ) + ] + user_info_dict["displayname"] = profile.display_name + user_info_dict["avatar_url"] = profile.avatar_url + user_info_dict["threepids"] = threepids + user_info_dict["external_ids"] = external_ids + + return user_info_dict async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") -> Any: """Write all data we have on the user to the given writer. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 84724b207c..bd1a322563 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -168,25 +168,25 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: } -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class SsoLoginExtraAttributes: """Data we track about SAML2 sessions""" # time the session was created, in milliseconds - creation_time = attr.ib(type=int) - extra_attributes = attr.ib(type=JsonDict) + creation_time: int + extra_attributes: JsonDict -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class LoginTokenAttributes: """Data we store in a short-term login token""" - user_id = attr.ib(type=str) + user_id: str - auth_provider_id = attr.ib(type=str) + auth_provider_id: str """The SSO Identity Provider that the user authenticated with, to get this token.""" - auth_provider_session_id = attr.ib(type=Optional[str]) + auth_provider_session_id: Optional[str] """The session ID advertised by the SSO Identity Provider.""" @@ -2281,7 +2281,7 @@ class PasswordAuthProvider: # call all of the 
on_logged_out callbacks for callback in self.on_logged_out_callbacks: try: - callback(user_id, device_id, access_token) + await callback(user_id, device_id, access_token) except Exception as e: logger.warning("Failed to run module API callback %s: %s", callback, e) continue diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 7665425232..b184a48cb1 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -948,8 +948,16 @@ class DeviceListUpdater: devices = [] ignore_devices = True else: + prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote( + user_id + ) cached_devices = await self.store.get_cached_devices_for_user(user_id) - if cached_devices == {d["device_id"]: d for d in devices}: + + # To ensure that a user with no devices is cached, we skip the resync only + # if we have a stream_id from previously writing a cache entry. + if prev_stream_id is not None and cached_devices == { + d["device_id"]: d for d in devices + }: logging.info( "Skipping device list resync for %s, as our cache matches already", user_id, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 14360b4e40..d4dfddf63f 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1321,14 +1321,14 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: return old_key == new_key_copy -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class SignatureListItem: """An item in the signature list as used by upload_signatures_for_device_keys.""" - signing_key_id = attr.ib(type=str) - target_user_id = attr.ib(type=str) - target_device_id = attr.ib(type=str) - signature = attr.ib(type=JsonDict) + signing_key_id: str + target_user_id: str + target_device_id: str + signature: JsonDict class SigningKeyEduUpdater: diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1b996c420d..bac5de0526 100644 --- a/synapse/handlers/events.py +++ 
b/synapse/handlers/events.py @@ -20,7 +20,6 @@ from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state -from synapse.logging.utils import log_function from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client @@ -43,7 +42,6 @@ class EventStreamHandler: self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @log_function async def get_stream( self, auth_user_id: str, @@ -119,7 +117,7 @@ class EventStreamHandler: events.extend(to_add) - chunks = await self._event_serializer.serialize_events( + chunks = self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 26b8e3f43c..a37ae0ca09 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -51,7 +51,6 @@ from synapse.logging.context import ( preserve_fn, run_in_background, ) -from synapse.logging.utils import log_function from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, @@ -556,7 +555,6 @@ class FederationHandler: run_in_background(self._handle_queued_pdus, room_queue) - @log_function async def do_knock( self, target_hosts: List[str], @@ -928,7 +926,6 @@ class FederationHandler: return event - @log_function async def on_make_knock_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: @@ -1039,7 +1036,6 @@ class FederationHandler: else: return [] - @log_function async def on_backfill_request( self, origin: str, room_id: str, pdu_list: List[str], limit: int ) -> List[EventBase]: @@ -1056,7 +1052,6 @@ class FederationHandler: return events - @log_function 
async def get_persisted_pdu( self, origin: str, event_id: str ) -> Optional[EventBase]: @@ -1118,7 +1113,6 @@ class FederationHandler: return missing_events - @log_function async def exchange_third_party_invite( self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict ) -> None: diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 11771f3c9c..3905f60b3a 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -56,7 +56,6 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_client import InvalidResponseError from synapse.logging.context import nested_logging_context, run_in_background -from synapse.logging.utils import log_function from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.federation import ( @@ -275,7 +274,6 @@ class FederationEventHandler: await self._process_received_pdu(origin, pdu, state=None) - @log_function async def on_send_membership_event( self, origin: str, event: EventBase ) -> Tuple[EventBase, EventContext]: @@ -472,7 +470,6 @@ class FederationEventHandler: return await self.persist_events_and_notify(room_id, [(event, context)]) - @log_function async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] ) -> None: diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 601bab67f9..346a06ff49 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -170,7 +170,7 @@ class InitialSyncHandler: d["inviter"] = event.sender invite_event = await self.store.get_event(event.event_id) - d["invite"] = await self._event_serializer.serialize_event( + d["invite"] = self._event_serializer.serialize_event( invite_event, time_now, 
as_client_event=as_client_event, @@ -222,7 +222,7 @@ class InitialSyncHandler: d["messages"] = { "chunk": ( - await self._event_serializer.serialize_events( + self._event_serializer.serialize_events( messages, time_now=time_now, as_client_event=as_client_event, @@ -232,7 +232,7 @@ class InitialSyncHandler: "end": await end_token.to_string(self.store), } - d["state"] = await self._event_serializer.serialize_events( + d["state"] = self._event_serializer.serialize_events( current_state.values(), time_now=time_now, as_client_event=as_client_event, @@ -376,16 +376,14 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), }, "state": ( # Don't bundle aggregations as this is a deprecated API. - await self._event_serializer.serialize_events( - room_state.values(), time_now - ) + self._event_serializer.serialize_events(room_state.values(), time_now) ), "presence": [], "receipts": [], @@ -404,7 +402,7 @@ class InitialSyncHandler: # TODO: These concurrently time_now = self.clock.time_msec() # Don't bundle aggregations as this is a deprecated API. - state = await self._event_serializer.serialize_events( + state = self._event_serializer.serialize_events( current_state.values(), time_now ) @@ -480,7 +478,7 @@ class InitialSyncHandler: "messages": { "chunk": ( # Don't bundle aggregations as this is a deprecated API. 
- await self._event_serializer.serialize_events(messages, time_now) + self._event_serializer.serialize_events(messages, time_now) ), "start": await start_token.to_string(self.store), "end": await end_token.to_string(self.store), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5e3d3886eb..b37250aa38 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -246,7 +246,7 @@ class MessageHandler: room_state = room_state_events[membership_event_id] now = self.clock.time_msec() - events = await self._event_serializer.serialize_events(room_state.values(), now) + events = self._event_serializer.serialize_events(room_state.values(), now) return events async def get_joined_members(self, requester: Requester, room_id: str) -> dict: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 7469cc55a2..973f262964 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -537,14 +537,16 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() + aggregations = await self.store.get_bundled_aggregations(events, user_id) + time_now = self.clock.time_msec() chunk = { "chunk": ( - await self._event_serializer.serialize_events( + self._event_serializer.serialize_events( events, time_now, - bundle_aggregations=True, + bundle_aggregations=aggregations, as_client_event=as_client_event, ) ), @@ -553,7 +555,7 @@ class PaginationHandler: } if state: - chunk["state"] = await self._event_serializer.serialize_events( + chunk["state"] = self._event_serializer.serialize_events( state, time_now, as_client_event=as_client_event ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c781fefb1b..067c43ae47 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -55,7 +55,6 @@ from synapse.api.presence import UserPresenceState from synapse.appservice import ApplicationService from 
synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background -from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.presence import ( @@ -1542,7 +1541,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self.clock = hs.get_clock() self.store = hs.get_datastore() - @log_function async def get_new_events( self, user: UserID, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9c1cbffa5..f963078e59 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -393,7 +393,9 @@ class RoomCreationHandler: user_id = requester.user.to_string() if not await self.spam_checker.user_may_create_room(user_id): - raise SynapseError(403, "You are not permitted to create rooms") + raise SynapseError( + 403, "You are not permitted to create rooms", Codes.FORBIDDEN + ) creation_content: JsonDict = { "room_version": new_room_version.identifier, @@ -685,7 +687,9 @@ class RoomCreationHandler: invite_3pid_list, ) ): - raise SynapseError(403, "You are not permitted to create rooms") + raise SynapseError( + 403, "You are not permitted to create rooms", Codes.FORBIDDEN + ) if ratelimit: await self.request_ratelimiter.ratelimit(requester) @@ -1177,6 +1181,22 @@ class RoomContextHandler: # `filtered` rather than the event we retrieved from the datastore. results["event"] = filtered[0] + # Fetch the aggregations. 
+ aggregations = await self.store.get_bundled_aggregations( + [results["event"]], user.to_string() + ) + aggregations.update( + await self.store.get_bundled_aggregations( + results["events_before"], user.to_string() + ) + ) + aggregations.update( + await self.store.get_bundled_aggregations( + results["events_after"], user.to_string() + ) + ) + results["aggregations"] = aggregations + if results["events_after"]: last_event_id = results["events_after"][-1].event_id else: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index b2cfe537df..4844b69a03 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -153,6 +153,9 @@ class RoomSummaryHandler: rooms_result: List[JsonDict] = [] events_result: List[JsonDict] = [] + if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: + max_rooms_per_space = MAX_ROOMS_PER_SPACE + while room_queue and len(rooms_result) < MAX_ROOMS: queue_entry = room_queue.popleft() room_id = queue_entry.room_id @@ -167,7 +170,7 @@ class RoomSummaryHandler: # The client-specified max_rooms_per_space limit doesn't apply to the # room_id specified in the request, so we ignore it if this is the # first room we are processing. - max_children = max_rooms_per_space if processed_rooms else None + max_children = max_rooms_per_space if processed_rooms else MAX_ROOMS if is_in_room: room_entry = await self._summarize_local_room( @@ -209,7 +212,7 @@ class RoomSummaryHandler: # Before returning to the client, remove the allowed_room_ids # and allowed_spaces keys. room.pop("allowed_room_ids", None) - room.pop("allowed_spaces", None) + room.pop("allowed_spaces", None) # historical rooms_result.append(room) events.extend(room_entry.children_state_events) @@ -395,7 +398,7 @@ class RoomSummaryHandler: None, room_id, suggested_only, - # TODO Handle max children. + # Do not limit the maximum children. 
max_children=None, ) @@ -525,6 +528,10 @@ class RoomSummaryHandler: rooms_result: List[JsonDict] = [] events_result: List[JsonDict] = [] + # Set a limit on the number of rooms to return. + if max_rooms_per_space is None or max_rooms_per_space > MAX_ROOMS_PER_SPACE: + max_rooms_per_space = MAX_ROOMS_PER_SPACE + while room_queue and len(rooms_result) < MAX_ROOMS: room_id = room_queue.popleft() if room_id in processed_rooms: @@ -583,7 +590,9 @@ class RoomSummaryHandler: # Iterate through each child and potentially add it, but not its children, # to the response. - for child_room in root_room_entry.children_state_events: + for child_room in itertools.islice( + root_room_entry.children_state_events, MAX_ROOMS_PER_SPACE + ): room_id = child_room.get("state_key") assert isinstance(room_id, str) # If the room is unknown, skip it. @@ -633,8 +642,8 @@ class RoomSummaryHandler: suggested_only: True if only suggested children should be returned. Otherwise, all children are returned. max_children: - The maximum number of children rooms to include. This is capped - to a server-set limit. + The maximum number of children rooms to include. A value of None + means no limit. Returns: A room entry if the room should be returned. None, otherwise. @@ -656,8 +665,13 @@ class RoomSummaryHandler: # we only care about suggested children child_events = filter(_is_suggested_child_event, child_events) - if max_children is None or max_children > MAX_ROOMS_PER_SPACE: - max_children = MAX_ROOMS_PER_SPACE + # TODO max_children is legacy code for the /spaces endpoint. 
+ if max_children is not None: + child_iter: Iterable[EventBase] = itertools.islice( + child_events, max_children + ) + else: + child_iter = child_events stripped_events: List[JsonDict] = [ { @@ -668,7 +682,7 @@ class RoomSummaryHandler: "sender": e.sender, "origin_server_ts": e.origin_server_ts, } - for e in itertools.islice(child_events, max_children) + for e in child_iter ] return _RoomEntry(room_id, room_entry, stripped_events) @@ -766,6 +780,7 @@ class RoomSummaryHandler: try: ( room_response, + children_state_events, children, inaccessible_children, ) = await self._federation_client.get_room_hierarchy( @@ -790,7 +805,7 @@ class RoomSummaryHandler: } return ( - _RoomEntry(room_id, room_response, room_response.pop("children_state", ())), + _RoomEntry(room_id, room_response, children_state_events), children_by_room_id, set(inaccessible_children), ) @@ -988,12 +1003,14 @@ class RoomSummaryHandler: "canonical_alias": stats["canonical_alias"], "num_joined_members": stats["joined_members"], "avatar_url": stats["avatar"], + # plural join_rules is a documentation error but kept for historical + # purposes. Should match /publicRooms. 
"join_rules": stats["join_rules"], + "join_rule": stats["join_rules"], "world_readable": ( stats["history_visibility"] == HistoryVisibility.WORLD_READABLE ), "guest_can_join": stats["guest_access"] == "can_join", - "creation_ts": create_event.origin_server_ts, "room_type": create_event.content.get(EventContentFields.ROOM_TYPE), } diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index ab7eaab2fb..0b153a6822 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -420,10 +420,10 @@ class SearchHandler: time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = await self._event_serializer.serialize_events( + context["events_before"] = self._event_serializer.serialize_events( context["events_before"], time_now ) - context["events_after"] = await self._event_serializer.serialize_events( + context["events_after"] = self._event_serializer.serialize_events( context["events_after"], time_now ) @@ -441,9 +441,7 @@ class SearchHandler: results.append( { "rank": rank_map[e.event_id], - "result": ( - await self._event_serializer.serialize_event(e, time_now) - ), + "result": self._event_serializer.serialize_event(e, time_now), "context": contexts.get(e.event_id, {}), } ) @@ -457,7 +455,7 @@ class SearchHandler: if state_results: s = {} for room_id, state_events in state_results.items(): - s[room_id] = await self._event_serializer.serialize_events( + s[room_id] = self._event_serializer.serialize_events( state_events, time_now ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 65c27bc64a..0bb8b0929e 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -126,45 +126,45 @@ class SsoIdentityProvider(Protocol): raise NotImplementedError() -@attr.s +@attr.s(auto_attribs=True) class UserAttributes: # the localpart of the mxid that the mapper has assigned to the user. # if `None`, the mapper has not picked a userid, and the user should be prompted to # enter one. 
- localpart = attr.ib(type=Optional[str]) - display_name = attr.ib(type=Optional[str], default=None) - emails = attr.ib(type=Collection[str], default=attr.Factory(list)) + localpart: Optional[str] + display_name: Optional[str] = None + emails: Collection[str] = attr.Factory(list) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class UsernameMappingSession: """Data we track about SSO sessions""" # A unique identifier for this SSO provider, e.g. "oidc" or "saml". - auth_provider_id = attr.ib(type=str) + auth_provider_id: str # user ID on the IdP server - remote_user_id = attr.ib(type=str) + remote_user_id: str # attributes returned by the ID mapper - display_name = attr.ib(type=Optional[str]) - emails = attr.ib(type=Collection[str]) + display_name: Optional[str] + emails: Collection[str] # An optional dictionary of extra attributes to be provided to the client in the # login response. - extra_login_attributes = attr.ib(type=Optional[JsonDict]) + extra_login_attributes: Optional[JsonDict] # where to redirect the client back to - client_redirect_url = attr.ib(type=str) + client_redirect_url: str # expiry time for the session, in milliseconds - expiry_time_ms = attr.ib(type=int) + expiry_time_ms: int # choices made by the user - chosen_localpart = attr.ib(type=Optional[str], default=None) - use_display_name = attr.ib(type=bool, default=True) - emails_to_use = attr.ib(type=Collection[str], default=()) - terms_accepted_version = attr.ib(type=Optional[str], default=None) + chosen_localpart: Optional[str] = None + use_display_name: bool = True + emails_to_use: Collection[str] = () + terms_accepted_version: Optional[str] = None # the HTTP cookie used to track the mapping session id diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7baf3f199c..ffc6b748e8 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -60,10 +60,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -# Debug logger for 
https://github.com/matrix-org/synapse/issues/4422 -issue4422_logger = logging.getLogger("synapse.handler.sync.4422_debug") - - # Counts the number of times we returned a non-empty sync. `type` is one of # "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is # "true" or "false" depending on if the request asked for lazy loaded members or @@ -102,6 +98,9 @@ class TimelineBatch: prev_batch: StreamToken events: List[EventBase] limited: bool + # A mapping of event ID to the bundled aggregations for the above events. + # This is only calculated if limited is true. + bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -634,10 +633,19 @@ class SyncHandler: prev_batch_token = now_token.copy_and_replace("room_key", room_key) + # Don't bother to bundle aggregations if the timeline is unlimited, + # as clients will have all the necessary information. + bundled_aggregations = None + if limited or newly_joined_room: + bundled_aggregations = await self.store.get_bundled_aggregations( + recents, sync_config.user.to_string() + ) + return TimelineBatch( events=recents, prev_batch=prev_batch_token, limited=limited or newly_joined_room, + bundled_aggregations=bundled_aggregations, ) async def get_state_after_event( @@ -1161,13 +1169,8 @@ class SyncHandler: num_events = 0 - # debug for https://github.com/matrix-org/synapse/issues/4422 + # debug for https://github.com/matrix-org/synapse/issues/9424 for joined_room in sync_result_builder.joined: - room_id = joined_room.room_id - if room_id in newly_joined_rooms: - issue4422_logger.debug( - "Sync result for newly joined room %s: %r", room_id, joined_room - ) num_events += len(joined_room.timeline.events) log_kv( @@ -1740,18 +1743,6 @@ class SyncHandler: old_mem_ev_id, allow_none=True ) - # debug for #4422 - if has_join: - prev_membership = None - if old_mem_ev: - prev_membership = old_mem_ev.membership - 
issue4422_logger.debug( - "Previous membership for room %s with join: %s (event %s)", - room_id, - prev_membership, - old_mem_ev_id, - ) - if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: newly_joined_rooms.append(room_id) @@ -1893,13 +1884,6 @@ class SyncHandler: upto_token=since_token, ) - if newly_joined: - # debugging for https://github.com/matrix-org/synapse/issues/4422 - issue4422_logger.debug( - "RoomSyncResultBuilder events for newly joined room %s: %r", - room_id, - entry.events, - ) room_entries.append(entry) return _RoomChanges( @@ -2077,14 +2061,6 @@ class SyncHandler: # `_load_filtered_recents` can't find any events the user should see # (e.g. due to having ignored the sender of the last 50 events). - if newly_joined: - # debug for https://github.com/matrix-org/synapse/issues/4422 - issue4422_logger.debug( - "Timeline events after filtering in newly-joined room %s: %r", - room_id, - batch, - ) - # When we join the room (or the client requests full_state), we should # send down any existing tags. 
Usually the user won't have tags in a # newly joined room, unless either a) they've joined before or b) the diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index fbafffd69b..203e995bb7 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -32,9 +32,9 @@ class ProxyConnectError(ConnectError): pass -@attr.s +@attr.s(auto_attribs=True) class ProxyCredentials: - username_password = attr.ib(type=bytes) + username_password: bytes def as_proxy_authorization_value(self) -> bytes: """ diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index deedde0b5b..2e668363b2 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -123,37 +123,37 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC): pass -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class MatrixFederationRequest: - method = attr.ib(type=str) + method: str """HTTP method """ - path = attr.ib(type=str) + path: str """HTTP path """ - destination = attr.ib(type=str) + destination: str """The remote server to send the HTTP request to. """ - json = attr.ib(default=None, type=Optional[JsonDict]) + json: Optional[JsonDict] = None """JSON to send in the body. """ - json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]]) + json_callback: Optional[Callable[[], JsonDict]] = None """A callback to generate the JSON. """ - query = attr.ib(default=None, type=Optional[dict]) + query: Optional[dict] = None """Query arguments. 
""" - txn_id = attr.ib(default=None, type=Optional[str]) + txn_id: Optional[str] = None """Unique ID for this request (for logging) """ - uri = attr.ib(init=False, type=bytes) + uri: bytes = attr.ib(init=False) """The URI of this request """ diff --git a/synapse/http/site.py b/synapse/http/site.py index 80f7a2ff58..c180a1d323 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -534,9 +534,9 @@ class XForwardedForRequest(SynapseRequest): @implementer(IAddress) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class _XForwardedForAddress: - host = attr.ib(type=str) + host: str class SynapseSite(Site): diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 8202d0494d..475756f1db 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -39,7 +39,7 @@ from twisted.python.failure import Failure logger = logging.getLogger(__name__) -@attr.s +@attr.s(slots=True, auto_attribs=True) @implementer(IPushProducer) class LogProducer: """ @@ -54,10 +54,10 @@ class LogProducer: # This is essentially ITCPTransport, but that is missing certain fields # (connected and registerProducer) which are part of the implementation. - transport = attr.ib(type=Connection) - _format = attr.ib(type=Callable[[logging.LogRecord], str]) - _buffer = attr.ib(type=deque) - _paused = attr.ib(default=False, type=bool, init=False) + transport: Connection + _format: Callable[[logging.LogRecord], str] + _buffer: Deque[logging.LogRecord] + _paused: bool = attr.ib(default=False, init=False) def pauseProducing(self): self._paused = True diff --git a/synapse/logging/context.py b/synapse/logging/context.py index d4ee893376..c31c2960ad 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -193,7 +193,7 @@ class ContextResourceUsage: return res -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class ContextRequest: """ A bundle of attributes from the SynapseRequest object. 
@@ -205,15 +205,15 @@ class ContextRequest: their children. """ - request_id = attr.ib(type=str) - ip_address = attr.ib(type=str) - site_tag = attr.ib(type=str) - requester = attr.ib(type=Optional[str]) - authenticated_entity = attr.ib(type=Optional[str]) - method = attr.ib(type=str) - url = attr.ib(type=str) - protocol = attr.ib(type=str) - user_agent = attr.ib(type=str) + request_id: str + ip_address: str + site_tag: str + requester: Optional[str] + authenticated_entity: Optional[str] + method: str + url: str + protocol: str + user_agent: str LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 622445e9f4..b240d2d21d 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -247,11 +247,11 @@ try: class BaseReporter: # type: ignore[no-redef] pass - @attr.s(slots=True, frozen=True) + @attr.s(slots=True, frozen=True, auto_attribs=True) class _WrappedRustReporter(BaseReporter): """Wrap the reporter to ensure `report_span` never throws.""" - _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter)) + _reporter: Reporter = attr.Factory(Reporter) def set_process(self, *args, **kwargs): return self._reporter.set_process(*args, **kwargs) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py deleted file mode 100644 index 4a01b902c2..0000000000 --- a/synapse/logging/utils.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -from functools import wraps -from inspect import getcallargs -from typing import Callable, TypeVar, cast - -_TIME_FUNC_ID = 0 - - -def _log_debug_as_f(f, msg, msg_args): - name = f.__module__ - logger = logging.getLogger(name) - - if logger.isEnabledFor(logging.DEBUG): - lineno = f.__code__.co_firstlineno - pathname = f.__code__.co_filename - - record = logger.makeRecord( - name=name, - level=logging.DEBUG, - fn=pathname, - lno=lineno, - msg=msg, - args=msg_args, - exc_info=None, - ) - - logger.handle(record) - - -F = TypeVar("F", bound=Callable) - - -def log_function(f: F) -> F: - """Function decorator that logs every call to that function.""" - func_name = f.__name__ - - @wraps(f) - def wrapped(*args, **kwargs): - name = f.__module__ - logger = logging.getLogger(name) - level = logging.DEBUG - - if logger.isEnabledFor(level): - bound_args = getcallargs(f, *args, **kwargs) - - def format(value): - r = str(value) - if len(r) > 50: - r = r[:50] + "..." - return r - - func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()] - - msg_args = {"func_name": func_name, "args": ", ".join(func_args)} - - _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args) - - return f(*args, **kwargs) - - wrapped.__name__ = func_name - return cast(F, wrapped) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index ceef57ad88..9e6c1b2f3b 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import functools -import gc import itertools import logging import os import platform import threading -import time from typing import ( - Any, Callable, Dict, Generic, @@ -34,35 +30,31 @@ from typing import ( Type, TypeVar, Union, - cast, ) import attr from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric from prometheus_client.core import ( REGISTRY, - CounterMetricFamily, GaugeHistogramMetricFamily, GaugeMetricFamily, ) -from twisted.internet import reactor -from twisted.internet.base import ReactorBase from twisted.python.threadpool import ThreadPool -import synapse +import synapse.metrics._reactor_metrics from synapse.metrics._exposition import ( MetricsResource, generate_latest, start_http_server, ) +from synapse.metrics._gc import MIN_TIME_BETWEEN_GCS, install_gc_manager from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) METRICS_PREFIX = "/_synapse/metrics" -running_on_pypy = platform.python_implementation() == "PyPy" all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {} HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") @@ -76,19 +68,17 @@ class RegistryProxy: yield metric -@attr.s(slots=True, hash=True) +@attr.s(slots=True, hash=True, auto_attribs=True) class LaterGauge: - name = attr.ib(type=str) - desc = attr.ib(type=str) - labels = attr.ib(hash=False, type=Optional[Iterable[str]]) + name: str + desc: str + labels: Optional[Iterable[str]] = attr.ib(hash=False) # callback: should either return a value (if there are no labels for this metric), # or dict mapping from a label tuple to a value - caller = attr.ib( - type=Callable[ - [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]] - ] - ) + caller: Callable[ + [], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]] + ] def collect(self) -> Iterable[Metric]: @@ -157,7 +147,9 @@ class InFlightGauge(Generic[MetricsEntry]): # Create a class which have the sub_metrics values as 
attributes, which # default to 0 on initialization. Used to pass to registered callbacks. self._metrics_class: Type[MetricsEntry] = attr.make_class( - "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True + "_MetricsEntry", + attrs={x: attr.ib(default=0) for x in sub_metrics}, + slots=True, ) # Counts number of in flight blocks for a given set of label values @@ -369,136 +361,6 @@ class CPUMetrics: REGISTRY.register(CPUMetrics()) -# -# Python GC metrics -# - -gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"]) -gc_time = Histogram( - "python_gc_time", - "Time taken to GC (sec)", - ["gen"], - buckets=[ - 0.0025, - 0.005, - 0.01, - 0.025, - 0.05, - 0.10, - 0.25, - 0.50, - 1.00, - 2.50, - 5.00, - 7.50, - 15.00, - 30.00, - 45.00, - 60.00, - ], -) - - -class GCCounts: - def collect(self) -> Iterable[Metric]: - cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) - for n, m in enumerate(gc.get_count()): - cm.add_metric([str(n)], m) - - yield cm - - -if not running_on_pypy: - REGISTRY.register(GCCounts()) - - -# -# PyPy GC / memory metrics -# - - -class PyPyGCStats: - def collect(self) -> Iterable[Metric]: - - # @stats is a pretty-printer object with __str__() returning a nice table, - # plus some fields that contain data from that table. - # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB'). - stats = gc.get_stats(memory_pressure=False) # type: ignore - # @s contains same fields as @stats, but as actual integers. - s = stats._s # type: ignore - - # also note that field naming is completely braindead - # and only vaguely correlates with the pretty-printed table. 
- # >>>> gc.get_stats(False) - # Total memory consumed: - # GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory - # in arenas: 3.0MB # s.total_arena_memory - # rawmalloced: 1.7MB # s.total_rawmalloced_memory - # nursery: 4.0MB # s.nursery_size - # raw assembler used: 31.0kB # s.jit_backend_used - # ----------------------------- - # Total: 8.8MB # stats.memory_used_sum - # - # Total memory allocated: - # GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory - # in arenas: 30.9MB # s.peak_arena_memory - # rawmalloced: 4.1MB # s.peak_rawmalloced_memory - # nursery: 4.0MB # s.nursery_size - # raw assembler allocated: 1.0MB # s.jit_backend_allocated - # ----------------------------- - # Total: 39.7MB # stats.memory_allocated_sum - # - # Total time spent in GC: 0.073 # s.total_gc_time - - pypy_gc_time = CounterMetricFamily( - "pypy_gc_time_seconds_total", - "Total time spent in PyPy GC", - labels=[], - ) - pypy_gc_time.add_metric([], s.total_gc_time / 1000) - yield pypy_gc_time - - pypy_mem = GaugeMetricFamily( - "pypy_memory_bytes", - "Memory tracked by PyPy allocator", - labels=["state", "class", "kind"], - ) - # memory used by JIT assembler - pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used) - pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated) - # memory used by GCed objects - pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory) - pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory) - pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory) - pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory) - pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size) - pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size) - # totals - pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory) - pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory) - pypy_mem.add_metric(["used", 
"totals", "gc_peak"], s.peak_memory) - pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory) - yield pypy_mem - - -if running_on_pypy: - REGISTRY.register(PyPyGCStats()) - - -# -# Twisted reactor metrics -# - -tick_time = Histogram( - "python_twisted_reactor_tick_time", - "Tick time of the Twisted reactor (sec)", - buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5], -) -pending_calls_metric = Histogram( - "python_twisted_reactor_pending_calls", - "Pending calls", - buckets=[1, 2, 5, 10, 25, 50, 100, 250, 500, 1000], -) # # Federation Metrics @@ -551,8 +413,6 @@ build_info.labels( " ".join([platform.system(), platform.release()]), ).set(1) -last_ticked = time.time() - # 3PID send info threepid_send_requests = Histogram( "synapse_threepid_send_requests_with_tries", @@ -600,116 +460,6 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None: ) -class ReactorLastSeenMetric: - def collect(self) -> Iterable[Metric]: - cm = GaugeMetricFamily( - "python_twisted_reactor_last_seen", - "Seconds since the Twisted reactor was last seen", - ) - cm.add_metric([], time.time() - last_ticked) - yield cm - - -REGISTRY.register(ReactorLastSeenMetric()) - -# The minimum time in seconds between GCs for each generation, regardless of the current GC -# thresholds and counts. -MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0) - -# The time (in seconds since the epoch) of the last time we did a GC for each generation. -_last_gc = [0.0, 0.0, 0.0] - - -F = TypeVar("F", bound=Callable[..., Any]) - - -def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F: - @functools.wraps(func) - def f(*args: Any, **kwargs: Any) -> Any: - now = reactor.seconds() - num_pending = 0 - - # _newTimedCalls is one long list of *all* pending calls. 
Below loop - # is based off of impl of reactor.runUntilCurrent - for delayed_call in reactor._newTimedCalls: - if delayed_call.time > now: - break - - if delayed_call.delayed_time > 0: - continue - - num_pending += 1 - - num_pending += len(reactor.threadCallQueue) - start = time.time() - ret = func(*args, **kwargs) - end = time.time() - - # record the amount of wallclock time spent running pending calls. - # This is a proxy for the actual amount of time between reactor polls, - # since about 25% of time is actually spent running things triggered by - # I/O events, but that is harder to capture without rewriting half the - # reactor. - tick_time.observe(end - start) - pending_calls_metric.observe(num_pending) - - # Update the time we last ticked, for the metric to test whether - # Synapse's reactor has frozen - global last_ticked - last_ticked = end - - if running_on_pypy: - return ret - - # Check if we need to do a manual GC (since its been disabled), and do - # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may - # promote an object into gen 2, and we don't want to handle the same - # object multiple times. - threshold = gc.get_threshold() - counts = gc.get_count() - for i in (2, 1, 0): - # We check if we need to do one based on a straightforward - # comparison between the threshold and count. We also do an extra - # check to make sure that we don't a GC too often. 
- if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]: - if i == 0: - logger.debug("Collecting gc %d", i) - else: - logger.info("Collecting gc %d", i) - - start = time.time() - unreachable = gc.collect(i) - end = time.time() - - _last_gc[i] = end - - gc_time.labels(i).observe(end - start) - gc_unreachable.labels(i).set(unreachable) - - return ret - - return cast(F, f) - - -try: - # Ensure the reactor has all the attributes we expect - reactor.seconds # type: ignore - reactor.runUntilCurrent # type: ignore - reactor._newTimedCalls # type: ignore - reactor.threadCallQueue # type: ignore - - # runUntilCurrent is called when we have pending calls. It is called once - # per iteratation after fd polling. - reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore - - # We manually run the GC each reactor tick so that we can get some metrics - # about time spent doing GC, - if not running_on_pypy: - gc.disable() -except AttributeError: - pass - - __all__ = [ "MetricsResource", "generate_latest", @@ -717,4 +467,6 @@ __all__ = [ "LaterGauge", "InFlightGauge", "GaugeBucketCollector", + "MIN_TIME_BETWEEN_GCS", + "install_gc_manager", ] diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py new file mode 100644 index 0000000000..2bc909efa0 --- /dev/null +++ b/synapse/metrics/_gc.py @@ -0,0 +1,203 @@ +# Copyright 2015-2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import gc +import logging +import platform +import time +from typing import Iterable + +from prometheus_client.core import ( + REGISTRY, + CounterMetricFamily, + Gauge, + GaugeMetricFamily, + Histogram, + Metric, +) + +from twisted.internet import task + +"""Prometheus metrics for garbage collection""" + + +logger = logging.getLogger(__name__) + +# The minimum time in seconds between GCs for each generation, regardless of the current GC +# thresholds and counts. +MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0) + +running_on_pypy = platform.python_implementation() == "PyPy" + +# +# Python GC metrics +# + +gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"]) +gc_time = Histogram( + "python_gc_time", + "Time taken to GC (sec)", + ["gen"], + buckets=[ + 0.0025, + 0.005, + 0.01, + 0.025, + 0.05, + 0.10, + 0.25, + 0.50, + 1.00, + 2.50, + 5.00, + 7.50, + 15.00, + 30.00, + 45.00, + 60.00, + ], +) + + +class GCCounts: + def collect(self) -> Iterable[Metric]: + cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) + for n, m in enumerate(gc.get_count()): + cm.add_metric([str(n)], m) + + yield cm + + +def install_gc_manager() -> None: + """Disable automatic GC, and replace it with a task that runs every 100ms + + This means that (a) we can limit how often GC runs; (b) we can get some metrics + about GC activity. + + It does nothing on PyPy. + """ + + if running_on_pypy: + return + + REGISTRY.register(GCCounts()) + + gc.disable() + + # The time (in seconds since the epoch) of the last time we did a GC for each generation. + _last_gc = [0.0, 0.0, 0.0] + + def _maybe_gc() -> None: + # Check if we need to do a manual GC (since its been disabled), and do + # one if necessary. Note we go in reverse order as e.g. a gen 1 GC may + # promote an object into gen 2, and we don't want to handle the same + # object multiple times. 
+ threshold = gc.get_threshold() + counts = gc.get_count() + end = time.time() + for i in (2, 1, 0): + # We check if we need to do one based on a straightforward + # comparison between the threshold and count. We also do an extra + # check to make sure that we don't a GC too often. + if threshold[i] < counts[i] and MIN_TIME_BETWEEN_GCS[i] < end - _last_gc[i]: + if i == 0: + logger.debug("Collecting gc %d", i) + else: + logger.info("Collecting gc %d", i) + + start = time.time() + unreachable = gc.collect(i) + end = time.time() + + _last_gc[i] = end + + gc_time.labels(i).observe(end - start) + gc_unreachable.labels(i).set(unreachable) + + gc_task = task.LoopingCall(_maybe_gc) + gc_task.start(0.1) + + +# +# PyPy GC / memory metrics +# + + +class PyPyGCStats: + def collect(self) -> Iterable[Metric]: + + # @stats is a pretty-printer object with __str__() returning a nice table, + # plus some fields that contain data from that table. + # unfortunately, fields are pretty-printed themselves (i. e. '4.5MB'). + stats = gc.get_stats(memory_pressure=False) # type: ignore + # @s contains same fields as @stats, but as actual integers. + s = stats._s # type: ignore + + # also note that field naming is completely braindead + # and only vaguely correlates with the pretty-printed table. 
+ # >>>> gc.get_stats(False) + # Total memory consumed: + # GC used: 8.7MB (peak: 39.0MB) # s.total_gc_memory, s.peak_memory + # in arenas: 3.0MB # s.total_arena_memory + # rawmalloced: 1.7MB # s.total_rawmalloced_memory + # nursery: 4.0MB # s.nursery_size + # raw assembler used: 31.0kB # s.jit_backend_used + # ----------------------------- + # Total: 8.8MB # stats.memory_used_sum + # + # Total memory allocated: + # GC allocated: 38.7MB (peak: 41.1MB) # s.total_allocated_memory, s.peak_allocated_memory + # in arenas: 30.9MB # s.peak_arena_memory + # rawmalloced: 4.1MB # s.peak_rawmalloced_memory + # nursery: 4.0MB # s.nursery_size + # raw assembler allocated: 1.0MB # s.jit_backend_allocated + # ----------------------------- + # Total: 39.7MB # stats.memory_allocated_sum + # + # Total time spent in GC: 0.073 # s.total_gc_time + + pypy_gc_time = CounterMetricFamily( + "pypy_gc_time_seconds_total", + "Total time spent in PyPy GC", + labels=[], + ) + pypy_gc_time.add_metric([], s.total_gc_time / 1000) + yield pypy_gc_time + + pypy_mem = GaugeMetricFamily( + "pypy_memory_bytes", + "Memory tracked by PyPy allocator", + labels=["state", "class", "kind"], + ) + # memory used by JIT assembler + pypy_mem.add_metric(["used", "", "jit"], s.jit_backend_used) + pypy_mem.add_metric(["allocated", "", "jit"], s.jit_backend_allocated) + # memory used by GCed objects + pypy_mem.add_metric(["used", "", "arenas"], s.total_arena_memory) + pypy_mem.add_metric(["allocated", "", "arenas"], s.peak_arena_memory) + pypy_mem.add_metric(["used", "", "rawmalloced"], s.total_rawmalloced_memory) + pypy_mem.add_metric(["allocated", "", "rawmalloced"], s.peak_rawmalloced_memory) + pypy_mem.add_metric(["used", "", "nursery"], s.nursery_size) + pypy_mem.add_metric(["allocated", "", "nursery"], s.nursery_size) + # totals + pypy_mem.add_metric(["used", "totals", "gc"], s.total_gc_memory) + pypy_mem.add_metric(["allocated", "totals", "gc"], s.total_allocated_memory) + pypy_mem.add_metric(["used", 
"totals", "gc_peak"], s.peak_memory) + pypy_mem.add_metric(["allocated", "totals", "gc_peak"], s.peak_allocated_memory) + yield pypy_mem + + +if running_on_pypy: + REGISTRY.register(PyPyGCStats()) diff --git a/synapse/metrics/_reactor_metrics.py b/synapse/metrics/_reactor_metrics.py new file mode 100644 index 0000000000..f38f798313 --- /dev/null +++ b/synapse/metrics/_reactor_metrics.py @@ -0,0 +1,83 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import select +import time +from typing import Any, Iterable, List, Tuple + +from prometheus_client import Histogram, Metric +from prometheus_client.core import REGISTRY, GaugeMetricFamily + +from twisted.internet import reactor + +# +# Twisted reactor metrics +# + +tick_time = Histogram( + "python_twisted_reactor_tick_time", + "Tick time of the Twisted reactor (sec)", + buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5], +) + + +class EpollWrapper: + """a wrapper for an epoll object which records the time between polls""" + + def __init__(self, poller: "select.epoll"): # type: ignore[name-defined] + self.last_polled = time.time() + self._poller = poller + + def poll(self, *args, **kwargs) -> List[Tuple[int, int]]: # type: ignore[no-untyped-def] + # record the time since poll() was last called. This gives a good proxy for + # how long it takes to run everything in the reactor - ie, how long anything + # waiting for the next tick will have to wait. 
+ tick_time.observe(time.time() - self.last_polled) + + ret = self._poller.poll(*args, **kwargs) + + self.last_polled = time.time() + return ret + + def __getattr__(self, item: str) -> Any: + return getattr(self._poller, item) + + +class ReactorLastSeenMetric: + def __init__(self, epoll_wrapper: EpollWrapper): + self._epoll_wrapper = epoll_wrapper + + def collect(self) -> Iterable[Metric]: + cm = GaugeMetricFamily( + "python_twisted_reactor_last_seen", + "Seconds since the Twisted reactor was last seen", + ) + cm.add_metric([], time.time() - self._epoll_wrapper.last_polled) + yield cm + + +try: + # if the reactor has a `_poller` attribute, which is an `epoll` object + # (ie, it's an EPollReactor), we wrap the `epoll` with a thing that will + # measure the time between ticks + from select import epoll # type: ignore[attr-defined] + + poller = reactor._poller # type: ignore[attr-defined] +except (AttributeError, ImportError): + pass +else: + if isinstance(poller, epoll): + poller = EpollWrapper(poller) + reactor._poller = poller # type: ignore[attr-defined] + REGISTRY.register(ReactorLastSeenMetric(poller)) diff --git a/synapse/notifier.py b/synapse/notifier.py index bbabdb0587..632b2245ef 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -40,7 +40,6 @@ from synapse.handlers.presence import format_user_presence_state from synapse.logging import issue9533_logger from synapse.logging.context import PreserveLoggingContext from synapse.logging.opentracing import log_kv, start_active_span -from synapse.logging.utils import log_function from synapse.metrics import LaterGauge from synapse.streams.config import PaginationConfig from synapse.types import ( @@ -193,15 +192,15 @@ class EventStreamResult: return bool(self.events) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class _PendingRoomEventEntry: - event_pos = attr.ib(type=PersistedEventPosition) - extra_users = attr.ib(type=Collection[UserID]) + event_pos: 
PersistedEventPosition + extra_users: Collection[UserID] - room_id = attr.ib(type=str) - type = attr.ib(type=str) - state_key = attr.ib(type=Optional[str]) - membership = attr.ib(type=Optional[str]) + room_id: str + type: str + state_key: Optional[str] + membership: Optional[str] class Notifier: @@ -686,7 +685,6 @@ class Notifier: else: return False - @log_function def remove_expired_streams(self) -> None: time_now_ms = self.clock.time_msec() expired_streams = [] @@ -700,7 +698,6 @@ class Notifier: for expired_stream in expired_streams: expired_stream.remove(self) - @log_function def _register_with_keys(self, user_stream: _NotifierUserStream): self.user_to_user_stream[user_stream.user_id] = user_stream diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 820f6f3f7e..5176a1c186 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -23,25 +23,25 @@ if TYPE_CHECKING: from synapse.server import HomeServer -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class PusherConfig: """Parameters necessary to configure a pusher.""" - id = attr.ib(type=Optional[str]) - user_name = attr.ib(type=str) - access_token = attr.ib(type=Optional[int]) - profile_tag = attr.ib(type=str) - kind = attr.ib(type=str) - app_id = attr.ib(type=str) - app_display_name = attr.ib(type=str) - device_display_name = attr.ib(type=str) - pushkey = attr.ib(type=str) - ts = attr.ib(type=int) - lang = attr.ib(type=Optional[str]) - data = attr.ib(type=Optional[JsonDict]) - last_stream_ordering = attr.ib(type=int) - last_success = attr.ib(type=Optional[int]) - failing_since = attr.ib(type=Optional[int]) + id: Optional[str] + user_name: str + access_token: Optional[int] + profile_tag: str + kind: str + app_id: str + app_display_name: str + device_display_name: str + pushkey: str + ts: int + lang: Optional[str] + data: Optional[JsonDict] + last_stream_ordering: int + last_success: Optional[int] + failing_since: Optional[int] def as_dict(self) -> Dict[str, Any]: 
"""Information that can be retrieved about a pusher after creation.""" @@ -57,12 +57,12 @@ class PusherConfig: } -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class ThrottleParams: """Parameters for controlling the rate of sending pushes via email.""" - last_sent_ts = attr.ib(type=int) - throttle_ms = attr.ib(type=int) + last_sent_ts: int + throttle_ms: int class Pusher(metaclass=abc.ABCMeta): diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 009d8e77b0..bee660893b 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -298,7 +298,7 @@ RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class RulesForRoomData: """The data stored in the cache by `RulesForRoom`. @@ -307,29 +307,29 @@ class RulesForRoomData: """ # event_id -> (user_id, state) - member_map = attr.ib(type=MemberMap, factory=dict) + member_map: MemberMap = attr.Factory(dict) # user_id -> rules - rules_by_user = attr.ib(type=RulesByUser, factory=dict) + rules_by_user: RulesByUser = attr.Factory(dict) # The last state group we updated the caches for. If the state_group of # a new event comes along, we know that we can just return the cached # result. # On invalidation of the rules themselves (if the user changes them), # we invalidate everything and set state_group to `object()` - state_group = attr.ib(type=StateGroup, factory=object) + state_group: StateGroup = attr.Factory(object) # A sequence number to keep track of when we're allowed to update the # cache. We bump the sequence number when we invalidate the cache. If # the sequence number changes while we're calculating stuff we should # not update the cache with it. - sequence = attr.ib(type=int, default=0) + sequence: int = 0 # A cache of user_ids that we *know* aren't interesting, e.g. user_ids # owned by AS's, or remote users, etc. (I.e. 
users we will never need to # calculate push for) # These never need to be invalidated as we will never set up push for # them. - uninteresting_user_set = attr.ib(type=Set[str], factory=set) + uninteresting_user_set: Set[str] = attr.Factory(set) class RulesForRoom: @@ -553,7 +553,7 @@ class RulesForRoom: self.data.state_group = state_group -@attr.attrs(slots=True, frozen=True) +@attr.attrs(slots=True, frozen=True, auto_attribs=True) class _Invalidation: # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules, # which means that it it is stored on the bulk_get_push_rules cache entry. In order @@ -564,8 +564,8 @@ class _Invalidation: # attrs provides suitable __hash__ and __eq__ methods, provided we remember to # set `frozen=True`. - cache = attr.ib(type=LruCache) - room_id = attr.ib(type=str) + cache: LruCache + room_id: str def __call__(self) -> None: rules_data = self.cache.get(self.room_id, None, update_metrics=False) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index ff904c2b4a..dadfc57413 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -178,7 +178,7 @@ class Mailer: await self.send_email( email_address, self.email_subjects.email_validation - % {"server_name": self.hs.config.server.server_name}, + % {"server_name": self.hs.config.server.server_name, "app": self.app_name}, template_vars, ) @@ -209,7 +209,7 @@ class Mailer: await self.send_email( email_address, self.email_subjects.email_validation - % {"server_name": self.hs.config.server.server_name}, + % {"server_name": self.hs.config.server.server_name, "app": self.app_name}, template_vars, ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index a390cfcb74..4f4f1ad453 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -50,12 +50,12 @@ data part are: """ -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class 
EventsStreamRow: """A parsed row from the events replication stream""" - type = attr.ib() # str: the TypeId of one of the *EventsStreamRows - data = attr.ib() # BaseEventsStreamRow + type: str # the TypeId of one of the *EventsStreamRows + data: "BaseEventsStreamRow" class BaseEventsStreamRow: @@ -79,28 +79,28 @@ class BaseEventsStreamRow: return cls(*data) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib(type=str) - room_id = attr.ib(type=str) - type = attr.ib(type=str) - state_key = attr.ib(type=Optional[str]) - redacts = attr.ib(type=Optional[str]) - relates_to = attr.ib(type=Optional[str]) - membership = attr.ib(type=Optional[str]) - rejected = attr.ib(type=bool) + event_id: str + room_id: str + type: str + state_key: Optional[str] + redacts: Optional[str] + relates_to: Optional[str] + membership: Optional[str] + rejected: bool -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class EventsStreamCurrentStateRow(BaseEventsStreamRow): TypeId = "state" - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str - event_id = attr.ib() # str, optional + room_id: str + type: str + state_key: str + event_id: Optional[str] _EventRows: Tuple[Type[BaseEventsStreamRow], ...] 
= ( diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py index 6ec00ce0b9..e9bce22a34 100644 --- a/synapse/rest/admin/background_updates.py +++ b/synapse/rest/admin/background_updates.py @@ -123,34 +123,25 @@ class BackgroundUpdateStartJobRestServlet(RestServlet): job_name = body["job_name"] if job_name == "populate_stats_process_rooms": - jobs = [ - { - "update_name": "populate_stats_process_rooms", - "progress_json": "{}", - }, - ] + jobs = [("populate_stats_process_rooms", "{}", "")] elif job_name == "regenerate_directory": jobs = [ - { - "update_name": "populate_user_directory_createtables", - "progress_json": "{}", - "depends_on": "", - }, - { - "update_name": "populate_user_directory_process_rooms", - "progress_json": "{}", - "depends_on": "populate_user_directory_createtables", - }, - { - "update_name": "populate_user_directory_process_users", - "progress_json": "{}", - "depends_on": "populate_user_directory_process_rooms", - }, - { - "update_name": "populate_user_directory_cleanup", - "progress_json": "{}", - "depends_on": "populate_user_directory_process_users", - }, + ("populate_user_directory_createtables", "{}", ""), + ( + "populate_user_directory_process_rooms", + "{}", + "populate_user_directory_createtables", + ), + ( + "populate_user_directory_process_users", + "{}", + "populate_user_directory_process_rooms", + ), + ( + "populate_user_directory_cleanup", + "{}", + "populate_user_directory_process_users", + ), ] else: raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid job_name") @@ -158,6 +149,7 @@ class BackgroundUpdateStartJobRestServlet(RestServlet): try: await self._store.db_pool.simple_insert_many( table="background_updates", + keys=("update_name", "progress_json", "depends_on"), values=jobs, desc=f"admin_api_run_{job_name}", ) diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py index 50d88c9109..8cd3fa189e 100644 --- a/synapse/rest/admin/federation.py +++ 
b/synapse/rest/admin/federation.py @@ -111,25 +111,37 @@ class DestinationsRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self._auth, request) + if not await self._store.is_destination_known(destination): + raise NotFoundError("Unknown destination") + destination_retry_timings = await self._store.get_destination_retry_timings( destination ) - if not destination_retry_timings: - raise NotFoundError("Unknown destination") - last_successful_stream_ordering = ( await self._store.get_destination_last_successful_stream_ordering( destination ) ) - response = { + response: JsonDict = { "destination": destination, - "failure_ts": destination_retry_timings.failure_ts, - "retry_last_ts": destination_retry_timings.retry_last_ts, - "retry_interval": destination_retry_timings.retry_interval, "last_successful_stream_ordering": last_successful_stream_ordering, } + if destination_retry_timings: + response = { + **response, + "failure_ts": destination_retry_timings.failure_ts, + "retry_last_ts": destination_retry_timings.retry_last_ts, + "retry_interval": destination_retry_timings.retry_interval, + } + else: + response = { + **response, + "failure_ts": None, + "retry_last_ts": 0, + "retry_interval": 0, + } + return HTTPStatus.OK, response diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 7236e4027f..299f5c9eb0 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -466,7 +466,7 @@ class UserMediaRestServlet(RestServlet): ) deleted_media, total = await self.media_repository.delete_local_media_ids( - ([row["media_id"] for row in media]) + [row["media_id"] for row in media] ) return HTTPStatus.OK, {"deleted_media": deleted_media, "total": total} diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 6030373ebc..efe25fe7eb 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -424,7 +424,7 @@ class RoomStateRestServlet(RestServlet): event_ids = await 
self.store.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() - room_state = await self._event_serializer.serialize_events(events.values(), now) + room_state = self._event_serializer.serialize_events(events.values(), now) ret = {"state": room_state} return HTTPStatus.OK, ret @@ -744,22 +744,17 @@ class RoomEventContextServlet(RestServlet): ) time_now = self.clock.time_msec() - results["events_before"] = await self._event_serializer.serialize_events( - results["events_before"], - time_now, - bundle_aggregations=True, + aggregations = results.pop("aggregations", None) + results["events_before"] = self._event_serializer.serialize_events( + results["events_before"], time_now, bundle_aggregations=aggregations ) - results["event"] = await self._event_serializer.serialize_event( - results["event"], - time_now, - bundle_aggregations=True, + results["event"] = self._event_serializer.serialize_event( + results["event"], time_now, bundle_aggregations=aggregations ) - results["events_after"] = await self._event_serializer.serialize_events( - results["events_after"], - time_now, - bundle_aggregations=True, + results["events_after"] = self._event_serializer.serialize_events( + results["events_after"], time_now, bundle_aggregations=aggregations ) - results["state"] = await self._event_serializer.serialize_events( + results["state"] = self._event_serializer.serialize_events( results["state"], time_now ) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 78e795c347..c2617ee30c 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -173,12 +173,11 @@ class UserRestServletV2(RestServlet): if not self.hs.is_mine(target_user): raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users") - ret = await self.admin_handler.get_user(target_user) - - if not ret: + user_info_dict = await self.admin_handler.get_user(target_user) + if not user_info_dict: raise 
NotFoundError("User not found") - return HTTPStatus.OK, ret + return HTTPStatus.OK, user_info_dict async def on_PUT( self, request: SynapseRequest, user_id: str @@ -399,10 +398,10 @@ class UserRestServletV2(RestServlet): target_user, requester, body["avatar_url"], True ) - user = await self.admin_handler.get_user(target_user) - assert user is not None + user_info_dict = await self.admin_handler.get_user(target_user) + assert user_info_dict is not None - return 201, user + return HTTPStatus.CREATED, user_info_dict class UserRegisterServlet(RestServlet): diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 13b72a045a..672c821061 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -91,7 +91,7 @@ class EventRestServlet(RestServlet): time_now = self.clock.time_msec() if event: - result = await self._event_serializer.serialize_event(event, time_now) + result = self._event_serializer.serialize_event(event, time_now) return 200, result else: return 404, "Event not found." diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index acd0c9e135..8e427a96a3 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -72,7 +72,7 @@ class NotificationsServlet(RestServlet): "actions": pa.actions, "ts": pa.received_ts, "event": ( - await self._event_serializer.serialize_event( + self._event_serializer.serialize_event( notif_events[pa.event_id], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 5815650ee6..8cf5ebaa07 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,28 +19,20 @@ any time to reflect changes in the MSC. 
""" import logging -from typing import TYPE_CHECKING, Awaitable, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple -from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import ShadowBanError, SynapseError +from synapse.api.constants import RelationTypes +from synapse.api.errors import SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import ( - RestServlet, - parse_integer, - parse_json_object_from_request, - parse_string, -) +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.rest.client.transactions import HttpTransactionCache +from synapse.rest.client._base import client_patterns from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) from synapse.types import JsonDict -from synapse.util.stringutils import random_string - -from ._base import client_patterns if TYPE_CHECKING: from synapse.server import HomeServer @@ -48,112 +40,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RelationSendServlet(RestServlet): - """Helper API for sending events that have relation data. 
- - Example API shape to send a 👍 reaction to a room: - - POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D - {} - - { - "event_id": "$foobar" - } - """ - - PATTERN = ( - "/rooms/(?P<room_id>[^/]*)/send_relation" - "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.event_creation_handler = hs.get_event_creation_handler() - self.txns = HttpTransactionCache(hs) - - def register(self, http_server: HttpServer) -> None: - http_server.register_paths( - "POST", - client_patterns(self.PATTERN + "$", releases=()), - self.on_PUT_or_POST, - self.__class__.__name__, - ) - http_server.register_paths( - "PUT", - client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), - self.on_PUT, - self.__class__.__name__, - ) - - def on_PUT( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - txn_id: Optional[str] = None, - ) -> Awaitable[Tuple[int, JsonDict]]: - return self.txns.fetch_or_execute_request( - request, - self.on_PUT_or_POST, - request, - room_id, - parent_id, - relation_type, - event_type, - txn_id, - ) - - async def on_PUT_or_POST( - self, - request: SynapseRequest, - room_id: str, - parent_id: str, - relation_type: str, - event_type: str, - txn_id: Optional[str] = None, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - - if event_type == EventTypes.Member: - # Add relations to a membership is meaningless, so we just deny it - # at the CS API rather than trying to handle it correctly. 
- raise SynapseError(400, "Cannot send member events with relations") - - content = parse_json_object_from_request(request) - - aggregation_key = parse_string(request, "key", encoding="utf-8") - - content["m.relates_to"] = { - "event_id": parent_id, - "rel_type": relation_type, - } - if aggregation_key is not None: - content["m.relates_to"]["key"] = aggregation_key - - event_dict = { - "type": event_type, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - } - - try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) - event_id = event.event_id - except ShadowBanError: - event_id = "$" + random_string(43) - - return 200, {"event_id": event_id} - - class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -227,13 +113,16 @@ class RelationPaginationServlet(RestServlet): now = self.clock.time_msec() # Do not bundle aggregations when retrieving the original event because # we want the content before relations are applied to it. - original_event = await self._event_serializer.serialize_event( - event, now, bundle_aggregations=False + original_event = self._event_serializer.serialize_event( + event, now, bundle_aggregations=None ) # The relations returned for the requested event do include their # bundled aggregations. 
- serialized_events = await self._event_serializer.serialize_events( - events, now, bundle_aggregations=True + aggregations = await self.store.get_bundled_aggregations( + events, requester.user.to_string() + ) + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations ) return_value = pagination_chunk.to_dict() @@ -422,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) now = self.clock.time_msec() - serialized_events = await self._event_serializer.serialize_events(events, now) + serialized_events = self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = serialized_events @@ -431,7 +320,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server) RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 40330749e5..90bb9142a0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -642,6 +642,7 @@ class RoomEventServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() + self._store = hs.get_datastore() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() @@ -660,10 +661,15 @@ class RoomEventServlet(RestServlet): # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) - time_now = self.clock.time_msec() if event: - event_dict = await self._event_serializer.serialize_event( - event, time_now, bundle_aggregations=True + # Ensure there are bundled aggregations 
available. + aggregations = await self._store.get_bundled_aggregations( + [event], requester.user.to_string() + ) + + time_now = self.clock.time_msec() + event_dict = self._event_serializer.serialize_event( + event, time_now, bundle_aggregations=aggregations ) return 200, event_dict @@ -708,16 +714,17 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = await self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=True + aggregations = results.pop("aggregations", None) + results["events_before"] = self._event_serializer.serialize_events( + results["events_before"], time_now, bundle_aggregations=aggregations ) - results["event"] = await self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=True + results["event"] = self._event_serializer.serialize_event( + results["event"], time_now, bundle_aggregations=aggregations ) - results["events_after"] = await self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=True + results["events_after"] = self._event_serializer.serialize_events( + results["events_after"], time_now, bundle_aggregations=aggregations ) - results["state"] = await self._event_serializer.serialize_events( + results["state"] = self._event_serializer.serialize_events( results["state"], time_now ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index e99a943d0d..d20ae1421e 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -17,7 +17,6 @@ from collections import defaultdict from typing import ( TYPE_CHECKING, Any, - Awaitable, Callable, Dict, Iterable, @@ -395,7 +394,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = await self._event_serializer.serialize_event( + invite = self._event_serializer.serialize_event( room.invite, 
time_now, token_id=token_id, @@ -432,7 +431,7 @@ class SyncRestServlet(RestServlet): """ knocked = {} for room in rooms: - knock = await self._event_serializer.serialize_event( + knock = self._event_serializer.serialize_event( room.knock, time_now, token_id=token_id, @@ -525,21 +524,14 @@ class SyncRestServlet(RestServlet): The room, encoded in our response format """ - def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]: + def serialize( + events: Iterable[EventBase], + aggregations: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> List[JsonDict]: return self._event_serializer.serialize_events( events, time_now=time_now, - # Don't bother to bundle aggregations if the timeline is unlimited, - # as clients will have all the necessary information. - # bundle_aggregations=room.timeline.limited, - # - # richvdh 2021-12-15: disable this temporarily as it has too high an - # overhead for initialsyncs. We need to figure out a way that the - # bundling can be done *before* the events are stored in the - # SyncResponseCache so that this part can be synchronous. - # - # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations. 
- bundle_aggregations=False, + bundle_aggregations=aggregations, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, @@ -561,8 +553,10 @@ class SyncRestServlet(RestServlet): event.room_id, ) - serialized_state = await serialize(state_events) - serialized_timeline = await serialize(timeline_events) + serialized_state = serialize(state_events) + serialized_timeline = serialize( + timeline_events, room.timeline.bundled_aggregations + ) account_data = room.account_data diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index fca239d8c7..9f6c251caf 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -343,7 +343,7 @@ class SpamMediaException(NotFoundError): """ -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class ReadableFileWrapper: """Wrapper that allows reading a file in chunks, yielding to the reactor, and writing to a callback. @@ -354,8 +354,8 @@ class ReadableFileWrapper: CHUNK_SIZE = 2 ** 14 - clock = attr.ib(type=Clock) - path = attr.ib(type=str) + clock: Clock + path: str async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None: """Reads the file in chunks and calls the callback with each chunk.""" diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py index cce1527ed9..2177b46c9e 100644 --- a/synapse/rest/media/v1/oembed.py +++ b/synapse/rest/media/v1/oembed.py @@ -33,6 +33,8 @@ logger = logging.getLogger(__name__) class OEmbedResult: # The Open Graph result (converted from the oEmbed result). open_graph_result: JsonDict + # The author_name of the oEmbed result + author_name: Optional[str] # Number of milliseconds to cache the content, according to the oEmbed response. # # This will be None if no cache-age is provided in the oEmbed response (or @@ -154,11 +156,12 @@ class OEmbedProvider: "og:url": url, } - # Use either title or author's name as the title. 
- title = oembed.get("title") or oembed.get("author_name") + title = oembed.get("title") if title: open_graph_response["og:title"] = title + author_name = oembed.get("author_name") + # Use the provider name and as the site. provider_name = oembed.get("provider_name") if provider_name: @@ -193,9 +196,10 @@ class OEmbedProvider: # Trap any exception and let the code follow as usual. logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) open_graph_response = {} + author_name = None cache_age = None - return OEmbedResult(open_graph_response, cache_age) + return OEmbedResult(open_graph_response, author_name, cache_age) def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index a3829d943b..e8881bc870 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -262,6 +262,7 @@ class PreviewUrlResource(DirectServeJsonResource): # The number of milliseconds that the response should be considered valid. expiration_ms = media_info.expires + author_name: Optional[str] = None if _is_media(media_info.media_type): file_id = media_info.filesystem_id @@ -294,17 +295,25 @@ class PreviewUrlResource(DirectServeJsonResource): # Check if this HTML document points to oEmbed information and # defer to that. oembed_url = self._oembed.autodiscover_from_html(tree) - og = {} + og_from_oembed: JsonDict = {} if oembed_url: oembed_info = await self._download_url(oembed_url, user) - og, expiration_ms = await self._handle_oembed_response( + ( + og_from_oembed, + author_name, + expiration_ms, + ) = await self._handle_oembed_response( url, oembed_info, expiration_ms ) - # If there was no oEmbed URL (or oEmbed parsing failed), attempt - # to generate the Open Graph information from the HTML. 
- if not oembed_url or not og: - og = parse_html_to_open_graph(tree, media_info.uri) + # Parse Open Graph information from the HTML in case the oEmbed + # response failed or is incomplete. + og_from_html = parse_html_to_open_graph(tree, media_info.uri) + + # Compile the Open Graph response by using the scraped + # information from the HTML and overlaying any information + # from the oEmbed response. + og = {**og_from_html, **og_from_oembed} await self._precache_image_url(user, media_info, og) else: @@ -312,7 +321,7 @@ class PreviewUrlResource(DirectServeJsonResource): elif oembed_url: # Handle the oEmbed information. - og, expiration_ms = await self._handle_oembed_response( + og, author_name, expiration_ms = await self._handle_oembed_response( url, media_info, expiration_ms ) await self._precache_image_url(user, media_info, og) @@ -321,6 +330,11 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Failed to find any OG data in %s", url) og = {} + # If we don't have a title but we have author_name, copy it as + # title + if not og.get("og:title") and author_name: + og["og:title"] = author_name + # filter out any stupidly long values keys_to_remove = [] for k, v in og.items(): @@ -484,7 +498,7 @@ class PreviewUrlResource(DirectServeJsonResource): async def _handle_oembed_response( self, url: str, media_info: MediaInfo, expiration_ms: int - ) -> Tuple[JsonDict, int]: + ) -> Tuple[JsonDict, Optional[str], int]: """ Parse the downloaded oEmbed info. @@ -497,11 +511,12 @@ class PreviewUrlResource(DirectServeJsonResource): Returns: A tuple of: The Open Graph dictionary, if the oEmbed info can be parsed. + The author name if it could be retrieved from oEmbed. The (possibly updated) length of time, in milliseconds, the media is valid for. """ # If JSON was not returned, there's nothing to do. 
if not _is_json(media_info.media_type): - return {}, expiration_ms + return {}, None, expiration_ms with open(media_info.filename, "rb") as file: body = file.read() @@ -513,7 +528,7 @@ class PreviewUrlResource(DirectServeJsonResource): if open_graph_result and oembed_response.cache_age is not None: expiration_ms = oembed_response.cache_age - return open_graph_result, expiration_ms + return open_graph_result, oembed_response.author_name, expiration_ms def _start_expire_url_cache_data(self) -> Deferred: return run_as_background_process( diff --git a/synapse/server.py b/synapse/server.py index 185e40e4da..3032f0b738 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -759,7 +759,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer(self) + return EventClientSerializer() @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 69ac8c3423..67e8bc6ec2 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -45,7 +45,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import ContextResourceUsage -from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo @@ -450,19 +449,19 @@ class StateHandler: return {key: state_map[ev_id] for key, ev_id in new_state.items()} -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _StateResMetrics: """Keeps track of some usage metrics about state res.""" # System and User CPU time, in seconds - cpu_time = attr.ib(type=float, default=0.0) + cpu_time: float = 0.0 # time spent on database transactions (excluding 
scheduling time). This roughly # corresponds to the amount of work done on the db server, excluding event fetches. - db_time = attr.ib(type=float, default=0.0) + db_time: float = 0.0 # number of events fetched from the db. - db_events = attr.ib(type=int, default=0) + db_events: int = 0 _biggest_room_by_cpu_counter = Counter( @@ -512,7 +511,6 @@ class StateResolutionHandler: self.clock.looping_call(self._report_metrics, 120 * 1000) - @log_function async def resolve_state_groups( self, room_id: str, diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 2cacc7dd6c..57cc1d76e0 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -143,7 +143,7 @@ def make_conn( return db_conn -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class LoggingDatabaseConnection: """A wrapper around a database connection that returns `LoggingTransaction` as its cursor class. @@ -151,9 +151,9 @@ class LoggingDatabaseConnection: This is mainly used on startup to ensure that queries get logged correctly """ - conn = attr.ib(type=Connection) - engine = attr.ib(type=BaseDatabaseEngine) - default_txn_name = attr.ib(type=str) + conn: Connection + engine: BaseDatabaseEngine + default_txn_name: str def cursor( self, *, txn_name=None, after_callbacks=None, exception_callbacks=None @@ -934,56 +934,6 @@ class DatabasePool: txn.execute(sql, vals) async def simple_insert_many( - self, table: str, values: List[Dict[str, Any]], desc: str - ) -> None: - """Executes an INSERT query on the named table. - - The input is given as a list of dicts, with one dict per row. - Generally simple_insert_many_values should be preferred for new code. 
- - Args: - table: string giving the table name - values: dict of new column names and values for them - desc: description of the transaction, for logging and metrics - """ - await self.runInteraction(desc, self.simple_insert_many_txn, table, values) - - @staticmethod - def simple_insert_many_txn( - txn: LoggingTransaction, table: str, values: List[Dict[str, Any]] - ) -> None: - """Executes an INSERT query on the named table. - - The input is given as a list of dicts, with one dict per row. - Generally simple_insert_many_values_txn should be preferred for new code. - - Args: - txn: The transaction to use. - table: string giving the table name - values: dict of new column names and values for them - """ - if not values: - return - - # This is a *slight* abomination to get a list of tuples of key names - # and a list of tuples of value names. - # - # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] - # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] - # - # The sort is to ensure that we don't rely on dictionary iteration - # order. 
- keys, vals = zip( - *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i) - ) - - for k in keys: - if k != keys[0]: - raise RuntimeError("All items must have the same keys") - - return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals) - - async def simple_insert_many_values( self, table: str, keys: Collection[str], @@ -1002,11 +952,11 @@ class DatabasePool: desc: description of the transaction, for logging and metrics """ await self.runInteraction( - desc, self.simple_insert_many_values_txn, table, keys, values + desc, self.simple_insert_many_txn, table, keys, values ) @staticmethod - def simple_insert_many_values_txn( + def simple_insert_many_txn( txn: LoggingTransaction, table: str, keys: Collection[str], diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 32a553fdd7..ef475e18c7 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -450,7 +450,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore): async def add_account_data_for_user( self, user_id: str, account_data_type: str, content: JsonDict ) -> int: - """Add some account_data to a room for a user. + """Add some global account_data for a user. Args: user_id: The user to add a tag for. 
@@ -536,9 +536,9 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore): self.db_pool.simple_insert_many_txn( txn, table="ignored_users", + keys=("ignorer_user_id", "ignored_user_id"), values=[ - {"ignorer_user_id": user_id, "ignored_user_id": u} - for u in currently_ignored_users - previously_ignored_users + (user_id, u) for u in currently_ignored_users - previously_ignored_users ], ) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 3682cb6a81..4eca97189b 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -432,14 +432,21 @@ class DeviceInboxWorkerStore(SQLBaseStore): self.db_pool.simple_insert_many_txn( txn, table="device_federation_outbox", + keys=( + "destination", + "stream_id", + "queued_ts", + "messages_json", + "instance_name", + ), values=[ - { - "destination": destination, - "stream_id": stream_id, - "queued_ts": now_ms, - "messages_json": json_encoder.encode(edu), - "instance_name": self._instance_name, - } + ( + destination, + stream_id, + now_ms, + json_encoder.encode(edu), + self._instance_name, + ) for destination, edu in remote_messages_by_destination.items() ], ) @@ -571,14 +578,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): self.db_pool.simple_insert_many_txn( txn, table="device_inbox", + keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"), values=[ - { - "user_id": user_id, - "device_id": device_id, - "stream_id": stream_id, - "message_json": message_json, - "instance_name": self._instance_name, - } + (user_id, device_id, stream_id, message_json, self._instance_name) for user_id, messages_by_device in local_by_user_then_device.items() for device_id, message_json in messages_by_device.items() ], diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index bc7e876047..b2a5cd9a65 100644 --- a/synapse/storage/databases/main/devices.py +++ 
b/synapse/storage/databases/main/devices.py @@ -53,6 +53,7 @@ if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) +issue_8631_logger = logging.getLogger("synapse.8631_debug") DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( "drop_device_list_streams_non_unique_indexes" @@ -229,6 +230,12 @@ class DeviceWorkerStore(SQLBaseStore): if not updates: return now_stream_id, [] + if issue_8631_logger.isEnabledFor(logging.DEBUG): + data = {(user, device): stream_id for user, device, stream_id, _ in updates} + issue_8631_logger.debug( + "device updates need to be sent to %s: %s", destination, data + ) + # get the cross-signing keys of the users in the list, so that we can # determine which of the device changes were cross-signing keys users = {r[0] for r in updates} @@ -365,6 +372,17 @@ class DeviceWorkerStore(SQLBaseStore): # and remove the length budgeting above. results.append(("org.matrix.signing_key_update", result)) + if issue_8631_logger.isEnabledFor(logging.DEBUG): + for (user_id, edu) in results: + issue_8631_logger.debug( + "device update to %s for %s from %s to %s: %s", + destination, + user_id, + from_stream_id, + last_processed_stream_id, + edu, + ) + return last_processed_stream_id, results def _get_device_updates_by_remote_txn( @@ -781,7 +799,7 @@ class DeviceWorkerStore(SQLBaseStore): @cached(max_entries=10000) async def get_device_list_last_stream_id_for_remote( self, user_id: str - ) -> Optional[Any]: + ) -> Optional[str]: """Get the last stream_id we got for a user. May be None if we haven't got any information for them. 
""" @@ -797,7 +815,9 @@ class DeviceWorkerStore(SQLBaseStore): cached_method_name="get_device_list_last_stream_id_for_remote", list_name="user_ids", ) - async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]): + async def get_device_list_last_stream_id_for_remotes( + self, user_ids: Iterable[str] + ) -> Dict[str, Optional[str]]: rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", @@ -1384,6 +1404,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): content: JsonDict, stream_id: str, ) -> None: + """Delete, update or insert a cache entry for this (user, device) pair.""" if content.get("deleted"): self.db_pool.simple_delete_txn( txn, @@ -1443,6 +1464,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def _update_remote_device_list_cache_txn( self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int ) -> None: + """Replace the list of cached devices for this user with the given list.""" self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) @@ -1450,12 +1472,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="device_lists_remote_cache", + keys=("user_id", "device_id", "content"), values=[ - { - "user_id": user_id, - "device_id": content["device_id"], - "content": json_encoder.encode(content), - } + (user_id, content["device_id"], json_encoder.encode(content)) for content in devices ], ) @@ -1543,8 +1562,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="device_lists_stream", + keys=("stream_id", "user_id", "device_id"), values=[ - {"stream_id": stream_id, "user_id": user_id, "device_id": device_id} + (stream_id, user_id, device_id) for stream_id, device_id in zip(stream_ids, device_ids) ], ) @@ -1571,18 +1591,27 @@ class 
DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", + keys=( + "destination", + "stream_id", + "user_id", + "device_id", + "sent", + "ts", + "opentracing_context", + ), values=[ - { - "destination": destination, - "stream_id": next(next_stream_id), - "user_id": user_id, - "device_id": device_id, - "sent": False, - "ts": now, - "opentracing_context": json_encoder.encode(context) + ( + destination, + next(next_stream_id), + user_id, + device_id, + False, + now, + json_encoder.encode(context) if whitelisted_homeserver(destination) else "{}", - } + ) for destination in hosts for device_id in device_ids ], diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index f76c6121e8..5903fdaf00 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -112,10 +112,8 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): self.db_pool.simple_insert_many_txn( txn, table="room_alias_servers", - values=[ - {"room_alias": room_alias.to_string(), "server": server} - for server in servers - ], + keys=("room_alias", "server"), + values=[(room_alias.to_string(), server) for server in servers], ) self._invalidate_cache_and_stream( diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index 0cb48b9dd7..b789a588a5 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -110,16 +110,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): values = [] for (room_id, session_id, room_key) in room_keys: values.append( - { - "user_id": user_id, - "version": version_int, - "room_id": room_id, - "session_id": session_id, - "first_message_index": room_key["first_message_index"], - "forwarded_count": room_key["forwarded_count"], - "is_verified": room_key["is_verified"], - "session_data": 
json_encoder.encode(room_key["session_data"]), - } + ( + user_id, + version_int, + room_id, + session_id, + room_key["first_message_index"], + room_key["forwarded_count"], + room_key["is_verified"], + json_encoder.encode(room_key["session_data"]), + ) ) log_kv( { @@ -131,7 +131,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) await self.db_pool.simple_insert_many( - table="e2e_room_keys", values=values, desc="add_e2e_room_keys" + table="e2e_room_keys", + keys=( + "user_id", + "version", + "room_id", + "session_id", + "first_message_index", + "forwarded_count", + "is_verified", + "session_data", + ), + values=values, + desc="add_e2e_room_keys", ) @trace diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 57b5ffbad3..1f8447b507 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -50,16 +50,16 @@ if TYPE_CHECKING: from synapse.server import HomeServer -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class DeviceKeyLookupResult: """The type returned by get_e2e_device_keys_and_signatures""" - display_name = attr.ib(type=Optional[str]) + display_name: Optional[str] # the key data from e2e_device_keys_json. Typically includes fields like # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.) 
- keys = attr.ib(type=Optional[JsonDict]) + keys: Optional[JsonDict] class EndToEndKeyBackgroundStore(SQLBaseStore): @@ -387,15 +387,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker self.db_pool.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", + keys=( + "user_id", + "device_id", + "algorithm", + "key_id", + "ts_added_ms", + "key_json", + ), values=[ - { - "user_id": user_id, - "device_id": device_id, - "algorithm": algorithm, - "key_id": key_id, - "ts_added_ms": time_now, - "key_json": json_bytes, - } + (user_id, device_id, algorithm, key_id, time_now, json_bytes) for algorithm, key_id, json_bytes in new_keys ], ) @@ -1186,15 +1187,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): """ await self.db_pool.simple_insert_many( "e2e_cross_signing_signatures", - [ - { - "user_id": user_id, - "key_id": item.signing_key_id, - "target_user_id": item.target_user_id, - "target_device_id": item.target_device_id, - "signature": item.signature, - } + keys=( + "user_id", + "key_id", + "target_user_id", + "target_device_id", + "signature", + ), + values=[ + ( + user_id, + item.signing_key_id, + item.target_user_id, + item.target_device_id, + item.signature, + ) for item in signatures ], - "add_e2e_signing_key", + desc="add_e2e_signing_key", ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index a98e6b2593..b7c4c62222 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -875,14 +875,21 @@ class EventPushActionsWorkerStore(SQLBaseStore): self.db_pool.simple_insert_many_txn( txn, table="event_push_summary", + keys=( + "user_id", + "room_id", + "notif_count", + "unread_count", + "stream_ordering", + ), values=[ - { - "user_id": user_id, - "room_id": room_id, - "notif_count": summary.notif_count, - "unread_count": summary.unread_count, - "stream_ordering": 
summary.stream_ordering, - } + ( + user_id, + room_id, + summary.notif_count, + summary.unread_count, + summary.stream_ordering, + ) for ((user_id, room_id), summary) in summaries.items() if summary.old_user_id is None ], diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index dd255aefb9..1ae1ebe108 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -39,7 +39,6 @@ from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import compute_event_reference_hash from synapse.events import EventBase # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401 -from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -69,7 +68,7 @@ event_counter = Counter( ) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -80,9 +79,9 @@ class DeltaState: should e.g. be removed from `current_state_events` table. 
""" - to_delete = attr.ib(type=List[Tuple[str, str]]) - to_insert = attr.ib(type=StateMap[str]) - no_longer_in_room = attr.ib(type=bool, default=False) + to_delete: List[Tuple[str, str]] + to_insert: StateMap[str] + no_longer_in_room: bool = False class PersistEventsStore: @@ -328,7 +327,6 @@ class PersistEventsStore: return existing_prevs - @log_function def _persist_events_txn( self, txn: LoggingTransaction, @@ -442,12 +440,9 @@ class PersistEventsStore: self.db_pool.simple_insert_many_txn( txn, table="event_auth", + keys=("event_id", "room_id", "auth_id"), values=[ - { - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - } + (event.event_id, event.room_id, auth_id) for event in events for auth_id in event.auth_event_ids() if event.is_state() @@ -675,8 +670,9 @@ class PersistEventsStore: db_pool.simple_insert_many_txn( txn, table="event_auth_chains", + keys=("event_id", "chain_id", "sequence_number"), values=[ - {"event_id": event_id, "chain_id": c_id, "sequence_number": seq} + (event_id, c_id, seq) for event_id, (c_id, seq) in new_chain_tuples.items() ], ) @@ -782,13 +778,14 @@ class PersistEventsStore: db_pool.simple_insert_many_txn( txn, table="event_auth_chain_links", + keys=( + "origin_chain_id", + "origin_sequence_number", + "target_chain_id", + "target_sequence_number", + ), values=[ - { - "origin_chain_id": source_id, - "origin_sequence_number": source_seq, - "target_chain_id": target_id, - "target_sequence_number": target_seq, - } + (source_id, source_seq, target_id, target_seq) for ( source_id, source_seq, @@ -943,20 +940,28 @@ class PersistEventsStore: txn_id = getattr(event.internal_metadata, "txn_id", None) if token_id and txn_id: to_insert.append( - { - "event_id": event.event_id, - "room_id": event.room_id, - "user_id": event.sender, - "token_id": token_id, - "txn_id": txn_id, - "inserted_ts": self._clock.time_msec(), - } + ( + event.event_id, + event.room_id, + event.sender, + token_id, + txn_id, + 
self._clock.time_msec(), + ) ) if to_insert: self.db_pool.simple_insert_many_txn( txn, table="event_txn_id", + keys=( + "event_id", + "room_id", + "user_id", + "token_id", + "txn_id", + "inserted_ts", + ), values=to_insert, ) @@ -1161,8 +1166,9 @@ class PersistEventsStore: self.db_pool.simple_insert_many_txn( txn, table="event_forward_extremities", + keys=("event_id", "room_id"), values=[ - {"event_id": ev_id, "room_id": room_id} + (ev_id, room_id) for room_id, new_extrem in new_forward_extremities.items() for ev_id in new_extrem ], @@ -1174,12 +1180,9 @@ class PersistEventsStore: self.db_pool.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", + keys=("room_id", "event_id", "stream_ordering"), values=[ - { - "room_id": room_id, - "event_id": event_id, - "stream_ordering": max_stream_order, - } + (room_id, event_id, max_stream_order) for room_id, new_extrem in new_forward_extremities.items() for event_id in new_extrem ], @@ -1251,20 +1254,22 @@ class PersistEventsStore: for room_id, depth in depth_updates.items(): self._update_min_depth_for_room_txn(txn, room_id, depth) - def _update_outliers_txn(self, txn, events_and_contexts): + def _update_outliers_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> List[Tuple[EventBase, EventContext]]: """Update any outliers with new event info. - This turns outliers into ex-outliers (unless the new event was - rejected). + This turns outliers into ex-outliers (unless the new event was rejected), and + also removes any other events we have already seen from the list. Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting Returns: - list[(EventBase, EventContext)] new list, without events which - are already in the events table. + new list, without events which are already in the events table. 
""" txn.execute( "SELECT event_id, outlier FROM events WHERE event_id in (%s)" @@ -1272,7 +1277,9 @@ class PersistEventsStore: [event.event_id for event, _ in events_and_contexts], ) - have_persisted = {event_id: outlier for event_id, outlier in txn} + have_persisted: Dict[str, bool] = { + event_id: outlier for event_id, outlier in txn + } to_remove = set() for event, context in events_and_contexts: @@ -1282,15 +1289,22 @@ class PersistEventsStore: to_remove.add(event) if context.rejected: - # If the event is rejected then we don't care if the event - # was an outlier or not. + # If the incoming event is rejected then we don't care if the event + # was an outlier or not - what we have is at least as good. continue outlier_persisted = have_persisted[event.event_id] if not event.internal_metadata.is_outlier() and outlier_persisted: # We received a copy of an event that we had already stored as - # an outlier in the database. We now have some state at that + # an outlier in the database. We now have some state at that event # so we need to update the state_groups table with that state. + # + # Note that we do not update the stream_ordering of the event in this + # scenario. XXX: does this cause bugs? It will mean we won't send such + # events down /sync. In general they will be historical events, so that + # doesn't matter too much, but that is not always the case. + + logger.info("Updating state for ex-outlier event %s", event.event_id) # insert into event_to_state_groups. 
try: @@ -1342,7 +1356,7 @@ class PersistEventsStore: d.pop("redacted_because", None) return d - self.db_pool.simple_insert_many_values_txn( + self.db_pool.simple_insert_many_txn( txn, table="event_json", keys=("event_id", "room_id", "internal_metadata", "json", "format_version"), @@ -1358,7 +1372,7 @@ class PersistEventsStore: ), ) - self.db_pool.simple_insert_many_values_txn( + self.db_pool.simple_insert_many_txn( txn, table="events", keys=( @@ -1412,7 +1426,7 @@ class PersistEventsStore: ) txn.execute(sql + clause, [False] + args) - self.db_pool.simple_insert_many_values_txn( + self.db_pool.simple_insert_many_txn( txn, table="state_events", keys=("event_id", "room_id", "type", "state_key"), @@ -1622,14 +1636,9 @@ class PersistEventsStore: return self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", + keys=("event_id", "label", "room_id", "topological_ordering"), values=[ - { - "event_id": event_id, - "label": label, - "room_id": room_id, - "topological_ordering": topological_ordering, - } - for label in labels + (event_id, label, room_id, topological_ordering) for label in labels ], ) @@ -1657,16 +1666,13 @@ class PersistEventsStore: vals = [] for event in events: ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append( - { - "event_id": event.event_id, - "algorithm": ref_alg, - "hash": memoryview(ref_hash_bytes), - } - ) + vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes))) self.db_pool.simple_insert_many_txn( - txn, table="event_reference_hashes", values=vals + txn, + table="event_reference_hashes", + keys=("event_id", "algorithm", "hash"), + values=vals, ) def _store_room_members_txn( @@ -1689,18 +1695,25 @@ class PersistEventsStore: self.db_pool.simple_insert_many_txn( txn, table="room_memberships", + keys=( + "event_id", + "user_id", + "sender", + "room_id", + "membership", + "display_name", + "avatar_url", + ), values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, 
- "room_id": event.room_id, - "membership": event.membership, - "display_name": non_null_str_or_none( - event.content.get("displayname") - ), - "avatar_url": non_null_str_or_none(event.content.get("avatar_url")), - } + ( + event.event_id, + event.state_key, + event.user_id, + event.room_id, + event.membership, + non_null_str_or_none(event.content.get("displayname")), + non_null_str_or_none(event.content.get("avatar_url")), + ) for event in events ], ) @@ -1791,6 +1804,13 @@ class PersistEventsStore: txn.call_after( self.store.get_thread_summary.invalidate, (parent_id, event.room_id) ) + # It should be safe to only invalidate the cache if the user has not + # previously participated in the thread, but that's difficult (and + # potentially error-prone) so it is always invalidated. + txn.call_after( + self.store.get_thread_participated.invalidate, + (parent_id, event.room_id, event.sender), + ) def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): """Handles keeping track of insertion events and edges/connections. @@ -2163,13 +2183,9 @@ class PersistEventsStore: self.db_pool.simple_insert_many_txn( txn, table="event_edges", + keys=("event_id", "prev_event_id", "room_id", "is_state"), values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } + (ev.event_id, e_id, ev.room_id, False) for ev in events for e_id in ev.prev_event_ids() ], @@ -2226,17 +2242,17 @@ class PersistEventsStore: ) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _LinkMap: """A helper type for tracking links between chains.""" # Stores the set of links as nested maps: source chain ID -> target chain ID # -> source sequence number -> target sequence number. 
- maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict) + maps: Dict[int, Dict[int, Dict[int, int]]] = attr.Factory(dict) # Stores the links that have been added (with new set to true), as tuples of # `(source chain ID, source sequence no, target chain ID, target sequence no.)` - additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set) + additions: Set[Tuple[int, int, int, int]] = attr.Factory(set) def add_link( self, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index a68f14ba48..d5f0059665 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -65,22 +65,22 @@ class _BackgroundUpdates: REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column" -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class _CalculateChainCover: """Return value for _calculate_chain_cover_txn.""" # The last room_id/depth/stream processed. - room_id = attr.ib(type=str) - depth = attr.ib(type=int) - stream = attr.ib(type=int) + room_id: str + depth: int + stream: int # Number of rows processed - processed_count = attr.ib(type=int) + processed_count: int # Map from room_id to last depth/stream processed for each room that we have # processed all events for (i.e. 
the rooms we can flip the # `has_auth_chain_index` for) - finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]]) + finished_room_map: Dict[str, Tuple[int, int]] class EventsBackgroundUpdatesStore(SQLBaseStore): @@ -684,13 +684,14 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", + keys=("event_id", "label", "room_id", "topological_ordering"), values=[ - { - "event_id": event_id, - "label": label, - "room_id": event_json["room_id"], - "topological_ordering": event_json["depth"], - } + ( + event_id, + label, + event_json["room_id"], + event_json["depth"], + ) for label in event_json["content"].get( EventContentFields.LABELS, [] ) @@ -803,29 +804,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): if not has_state: state_events.append( - { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } + (event.event_id, event.room_id, event.type, event.state_key) ) if not has_event_auth: # Old, dodgy, events may have duplicate auth events, which we # need to deduplicate as we have a unique constraint. 
for auth_id in set(event.auth_event_ids()): - auth_events.append( - { - "room_id": event.room_id, - "event_id": event.event_id, - "auth_id": auth_id, - } - ) + auth_events.append((event.event_id, event.room_id, auth_id)) if state_events: await self.db_pool.simple_insert_many( table="state_events", + keys=("event_id", "room_id", "type", "state_key"), values=state_events, desc="_rejected_events_metadata_state_events", ) @@ -833,6 +824,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): if auth_events: await self.db_pool.simple_insert_many( table="event_auth", + keys=("event_id", "room_id", "auth_id"), values=auth_events, desc="_rejected_events_metadata_event_auth", ) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index cbf9ec38f7..4f05811a77 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -129,18 +129,29 @@ class PresenceStore(PresenceBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="presence_stream", + keys=( + "stream_id", + "user_id", + "state", + "last_active_ts", + "last_federation_update_ts", + "last_user_sync_ts", + "status_msg", + "currently_active", + "instance_name", + ), values=[ - { - "stream_id": stream_id, - "user_id": state.user_id, - "state": state.state, - "last_active_ts": state.last_active_ts, - "last_federation_update_ts": state.last_federation_update_ts, - "last_user_sync_ts": state.last_user_sync_ts, - "status_msg": state.status_msg, - "currently_active": state.currently_active, - "instance_name": self._instance_name, - } + ( + stream_id, + state.user_id, + state.state, + state.last_active_ts, + state.last_federation_update_ts, + state.last_user_sync_ts, + state.status_msg, + state.currently_active, + self._instance_name, + ) for stream_id, state in zip(stream_orderings, presence_states) ], ) diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 
747b4f31df..cf64cd63a4 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -561,13 +561,9 @@ class PusherStore(PusherWorkerStore): self.db_pool.simple_insert_many_txn( txn, table="deleted_pushers", + keys=("stream_id", "app_id", "pushkey", "user_id"), values=[ - { - "stream_id": stream_id, - "app_id": pusher.app_id, - "pushkey": pusher.pushkey, - "user_id": user_id, - } + (stream_id, pusher.app_id, pusher.pushkey, user_id) for stream_id, pusher in zip(stream_ids, pushers) ], ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 4175c82a25..aac94fa464 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -51,7 +51,7 @@ class ExternalIDReuseException(Exception): pass -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: """Result of looking up an access token. @@ -69,14 +69,14 @@ class TokenLookupResult: cached. """ - user_id = attr.ib(type=str) - is_guest = attr.ib(type=bool, default=False) - shadow_banned = attr.ib(type=bool, default=False) - token_id = attr.ib(type=Optional[int], default=None) - device_id = attr.ib(type=Optional[str], default=None) - valid_until_ms = attr.ib(type=Optional[int], default=None) - token_owner = attr.ib(type=str) - token_used = attr.ib(type=bool, default=False) + user_id: str + is_guest: bool = False + shadow_banned: bool = False + token_id: Optional[int] = None + device_id: Optional[str] = None + valid_until_ms: Optional[int] = None + token_owner: str = attr.ib() + token_used: bool = False # Make the token owner default to the user ID, which is the common case. 
@token_owner.default diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 4ff6aed253..2cb5d06c13 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,14 +13,30 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + cast, +) import attr +from frozendict import frozendict -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.relations import ( AggregationPaginationToken, @@ -29,10 +45,24 @@ from synapse.storage.relations import ( ) from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._msc1849_enabled = hs.config.experimental.msc1849_enabled + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + @cached(tree=True) async def get_relations_for_event( self, @@ -354,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore): async def get_thread_summary( self, event_id: str, room_id: str ) -> Tuple[int, Optional[EventBase]]: - """Get the number of threaded replies, the senders of those replies, and - the latest reply (if any) for the given event. 
+ """Get the number of threaded replies and the latest reply (if any) for the given event. Args: event_id: Summarize the thread related to this event ID. @@ -368,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_thread_summary_txn( txn: LoggingTransaction, ) -> Tuple[int, Optional[str]]: - # Fetch the count of threaded events and the latest event ID. + # Fetch the latest event ID in the thread. # TODO Should this only allow m.room.message events. sql = """ SELECT event_id @@ -389,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore): latest_event_id = row[0] + # Fetch the number of threaded replies. sql = """ SELECT COUNT(event_id) FROM event_relations @@ -413,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore): return count, latest_event + @cached() + async def get_thread_participated( + self, event_id: str, room_id: str, user_id: str + ) -> bool: + """Get whether the requesting user participated in a thread. + + This is separate from get_thread_summary since that can be cached across + all users while this value is specific to the requeser. + + Args: + event_id: The thread related to this event ID. + room_id: The room the event belongs to. + user_id: The user requesting the summary. + + Returns: + True if the requesting user participated in the thread, otherwise false. + """ + + def _get_thread_summary_txn(txn: LoggingTransaction) -> bool: + # Fetch whether the requester has participated or not. + sql = """ + SELECT 1 + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id = ? + AND room_id = ? + AND relation_type = ? + AND sender = ? 
+ """ + + txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id)) + return bool(txn.fetchone()) + + return await self.db_pool.runInteraction( + "get_thread_summary", _get_thread_summary_txn + ) + async def events_have_relations( self, parent_ids: List[str], @@ -515,6 +583,104 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + async def _get_bundled_aggregation_for_event( + self, event: EventBase, user_id: str + ) -> Optional[Dict[str, Any]]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + # State events and redacted events do not get bundled aggregations. + if event.is_state() or event.internal_metadata.is_redacted(): + return None + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. 
+ aggregations: Dict[str, Any] = {} + + annotations = await self.get_aggregation_groups_for_event(event_id, room_id) + if annotations.chunk: + aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + + references = await self.get_relations_for_event( + event_id, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations[RelationTypes.REFERENCE] = references.to_dict() + + edit = None + if event.type == EventTypes.Message: + edit = await self.get_applicable_edit(event_id, room_id) + + if edit: + aggregations[RelationTypes.REPLACE] = edit + + # If this event is the start of a thread, include a summary of the replies. + if self._msc3440_enabled: + thread_count, latest_thread_event = await self.get_thread_summary( + event_id, room_id + ) + participated = await self.get_thread_participated( + event_id, room_id, user_id + ) + if latest_thread_event: + aggregations[RelationTypes.THREAD] = { + "latest_event": latest_thread_event, + "count": thread_count, + "current_user_participated": participated, + } + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, + events: Iterable[EventBase], + user_id: str, + ) -> Dict[str, Dict[str, Any]]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # If bundled aggregations are disabled, nothing to do. + if not self._msc1849_enabled: + return {} + + # TODO Parallelize. 
+ results = {} + for event in events: + event_result = await self._get_bundled_aggregation_for_event(event, user_id) + if event_result is not None: + results[event.event_id] = event_result + + return results + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index c0e837854a..95167116c9 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -551,24 +551,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): FROM room_stats_state state INNER JOIN room_stats_current curr USING (room_id) INNER JOIN rooms USING (room_id) - %s - ORDER BY %s %s + {where} + ORDER BY {order_by} {direction}, state.room_id {direction} LIMIT ? OFFSET ? - """ % ( - where_statement, - order_by_column, - "ASC" if order_by_asc else "DESC", + """.format( + where=where_statement, + order_by=order_by_column, + direction="ASC" if order_by_asc else "DESC", ) # Use a nested SELECT statement as SQL can't count(*) with an OFFSET count_sql = """ SELECT count(*) FROM ( SELECT room_id FROM room_stats_state state - %s + {where} ) AS get_room_ids - """ % ( - where_statement, + """.format( + where=where_statement, ) def _get_rooms_paginate_txn( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index cda80d6511..4489732fda 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1177,18 +1177,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): await self.db_pool.runInteraction("forget_membership", f) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _JoinedHostsCache: """The cached data used by the `_get_joined_hosts_cache`.""" # Dict of host to the set of their users in the room at the state group. 
- hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict) + hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict) # The state group `hosts_to_joined_users` is derived from. Will be an object # if the instance is newly created or if the state is not based on a state # group. (An object is used as a sentinel value to ensure that it never is # equal to anything else). - state_group = attr.ib(type=Union[object, int], factory=object) + state_group: Union[object, int] = attr.Factory(object) def __len__(self): return sum(len(v) for v in self.hosts_to_joined_users.values()) diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py index 5a97120437..e8c776b97a 100644 --- a/synapse/storage/databases/main/session.py +++ b/synapse/storage/databases/main/session.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 6c299cafa5..4b78b4d098 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -560,3 +560,14 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): return await self.db_pool.runInteraction( "get_destinations_paginate_txn", get_destinations_paginate_txn ) + + async def is_destination_known(self, destination: str) -> bool: + """Check if a destination is known to the server.""" + result = await self.db_pool.simple_select_one_onecol( + table="destinations", + keyvalues={"destination": destination}, + retcol="1", + allow_none=True, + desc="is_destination_known", + ) + return bool(result) diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index a1a1a6a14a..2d339b6008 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ 
-23,19 +23,19 @@ from synapse.types import JsonDict from synapse.util import json_encoder, stringutils -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class UIAuthSessionData: - session_id = attr.ib(type=str) + session_id: str # The dictionary from the client root level, not the 'auth' key. - clientdict = attr.ib(type=JsonDict) + clientdict: JsonDict # The URI and method the session was intiatied with. These are checked at # each stage of the authentication to ensure that the asked for operation # has not changed. - uri = attr.ib(type=str) - method = attr.ib(type=str) + uri: str + method: str # A string description of the operation that the current authentication is # authorising. - description = attr.ib(type=str) + description: str class UIAuthWorkerStore(SQLBaseStore): diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 0f9b8575d3..f7c778bdf2 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -105,8 +105,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): GROUP BY room_id """ txn.execute(sql) - rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + rooms = list(txn.fetchall()) + self.db_pool.simple_insert_many_txn( + txn, TEMP_TABLE + "_rooms", keys=("room_id", "events"), values=rooms + ) del rooms sql = ( @@ -117,9 +119,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.execute(sql) txn.execute("SELECT name FROM users") - users = [{"user_id": x[0]} for x in txn.fetchall()] + users = list(txn.fetchall()) - self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db_pool.simple_insert_many_txn( + txn, TEMP_TABLE + "_users", keys=("user_id",), values=users + ) new_pos = await self.get_max_stream_id_in_current_state_deltas() await self.db_pool.runInteraction( diff --git 
a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index eb1118d2cb..5de70f31d2 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -327,14 +327,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", + keys=( + "state_group", + "room_id", + "type", + "state_key", + "event_id", + ), values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } + (state_group, room_id, key[0], key[1], state_id) for key, state_id in delta_state.items() ], ) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index c4c8c0021b..7614d76ac6 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -460,14 +460,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } + (state_group, room_id, key[0], key[1], state_id) for key, state_id in delta_ids.items() ], ) @@ -475,14 +470,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self.db_pool.simple_insert_many_txn( txn, table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } + (state_group, room_id, key[0], key[1], state_id) for key, state_id in current_state_ids.items() ], ) @@ -589,14 +579,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self.db_pool.simple_insert_many_txn( txn, 
table="state_groups_state", + keys=("state_group", "room_id", "type", "state_key", "event_id"), values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } + (sg, room_id, key[0], key[1], state_id) for key, state_id in curr_state.items() ], ) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 540adb8781..71584f3f74 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -21,7 +21,7 @@ from signedjson.types import VerifyKey logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class FetchKeyResult: - verify_key = attr.ib(type=VerifyKey) # the key itself - valid_until_ts = attr.ib(type=int) # how long we can use this key for + verify_key: VerifyKey # the key itself + valid_until_ts: int # how long we can use this key for diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index e45adfcb55..1823e18720 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -696,7 +696,7 @@ def _get_or_create_schema_state( ) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _DirectoryListing: """Helper class to store schema file name and the absolute path to it. @@ -705,5 +705,5 @@ class _DirectoryListing: `file_name` attr is kept first. """ - file_name = attr.ib(type=str) - absolute_path = attr.ib(type=str) + file_name: str + absolute_path: str diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 10a46b5e82..b1536c1ca4 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -23,7 +23,7 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class PaginationChunk: """Returned by relation pagination APIs. @@ -35,9 +35,9 @@ class PaginationChunk: None then there are no previous results. 
""" - chunk = attr.ib(type=List[JsonDict]) - next_batch = attr.ib(type=Optional[Any], default=None) - prev_batch = attr.ib(type=Optional[Any], default=None) + chunk: List[JsonDict] + next_batch: Optional[Any] = None + prev_batch: Optional[Any] = None def to_dict(self) -> Dict[str, Any]: d = {"chunk": self.chunk} @@ -51,7 +51,7 @@ class PaginationChunk: return d -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class RelationPaginationToken: """Pagination token for relation pagination API. @@ -64,8 +64,8 @@ class RelationPaginationToken: stream: The stream ordering of the boundary event. """ - topological = attr.ib(type=int) - stream = attr.ib(type=int) + topological: int + stream: int @staticmethod def from_string(string: str) -> "RelationPaginationToken": @@ -82,7 +82,7 @@ class RelationPaginationToken: return attr.astuple(self) -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class AggregationPaginationToken: """Pagination token for relation aggregation pagination API. @@ -94,8 +94,8 @@ class AggregationPaginationToken: stream: The MAX stream ordering in the boundary group. """ - count = attr.ib(type=int) - stream = attr.ib(type=int) + count: int + stream: int @staticmethod def from_string(string: str) -> "AggregationPaginationToken": diff --git a/synapse/storage/state.py b/synapse/storage/state.py index b5ba1560d1..df8b2f1088 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) T = TypeVar("T") -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class StateFilter: """A filter used when querying for state. @@ -58,8 +58,8 @@ class StateFilter: appear in `types`. 
""" - types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]") - include_others = attr.ib(default=False, type=bool) + types: "frozendict[str, Optional[FrozenSet[str]]]" + include_others: bool = False def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b8112e1c05..3c13859faa 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -762,13 +762,13 @@ class _AsyncCtxManagerWrapper(Generic[T]): return self.inner.__exit__(exc_type, exc, tb) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _MultiWriterCtxManager: """Async context manager returned by MultiWriterIdGenerator""" - id_gen = attr.ib(type=MultiWriterIdGenerator) - multiple_ids = attr.ib(type=Optional[int], default=None) - stream_ids = attr.ib(type=List[int], factory=list) + id_gen: MultiWriterIdGenerator + multiple_ids: Optional[int] = None + stream_ids: List[int] = attr.Factory(list) async def __aenter__(self) -> Union[int, List[int]]: # It's safe to run this in autocommit mode as fetching values from a diff --git a/synapse/streams/config.py b/synapse/streams/config.py index c08d591f29..b52723e2b8 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -28,14 +28,14 @@ logger = logging.getLogger(__name__) MAX_LIMIT = 1000 -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class PaginationConfig: """A configuration object which stores pagination parameters.""" - from_token = attr.ib(type=Optional[StreamToken]) - to_token = attr.ib(type=Optional[StreamToken]) - direction = attr.ib(type=str) - limit = attr.ib(type=Optional[int]) + from_token: Optional[StreamToken] + to_token: Optional[StreamToken] + direction: str + limit: Optional[int] @classmethod async def from_request( diff --git a/synapse/types.py b/synapse/types.py index 42aeaf6270..f89fb216a6 100644 --- 
a/synapse/types.py +++ b/synapse/types.py @@ -20,7 +20,9 @@ from typing import ( Any, ClassVar, Dict, + List, Mapping, + Match, MutableMapping, Optional, Tuple, @@ -79,7 +81,7 @@ class ISynapseReactor( """The interfaces necessary for Synapse to function.""" -@attr.s(frozen=True, slots=True) +@attr.s(frozen=True, slots=True, auto_attribs=True) class Requester: """ Represents the user making a request @@ -97,13 +99,13 @@ class Requester: "puppeting" the user. """ - user = attr.ib(type="UserID") - access_token_id = attr.ib(type=Optional[int]) - is_guest = attr.ib(type=bool) - shadow_banned = attr.ib(type=bool) - device_id = attr.ib(type=Optional[str]) - app_service = attr.ib(type=Optional["ApplicationService"]) - authenticated_entity = attr.ib(type=str) + user: "UserID" + access_token_id: Optional[int] + is_guest: bool + shadow_banned: bool + device_id: Optional[str] + app_service: Optional["ApplicationService"] + authenticated_entity: str def serialize(self): """Converts self to a type that can be serialized as JSON, and then @@ -210,7 +212,7 @@ def get_localpart_from_id(string: str) -> str: DS = TypeVar("DS", bound="DomainSpecificString") -@attr.s(slots=True, frozen=True, repr=False) +@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) class DomainSpecificString(metaclass=abc.ABCMeta): """Common base class among ID/name strings that have a local part and a domain name, prefixed with a sigil. @@ -223,8 +225,8 @@ class DomainSpecificString(metaclass=abc.ABCMeta): SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore - localpart = attr.ib(type=str) - domain = attr.ib(type=str) + localpart: str + domain: str # Because this is a frozen class, it is deeply immutable. 
def __copy__(self): @@ -380,7 +382,7 @@ def map_username_to_mxid_localpart( onto different mxids Returns: - unicode: string suitable for a mxid localpart + string suitable for a mxid localpart """ if not isinstance(username, bytes): username = username.encode("utf-8") @@ -388,29 +390,23 @@ def map_username_to_mxid_localpart( # first we sort out upper-case characters if case_sensitive: - def f1(m): + def f1(m: Match[bytes]) -> bytes: return b"_" + m.group().lower() username = UPPER_CASE_PATTERN.sub(f1, username) else: username = username.lower() - # then we sort out non-ascii characters - def f2(m): - g = m.group()[0] - if isinstance(g, str): - # on python 2, we need to do a ord(). On python 3, the - # byte itself will do. - g = ord(g) - return b"=%02x" % (g,) + # then we sort out non-ascii characters by converting to the hex equivalent. + def f2(m: Match[bytes]) -> bytes: + return b"=%02x" % (m.group()[0],) username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) # we also do the =-escaping to mxids starting with an underscore. username = re.sub(b"^_", b"=5f", username) - # we should now only have ascii bytes left, so can decode back to a - # unicode. + # we should now only have ascii bytes left, so can decode back to a string. return username.decode("ascii") @@ -466,14 +462,12 @@ class RoomStreamToken: attributes, must be hashable. 
""" - topological = attr.ib( - type=Optional[int], + topological: Optional[int] = attr.ib( validator=attr.validators.optional(attr.validators.instance_of(int)), ) - stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) + stream: int = attr.ib(validator=attr.validators.instance_of(int)) - instance_map = attr.ib( - type="frozendict[str, int]", + instance_map: "frozendict[str, int]" = attr.ib( factory=frozendict, validator=attr.validators.deep_mapping( key_validator=attr.validators.instance_of(str), @@ -482,7 +476,7 @@ class RoomStreamToken: ), ) - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: """Validates that both `topological` and `instance_map` aren't set.""" if self.instance_map and self.topological: @@ -598,7 +592,7 @@ class RoomStreamToken: return "s%d" % (self.stream,) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class StreamToken: """A collection of positions within multiple streams. @@ -606,20 +600,20 @@ class StreamToken: must be hashable. 
""" - room_key = attr.ib( - type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken) + room_key: RoomStreamToken = attr.ib( + validator=attr.validators.instance_of(RoomStreamToken) ) - presence_key = attr.ib(type=int) - typing_key = attr.ib(type=int) - receipt_key = attr.ib(type=int) - account_data_key = attr.ib(type=int) - push_rules_key = attr.ib(type=int) - to_device_key = attr.ib(type=int) - device_list_key = attr.ib(type=int) - groups_key = attr.ib(type=int) + presence_key: int + typing_key: int + receipt_key: int + account_data_key: int + push_rules_key: int + to_device_key: int + device_list_key: int + groups_key: int _SEPARATOR = "_" - START: "StreamToken" + START: ClassVar["StreamToken"] @classmethod async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": @@ -679,7 +673,7 @@ class StreamToken: StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class PersistedEventPosition: """Position of a newly persisted event with instance that persisted it. @@ -687,8 +681,8 @@ class PersistedEventPosition: RoomStreamToken. 
""" - instance_name = attr.ib(type=str) - stream = attr.ib(type=int) + instance_name: str + stream: int def persisted_after(self, token: RoomStreamToken) -> bool: return token.get_stream_pos_for_instance(self.instance_name) < self.stream @@ -738,15 +732,15 @@ class ThirdPartyInstanceID: __str__ = to_string -@attr.s(slots=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class ReadReceipt: """Information about a read-receipt""" - room_id = attr.ib() - receipt_type = attr.ib() - user_id = attr.ib() - event_ids = attr.ib() - data = attr.ib() + room_id: str + receipt_type: str + user_id: str + event_ids: List[str] + data: JsonDict def get_verify_key_from_cross_signing_key(key_info): diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 150a04b53e..3f7299aff7 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -309,12 +309,12 @@ def gather_results( # type: ignore[misc] return deferred.addCallback(tuple) -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class _LinearizerEntry: # The number of things executing. - count = attr.ib(type=int) + count: int # Deferreds for the things blocked from executing. - deferreds = attr.ib(type=collections.OrderedDict) + deferreds: collections.OrderedDict class Linearizer: diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 485ddb1893..d267703df0 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -33,7 +33,7 @@ DV = TypeVar("DV") # This class can't be generic because it uses slots with attrs. # See: https://github.com/python-attrs/attrs/issues/313 -@attr.s(slots=True) +@attr.s(slots=True, auto_attribs=True) class DictionaryEntry: # should be: Generic[DKT, DV]. """Returned when getting an entry from the cache @@ -41,14 +41,13 @@ class DictionaryEntry: # should be: Generic[DKT, DV]. full: Whether the cache has the full or dict or just some keys. 
If not full then not all requested keys will necessarily be present in `value` - known_absent: Keys that were looked up in the dict and were not - there. + known_absent: Keys that were looked up in the dict and were not there. value: The full or partial dict value """ - full = attr.ib(type=bool) - known_absent = attr.ib(type=Set[Any]) # should be: Set[DKT] - value = attr.ib(type=Dict[Any, Any]) # should be: Dict[DKT, DV] + full: bool + known_absent: Set[Any] # should be: Set[DKT] + value: Dict[Any, Any] # should be: Dict[DKT, DV] def __len__(self) -> int: return len(self.value) diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index a2dfa1ed05..4b53b6d40b 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -274,6 +274,39 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEquals(failure.value.code, 400) self.assertEquals(failure.value.errcode, Codes.EXCLUSIVE) + def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self): + self.store.get_user_by_access_token = simple_async_mock( + TokenLookupResult( + user_id="@baldrick:matrix.org", + device_id="device", + token_owner="@admin:matrix.org", + ) + ) + self.store.insert_client_ip = simple_async_mock(None) + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args[b"access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + self.get_success(self.auth.get_user_by_req(request)) + self.store.insert_client_ip.assert_called_once() + + def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self): + self.auth._track_puppeted_user_ips = True + self.store.get_user_by_access_token = simple_async_mock( + TokenLookupResult( + user_id="@baldrick:matrix.org", + device_id="device", + token_owner="@admin:matrix.org", + ) + ) + self.store.insert_client_ip = simple_async_mock(None) + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args[b"access_token"] = [self.test_token] 
+ request.requestHeaders.getRawHeaders = mock_getRawHeaders() + self.get_success(self.auth.get_user_by_req(request)) + self.assertEquals(self.store.insert_client_ip.call_count, 2) + def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = simple_async_mock( TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index ddcf3ee348..734ed84d78 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -13,8 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Iterable from unittest import mock +from parameterized import parameterized from signedjson import key as key, sign as sign from twisted.internet import defer @@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError from tests import unittest +from tests.test_utils import make_awaitable class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): @@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_user_id = "@test:other" local_user_id = "@test:test" + # Pretend we're sharing a room with the user we're querying. If not, + # `_query_devices_for_destination` will return early. self.store.get_rooms_for_user = mock.Mock( return_value=defer.succeed({"some_room_id"}) ) @@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): } }, ) + + @parameterized.expand( + [ + # The remote homeserver's response indicates that this user has 0/1/2 devices. + ([],), + (["device_1"],), + (["device_1", "device_2"],), + ] + ) + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): + """Test that requests for all of a remote user's devices are cached. 
+ + We do this by asserting that only one call over federation was made, and that + the two queries to the local homeserver produce the same response. + """ + local_user_id = "@test:test" + remote_user_id = "@test:other" + request_body = {"device_keys": {remote_user_id: []}} + + response_devices = [ + { + "device_id": device_id, + "keys": { + "algorithms": ["dummy"], + "device_id": device_id, + "keys": {f"dummy:{device_id}": "dummy"}, + "signatures": {device_id: {f"dummy:{device_id}": "dummy"}}, + "unsigned": {}, + "user_id": "@test:other", + }, + } + for device_id in device_ids + ] + + response_body = { + "devices": response_devices, + "user_id": remote_user_id, + "stream_id": 12345, # an integer, according to the spec + } + + e2e_handler = self.hs.get_e2e_keys_handler() + + # Pretend we're sharing a room with the user we're querying. If not, + # `_query_devices_for_destination` will return early. + mock_get_rooms = mock.patch.object( + self.store, + "get_rooms_for_user", + new_callable=mock.MagicMock, + return_value=make_awaitable(["some_room_id"]), + ) + mock_request = mock.patch.object( + self.hs.get_federation_client(), + "query_user_devices", + new_callable=mock.MagicMock, + return_value=make_awaitable(response_body), + ) + + with mock_get_rooms, mock_request as mocked_federation_request: + # Make the first query and sanity check it succeeds. + response_1 = self.get_success( + e2e_handler.query_devices( + request_body, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + self.assertEqual(response_1["failures"], {}) + + # We should have made a federation request to do so. + mocked_federation_request.assert_called_once() + + # Reset the mock so we can prove we don't make a second federation request. + mocked_federation_request.reset_mock() + + # Repeat the query. 
+ response_2 = self.get_success( + e2e_handler.query_devices( + request_body, + timeout=10, + from_user_id=local_user_id, + from_device_id="some_device_id", + ) + ) + self.assertEqual(response_2["failures"], {}) + + # We should not have made a second federation request. + mocked_federation_request.assert_not_called() + + # The two requests to the local homeserver should be identical. + self.assertEqual(response_1, response_2) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 08e9730d4d..2add72b28a 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -22,7 +22,7 @@ from twisted.internet import defer import synapse from synapse.handlers.auth import load_legacy_password_auth_providers from synapse.module_api import ModuleApi -from synapse.rest.client import devices, login +from synapse.rest.client import devices, login, logout from synapse.types import JsonDict from tests import unittest @@ -155,6 +155,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): synapse.rest.admin.register_servlets, login.register_servlets, devices.register_servlets, + logout.register_servlets, ] def setUp(self): @@ -719,6 +720,31 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) + def test_on_logged_out(self): + """Tests that the on_logged_out callback is called when the user logs out.""" + self.register_user("rin", "password") + tok = self.login("rin", "password") + + self.called = False + + async def on_logged_out(user_id, device_id, access_token): + self.called = True + + on_logged_out = Mock(side_effect=on_logged_out) + self.hs.get_password_auth_provider().on_logged_out_callbacks.append( + on_logged_out + ) + + channel = self.make_request( + "POST", + "/_matrix/client/v3/logout", + {}, + access_token=tok, + ) + self.assertEqual(channel.code, 
200) + on_logged_out.assert_called_once() + self.assertTrue(self.called) + def _get_login_flows(self) -> JsonDict: channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index e5a6a6c747..51b22d2998 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -28,6 +28,7 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict +from synapse.federation.transport.client import TransportLayerClient from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry from synapse.rest import admin from synapse.rest.client import login, room @@ -134,10 +135,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._add_child(self.space, self.room, self.token) def _add_child( - self, space_id: str, room_id: str, token: str, order: Optional[str] = None + self, + space_id: str, + room_id: str, + token: str, + order: Optional[str] = None, + via: Optional[List[str]] = None, ) -> None: """Add a child room to a space.""" - content: JsonDict = {"via": [self.hs.hostname]} + if via is None: + via = [self.hs.hostname] + + content: JsonDict = {"via": via} if order is not None: content["order"] = order self.helper.send_state( @@ -253,6 +262,38 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) + def test_large_space(self): + """Test a space with a large number of rooms.""" + rooms = [self.room] + # Make at least 51 rooms that are part of the space. 
+ for _ in range(55): + room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, room, self.token) + rooms.append(room) + + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) + # The spaces result should have the space and the first 50 rooms in it, + # along with the links from space -> room for those 50 rooms. + expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]] + self._assert_rooms(result, expected) + + # The result should have the space and the rooms in it, along with the links + # from space -> room. + expected = [(self.space, rooms)] + [(room, []) for room in rooms] + + # Make two requests to fully paginate the results. + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + result2 = self.get_success( + self.handler.get_room_hierarchy( + create_requester(self.user), self.space, from_token=result["next_batch"] + ) + ) + # Combine the results. + result["rooms"] += result2["rooms"] + self._assert_hierarchy(result, expected) + def test_visibility(self): """A user not in a space cannot inspect it.""" user2 = self.register_user("user2", "pass") @@ -1004,6 +1045,85 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) + def test_fed_caching(self): + """ + Federation `/hierarchy` responses should be cached. + """ + fed_hostname = self.hs.hostname + "2" + fed_subspace = "#space:" + fed_hostname + fed_room = "#room:" + fed_hostname + + # Add a room to the space which is on another server. 
+ self._add_child(self.space, fed_subspace, self.token, via=[fed_hostname]) + + federation_requests = 0 + + async def get_room_hierarchy( + _self: TransportLayerClient, + destination: str, + room_id: str, + suggested_only: bool, + ) -> JsonDict: + nonlocal federation_requests + federation_requests += 1 + + return { + "room": { + "room_id": fed_subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + "children_state": [ + { + "type": EventTypes.SpaceChild, + "room_id": fed_subspace, + "state_key": fed_room, + "content": {"via": [fed_hostname]}, + }, + ], + }, + "children": [ + { + "room_id": fed_room, + "world_readable": True, + }, + ], + "inaccessible_children": [], + } + + expected = [ + (self.space, [self.room, fed_subspace]), + (self.room, ()), + (fed_subspace, [fed_room]), + (fed_room, ()), + ] + + with mock.patch( + "synapse.federation.transport.client.TransportLayerClient.get_room_hierarchy", + new=get_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + self.assertEqual(federation_requests, 1) + self._assert_hierarchy(result, expected) + + # The previous federation response should be reused. + result = self.get_success( + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + self.assertEqual(federation_requests, 1) + self._assert_hierarchy(result, expected) + + # Expire the response cache + self.reactor.advance(5 * 60 + 1) + + # A new federation request should be made. 
+ result = self.get_success( + self.handler.get_room_hierarchy(create_requester(self.user), self.space) + ) + self.assertEqual(federation_requests, 2) + self._assert_hierarchy(result, expected) + class RoomSummaryTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 638186f173..07a760e91a 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -11,15 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Optional -from unittest.mock import Mock +from unittest.mock import MagicMock, Mock, patch from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import Filtering from synapse.api.room_versions import RoomVersions -from synapse.handlers.sync import SyncConfig +from synapse.handlers.sync import SyncConfig, SyncResult from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer @@ -27,6 +26,7 @@ from synapse.types import UserID, create_requester import tests.unittest import tests.utils +from tests.test_utils import make_awaitable class SyncTestCase(tests.unittest.HomeserverTestCase): @@ -186,6 +186,97 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.assertNotIn(invite_room, [r.room_id for r in result.invited]) self.assertNotIn(knock_room, [r.room_id for r in result.knocked]) + def test_ban_wins_race_with_join(self): + """Rooms shouldn't appear under "joined" if a join loses a race to a ban. + + A complicated edge case. Imagine the following scenario: + + * you attempt to join a room + * racing with that is a ban which comes in over federation, which ends up with + an earlier stream_ordering than the join. 
+ * you get a sync response with a sync token which is _after_ the ban, but before + the join + * now your join lands; it is a valid event because its `prev_event`s predate the + ban, but will not make it into current_state_events (because bans win over + joins in state res, essentially). + * When we do a sync from the incremental sync, the only event in the timeline + is your join ... and yet you aren't joined. + + The ban coming in over federation isn't crucial for this behaviour; the key + requirements are: + 1. the homeserver generates a join event with prev_events that precede the ban + (so that it passes the "are you banned" test) + 2. the join event has a stream_ordering after that of the ban. + + We use monkeypatching to artificially trigger condition (1). + """ + # A local user Alice creates a room. + owner = self.register_user("alice", "password") + owner_tok = self.login(owner, "password") + room_id = self.helper.create_room_as(owner, is_public=True, tok=owner_tok) + + # Do a sync as Alice to get the latest event in the room. + alice_sync_result: SyncResult = self.get_success( + self.sync_handler.wait_for_sync_for_user( + create_requester(owner), generate_sync_config(owner) + ) + ) + self.assertEqual(len(alice_sync_result.joined), 1) + self.assertEqual(alice_sync_result.joined[0].room_id, room_id) + last_room_creation_event_id = ( + alice_sync_result.joined[0].timeline.events[-1].event_id + ) + + # Eve, a ne'er-do-well, registers. + eve = self.register_user("eve", "password") + eve_token = self.login(eve, "password") + + # Alice preemptively bans Eve. + self.helper.ban(room_id, owner, eve, tok=owner_tok) + + # Eve syncs. + eve_requester = create_requester(eve) + eve_sync_config = generate_sync_config(eve) + eve_sync_after_ban: SyncResult = self.get_success( + self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config) + ) + + # Sanity check this sync result. We shouldn't be joined to the room. 
+ self.assertEqual(eve_sync_after_ban.joined, []) + + # Eve tries to join the room. We monkey patch the internal logic which selects + # the prev_events used when creating the join event, such that the ban does not + # precede the join. + mocked_get_prev_events = patch.object( + self.hs.get_datastore(), + "get_prev_events_for_room", + new_callable=MagicMock, + return_value=make_awaitable([last_room_creation_event_id]), + ) + with mocked_get_prev_events: + self.helper.join(room_id, eve, tok=eve_token) + + # Eve makes a second, incremental sync. + eve_incremental_sync_after_join: SyncResult = self.get_success( + self.sync_handler.wait_for_sync_for_user( + eve_requester, + eve_sync_config, + since_token=eve_sync_after_ban.next_batch, + ) + ) + # Eve should not see herself as joined to the room. + self.assertEqual(eve_incremental_sync_after_join.joined, []) + + # If we did a third initial sync, we should _still_ see eve is not joined to the room. + eve_initial_sync_after_join: SyncResult = self.get_success( + self.sync_handler.wait_for_sync_for_user( + eve_requester, + eve_sync_config, + since_token=None, + ) + ) + self.assertEqual(eve_initial_sync_after_join.joined, []) + _request_key = 0 diff --git a/tests/http/test_webclient.py b/tests/http/test_webclient.py new file mode 100644 index 0000000000..ee5cf299f6 --- /dev/null +++ b/tests/http/test_webclient.py @@ -0,0 +1,108 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from http import HTTPStatus +from typing import Dict + +from twisted.web.resource import Resource + +from synapse.app.homeserver import SynapseHomeServer +from synapse.config.server import HttpListenerConfig, HttpResourceConfig, ListenerConfig +from synapse.http.site import SynapseSite + +from tests.server import make_request +from tests.unittest import HomeserverTestCase, create_resource_tree, override_config + + +class WebClientTests(HomeserverTestCase): + @override_config( + { + "web_client_location": "https://example.org", + } + ) + def test_webclient_resolves_with_client_resource(self): + """ + Tests that both client and webclient resources can be accessed simultaneously. + + This is a regression test created in response to https://github.com/matrix-org/synapse/issues/11763. + """ + for resource_name_order_list in [ + ["webclient", "client"], + ["client", "webclient"], + ]: + # Create a dictionary from path regex -> resource + resource_dict: Dict[str, Resource] = {} + + for resource_name in resource_name_order_list: + resource_dict.update( + SynapseHomeServer._configure_named_resource(self.hs, resource_name) + ) + + # Create a root resource which ties the above resources together into one + root_resource = Resource() + create_resource_tree(resource_dict, root_resource) + + # Create a site configured with this resource to make HTTP requests against + listener_config = ListenerConfig( + port=8008, + bind_addresses=["127.0.0.1"], + type="http", + http_options=HttpListenerConfig( + resources=[HttpResourceConfig(names=resource_name_order_list)] + ), + ) + test_site = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag=self.hs.config.server.server_name, + config=listener_config, + resource=root_resource, + server_version_string="1", + max_request_body_size=1234, + reactor=self.reactor, + ) + + # Attempt to make requests to endpoints on both the webclient and client resources + # on test_site. 
+ self._request_client_and_webclient_resources(test_site) + + def _request_client_and_webclient_resources(self, test_site: SynapseSite) -> None: + """Make a request to an endpoint on both the webclient and client-server resources + of the given SynapseSite. + + Args: + test_site: The SynapseSite object to make requests against. + """ + + # Ensure that the *webclient* resource is behaving as expected (we get redirected to + # the configured web_client_location) + channel = make_request( + self.reactor, + site=test_site, + method="GET", + path="/_matrix/client", + ) + # Check that we are being redirected to the webclient location URI. + self.assertEqual(channel.code, HTTPStatus.FOUND) + self.assertEqual( + channel.headers.getRawHeaders("Location"), ["https://example.org"] + ) + + # Ensure that a request to the *client* resource works. + channel = make_request( + self.reactor, + site=test_site, + method="GET", + path="/_matrix/client/v3/login", + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertIn("flows", channel.json_body) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 742f194257..b70350b6f1 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -314,15 +314,12 @@ class FederationTestCase(unittest.HomeserverTestCase): retry_interval, last_successful_stream_ordering, ) in dest: - self.get_success( - self.store.set_destination_retry_timings( - destination, failure_ts, retry_last_ts, retry_interval - ) - ) - self.get_success( - self.store.set_destination_last_successful_stream_ordering( - destination, last_successful_stream_ordering - ) + self._create_destination( + destination, + failure_ts, + retry_last_ts, + retry_interval, + last_successful_stream_ordering, ) # order by default (destination) @@ -413,11 +410,9 @@ class FederationTestCase(unittest.HomeserverTestCase): _search_test(None, "foo") _search_test(None, "bar") - def test_get_single_destination(self) -> None: - 
""" - Get one specific destinations. - """ - self._create_destinations(5) + def test_get_single_destination_with_retry_timings(self) -> None: + """Get one specific destination which has retry timings.""" + self._create_destinations(1) channel = self.make_request( "GET", @@ -432,6 +427,53 @@ class FederationTestCase(unittest.HomeserverTestCase): # convert channel.json_body into a List self._check_fields([channel.json_body]) + def test_get_single_destination_no_retry_timings(self) -> None: + """Get one specific destination which has no retry timings.""" + self._create_destination("sub0.example.com") + + channel = self.make_request( + "GET", + self.url + "/sub0.example.com", + access_token=self.admin_user_tok, + ) + + self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) + self.assertEqual("sub0.example.com", channel.json_body["destination"]) + self.assertEqual(0, channel.json_body["retry_last_ts"]) + self.assertEqual(0, channel.json_body["retry_interval"]) + self.assertIsNone(channel.json_body["failure_ts"]) + self.assertIsNone(channel.json_body["last_successful_stream_ordering"]) + + def _create_destination( + self, + destination: str, + failure_ts: Optional[int] = None, + retry_last_ts: int = 0, + retry_interval: int = 0, + last_successful_stream_ordering: Optional[int] = None, + ) -> None: + """Create one specific destination + + Args: + destination: the destination we have successfully sent to + failure_ts: when the server started failing (ms since epoch) + retry_last_ts: time of last retry attempt in unix epoch ms + retry_interval: how long until next retry in ms + last_successful_stream_ordering: the stream_ordering of the most + recent successfully-sent PDU + """ + self.get_success( + self.store.set_destination_retry_timings( + destination, failure_ts, retry_last_ts, retry_interval + ) + ) + if last_successful_stream_ordering is not None: + self.get_success( + self.store.set_destination_last_successful_stream_ordering( + destination, 
last_successful_stream_ordering + ) + ) + def _create_destinations(self, number_destinations: int) -> None: """Create a number of destinations @@ -440,10 +482,7 @@ class FederationTestCase(unittest.HomeserverTestCase): """ for i in range(0, number_destinations): dest = f"sub{i}.example.com" - self.get_success(self.store.set_destination_retry_timings(dest, 50, 50, 50)) - self.get_success( - self.store.set_destination_last_successful_stream_ordering(dest, 100) - ) + self._create_destination(dest, 50, 50, 50, 100) def _check_fields(self, content: List[JsonDict]) -> None: """Checks that the expected destination attributes are present in content diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 81f3ac7f04..8513b1d2df 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -223,20 +223,13 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): # Create all possible single character tokens tokens = [] for c in string.ascii_letters + string.digits + "._~-": - tokens.append( - { - "token": c, - "uses_allowed": None, - "pending": 0, - "completed": 0, - "expiry_time": None, - } - ) + tokens.append((c, None, 0, 0, None)) self.get_success( self.store.db_pool.simple_insert_many( "registration_tokens", - tokens, - "create_all_registration_tokens", + keys=("token", "uses_allowed", "pending", "completed", "expiry_time"), + values=tokens, + desc="create_all_registration_tokens", ) ) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index d2c8781cd4..3495a0366a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1089,6 +1089,8 @@ class RoomTestCase(unittest.HomeserverTestCase): ) room_ids.append(room_id) + room_ids.sort() + # Request the list of rooms url = "/_synapse/admin/v1/rooms" channel = self.make_request( @@ -1360,6 +1362,12 @@ class RoomTestCase(unittest.HomeserverTestCase): room_id_2 = 
self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) room_id_3 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + # Also create a list sorted by IDs for properties that are equal (and thus sorted by room_id) + sorted_by_room_id_asc = [room_id_1, room_id_2, room_id_3] + sorted_by_room_id_asc.sort() + sorted_by_room_id_desc = sorted_by_room_id_asc.copy() + sorted_by_room_id_desc.reverse() + # Set room names in alphabetical order. room 1 -> A, 2 -> B, 3 -> C self.helper.send_state( room_id_1, @@ -1405,41 +1413,42 @@ class RoomTestCase(unittest.HomeserverTestCase): _order_test("canonical_alias", [room_id_1, room_id_2, room_id_3]) _order_test("canonical_alias", [room_id_3, room_id_2, room_id_1], reverse=True) + # Note: joined_member counts are sorted in descending order when dir=f _order_test("joined_members", [room_id_3, room_id_2, room_id_1]) _order_test("joined_members", [room_id_1, room_id_2, room_id_3], reverse=True) + # Note: joined_local_member counts are sorted in descending order when dir=f _order_test("joined_local_members", [room_id_3, room_id_2, room_id_1]) _order_test( "joined_local_members", [room_id_1, room_id_2, room_id_3], reverse=True ) - _order_test("version", [room_id_1, room_id_2, room_id_3]) - _order_test("version", [room_id_1, room_id_2, room_id_3], reverse=True) + # Note: versions are sorted in descending order when dir=f + _order_test("version", sorted_by_room_id_asc, reverse=True) + _order_test("version", sorted_by_room_id_desc) - _order_test("creator", [room_id_1, room_id_2, room_id_3]) - _order_test("creator", [room_id_1, room_id_2, room_id_3], reverse=True) + _order_test("creator", sorted_by_room_id_asc) + _order_test("creator", sorted_by_room_id_desc, reverse=True) - _order_test("encryption", [room_id_1, room_id_2, room_id_3]) - _order_test("encryption", [room_id_1, room_id_2, room_id_3], reverse=True) + _order_test("encryption", sorted_by_room_id_asc) + _order_test("encryption", sorted_by_room_id_desc, 
reverse=True) - _order_test("federatable", [room_id_1, room_id_2, room_id_3]) - _order_test("federatable", [room_id_1, room_id_2, room_id_3], reverse=True) + _order_test("federatable", sorted_by_room_id_asc) + _order_test("federatable", sorted_by_room_id_desc, reverse=True) - _order_test("public", [room_id_1, room_id_2, room_id_3]) - # Different sort order of SQlite and PostreSQL - # _order_test("public", [room_id_3, room_id_2, room_id_1], reverse=True) + _order_test("public", sorted_by_room_id_asc) + _order_test("public", sorted_by_room_id_desc, reverse=True) - _order_test("join_rules", [room_id_1, room_id_2, room_id_3]) - _order_test("join_rules", [room_id_1, room_id_2, room_id_3], reverse=True) + _order_test("join_rules", sorted_by_room_id_asc) + _order_test("join_rules", sorted_by_room_id_desc, reverse=True) - _order_test("guest_access", [room_id_1, room_id_2, room_id_3]) - _order_test("guest_access", [room_id_1, room_id_2, room_id_3], reverse=True) + _order_test("guest_access", sorted_by_room_id_asc) + _order_test("guest_access", sorted_by_room_id_desc, reverse=True) - _order_test("history_visibility", [room_id_1, room_id_2, room_id_3]) - _order_test( - "history_visibility", [room_id_1, room_id_2, room_id_3], reverse=True - ) + _order_test("history_visibility", sorted_by_room_id_asc) + _order_test("history_visibility", sorted_by_room_id_desc, reverse=True) + # Note: state_event counts are sorted in descending order when dir=f _order_test("state_events", [room_id_3, room_id_2, room_id_1]) _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e0b9fe8e91..9711405735 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1181,6 +1181,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.other_user, device_id=None, valid_until_ms=None ) ) + self.url_prefix = "/_synapse/admin/v2/users/%s" self.url_other_user = self.url_prefix % 
self.other_user @@ -1188,7 +1189,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ If the user is not a server admin, an error is returned. """ - url = "/_synapse/admin/v2/users/@bob:test" + url = self.url_prefix % "@bob:test" channel = self.make_request( "GET", @@ -1216,7 +1217,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "/_synapse/admin/v2/users/@unknown_person:test", + self.url_prefix % "@unknown_person:test", access_token=self.admin_user_tok, ) @@ -1337,7 +1338,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Check that a new admin user is created successfully. """ - url = "/_synapse/admin/v2/users/@bob:test" + url = self.url_prefix % "@bob:test" # Create user (server admin) body = { @@ -1386,7 +1387,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ Check that a new regular user is created successfully. """ - url = "/_synapse/admin/v2/users/@bob:test" + url = self.url_prefix % "@bob:test" # Create user body = { @@ -1478,7 +1479,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # Register new user with admin API - url = "/_synapse/admin/v2/users/@bob:test" + url = self.url_prefix % "@bob:test" # Create user channel = self.make_request( @@ -1515,7 +1516,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # Register new user with admin API - url = "/_synapse/admin/v2/users/@bob:test" + url = self.url_prefix % "@bob:test" # Create user channel = self.make_request( @@ -1545,7 +1546,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): Check that a new regular user is created successfully and got an email pusher. """ - url = "/_synapse/admin/v2/users/@bob:test" + url = self.url_prefix % "@bob:test" # Create user body = { @@ -1588,7 +1589,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): Check that a new regular user is created successfully and got not an email pusher. 
""" - url = "/_synapse/admin/v2/users/@bob:test" + url = self.url_prefix % "@bob:test" # Create user body = { @@ -2085,10 +2086,13 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) - self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) + + # This key was removed intentionally. Ensure it is not accidentally re-included. + self.assertNotIn("password_hash", channel.json_body) + # the user is deactivated, the threepid will be deleted # Get user @@ -2101,11 +2105,13 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["deactivated"]) - self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) + # This key was removed intentionally. Ensure it is not accidentally re-included. 
+ self.assertNotIn("password_hash", channel.json_body) + @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) def test_change_name_deactivate_user_user_directory(self): """ @@ -2177,9 +2183,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) - self.assertIsNotNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) + # This key was removed intentionally. Ensure it is not accidentally re-included. + self.assertNotIn("password_hash", channel.json_body) + @override_config({"password_config": {"localdb_enabled": False}}) def test_reactivate_user_localdb_disabled(self): """ @@ -2209,9 +2217,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) - self.assertIsNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) + # This key was removed intentionally. Ensure it is not accidentally re-included. + self.assertNotIn("password_hash", channel.json_body) + @override_config({"password_config": {"enabled": False}}) def test_reactivate_user_password_disabled(self): """ @@ -2241,9 +2251,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual("@user:test", channel.json_body["name"]) self.assertFalse(channel.json_body["deactivated"]) - self.assertIsNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) + # This key was removed intentionally. Ensure it is not accidentally re-included. + self.assertNotIn("password_hash", channel.json_body) + def test_set_user_as_admin(self): """ Test setting the admin flag on a user. 
@@ -2328,7 +2340,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): Ensure an account can't accidentally be deactivated by using a str value for the deactivated body parameter """ - url = "/_synapse/admin/v2/users/@bob:test" + url = self.url_prefix % "@bob:test" # Create user channel = self.make_request( @@ -2392,18 +2404,20 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Deactivate the user. channel = self.make_request( "PUT", - "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id), + self.url_prefix % urllib.parse.quote(user_id), access_token=self.admin_user_tok, content={"deactivated": True}, ) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertTrue(channel.json_body["deactivated"]) - self.assertIsNone(channel.json_body["password_hash"]) self._is_erased(user_id, False) d = self.store.mark_user_erased(user_id) self.assertIsNone(self.get_success(d)) self._is_erased(user_id, True) + # This key was removed intentionally. Ensure it is not accidentally re-included. + self.assertNotIn("password_hash", channel.json_body) + def _check_fields(self, content: JsonDict): """Checks that the expected user attributes are present in content @@ -2416,13 +2430,15 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("admin", content) self.assertIn("deactivated", content) self.assertIn("shadow_banned", content) - self.assertIn("password_hash", content) self.assertIn("creation_ts", content) self.assertIn("appservice_id", content) self.assertIn("consent_server_notice_sent", content) self.assertIn("consent_version", content) self.assertIn("external_ids", content) + # This key was removed intentionally. Ensure it is not accidentally re-included. 
+ self.assertNotIn("password_hash", content) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c026d526ef..c9b220e73d 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -21,6 +21,7 @@ from unittest.mock import patch from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room, sync +from synapse.types import JsonDict from tests import unittest from tests.server import FakeChannel @@ -93,11 +94,6 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel.json_body, ) - def test_deny_membership(self): - """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, EventTypes.Member) - self.assertEquals(400, channel.code, channel.json_body) - def test_deny_invalid_event(self): """Test that we deny relations on non-existant events""" channel = self._send_relation( @@ -459,7 +455,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_bundled_aggregations(self): - """Test that annotations, references, and threads get correctly bundled.""" + """ + Test that annotations, references, and threads get correctly bundled. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + + See test_edit for a similar test for edits. + """ # Setup by sending a variety of relations. 
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(200, channel.code, channel.json_body) @@ -487,12 +490,13 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] - def assert_bundle(actual): + def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") # Ensure the fields are as expected. self.assertCountEqual( - actual.keys(), + relations_dict.keys(), ( RelationTypes.ANNOTATION, RelationTypes.REFERENCE, @@ -508,17 +512,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"type": "m.reaction", "key": "b", "count": 1}, ] }, - actual[RelationTypes.ANNOTATION], + relations_dict[RelationTypes.ANNOTATION], ) self.assertEquals( {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - actual[RelationTypes.REFERENCE], + relations_dict[RelationTypes.REFERENCE], ) self.assertEquals( 2, - actual[RelationTypes.THREAD].get("count"), + relations_dict[RelationTypes.THREAD].get("count"), + ) + self.assertTrue( + relations_dict[RelationTypes.THREAD].get("current_user_participated") ) # The latest thread event has some fields that don't matter. self.assert_dict( @@ -535,20 +542,9 @@ class RelationsTestCase(unittest.HomeserverTestCase): "type": "m.room.test", "user_id": self.user_id, }, - actual[RelationTypes.THREAD].get("latest_event"), + relations_dict[RelationTypes.THREAD].get("latest_event"), ) - def _find_and_assert_event(events): - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - break - else: - raise AssertionError(f"Event {self.parent_id} not found in chunk") - assert_bundle(event["unsigned"].get("m.relations")) - # Request the event directly. 
channel = self.make_request( "GET", @@ -556,7 +552,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["unsigned"].get("m.relations")) + assert_bundle(channel.json_body) # Request the room messages. channel = self.make_request( @@ -565,7 +561,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - _find_and_assert_event(channel.json_body["chunk"]) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) # Request the room context. channel = self.make_request( @@ -574,17 +570,14 @@ class RelationsTestCase(unittest.HomeserverTestCase): access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations")) + assert_bundle(channel.json_body["event"]) # Request sync. - # channel = self.make_request("GET", "/sync", access_token=self.user_token) - # self.assertEquals(200, channel.code, channel.json_body) - # room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - # self.assertTrue(room_timeline["limited"]) - # _find_and_assert_event(room_timeline["events"]) - - # Note that /relations is tested separately in test_aggregation_get_event_for_thread - # since it needs different data configured. 
+ channel = self.make_request("GET", "/sync", access_token=self.user_token) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + self._find_event_in_chunk(room_timeline["events"]) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included @@ -779,25 +772,58 @@ class RelationsTestCase(unittest.HomeserverTestCase): edit_event_id = channel.json_body["event_id"] + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + channel = self.make_request( "GET", - "/rooms/%s/event/%s" % (self.room, self.parent_id), + f"/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) self.assertEquals(200, channel.code, channel.json_body) - self.assertEquals(channel.json_body["content"], new_body) + assert_bundle(channel.json_body) - relations_dict = channel.json_body["unsigned"].get("m.relations") - self.assertIn(RelationTypes.REPLACE, relations_dict) + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - m_replace_dict = relations_dict[RelationTypes.REPLACE] - for key in ["event_id", "sender", "origin_server_ts"]: - self.assertIn(key, m_replace_dict) + # Request the room context. 
+ channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) - self.assert_dict( - {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + # Request sync, but limit the timeline so it becomes limited (and includes + # bundled aggregations). + filter = urllib.parse.quote_plus( + '{"room": {"timeline": {"limit": 2}}}'.encode() ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEquals(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_multi_edit(self): """Test that multiple edits, including attempts by people who @@ -1104,6 +1130,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(channel.json_body["chunk"], []) + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + def _send_relation( self, relation_type: str, @@ -1119,7 +1155,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): relation_type: One of `RelationTypes` event_type: The type of the event to create key: The aggregation key used for m.annotation relation type. - content: The content of the created event. + content: The content of the created event. Will be modified to configure + the m.relates_to key based on the other provided parameters. 
access_token: The access token used to send the relation, defaults to `self.user_token` parent_id: The event_id this relation relates to. If None, then self.parent_id @@ -1130,17 +1167,21 @@ class RelationsTestCase(unittest.HomeserverTestCase): if not access_token: access_token = self.user_token - query = "" - if key: - query = "?key=" + urllib.parse.quote_plus(key.encode("utf-8")) - original_id = parent_id if parent_id else self.parent_id + if content is None: + content = {} + content["m.relates_to"] = { + "event_id": original_id, + "rel_type": relation_type, + } + if key is not None: + content["m.relates_to"]["key"] = key + channel = self.make_request( "POST", - "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" - % (self.room, original_id, relation_type, event_type, query), - content or {}, + f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}", + content, access_token=access_token, ) return channel diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index b58452195a..fe5b536d97 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -228,7 +228,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(event) time_now = self.clock.time_msec() - serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + serialized = self.serializer.serialize_event(event, time_now) return serialized diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 1af5e5cee5..8424383580 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -196,6 +196,16 @@ class RestHelper: expect_code=expect_code, ) + def ban(self, room: str, src: str, targ: str, **kwargs: object): + """A convenience helper: `change_membership` with `membership` preset to "ban".""" + self.change_membership( + room=room, + src=src, + targ=targ, + membership=Membership.BAN, + **kwargs, + ) + def change_membership( self, room: str, diff --git 
a/tests/server.py b/tests/server.py index ca2b7a5b97..a0cd14ea45 100644 --- a/tests/server.py +++ b/tests/server.py @@ -14,6 +14,8 @@ import hashlib import json import logging +import os +import os.path import time import uuid import warnings @@ -71,6 +73,7 @@ from tests.utils import ( POSTGRES_HOST, POSTGRES_PASSWORD, POSTGRES_USER, + SQLITE_PERSIST_DB, USE_POSTGRES_FOR_TESTS, MockClock, default_config, @@ -739,9 +742,23 @@ def setup_test_homeserver( }, } else: + if SQLITE_PERSIST_DB: + # The current working directory is in _trial_temp, so this gets created within that directory. + test_db_location = os.path.abspath("test.db") + logger.debug("Will persist db to %s", test_db_location) + # Ensure each test gets a clean database. + try: + os.remove(test_db_location) + except FileNotFoundError: + pass + else: + logger.debug("Removed existing DB at %s", test_db_location) + else: + test_db_location = ":memory:" + database_config = { "name": "sqlite3", - "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, + "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1}, } if "db_txn_limit" in kwargs: diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index bf78084869..2bc89512f8 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -531,17 +531,25 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.get_success( self.store.db_pool.simple_insert_many( table="federation_inbound_events_staging", + keys=( + "origin", + "room_id", + "received_ts", + "event_id", + "event_json", + "internal_metadata", + ), values=[ - { - "origin": "some_origin", - "room_id": room_id, - "received_ts": 0, - "event_id": f"$fake_event_id_{i + 1}", - "event_json": json_encoder.encode( + ( + "some_origin", + room_id, + 0, + f"$fake_event_id_{i + 1}", + json_encoder.encode( {"prev_events": [prev_event_format(f"$fake_event_id_{i}")]} ), - "internal_metadata": "{}", - } + 
"{}", + ) for i in range(500) ], desc="test_prune_inbound_federation_queue", diff --git a/tests/test_federation.py b/tests/test_federation.py index 3eef1c4c05..2b9804aba0 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -17,7 +17,9 @@ from unittest.mock import Mock from twisted.internet.defer import succeed from synapse.api.errors import FederationError +from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict +from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext from synapse.types import UserID, create_requester from synapse.util import Clock @@ -276,3 +278,73 @@ class MessageAcceptTests(unittest.HomeserverTestCase): "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), ) self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values()) + + +class StripUnsignedFromEventsTestCase(unittest.TestCase): + def test_strip_unauthorized_unsigned_values(self): + event1 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event1:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "content": {"membership": "join"}, + "auth_events": [], + "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"}, + } + filtered_event = event_from_pdu_json(event1, RoomVersions.V1) + # Make sure unauthorized fields are stripped from unsigned + self.assertNotIn("more warez", filtered_event.unsigned) + + def test_strip_event_maintains_allowed_fields(self): + event2 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event2:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "auth_events": [], + "content": {"membership": "join"}, + "unsigned": { + "malicious garbage": "hackz", + "more warez": "more hackz", + "age": 14, + 
"invite_room_state": [], + }, + } + + filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1) + self.assertIn("age", filtered_event2.unsigned) + self.assertEqual(14, filtered_event2.unsigned["age"]) + self.assertNotIn("more warez", filtered_event2.unsigned) + # Invite_room_state is allowed in events of type m.room.member + self.assertIn("invite_room_state", filtered_event2.unsigned) + self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) + + def test_strip_event_removes_fields_based_on_event_type(self): + event3 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event3:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.power_levels", + "origin": "test.servx", + "content": {}, + "auth_events": [], + "unsigned": { + "malicious garbage": "hackz", + "more warez": "more hackz", + "age": 14, + "invite_room_state": [], + }, + } + filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1) + self.assertIn("age", filtered_event3.unsigned) + # Invite_room_state field is only permitted in event type m.room.member + self.assertNotIn("invite_room_state", filtered_event3.unsigned) + self.assertNotIn("more warez", filtered_event3.unsigned) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 15ac2bfeba..f05a373aa0 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -19,7 +19,7 @@ import sys import warnings from asyncio import Future from binascii import unhexlify -from typing import Any, Awaitable, Callable, TypeVar +from typing import Awaitable, Callable, TypeVar from unittest.mock import Mock import attr @@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -def make_awaitable(result: Any) -> Awaitable[Any]: +def make_awaitable(result: TV) -> Awaitable[TV]: """ Makes an awaitable, suitable for mocking an `async` function. 
This uses Futures as they can be awaited multiple times so can be returned diff --git a/tests/utils.py b/tests/utils.py index 6d013e8518..c06fc320f3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,10 @@ POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None) POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None) POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) +# When debugging a specific test, it's occasionally useful to write the +# DB to disk and query it with the sqlite CLI. +SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None + # the dbname we will connect to in order to create the base database. POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" |