diff options
author | Patrick Cloke <patrickc@matrix.org> | 2022-03-10 10:38:28 -0500 |
---|---|---|
committer | Patrick Cloke <patrickc@matrix.org> | 2022-03-10 10:38:28 -0500 |
commit | fdc106378212da907449fa9c747a3d7c332f59c2 (patch) | |
tree | b695d38b0baa293af702fdf1e9e645b43ed2a7a5 | |
parent | Merge branch 'release-v1.54', remote-tracking branch 'origin' into matrix-org... (diff) | |
parent | Support stable identifiers for MSC3440: Threading (#12151) (diff) | |
download | synapse-fdc106378212da907449fa9c747a3d7c332f59c2.tar.xz |
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
78 files changed, 938 insertions, 429 deletions
diff --git a/.github/workflows/release-artifacts.yml b/.github/workflows/release-artifacts.yml index 65ea761ad7..ed4fc6179d 100644 --- a/.github/workflows/release-artifacts.yml +++ b/.github/workflows/release-artifacts.yml @@ -112,7 +112,8 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: files: | - python-dist/* + Sdist/* + Wheel/* debs.tar.xz # if it's not already published, keep the release as a draft. draft: true diff --git a/CHANGES.md b/CHANGES.md index 9d27cd35aa..ef671e73f1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,27 +1,27 @@ Synapse 1.54.0 (2022-03-08) =========================== +Please note that this will be the last release of Synapse that is compatible with Mjolnir 1.3.1 and earlier. +Administrators of servers which have the Mjolnir module installed are advised to upgrade Mjolnir to version 1.3.2 or later. + + Bugfixes -------- - Fix a bug introduced in Synapse 1.54.0rc1 preventing the new module callbacks introduced in this release from being registered by modules. ([\#12141](https://github.com/matrix-org/synapse/issues/12141)) -- Fix a bug introduced in Synapse 1.54.0rc1 which meant that Synapse would refuse to start if pre-release versions of dependencies were installed. ([\#12177](https://github.com/matrix-org/synapse/issues/12177)) +- Fix a bug introduced in Synapse 1.54.0rc1 where runtime dependency version checks would mistakenly check development dependencies if they were present and would not accept pre-release versions of dependencies. ([\#12129](https://github.com/matrix-org/synapse/issues/12129), [\#12177](https://github.com/matrix-org/synapse/issues/12177)) Internal Changes ---------------- - Update release script to insert the previous version when writing "No significant changes" line in the changelog. ([\#12127](https://github.com/matrix-org/synapse/issues/12127)) -- Inspect application dependencies using `importlib.metadata` or its backport. ([\#12129](https://github.com/matrix-org/synapse/issues/12129)) -- Relax the version guard for "packaging" added in #12088. ([\#12166](https://github.com/matrix-org/synapse/issues/12166)) +- Relax the version guard for "packaging" added in [\#12088](https://github.com/matrix-org/synapse/issues/12088). ([\#12166](https://github.com/matrix-org/synapse/issues/12166)) Synapse 1.54.0rc1 (2022-03-02) ============================== -Please note that this will be the last release of Synapse that is compatible with Mjolnir 1.3.1 and earlier. -Administrators of servers which have the Mjolnir module installed are advised to upgrade Mjolnir to version 1.3.2 or later. - Features -------- diff --git a/README.rst b/README.rst index 4281c87d1f..595fb5ff62 100644 --- a/README.rst +++ b/README.rst @@ -312,6 +312,9 @@ We recommend using the demo which starts 3 federated instances running on ports (to stop, you can use `./demo/stop.sh`) +See the [demo documentation](https://matrix-org.github.io/synapse/develop/development/demo.html) +for more information. + If you just want to start a single instance of the app and run it directly:: # Create the homeserver.yaml config once diff --git a/changelog.d/12028.feature b/changelog.d/12028.feature new file mode 100644 index 0000000000..5549c8f6fc --- /dev/null +++ b/changelog.d/12028.feature @@ -0,0 +1 @@ +Add third-party rules rules callbacks `check_can_shutdown_room` and `check_can_deactivate_user`. diff --git a/changelog.d/12042.misc b/changelog.d/12042.misc new file mode 100644 index 0000000000..6ecdc96021 --- /dev/null +++ b/changelog.d/12042.misc @@ -0,0 +1 @@ +Correct type hints for txredis. diff --git a/changelog.d/12130.bugfix b/changelog.d/12130.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12130.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12131.misc b/changelog.d/12131.misc new file mode 100644 index 0000000000..8ef23c22d5 --- /dev/null +++ b/changelog.d/12131.misc @@ -0,0 +1 @@ +Fix CI not attaching source distributions and wheels to the GitHub releases. \ No newline at end of file diff --git a/changelog.d/12135.feature b/changelog.d/12135.feature new file mode 100644 index 0000000000..b337f51730 --- /dev/null +++ b/changelog.d/12135.feature @@ -0,0 +1 @@ +Add experimental env var `SYNAPSE_ASYNC_IO_REACTOR` that causes Synapse to use the asyncio reactor for Twisted. diff --git a/changelog.d/12143.doc b/changelog.d/12143.doc new file mode 100644 index 0000000000..4b9db74b1f --- /dev/null +++ b/changelog.d/12143.doc @@ -0,0 +1 @@ +Improve documentation for demo scripts. diff --git a/changelog.d/12150.misc b/changelog.d/12150.misc new file mode 100644 index 0000000000..2d2706dac7 --- /dev/null +++ b/changelog.d/12150.misc @@ -0,0 +1 @@ +Use `ParamSpec` in type hints for `synapse.logging.context`. diff --git a/changelog.d/12151.feature b/changelog.d/12151.feature new file mode 100644 index 0000000000..18432b2da9 --- /dev/null +++ b/changelog.d/12151.feature @@ -0,0 +1 @@ +Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. diff --git a/changelog.d/12173.misc b/changelog.d/12173.misc new file mode 100644 index 0000000000..9f333e718a --- /dev/null +++ b/changelog.d/12173.misc @@ -0,0 +1 @@ +Avoid trying to calculate the state at outlier events. diff --git a/changelog.d/12175.bugfix b/changelog.d/12175.bugfix new file mode 100644 index 0000000000..881cb9b76c --- /dev/null +++ b/changelog.d/12175.bugfix @@ -0,0 +1 @@ +Fix a bug where non-standard information was returned from the `/hierarchy` API. Introduced in Synapse v1.41.0. diff --git a/changelog.d/12179.doc b/changelog.d/12179.doc new file mode 100644 index 0000000000..55d8caa45a --- /dev/null +++ b/changelog.d/12179.doc @@ -0,0 +1 @@ +Updates to the Room DAG concepts development document. diff --git a/changelog.d/12182.misc b/changelog.d/12182.misc new file mode 100644 index 0000000000..7e9ad2c752 --- /dev/null +++ b/changelog.d/12182.misc @@ -0,0 +1 @@ +Retry HTTP replication failures, this should prevent 502's when restarting stateful workers (main, event persisters, stream writers). Contributed by Nick @ Beeper. diff --git a/changelog.d/12187.misc b/changelog.d/12187.misc new file mode 100644 index 0000000000..c53e68faa5 --- /dev/null +++ b/changelog.d/12187.misc @@ -0,0 +1 @@ +Remove unused variables. diff --git a/changelog.d/12189.bugfix b/changelog.d/12189.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12189.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12192.misc b/changelog.d/12192.misc new file mode 100644 index 0000000000..bdfe8dad98 --- /dev/null +++ b/changelog.d/12192.misc @@ -0,0 +1 @@ +Rename `HomeServer.get_tcp_replication` to `get_replication_command_handler`. diff --git a/changelog.d/12197.misc b/changelog.d/12197.misc new file mode 100644 index 0000000000..7d0e9b6bbf --- /dev/null +++ b/changelog.d/12197.misc @@ -0,0 +1 @@ +Remove some dead code. diff --git a/demo/.gitignore b/demo/.gitignore index 4d12712343..5663aba8e7 100644 --- a/demo/.gitignore +++ b/demo/.gitignore @@ -1,7 +1,4 @@ -*.db -*.log -*.log.* -*.pid - -/media_store.* -/etc +# Ignore all the temporary files from the demo servers. +8080/ +8081/ +8082/ diff --git a/demo/README b/demo/README deleted file mode 100644 index a5a95bd196..0000000000 --- a/demo/README +++ /dev/null @@ -1,26 +0,0 @@ -DO NOT USE THESE DEMO SERVERS IN PRODUCTION - -Requires you to have done: - python setup.py develop - - -The demo start.sh will start three synapse servers on ports 8080, 8081 and 8082, with host names localhost:$port. This can be easily changed to `hostname`:$port in start.sh if required. - -To enable the servers to communicate untrusted ssl certs are used. In order to do this the servers do not check the certs -and are configured in a highly insecure way. Do not use these configuration files in production. - -stop.sh will stop the synapse servers and the webclient. - -clean.sh will delete the databases and log files. - -To start a completely new set of servers, run: - - ./demo/stop.sh; ./demo/clean.sh && ./demo/start.sh - - -Logs and sqlitedb will be stored in demo/808{0,1,2}.{log,db} - - - -Also note that when joining a public room on a different HS via "#foo:bar.net", then you are (in the current impl) joining a room with room_id "foo". This means that it won't work if your HS already has a room with that name. - diff --git a/demo/clean.sh b/demo/clean.sh index e9b440d90d..7f1e192021 100755 --- a/demo/clean.sh +++ b/demo/clean.sh @@ -4,6 +4,9 @@ set -e DIR="$( cd "$( dirname "$0" )" && pwd )" +# Ensure that the servers are stopped. +$DIR/stop.sh + PID_FILE="$DIR/servers.pid" if [ -f "$PID_FILE" ]; then diff --git a/demo/start.sh b/demo/start.sh index 8ffb14e30a..55e69685e3 100755 --- a/demo/start.sh +++ b/demo/start.sh @@ -6,8 +6,6 @@ CWD=$(pwd) cd "$DIR/.." || exit -mkdir -p demo/etc - PYTHONPATH=$(readlink -f "$(pwd)") export PYTHONPATH @@ -21,22 +19,26 @@ for port in 8080 8081 8082; do mkdir -p demo/$port pushd demo/$port || exit - #rm $DIR/etc/$port.config + # Generate the configuration for the homeserver at localhost:848x. python3 -m synapse.app.homeserver \ --generate-config \ - -H "localhost:$https_port" \ - --config-path "$DIR/etc/$port.config" \ + --server-name "localhost:$port" \ + --config-path "$port.config" \ --report-stats no - if ! grep -F "Customisation made by demo/start.sh" -q "$DIR/etc/$port.config"; then - # Generate tls keys - openssl req -x509 -newkey rsa:4096 -keyout "$DIR/etc/localhost:$https_port.tls.key" -out "$DIR/etc/localhost:$https_port.tls.crt" -days 365 -nodes -subj "/O=matrix" + if ! grep -F "Customisation made by demo/start.sh" -q "$port.config"; then + # Generate TLS keys. + openssl req -x509 -newkey rsa:4096 \ + -keyout "localhost:$port.tls.key" \ + -out "localhost:$port.tls.crt" \ + -days 365 -nodes -subj "/O=matrix" - # Regenerate configuration + # Add customisations to the configuration. { - printf '\n\n# Customisation made by demo/start.sh\n' + printf '\n\n# Customisation made by demo/start.sh\n\n' echo "public_baseurl: http://localhost:$port/" echo 'enable_registration: true' + echo '' # Warning, this heredoc depends on the interaction of tabs and spaces. # Please don't accidentaly bork me with your fancy settings. @@ -63,38 +65,34 @@ for port in 8080 8081 8082; do echo "${listeners}" - # Disable tls for the servers - printf '\n\n# Disable tls on the servers.' + # Disable TLS for the servers + printf '\n\n# Disable TLS for the servers.' echo '# DO NOT USE IN PRODUCTION' echo 'use_insecure_ssl_client_just_for_testing_do_not_use: true' echo 'federation_verify_certificates: false' - # Set tls paths - echo "tls_certificate_path: \"$DIR/etc/localhost:$https_port.tls.crt\"" - echo "tls_private_key_path: \"$DIR/etc/localhost:$https_port.tls.key\"" + # Set paths for the TLS certificates. + echo "tls_certificate_path: \"$DIR/$port/localhost:$port.tls.crt\"" + echo "tls_private_key_path: \"$DIR/$port/localhost:$port.tls.key\"" # Ignore keys from the trusted keys server echo '# Ignore keys from the trusted keys server' echo 'trusted_key_servers:' echo ' - server_name: "matrix.org"' echo ' accept_keys_insecurely: true' - - # Reduce the blacklist - blacklist=$(cat <<-BLACK - # Set the blacklist so that it doesn't include 127.0.0.1, ::1 - federation_ip_range_blacklist: - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - 'fe80::/64' - - 'fc00::/7' - BLACK + echo '' + + # Allow the servers to communicate over localhost. + allow_list=$(cat <<-ALLOW_LIST + # Allow the servers to communicate over localhost. + ip_range_whitelist: + - '127.0.0.1/8' + - '::1/128' + ALLOW_LIST ) - echo "${blacklist}" - } >> "$DIR/etc/$port.config" + echo "${allow_list}" + } >> "$port.config" fi # Check script parameters @@ -141,19 +139,18 @@ for port in 8080 8081 8082; do burst_count: 1000 RC ) - echo "${ratelimiting}" >> "$DIR/etc/$port.config" + echo "${ratelimiting}" >> "$port.config" fi fi - if ! grep -F "full_twisted_stacktraces" -q "$DIR/etc/$port.config"; then - echo "full_twisted_stacktraces: true" >> "$DIR/etc/$port.config" - fi - if ! grep -F "report_stats" -q "$DIR/etc/$port.config" ; then - echo "report_stats: false" >> "$DIR/etc/$port.config" + # Always disable reporting of stats if the option is not there. + if ! grep -F "report_stats" -q "$port.config" ; then + echo "report_stats: false" >> "$port.config" fi + # Run the homeserver in the background. python3 -m synapse.app.homeserver \ - --config-path "$DIR/etc/$port.config" \ + --config-path "$port.config" \ -D \ popd || exit diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index ef9cabf555..21f80efc99 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -82,6 +82,7 @@ - [Release Cycle](development/releases.md) - [Git Usage](development/git.md) - [Testing]() + - [Demo scripts](development/demo.md) - [OpenTracing](opentracing.md) - [Database Schemas](development/database_schema.md) - [Experimental features](development/experimental_features.md) diff --git a/docs/development/demo.md b/docs/development/demo.md new file mode 100644 index 0000000000..4277252ceb --- /dev/null +++ b/docs/development/demo.md @@ -0,0 +1,41 @@ +# Synapse demo setup + +**DO NOT USE THESE DEMO SERVERS IN PRODUCTION** + +Requires you to have a [Synapse development environment setup](https://matrix-org.github.io/synapse/develop/development/contributing_guide.html#4-install-the-dependencies). + +The demo setup allows running three federation Synapse servers, with server +names `localhost:8080`, `localhost:8081`, and `localhost:8082`. + +You can access them via any Matrix client over HTTP at `localhost:8080`, +`localhost:8081`, and `localhost:8082` or over HTTPS at `localhost:8480`, +`localhost:8481`, and `localhost:8482`. + +To enable the servers to communicate, self-signed SSL certificates are generated +and the servers are configured in a highly insecure way, including: + +* Not checking certificates over federation. +* Not verifying keys. + +The servers are configured to store their data under `demo/8080`, `demo/8081`, and +`demo/8082`. This includes configuration, logs, SQLite databases, and media. + +Note that when joining a public room on a different HS via "#foo:bar.net", then +you are (in the current impl) joining a room with room_id "foo". This means that +it won't work if your HS already has a room with that name. + +## Using the demo scripts + +There's three main scripts with straightforward purposes: + +* `start.sh` will start the Synapse servers, generating any missing configuration. + * This accepts a single parameter `--no-rate-limit` to "disable" rate limits + (they actually still exist, but are very high). +* `stop.sh` will stop the Synapse servers. +* `clean.sh` will delete the configuration, databases, log files, etc. + +To start a completely new set of servers, run: + +```sh +./demo/stop.sh; ./demo/clean.sh && ./demo/start.sh +``` diff --git a/docs/development/room-dag-concepts.md b/docs/development/room-dag-concepts.md index cbc7cf2949..3eb4d5acc4 100644 --- a/docs/development/room-dag-concepts.md +++ b/docs/development/room-dag-concepts.md @@ -30,37 +30,72 @@ rather than skipping any that arrived late; whereas if you're looking at a historical section of timeline (i.e. `/messages`), you want to see the best representation of the state of the room as others were seeing it at the time. +## Outliers -## Forward extremity +We mark an event as an `outlier` when we haven't figured out the state for the +room at that point in the DAG yet. They are "floating" events that we haven't +yet correlated to the DAG. -Most-recent-in-time events in the DAG which are not referenced by any other events' `prev_events` yet. +Outliers typically arise when we fetch the auth chain or state for a given +event. When that happens, we just grab the events in the state/auth chain, +without calculating the state at those events, or backfilling their +`prev_events`. -The forward extremities of a room are used as the `prev_events` when the next event is sent. +So, typically, we won't have the `prev_events` of an `outlier` in the database, +(though it's entirely possible that we *might* have them for some other +reason). Other things that make outliers different from regular events: + * We don't have state for them, so there should be no entry in + `event_to_state_groups` for an outlier. (In practice this isn't always + the case, though I'm not sure why: see https://github.com/matrix-org/synapse/issues/12201). -## Backward extremity + * We don't record entries for them in the `event_edges`, + `event_forward_extremeties` or `event_backward_extremities` tables. -The current marker of where we have backfilled up to and will generally be the -`prev_events` of the oldest-in-time events we have in the DAG. This gives a starting point when -backfilling history. +Since outliers are not tied into the DAG, they do not normally form part of the +timeline sent down to clients via `/sync` or `/messages`; however there is an +exception: -When we persist a non-outlier event, we clear it as a backward extremity and set -all of its `prev_events` as the new backward extremities if they aren't already -persisted in the `events` table. +### Out-of-band membership events +A special case of outlier events are some membership events for federated rooms +that we aren't full members of. For example: -## Outliers + * invites received over federation, before we join the room + * *rejections* for said invites + * knock events for rooms that we would like to join but have not yet joined. -We mark an event as an `outlier` when we haven't figured out the state for the -room at that point in the DAG yet. +In all the above cases, we don't have the state for the room, which is why they +are treated as outliers. They are a bit special though, in that they are +proactively sent to clients via `/sync`. -We won't *necessarily* have the `prev_events` of an `outlier` in the database, -but it's entirely possible that we *might*. +## Forward extremity + +Most-recent-in-time events in the DAG which are not referenced by any other +events' `prev_events` yet. (In this definition, outliers, rejected events, and +soft-failed events don't count.) + +The forward extremities of a room (or at least, a subset of them, if there are +more than ten) are used as the `prev_events` when the next event is sent. + +The "current state" of a room (ie: the state which would be used if we +generated a new event) is, therefore, the resolution of the room states +at each of the forward extremities. + +## Backward extremity + +The current marker of where we have backfilled up to and will generally be the +`prev_events` of the oldest-in-time events we have in the DAG. This gives a starting point when +backfilling history. -For example, when we fetch the event auth chain or state for a given event, we -mark all of those claimed auth events as outliers because we haven't done the -state calculation ourself. +Note that, unlike forward extremities, we typically don't have any backward +extremity events themselves in the database - or, if we do, they will be "outliers" (see +above). Either way, we don't expect to have the room state at a backward extremity. +When we persist a non-outlier event, if it was previously a backward extremity, +we clear it as a backward extremity and set all of its `prev_events` as the new +backward extremities if they aren't already persisted as non-outliers. This +therefore keeps the backward extremities up-to-date. ## State groups diff --git a/docs/federate.md b/docs/federate.md index 5107f995be..df4c87da51 100644 --- a/docs/federate.md +++ b/docs/federate.md @@ -63,4 +63,5 @@ release of Synapse. If you want to get up and running quickly with a trio of homeservers in a private federation, there is a script in the `demo` directory. This is mainly -useful just for development purposes. See [demo/README](https://github.com/matrix-org/synapse/tree/develop/demo/). +useful just for development purposes. See +[demo scripts](https://matrix-org.github.io/synapse/develop/development/demo.html). diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md index 09ac838107..1d3c39967f 100644 --- a/docs/modules/third_party_rules_callbacks.md +++ b/docs/modules/third_party_rules_callbacks.md @@ -148,6 +148,49 @@ deny an incoming event, see [`check_event_for_spam`](spam_checker_callbacks.md#c If multiple modules implement this callback, Synapse runs them all in order. +### `check_can_shutdown_room` + +_First introduced in Synapse v1.55.0_ + +```python +async def check_can_shutdown_room( + user_id: str, room_id: str, +) -> bool: +``` + +Called when an admin user requests the shutdown of a room. The module must return a +boolean indicating whether the shutdown can go through. If the callback returns `False`, +the shutdown will not proceed and the caller will see a `M_FORBIDDEN` error. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `True`, Synapse falls through to the next one. The value of the first +callback that does not return `True` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. + +### `check_can_deactivate_user` + +_First introduced in Synapse v1.55.0_ + +```python +async def check_can_deactivate_user( + user_id: str, by_admin: bool, +) -> bool: +``` + +Called when the deactivation of a user is requested. User deactivation can be +performed by an admin or the user themselves, so developers are encouraged to check the +requester when implementing this callback. The module must return a +boolean indicating whether the deactivation can go through. If the callback returns `False`, +the deactivation will not proceed and the caller will see a `M_FORBIDDEN` error. + +The module is passed two parameters, `user_id` which is the ID of the user being deactivated, and `by_admin` which is `True` if the request is made by a serve admin, and `False` otherwise. + +If multiple modules implement this callback, they will be considered in order. If a +callback returns `True`, Synapse falls through to the next one. The value of the first +callback that does not return `True` will be used. If this happens, Synapse will not call +any of the subsequent implementations of this callback. + + ### `on_profile_update` _First introduced in Synapse v1.54.0_ diff --git a/mypy.ini b/mypy.ini index 481e8a5366..c8390ddba9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -353,3 +353,6 @@ ignore_missing_imports = True [mypy-zope] ignore_missing_imports = True + +[mypy-incremental.*] +ignore_missing_imports = True diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 429234d7ae..2d8ca018fb 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -20,7 +20,7 @@ from twisted.internet import protocol from twisted.internet.defer import Deferred class RedisProtocol(protocol.Protocol): - def publish(self, channel: str, message: bytes): ... + def publish(self, channel: str, message: bytes) -> "Deferred[None]": ... def ping(self) -> "Deferred[None]": ... def set( self, @@ -52,11 +52,14 @@ def lazyConnection( convertNumbers: bool = ..., ) -> RedisProtocol: ... -class ConnectionHandler: ... +# ConnectionHandler doesn't actually inherit from RedisProtocol, but it proxies +# most methods to it via ConnectionHandler.__getattr__. +class ConnectionHandler(RedisProtocol): + def disconnect(self) -> "Deferred[None]": ... class RedisFactory(protocol.ReconnectingClientFactory): continueTrying: bool - handler: RedisProtocol + handler: ConnectionHandler pool: List[RedisProtocol] replyTimeout: Optional[int] def __init__( diff --git a/synapse/__init__.py b/synapse/__init__.py index c6727024f0..4b00565976 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -25,6 +25,27 @@ if sys.version_info < (3, 7): print("Synapse requires Python 3.7 or above.") sys.exit(1) +# Allow using the asyncio reactor via env var. +if bool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", False)): + try: + from incremental import Version + + import twisted + + # We need a bugfix that is included in Twisted 21.2.0: + # https://twistedmatrix.com/trac/ticket/9787 + if twisted.version < Version("Twisted", 21, 2, 0): + print("Using asyncio reactor requires Twisted>=21.2.0") + sys.exit(1) + + import asyncio + + from twisted.internet import asyncioreactor + + asyncioreactor.install(asyncio.get_event_loop()) + except ImportError: + pass + # Twisted and canonicaljson will fail to import when this file is executed to # get the __version__ during a fresh install. That's OK and subsequent calls to # actually start Synapse will import these libraries fine. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 36ace7c613..b0c08a074d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -178,7 +178,9 @@ class RelationTypes: ANNOTATION: Final = "m.annotation" REPLACE: Final = "m.replace" REFERENCE: Final = "m.reference" - THREAD: Final = "io.element.thread" + THREAD: Final = "m.thread" + # TODO Remove this in Synapse >= v1.57.0. + UNSTABLE_THREAD: Final = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cb532d7238..27e97d6f37 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -88,7 +88,9 @@ ROOM_EVENT_FILTER_SCHEMA = { "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, # MSC3440, filtering by event relations. + "related_by_senders": {"type": "array", "items": {"type": "string"}}, "io.element.relation_senders": {"type": "array", "items": {"type": "string"}}, + "related_by_rel_types": {"type": "array", "items": {"type": "string"}}, "io.element.relation_types": {"type": "array", "items": {"type": "string"}}, }, } @@ -318,19 +320,18 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - # Ideally these would be rejected at the endpoint if they were provided - # and not supported, but that would involve modifying the JSON schema - # based on the homeserver configuration. + self.related_by_senders = self.filter_json.get("related_by_senders", None) + self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + + # Fallback to the unstable prefix if the stable version is not given. if hs.config.experimental.msc3440_enabled: - self.relation_senders = self.filter_json.get( + self.related_by_senders = self.related_by_senders or self.filter_json.get( "io.element.relation_senders", None ) - self.relation_types = self.filter_json.get( - "io.element.relation_types", None + self.related_by_rel_types = ( + self.related_by_rel_types + or self.filter_json.get("io.element.relation_types", None) ) - else: - self.relation_senders = None - self.relation_types = None def filters_all_types(self) -> bool: return "*" in self.not_types @@ -461,7 +462,7 @@ class Filter: event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] event_ids_to_keep = set( await self._store.events_have_relations( - event_ids, self.relation_senders, self.relation_types + event_ids, self.related_by_senders, self.related_by_rel_types ) ) @@ -474,7 +475,7 @@ class Filter: async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: result = [event for event in events if self._check(event)] - if self.relation_senders or self.relation_types: + if self.related_by_senders or self.related_by_rel_types: return await self._check_event_relations(result) return result diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 1536a42723..a10a63b06c 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -417,7 +417,7 @@ class GenericWorkerServer(HomeServer): else: logger.warning("Unsupported listener type: %s", listener.type) - self.get_tcp_replication().start_replication(self) + self.get_replication_command_handler().start_replication(self) def start(config_options: List[str]) -> None: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index a6789a840e..e4dc04c0b4 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -273,7 +273,7 @@ class SynapseHomeServer(HomeServer): # If redis is enabled we connect via the replication command handler # in the same way as the workers (since we're effectively a client # rather than a server). - self.get_tcp_replication().start_replication(self) + self.get_replication_command_handler().start_replication(self) for listener in self.config.server.listeners: if listener.type == "http": diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index ede72ee876..bfca454f51 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -38,6 +38,8 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[ [str, StateMap[EventBase], str], Awaitable[bool] ] ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable] +CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] +CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] @@ -157,6 +159,12 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = [] self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = [] + self._check_can_shutdown_room_callbacks: List[ + CHECK_CAN_SHUTDOWN_ROOM_CALLBACK + ] = [] + self._check_can_deactivate_user_callbacks: List[ + CHECK_CAN_DEACTIVATE_USER_CALLBACK + ] = [] self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = [] self._on_user_deactivation_status_changed_callbacks: List[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -173,6 +181,8 @@ class ThirdPartyEventRules: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None, + check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None, on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -198,6 +208,11 @@ class ThirdPartyEventRules: if on_new_event is not None: self._on_new_event_callbacks.append(on_new_event) + if check_can_shutdown_room is not None: + self._check_can_shutdown_room_callbacks.append(check_can_shutdown_room) + + if check_can_deactivate_user is not None: + self._check_can_deactivate_user_callbacks.append(check_can_deactivate_user) if on_profile_update is not None: self._on_profile_update_callbacks.append(on_profile_update) @@ -369,6 +384,46 @@ class ThirdPartyEventRules: "Failed to run module API callback %s: %s", callback, e ) + async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool: + """Intercept requests to shutdown a room. If `False` is returned, the + room must not be shut down. + + Args: + requester: The ID of the user requesting the shutdown. + room_id: The ID of the room. + """ + for callback in self._check_can_shutdown_room_callbacks: + try: + if await callback(user_id, room_id) is False: + return False + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + return True + + async def check_can_deactivate_user( + self, + user_id: str, + by_admin: bool, + ) -> bool: + """Intercept requests to deactivate a user. If `False` is returned, the + user should not be deactivated. + + Args: + requester + user_id: The ID of the room. + """ + for callback in self._check_can_deactivate_user_callbacks: + try: + if await callback(user_id, by_admin) is False: + return False + except Exception as e: + logger.exception( + "Failed to run module API callback %s: %s", callback, e + ) + return True + async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: """Given a room ID, return the state events of that room. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ee34cb46e4..b2a237c1e0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,6 +38,7 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.storage.databases.main.relations import BundledAggregations @@ -395,6 +396,9 @@ class EventClientSerializer: clients. """ + def __init__(self, hs: "HomeServer"): + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + def serialize_event( self, event: Union[JsonDict, EventBase], @@ -515,11 +519,14 @@ class EventClientSerializer: thread.latest_event, serialized_latest_event, thread.latest_edit ) - serialized_aggregations[RelationTypes.THREAD] = { + thread_summary = { "latest_event": serialized_latest_event, "count": thread.count, "current_user_participated": thread.current_user_participated, } + serialized_aggregations[RelationTypes.THREAD] = thread_summary + if self._msc3440_enabled: + serialized_aggregations[RelationTypes.UNSTABLE_THREAD] = thread_summary # Include the bundled aggregations in the event. if serialized_aggregations: diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 87e99c7ddf..2529dee613 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -63,7 +63,7 @@ class Authenticator: self.replication_client = None if hs.config.worker.worker_app: - self.replication_client = hs.get_tcp_replication() + self.replication_client = hs.get_replication_command_handler() # A method just so we can pass 'self' as the authenticator to the Servlets async def authenticate_request( diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 76ae768e6e..816e1a6d79 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import Requester, UserID, create_requester +from synapse.types import Codes, Requester, UserID, create_requester if TYPE_CHECKING: from synapse.server import HomeServer @@ -42,6 +42,7 @@ class DeactivateAccountHandler: # Flag that indicates whether the process to part users from rooms is running self._user_parter_running = False + self._third_party_rules = hs.get_third_party_event_rules() # Start the user parter loop so it can resume parting users from rooms where # it left off (if it has work left to do). @@ -74,6 +75,15 @@ class DeactivateAccountHandler: Returns: True if identity server supports removing threepids, otherwise False. """ + + # Check if this user can be deactivated + if not await self._third_party_rules.check_can_deactivate_user( + user_id, by_admin + ): + raise SynapseError( + 403, "Deactivation of this user is forbidden", Codes.FORBIDDEN + ) + # FIXME: Theoretically there is a race here wherein user resets # password using threepid. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index eb03a5accb..db39aeabde 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -23,8 +23,6 @@ from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 -from twisted.internet import defer - from synapse import event_auth from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import ( @@ -45,11 +43,7 @@ from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError from synapse.http.servlet import assert_params_in_dict -from synapse.logging.context import ( - make_deferred_yieldable, - nested_logging_context, - preserve_fn, -) +from synapse.logging.context import nested_logging_context from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, @@ -355,56 +349,8 @@ class FederationHandler: if success: return True - # Huh, well *those* domains didn't work out. Lets try some domains - # from the time. - - tried_domains = set(likely_domains) - tried_domains.add(self.server_name) - - event_ids = list(extremities.keys()) - - logger.debug("calling resolve_state_groups in _maybe_backfill") - resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states_list = await make_deferred_yieldable( - defer.gatherResults( - [resolve(room_id, [e]) for e in event_ids], consumeErrors=True - ) - ) - - # A map from event_id to state map of event_ids. - state_ids: Dict[str, StateMap[str]] = dict( - zip(event_ids, [s.state for s in states_list]) - ) - - state_map = await self.store.get_events( - [e_id for ids in state_ids.values() for e_id in ids.values()], - get_prev_content=False, - ) - - # A map from event_id to state map of events. - state_events: Dict[str, StateMap[EventBase]] = { - key: { - k: state_map[e_id] - for k, e_id in state_dict.items() - if e_id in state_map - } - for key, state_dict in state_ids.items() - } - - for e_id in event_ids: - likely_extremeties_domains = get_domains_from_state(state_events[e_id]) - - success = await try_backfill( - [ - dom - for dom, _ in likely_extremeties_domains - if dom not in tried_domains - ] - ) - if success: - return True - - tried_domains.update(dom for dom, _ in likely_extremeties_domains) + # TODO: we could also try servers which were previously in the room, but + # are no longer. return False diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 316cfae24f..a7db8feb57 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -153,8 +153,9 @@ class InitialSyncHandler: public_room_ids = await self.store.get_public_room_ids() - limit = pagin_config.limit - if limit is None: + if pagin_config.limit is not None: + limit = pagin_config.limit + else: limit = 10 serializer_options = SerializeEventConfig(as_client_event=as_client_event) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 2123e97245..848acbee36 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1079,7 +1079,10 @@ class EventCreationHandler: raise SynapseError(400, "Can't send same reaction twice") # Don't attempt to start a thread if the parent event is a relation. - elif relation_type == RelationTypes.THREAD: + elif ( + relation_type == RelationTypes.THREAD + or relation_type == RelationTypes.UNSTABLE_THREAD + ): if await self.store.event_includes_relation(relates_to): raise SynapseError( 400, "Cannot start threads from an event with a relation" diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c155098bee..9927a30e6e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -424,13 +424,13 @@ class WorkerPresenceHandler(BasePresenceHandler): async def _on_shutdown(self) -> None: if self._presence_enabled: - self.hs.get_tcp_replication().send_command( + self.hs.get_replication_command_handler().send_command( ClearUserSyncsCommand(self.instance_id) ) def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None: if self._presence_enabled: - self.hs.get_tcp_replication().send_user_sync( + self.hs.get_replication_command_handler().send_user_sync( self.instance_id, user_id, is_syncing, last_sync_ms ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7b965b4b96..b9735631fc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1475,6 +1475,7 @@ class RoomShutdownHandler: self.room_member_handler = hs.get_room_member_handler() self._room_creation_handler = hs.get_room_creation_handler() self._replication = hs.get_replication_data_handler() + self._third_party_rules = hs.get_third_party_event_rules() self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main @@ -1548,6 +1549,13 @@ class RoomShutdownHandler: if not RoomID.is_valid(room_id): raise SynapseError(400, "%s is not a legal room ID" % (room_id,)) + if not await self._third_party_rules.check_can_shutdown_room( + requester_user_id, room_id + ): + raise SynapseError( + 403, "Shutdown of this room is forbidden", Codes.FORBIDDEN + ) + # Action the block first (even if the room doesn't exist yet) if block: # This will work even if the room is already blocked, but that is diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 3979cbba71..486145f48a 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -295,7 +295,7 @@ class RoomSummaryHandler: # inaccessible to the requesting user. if room_entry: # Add the room (including the stripped m.space.child events). - rooms_result.append(room_entry.as_json()) + rooms_result.append(room_entry.as_json(for_client=True)) # If this room is not at the max-depth, check if there are any # children to process. @@ -843,14 +843,25 @@ class _RoomEntry: # This may not include all children. children_state_events: Sequence[JsonDict] = () - def as_json(self) -> JsonDict: + def as_json(self, for_client: bool = False) -> JsonDict: """ Returns a JSON dictionary suitable for the room hierarchy endpoint. It returns the room summary including the stripped m.space.child events as a sub-key. + + Args: + for_client: If true, any server-server only fields are stripped from + the result. + """ result = dict(self.room) + + # Before returning to the client, remove the allowed_room_ids key, if it + # exists. + if for_client: + result.pop("allowed_room_ids", False) + result["children_state"] = self.children_state_events return result diff --git a/synapse/logging/context.py b/synapse/logging/context.py index c31c2960ad..88cd8a9e1c 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -29,7 +29,6 @@ import warnings from types import TracebackType from typing import ( TYPE_CHECKING, - Any, Awaitable, Callable, Optional, @@ -41,7 +40,7 @@ from typing import ( ) import attr -from typing_extensions import Literal +from typing_extensions import Literal, ParamSpec from twisted.internet import defer, threads from twisted.python.threadpool import ThreadPool @@ -719,32 +718,33 @@ def nested_logging_context(suffix: str) -> LoggingContext: ) +P = ParamSpec("P") R = TypeVar("R") @overload def preserve_fn( # type: ignore[misc] - f: Callable[..., Awaitable[R]], -) -> Callable[..., "defer.Deferred[R]"]: + f: Callable[P, Awaitable[R]], +) -> Callable[P, "defer.Deferred[R]"]: # The `type: ignore[misc]` above suppresses # "Overloaded function signatures 1 and 2 overlap with incompatible return types" ... @overload -def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]: +def preserve_fn(f: Callable[P, R]) -> Callable[P, "defer.Deferred[R]"]: ... def preserve_fn( f: Union[ - Callable[..., R], - Callable[..., Awaitable[R]], + Callable[P, R], + Callable[P, Awaitable[R]], ] -) -> Callable[..., "defer.Deferred[R]"]: +) -> Callable[P, "defer.Deferred[R]"]: """Function decorator which wraps the function with run_in_background""" - def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]": + def g(*args: P.args, **kwargs: P.kwargs) -> "defer.Deferred[R]": return run_in_background(f, *args, **kwargs) return g @@ -752,7 +752,7 @@ def preserve_fn( @overload def run_in_background( # type: ignore[misc] - f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any + f: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[R]": # The `type: ignore[misc]` above suppresses # "Overloaded function signatures 1 and 2 overlap with incompatible return types" @@ -761,18 +761,22 @@ def run_in_background( # type: ignore[misc] @overload def run_in_background( - f: Callable[..., R], *args: Any, **kwargs: Any + f: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[R]": ... -def run_in_background( +def run_in_background( # type: ignore[misc] + # The `type: ignore[misc]` above suppresses + # "Overloaded function implementation does not accept all possible arguments of signature 1" + # "Overloaded function implementation does not accept all possible arguments of signature 2" + # which seems like a bug in mypy. f: Union[ - Callable[..., R], - Callable[..., Awaitable[R]], + Callable[P, R], + Callable[P, Awaitable[R]], ], - *args: Any, - **kwargs: Any, + *args: P.args, + **kwargs: P.kwargs, ) -> "defer.Deferred[R]": """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the @@ -872,7 +876,7 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: def defer_to_thread( - reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any + reactor: "ISynapseReactor", f: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[R]": """ Calls the function `f` using a thread from the reactor's default threadpool and @@ -908,9 +912,9 @@ def defer_to_thread( def defer_to_threadpool( reactor: "ISynapseReactor", threadpool: ThreadPool, - f: Callable[..., R], - *args: Any, - **kwargs: Any, + f: Callable[P, R], + *args: P.args, + **kwargs: P.kwargs, ) -> "defer.Deferred[R]": """ A wrapper for twisted.internet.threads.deferToThreadpool, which handles diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index c42eeedd87..d735c1d461 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -54,6 +54,8 @@ from synapse.events.spamcheck import ( USER_MAY_SEND_3PID_INVITE_CALLBACK, ) from synapse.events.third_party_rules import ( + CHECK_CAN_DEACTIVATE_USER_CALLBACK, + CHECK_CAN_SHUTDOWN_ROOM_CALLBACK, CHECK_EVENT_ALLOWED_CALLBACK, CHECK_THREEPID_CAN_BE_INVITED_CALLBACK, CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK, @@ -283,6 +285,8 @@ class ModuleApi: CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK ] = None, on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None, + check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None, + check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None, on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None, on_user_deactivation_status_changed: Optional[ ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK @@ -298,6 +302,8 @@ class ModuleApi: check_threepid_can_be_invited=check_threepid_can_be_invited, check_visibility_can_be_modified=check_visibility_can_be_modified, on_new_event=on_new_event, + check_can_shutdown_room=check_can_shutdown_room, + check_can_deactivate_user=check_can_deactivate_user, on_profile_update=on_profile_update, on_user_deactivation_status_changed=on_user_deactivation_status_changed, ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index b40a7bbb76..1dd39f06cf 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -76,7 +76,8 @@ REQUIREMENTS = [ "netaddr>=0.7.18", "Jinja2>=2.9", "bleach>=1.4.3", - "typing-extensions>=3.7.4", + # We use `ParamSpec`, which was added in `typing-extensions` 3.10.0.0. + "typing-extensions>=3.10.0", # We enforce that we have a `cryptography` version that bundles an `openssl` # with the latest security patches. "cryptography>=3.4.7", diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 2e697c74a6..f1abb98653 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple from prometheus_client import Counter, Gauge +from twisted.internet.error import ConnectError, DNSLookupError from twisted.web.server import Request from synapse.api.errors import HttpResponseException, SynapseError @@ -87,6 +88,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): `_handle_request` must return a Deferred. RETRY_ON_TIMEOUT(bool): Whether or not to retry the request when a 504 is received. + RETRY_ON_CONNECT_ERROR (bool): Whether or not to retry the request when + a connection error is received. + RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when + receiving connection errors, each will backoff exponentially longer. """ NAME: str = abc.abstractproperty() # type: ignore @@ -94,6 +99,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): METHOD = "POST" CACHE = True RETRY_ON_TIMEOUT = True + RETRY_ON_CONNECT_ERROR = True + RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1) def __init__(self, hs: "HomeServer"): if self.CACHE: @@ -236,18 +243,20 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): "/".join(url_args), ) + headers: Dict[bytes, List[bytes]] = {} + # Add an authorization header, if configured. + if replication_secret: + headers[b"Authorization"] = [b"Bearer " + replication_secret] + opentracing.inject_header_dict(headers, check_destination=False) + try: + # Keep track of attempts made so we can bail if we don't manage to + # connect to the target after N tries. + attempts = 0 # We keep retrying the same request for timeouts. This is so that we # have a good idea that the request has either succeeded or failed # on the master, and so whether we should clean up or not. while True: - headers: Dict[bytes, List[bytes]] = {} - # Add an authorization header, if configured. - if replication_secret: - headers[b"Authorization"] = [ - b"Bearer " + replication_secret - ] - opentracing.inject_header_dict(headers, check_destination=False) try: result = await request_func(uri, data, headers=headers) break @@ -255,11 +264,27 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if not cls.RETRY_ON_TIMEOUT: raise - logger.warning("%s request timed out; retrying", cls.NAME) + logger.warning("%s request timed out; retrying", cls.NAME) + + # If we timed out we probably don't need to worry about backing + # off too much, but lets just wait a little anyway. + await clock.sleep(1) + except (ConnectError, DNSLookupError) as e: + if not cls.RETRY_ON_CONNECT_ERROR: + raise + if attempts > cls.RETRY_ON_CONNECT_ERROR_ATTEMPTS: + raise + + delay = 2 ** attempts + logger.warning( + "%s request connection failed; retrying in %ds: %r", + cls.NAME, + delay, + e, + ) - # If we timed out we probably don't need to worry about backing - # off too much, but lets just wait a little anyway. - await clock.sleep(1) + await clock.sleep(delay) + attempts += 1 except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the main process that we should send to the client. (And diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index b5b84c09ae..14706a0817 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -54,6 +54,6 @@ class SlavedClientIpStore(BaseSlavedStore): self.client_ip_last_seen.set(key, now) - self.hs.get_tcp_replication().send_user_ip( + self.hs.get_replication_command_handler().send_user_ip( user_id, access_token, ip, user_agent, device_id, now ) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b8fc1d4db9..deeaaec4e6 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -462,6 +462,8 @@ class FederationSenderHandler: # We ACK this token over replication so that the master can drop # its in memory queues - self._hs.get_tcp_replication().send_federation_ack(current_position) + self._hs.get_replication_command_handler().send_federation_ack( + current_position + ) except Exception: logger.exception("Error updating federation stream position") diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index aaf91e5e02..bf7d017968 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -21,7 +21,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: - from txredisapi import RedisProtocol + from txredisapi import ConnectionHandler from synapse.server import HomeServer @@ -63,7 +63,7 @@ class ExternalCache: def __init__(self, hs: "HomeServer"): if hs.config.redis.redis_enabled: self._redis_connection: Optional[ - "RedisProtocol" + "ConnectionHandler" ] = hs.get_outbound_redis_connection() else: self._redis_connection = None diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 0d2013a3cf..d51f045f22 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -295,9 +295,7 @@ class ReplicationCommandHandler: raise Exception("Unrecognised command %s in stream queue", cmd.NAME) def start_replication(self, hs: "HomeServer") -> None: - """Helper method to start a replication connection to the remote server - using TCP. - """ + """Helper method to start replication.""" if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 3170f7c59b..989c5be032 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -93,7 +93,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): synapse_handler: "ReplicationCommandHandler" synapse_stream_name: str - synapse_outbound_redis_connection: txredisapi.RedisProtocol + synapse_outbound_redis_connection: txredisapi.ConnectionHandler def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -313,7 +313,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): protocol = RedisSubscriber def __init__( - self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol + self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler ): super().__init__( @@ -325,7 +325,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): password=hs.config.redis.redis_password, ) - self.synapse_handler = hs.get_tcp_replication() + self.synapse_handler = hs.get_replication_command_handler() self.synapse_stream_name = hs.hostname self.synapse_outbound_redis_connection = outbound_redis_connection @@ -353,7 +353,7 @@ def lazyConnection( reconnect: bool = True, password: Optional[str] = None, replyTimeout: int = 30, -) -> txredisapi.RedisProtocol: +) -> txredisapi.ConnectionHandler: """Creates a connection to Redis that is lazily set up and reconnects if the connections is lost. """ diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 494e42a2be..ab829040cd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -44,7 +44,7 @@ class ReplicationStreamProtocolFactory(ServerFactory): """Factory for new replication connections.""" def __init__(self, hs: "HomeServer"): - self.command_handler = hs.get_tcp_replication() + self.command_handler = hs.get_replication_command_handler() self.clock = hs.get_clock() self.server_name = hs.config.server.server_name @@ -85,7 +85,7 @@ class ReplicationStreamer: self.is_looping = False self.pending_updates = False - self.command_handler = hs.get_tcp_replication() + self.command_handler = hs.get_replication_command_handler() # Set of streams to replicate. self.streams = self.command_handler.get_streams_to_replicate() diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f4736a3dad..356d6f74d7 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -67,6 +67,7 @@ class RoomRestV2Servlet(RestServlet): self._auth = hs.get_auth() self._store = hs.get_datastores().main self._pagination_handler = hs.get_pagination_handler() + self._third_party_rules = hs.get_third_party_event_rules() async def on_DELETE( self, request: SynapseRequest, room_id: str @@ -106,6 +107,14 @@ class RoomRestV2Servlet(RestServlet): HTTPStatus.BAD_REQUEST, "%s is not a legal room ID" % (room_id,) ) + # Check this here, as otherwise we'll only fail after the background job has been started. + if not await self._third_party_rules.check_can_shutdown_room( + requester.user.to_string(), room_id + ): + raise SynapseError( + 403, "Shutdown of this room is forbidden", Codes.FORBIDDEN + ) + delete_id = self._pagination_handler.start_shutdown_and_purge_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 07fa1cdd4c..d9a6be43f7 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -27,7 +27,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.storage.relations import AggregationPaginationToken from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: @@ -82,28 +82,25 @@ class RelationPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) - - pagination_chunk = await self.store.get_relations_for_event( - event_id=parent_id, - room_id=room_id, - relation_type=relation_type, - event_type=event_type, - limit=limit, - direction=direction, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) + + pagination_chunk = await self.store.get_relations_for_event( + event_id=parent_id, + event=event, + room_id=room_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + direction=direction, + from_token=from_token, + to_token=to_token, + ) events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] @@ -193,27 +190,23 @@ class RelationAggregationPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - room_id=room_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = AggregationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = AggregationPaginationToken.from_string(to_token_str) + + pagination_chunk = await self.store.get_aggregation_groups_for_event( + event_id=parent_id, + room_id=room_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) return 200, await pagination_chunk.to_dict(self.store) @@ -295,6 +288,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2e5d0e4e22..9a65aa4843 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,6 +101,7 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440": self.config.experimental.msc3440_enabled, + "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above }, }, ) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 18bf977d3d..1c9b71d69c 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -16,7 +16,7 @@ import abc import logging import os import shutil -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background @@ -150,8 +150,13 @@ class FileStorageProviderBackend(StorageProvider): dirname = os.path.dirname(backup_fname) os.makedirs(dirname, exist_ok=True) + # mypy needs help inferring the type of the second parameter, which is generic + shutil_copyfile: Callable[[str, str], str] = shutil.copyfile await defer_to_thread( - self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname + self.hs.get_reactor(), + shutil_copyfile, + primary_fname, + backup_fname, ) async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: diff --git a/synapse/server.py b/synapse/server.py index b5e2a319bc..7741ff29dc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -145,7 +145,7 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) if TYPE_CHECKING: - from txredisapi import RedisProtocol + from txredisapi import ConnectionHandler from synapse.handlers.oidc import OidcHandler from synapse.handlers.saml import SamlHandler @@ -639,7 +639,7 @@ class HomeServer(metaclass=abc.ABCMeta): return ReadMarkerHandler(self) @cache_in_self - def get_tcp_replication(self) -> ReplicationCommandHandler: + def get_replication_command_handler(self) -> ReplicationCommandHandler: return ReplicationCommandHandler(self) @cache_in_self @@ -754,7 +754,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer() + return EventClientSerializer(self) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: @@ -807,7 +807,7 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountHandler(self) @cache_in_self - def get_outbound_redis_connection(self) -> "RedisProtocol": + def get_outbound_redis_connection(self) -> "ConnectionHandler": """ The Redis connection used for replication. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index abd54c7dc7..d6a2df1afe 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -191,6 +191,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if redacts: self._invalidate_get_event_cache(redacts) + # Caches which might leak edits must be invalidated for the event being + # redacted. + self.get_relations_for_event.invalidate((redacts,)) + self.get_applicable_edit.invalidate((redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1dc83aa5e3..1f60aef180 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1619,9 +1619,12 @@ class PersistEventsStore: txn.call_after(prefill) - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event + def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: + # Invalidate the caches for the redacted event, note that these caches + # are also cleared as part of event replication in _invalidate_caches_for_event. txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) + txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) self.db_pool.simple_upsert_txn( txn, @@ -1811,10 +1814,11 @@ class PersistEventsStore: if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - if rel_type == RelationTypes.THREAD: - txn.call_after( - self.store.get_thread_summary.invalidate, (parent_id, event.room_id) - ) + if ( + rel_type == RelationTypes.THREAD + or rel_type == RelationTypes.UNSTABLE_THREAD + ): + txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and # potentially error-prone) so it is always invalidated. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 26784f755e..59454a47df 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1286,7 +1286,7 @@ class EventsWorkerStore(SQLBaseStore): ) return {eid for ((_rid, eid), have_event) in res.items() if have_event} - @cachedList("have_seen_event", "keys") + @cachedList(cached_method_name="have_seen_event", list_name="keys") async def _have_seen_events_dict( self, keys: Iterable[Tuple[str, str]] ) -> Dict[Tuple[str, str], bool]: @@ -1954,7 +1954,7 @@ class EventsWorkerStore(SQLBaseStore): get_event_id_for_timestamp_txn, ) - @cachedList("is_partial_state_event", list_name="event_ids") + @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") async def get_partial_state_events( self, event_ids: Collection[str] ) -> Dict[str, bool]: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 36aa1092f6..c4869d64e6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -91,10 +91,11 @@ class RelationsWorkerStore(SQLBaseStore): self._msc3440_enabled = hs.config.experimental.msc3440_enabled - @cached(tree=True) + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, event_id: str, + event: EventBase, room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -108,6 +109,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + event: The matching EventBase to event_id. room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -122,9 +124,13 @@ class RelationsWorkerStore(SQLBaseStore): List of event IDs that match relations requested. The rows are of the form `{"event_id": "..."}`. """ + # We don't use `event_id`, it's there so that we can cache based on + # it. The `event_id` must match the `event.event_id`. + assert event.event_id == event_id where_clause = ["relates_to_id = ?", "room_id = ?"] - where_args: List[Union[str, int]] = [event_id, room_id] + where_args: List[Union[str, int]] = [event.event_id, room_id] + is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: where_clause.append("relation_type = ?") @@ -157,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore): order = "ASC" sql = """ - SELECT event_id, topological_ordering, stream_ordering + SELECT event_id, relation_type, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -178,9 +184,12 @@ class RelationsWorkerStore(SQLBaseStore): last_stream_id = None events = [] for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] + # Do not include edits for redacted events as they leak event + # content. + if not is_redacted or row[1] != RelationTypes.REPLACE: + events.append({"event_id": row[0]}) + last_topo_id = row[2] + last_stream_id = row[3] # If there are more events, generate the next pagination key. next_token = None @@ -499,7 +508,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC """ else: @@ -514,16 +523,22 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY child.topological_ordering DESC, child.stream_ordering DESC """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) latest_event_ids = {} for parent_event_id, child_event_id in txn: # Only consider the latest threaded reply (by topological ordering). @@ -543,7 +558,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s GROUP BY parent.event_id """ @@ -552,9 +567,15 @@ class RelationsWorkerStore(SQLBaseStore): clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", latest_event_ids.keys() ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) return counts, latest_event_ids @@ -617,16 +638,24 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s AND child.sender = ? """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.extend((RelationTypes.THREAD, user_id)) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + args.append(user_id) + + txn.execute(sql % (clause, relations_clause), args) return {row[0] for row in txn.fetchall()} participated_threads = await self.db_pool.runInteraction( @@ -776,7 +805,7 @@ class RelationsWorkerStore(SQLBaseStore): ) references = await self.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations.references = await references.to_dict(cast("DataStore", self)) @@ -797,59 +826,51 @@ class RelationsWorkerStore(SQLBaseStore): A map of event ID to the bundled aggregation for the event. Not all events may have bundled aggregations in the results. """ - # The already processed event IDs. Tracked separately from the result - # since the result omits events which do not have bundled aggregations. - seen_event_ids = set() - - # State events and redacted events do not get bundled aggregations. - events = [ - event - for event in events - if not event.is_state() and not event.internal_metadata.is_redacted() - ] + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} # Fetch other relations per event. - for event in events: - # De-duplicate events by ID to handle the same event requested multiple - # times. The caches that _get_bundled_aggregation_for_event use should - # capture this, but best to reduce work. - if event.event_id in seen_event_ids: - continue - seen_event_ids.add(event.event_id) - + for event in events_by_id.values(): event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result: results[event.event_id] = event_result - # Fetch any edits. - edits = await self._get_applicable_edits(seen_event_ids) + # Fetch any edits (but not for redacted events). + edits = await self._get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) for event_id, edit in edits.items(): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. - if self._msc3440_enabled: - summaries = await self._get_thread_summaries(seen_event_ids) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - summaries.keys(), user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) + summaries = await self._get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._get_threads_participated(summaries.keys(), user_id) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) return results diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e48ec5f495..bef675b845 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -46,7 +46,7 @@ from synapse.storage.roommember import ( ProfileInfo, RoomsForUser, ) -from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id +from synapse.types import PersistedEventPosition, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList @@ -273,7 +273,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (room_id,)) res = {} for count, membership in txn: - summary = res.setdefault(membership, MemberSummary([], count)) + res.setdefault(membership, MemberSummary([], count)) # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent @@ -839,18 +839,14 @@ class RoomMemberWorkerStore(EventsWorkerStore): with Measure(self._clock, "get_joined_hosts"): return await self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry + room_id, state_group, state_entry=state_entry ) @cached(num_args=2, max_entries=10000, iterable=True) async def _get_joined_hosts( - self, - room_id: str, - state_group: int, - current_state_ids: StateMap[str], - state_entry: "_StateCacheEntry", + self, room_id: str, state_group: int, state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: - # We don't use `state_group`, its there so that we can cache based on + # We don't use `state_group`, it's there so that we can cache based on # it. However, its important that its never None, since two # current_state's with a state_group of None are likely to be different. # diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a898f847e7..39e1efe373 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: args.extend(event_filter.labels) # Filter on relation_senders / relation types from the joined tables. - if event_filter.relation_senders: + if event_filter.related_by_senders: clauses.append( "(%s)" % " OR ".join( - "related_event.sender = ?" for _ in event_filter.relation_senders + "related_event.sender = ?" for _ in event_filter.related_by_senders ) ) - args.extend(event_filter.relation_senders) + args.extend(event_filter.related_by_senders) - if event_filter.relation_types: + if event_filter.related_by_rel_types: clauses.append( "(%s)" - % " OR ".join("relation_type = ?" for _ in event_filter.relation_types) + % " OR ".join( + "relation_type = ?" for _ in event_filter.related_by_rel_types + ) ) - args.extend(event_filter.relation_types) + args.extend(event_filter.related_by_rel_types) return " AND ".join(clauses), args @@ -1203,7 +1205,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # If there is a filter on relation_senders and relation_types join to the # relations table. if event_filter and ( - event_filter.relation_senders or event_filter.relation_types + event_filter.related_by_senders or event_filter.related_by_rel_types ): # Filtering by relations could cause the same event to appear multiple # times (since there's no limit on the number of relations to an event). @@ -1211,7 +1213,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): join_clause += """ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) """ - if event_filter.relation_senders: + if event_filter.related_by_senders: join_clause += """ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 1cdead02f1..c3c5c16db9 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -20,6 +20,7 @@ from typing import ( Any, Awaitable, Callable, + Collection, Dict, Generic, Hashable, @@ -69,6 +70,7 @@ class _CacheDescriptorBase: self, orig: Callable[..., Any], num_args: Optional[int], + uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, ): self.orig = orig @@ -76,6 +78,13 @@ class _CacheDescriptorBase: arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args + # There's no reason that keyword-only arguments couldn't be supported, + # but right now they're buggy so do not allow them. + if arg_spec.kwonlyargs: + raise ValueError( + "_CacheDescriptorBase does not support keyword-only arguments." + ) + if "cache_context" in all_args: if not cache_context: raise ValueError( @@ -88,6 +97,9 @@ class _CacheDescriptorBase: " named `cache_context`" ) + if num_args is not None and uncached_args is not None: + raise ValueError("Cannot provide both num_args and uncached_args") + if num_args is None: num_args = len(all_args) - 1 if cache_context: @@ -105,6 +117,12 @@ class _CacheDescriptorBase: # list of the names of the args used as the cache key self.arg_names = all_args[1 : num_args + 1] + # If there are args to not cache on, filter them out (and fix the size of num_args). + if uncached_args is not None: + include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names] + else: + include_arg_in_cache_key = [True] * len(self.arg_names) + # self.arg_defaults is a map of arg name to its default value for each # argument that has a default value if arg_spec.defaults: @@ -119,8 +137,8 @@ class _CacheDescriptorBase: self.add_cache_context = cache_context - self.cache_key_builder = get_cache_key_builder( - self.arg_names, self.arg_defaults + self.cache_key_builder = _get_cache_key_builder( + self.arg_names, include_arg_in_cache_key, self.arg_defaults ) @@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, - cache_context: bool = False, + *, max_entries: int = 1000, cache_context: bool = False ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -186,7 +203,9 @@ class LruCacheDescriptor(_CacheDescriptorBase): max_entries: int = 1000, cache_context: bool = False, ): - super().__init__(orig, num_args=None, cache_context=cache_context) + super().__init__( + orig, num_args=None, uncached_args=None, cache_context=cache_context + ) self.max_entries = max_entries def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: @@ -260,6 +279,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): num_args: number of positional arguments (excluding ``self`` and ``cache_context``) to use as cache keys. Defaults to all named args of the function. + uncached_args: a list of argument names to not use as the cache key. + (``self`` and ``cache_context`` are always ignored.) Cannot be used + with num_args. tree: cache_context: iterable: @@ -273,12 +295,18 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): orig: Callable[..., Any], max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, prune_unread_entries: bool = True, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) + super().__init__( + orig, + num_args=num_args, + uncached_args=uncached_args, + cache_context=cache_context, + ) if tree and self.num_args < 2: raise RuntimeError( @@ -369,7 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): but including list_name) to use as cache keys. Defaults to all named args of the function. """ - super().__init__(orig, num_args=num_args) + super().__init__(orig, num_args=num_args, uncached_args=None) self.list_name = list_name @@ -530,8 +558,10 @@ class _CacheContext: def cached( + *, max_entries: int = 1000, num_args: Optional[int] = None, + uncached_args: Optional[Collection[str]] = None, tree: bool = False, cache_context: bool = False, iterable: bool = False, @@ -541,6 +571,7 @@ def cached( orig, max_entries=max_entries, num_args=num_args, + uncached_args=uncached_args, tree=tree, cache_context=cache_context, iterable=iterable, @@ -551,7 +582,7 @@ def cached( def cachedList( - cached_method_name: str, list_name: str, num_args: Optional[int] = None + *, cached_method_name: str, list_name: str, num_args: Optional[int] = None ) -> Callable[[F], _CachedFunction[F]]: """Creates a descriptor that wraps a function in a `CacheListDescriptor`. @@ -590,13 +621,16 @@ def cachedList( return cast(Callable[[F], _CachedFunction[F]], func) -def get_cache_key_builder( - param_names: Sequence[str], param_defaults: Mapping[str, Any] +def _get_cache_key_builder( + param_names: Sequence[str], + include_params: Sequence[bool], + param_defaults: Mapping[str, Any], ) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: """Construct a function which will build cache keys suitable for a cached function Args: param_names: list of formal parameter names for the cached function + include_params: list of bools of whether to include the parameter name in the cache key param_defaults: a mapping from parameter name to default value for that param Returns: @@ -608,6 +642,7 @@ def get_cache_key_builder( if len(param_names) == 1: nm = param_names[0] + assert include_params[0] is True def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: if nm in kwargs: @@ -620,13 +655,18 @@ def get_cache_key_builder( else: def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: - return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + return tuple( + _get_cache_key_gen( + param_names, include_params, param_defaults, args, kwargs + ) + ) return get_cache_key def _get_cache_key_gen( param_names: Iterable[str], + include_params: Iterable[bool], param_defaults: Mapping[str, Any], args: Sequence[Any], kwargs: Mapping[str, Any], @@ -637,16 +677,18 @@ def _get_cache_key_gen( This is essentially the same operation as `inspect.getcallargs`, but optimised so that we don't need to inspect the target function for each call. """ - # We loop through each arg name, looking up if its in the `kwargs`, # otherwise using the next argument in `args`. If there are no more # args then we try looking the arg name up in the defaults. pos = 0 - for nm in param_names: + for nm, inc in zip(param_names, include_params): if nm in kwargs: - yield kwargs[nm] + if inc: + yield kwargs[nm] elif pos < len(args): - yield args[pos] + if inc: + yield args[pos] pos += 1 else: - yield param_defaults[nm] + if inc: + yield param_defaults[nm] diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index cff07a8973..d37292ce13 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -172,6 +172,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): result_room_ids = [] result_children_ids = [] for result_room in result["rooms"]: + # Ensure federation results are not leaking over the client-server API. + self.assertNotIn("allowed_room_ids", result_room) + result_room_ids.append(result_room["room_id"]) result_children_ids.append( [ diff --git a/tests/replication/_base.py b/tests/replication/_base.py index a7a05a564f..9c5df266bd 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -251,7 +251,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.connect_any_redis_attempts, ) - self.hs.get_tcp_replication().start_replication(self.hs) + self.hs.get_replication_command_handler().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't @@ -375,7 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): ) if worker_hs.config.redis.redis_enabled: - worker_hs.get_tcp_replication().start_replication(worker_hs) + worker_hs.get_replication_command_handler().start_replication(worker_hs) return worker_hs diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index f9d5da723c..641a94133b 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -420,7 +420,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # Manually send an old RDATA command, which should get dropped. This # re-uses the row from above, but with an earlier stream token. - self.hs.get_tcp_replication().send_command( + self.hs.get_replication_command_handler().send_command( RdataCommand("events", "master", 1, row) ) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 3ff5afc6e5..9a229dd23f 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -118,7 +118,7 @@ class TypingStreamTestCase(BaseStreamTestCase): # Reset the typing handler self.hs.get_replication_streams()["typing"].last_token = 0 - self.hs.get_tcp_replication()._streams["typing"].last_token = 0 + self.hs.get_replication_command_handler()._streams["typing"].last_token = 0 typing._latest_room_serial = 0 typing._typing_stream_change_cache = StreamChangeCache( "TypingStreamChangeCache", typing._latest_room_serial diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 1b6a4bf4b0..26b8bd512a 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -48,7 +48,7 @@ class FederationAckTestCase(HomeserverTestCase): transport, rather than assuming that the implementation has a ReplicationCommandHandler. """ - rch = self.hs.get_tcp_replication() + rch = self.hs.get_replication_command_handler() # wire up the ReplicationCommandHandler to a mock connection, which needs # to implement IReplicationConnection. (Note that Mock doesn't understand diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index a40a5de399..0cbe6c0cf7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -547,9 +547,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config( - {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} - ) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -758,7 +756,6 @@ class RelationsTestCase(BaseRelationsTestCase): }, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -1065,7 +1062,6 @@ class RelationsTestCase(BaseRelationsTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_edit_thread(self) -> None: """Test that editing a thread works.""" @@ -1383,7 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): chunk = self._get_aggregations() self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. @@ -1475,12 +1470,13 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(relations, {}) def test_redact_parent_annotation(self) -> None: - """Test that annotations of an event are redacted when the original event + """Test that annotations of an event are viewable when the original event is redacted. """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] # The relations should exist. event_ids, relations = self._make_relation_requests() @@ -1494,11 +1490,45 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # Redact the original event. self._redact(self.parent_id) - # The relations are not returned. + # The relations are returned. event_ids, relations = self._make_relation_requests() - self.assertEqual(event_ids, []) - self.assertEqual(relations, {}) + self.assertEquals(event_ids, [related_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, + ) # There's nothing to aggregate. chunk = self._get_aggregations() - self.assertEqual(chunk, []) + self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_parent_thread(self) -> None: + """ + Test that thread replies are still available when the root event is redacted. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] + + # Redact one of the reactions. + self._redact(self.parent_id) + + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(len(event_ids), 1) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + related_event_id, + ) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 37866ee330..3a9617d6da 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_type(self) -> None: # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 58f1ea11b7..e7de67e3a3 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -775,3 +775,124 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): self.assertEqual(args[0], user_id) self.assertFalse(args[1]) self.assertTrue(args[2]) + + def test_check_can_deactivate_user(self) -> None: + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_deactivate_user_callbacks.append( + deactivation_mock, + ) + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + tok = self.login("altan", "password") + + # Deactivate that user. + channel = self.make_request( + "POST", + "/_matrix/client/v3/account/deactivate", + { + "auth": { + "type": LoginType.PASSWORD, + "password": "password", + "identifier": { + "type": "m.id.user", + "user": user_id, + }, + }, + "erase": True, + }, + access_token=tok, + ) + + # Check that the deactivation was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], user_id) + + # Check that the request was not made by an admin + self.assertEqual(args[1], False) + + def test_check_can_deactivate_user_admin(self) -> None: + """Tests that the on_user_deactivation_status_changed module callback is called + correctly when processing a user's deactivation triggered by a server admin. + """ + # Register a mocked callback. + deactivation_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_deactivate_user_callbacks.append( + deactivation_mock, + ) + + # Register an admin user. + self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Register a user that we'll deactivate. + user_id = self.register_user("altan", "password") + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % user_id, + {"deactivated": True}, + access_token=admin_tok, + ) + + # Check that the deactivation was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + deactivation_mock.assert_called_once() + args = deactivation_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], user_id) + + # Check that the mock was made by an admin + self.assertEqual(args[1], True) + + def test_check_can_shutdown_room(self) -> None: + """Tests that the check_can_shutdown_room module callback is called + correctly when processing an admin's shutdown room request. + """ + # Register a mocked callback. + shutdown_mock = Mock(return_value=make_awaitable(False)) + third_party_rules = self.hs.get_third_party_event_rules() + third_party_rules._check_can_shutdown_room_callbacks.append( + shutdown_mock, + ) + + # Register an admin user. + admin_user_id = self.register_user("admin", "password", admin=True) + admin_tok = self.login("admin", "password") + + # Shutdown the room. + channel = self.make_request( + "DELETE", + "/_synapse/admin/v2/rooms/%s" % self.room_id, + {}, + access_token=admin_tok, + ) + + # Check that the shutdown was blocked + self.assertEqual(channel.code, 403, channel.json_body) + + # Check that the mock was called once. + shutdown_mock.assert_called_once() + args = shutdown_mock.call_args[0] + + # Check that the mock was called with the right user ID + self.assertEqual(args[0], admin_user_id) + + # Check that the mock was called with the right room ID + self.assertEqual(args[1], self.room_id) diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 6fbac0ab14..8597867563 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -13,26 +13,10 @@ # limitations under the License. from synapse.storage.database import make_tuple_comparison_clause -from synapse.storage.engines import BaseDatabaseEngine from tests import unittest -def _stub_db_engine(**kwargs) -> BaseDatabaseEngine: - # returns a DatabaseEngine, circumventing the abc mechanism - # any kwargs are set as attributes on the class before instantiating it - t = type( - "TestBaseDatabaseEngine", - (BaseDatabaseEngine,), - dict(BaseDatabaseEngine.__dict__), - ) - # defeat the abc mechanism - t.__abstractmethods__ = set() - for k, v in kwargs.items(): - setattr(t, k, v) - return t(None, None) - - class TupleComparisonClauseTestCase(unittest.TestCase): def test_native_tuple_comparison(self): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 6a1cf33054..eaa0d7d749 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -129,21 +129,19 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders(self): # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -152,20 +150,20 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_type(self): # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -179,8 +177,8 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) @@ -201,7 +199,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.second_tok, ) - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 19741ffcda..6a4b17527a 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -141,6 +141,84 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(r, "chips") obj.mock.assert_not_called() + @defer.inlineCallbacks + def test_cache_uncached_args(self): + """ + Only the arguments not named in uncached_args should matter to the cache + + Note that this is identical to test_cache_num_args, but provides the + arguments differently. + """ + + class Cls: + # Note that it is important that this is not the last argument to + # test behaviour of skipping arguments properly. + @descriptors.cached(uncached_args=("arg2",)) + def fn(self, arg1, arg2, arg3): + return self.mock(arg1, arg2, arg3) + + def __init__(self): + self.mock = mock.Mock() + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, 2, 3) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2, 3) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(2, 3, 4) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(2, 3, 4) + obj.mock.reset_mock() + + # the two values should now be cached; we should be able to vary + # the second argument and still get the cached result. + r = yield obj.fn(1, 4, 3) + self.assertEqual(r, "fish") + r = yield obj.fn(2, 5, 4) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + @defer.inlineCallbacks + def test_cache_kwargs(self): + """Test that keyword arguments are treated properly""" + + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, kwarg1=2): + return self.mock(arg1, kwarg1=kwarg1) + + obj = Cls() + obj.mock.return_value = "fish" + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, kwarg1=2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = yield obj.fn(1, kwarg1=3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, kwarg1=3) + obj.mock.reset_mock() + + # the values should now be cached. + r = yield obj.fn(1, kwarg1=2) + self.assertEqual(r, "fish") + # We should be able to not provide kwarg1 and get the cached value back. + r = yield obj.fn(1) + self.assertEqual(r, "fish") + # Keyword arguments can be in any order. + r = yield obj.fn(kwarg1=2, arg1=1) + self.assertEqual(r, "fish") + obj.mock.assert_not_called() + def test_cache_with_sync_exception(self): """If the wrapped function throws synchronously, things should continue to work""" @@ -656,7 +734,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): assert current_context().name == "c1" # we want this to behave like an asynchronous function @@ -715,7 +793,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") def list_fn(self, args1) -> "Deferred[dict]": return self.mock(args1) @@ -758,7 +836,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): def fn(self, arg1, arg2): pass - @descriptors.cachedList("fn", "args1") + @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): # we want this to behave like an asynchronous function await run_on_reactor() |