diff options
151 files changed, 3048 insertions, 1188 deletions
diff --git a/changelog.d/10533.misc b/changelog.d/10533.misc new file mode 100644 index 0000000000..f70dc6496f --- /dev/null +++ b/changelog.d/10533.misc @@ -0,0 +1 @@ +Improve event caching mechanism to avoid having multiple copies of an event in memory at a time. diff --git a/changelog.d/12477.misc b/changelog.d/12477.misc new file mode 100644 index 0000000000..e793d08e5e --- /dev/null +++ b/changelog.d/12477.misc @@ -0,0 +1 @@ +Add some type hints to datastore. \ No newline at end of file diff --git a/changelog.d/12513.feature b/changelog.d/12513.feature new file mode 100644 index 0000000000..01bf1d9d2c --- /dev/null +++ b/changelog.d/12513.feature @@ -0,0 +1 @@ +Measure the time taken in spam-checking callbacks and expose those measurements as metrics. diff --git a/changelog.d/12567.misc b/changelog.d/12567.misc new file mode 100644 index 0000000000..35f08569ba --- /dev/null +++ b/changelog.d/12567.misc @@ -0,0 +1 @@ +Replace string literal instances of stream key types with typed constants. \ No newline at end of file diff --git a/changelog.d/12586.misc b/changelog.d/12586.misc new file mode 100644 index 0000000000..d26e332305 --- /dev/null +++ b/changelog.d/12586.misc @@ -0,0 +1 @@ +Add `@cancellable` decorator, for use on endpoint methods that can be cancelled when clients disconnect. diff --git a/changelog.d/12588.misc b/changelog.d/12588.misc new file mode 100644 index 0000000000..f62d5c8e21 --- /dev/null +++ b/changelog.d/12588.misc @@ -0,0 +1 @@ +Add ability to cancel disconnected requests to `SynapseRequest`. diff --git a/changelog.d/12618.feature b/changelog.d/12618.feature new file mode 100644 index 0000000000..37fa03b3cb --- /dev/null +++ b/changelog.d/12618.feature @@ -0,0 +1 @@ +Add a `default_power_level_content_override` config option to set default room power levels per room preset. diff --git a/changelog.d/12630.misc b/changelog.d/12630.misc new file mode 100644 index 0000000000..43e12603e2 --- /dev/null +++ b/changelog.d/12630.misc @@ -0,0 +1 @@ +Add a helper class for testing request cancellation. diff --git a/changelog.d/12673.feature b/changelog.d/12673.feature new file mode 100644 index 0000000000..f2bddd6e1c --- /dev/null +++ b/changelog.d/12673.feature @@ -0,0 +1 @@ +Synapse will now reload [cache config](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#caching) when it receives a [SIGHUP](https://en.wikipedia.org/wiki/SIGHUP) signal. diff --git a/changelog.d/12676.misc b/changelog.d/12676.misc new file mode 100644 index 0000000000..26490af00d --- /dev/null +++ b/changelog.d/12676.misc @@ -0,0 +1 @@ +Improve documentation of the `synapse.push` module. diff --git a/changelog.d/12677.misc b/changelog.d/12677.misc new file mode 100644 index 0000000000..eed12e69e9 --- /dev/null +++ b/changelog.d/12677.misc @@ -0,0 +1 @@ +Refactor functions to on `PushRuleEvaluatorForEvent`. diff --git a/changelog.d/12679.misc b/changelog.d/12679.misc new file mode 100644 index 0000000000..6df1116b49 --- /dev/null +++ b/changelog.d/12679.misc @@ -0,0 +1 @@ +Preparation for database schema simplifications: stop writing to `event_reference_hashes`. diff --git a/changelog.d/12683.bugfix b/changelog.d/12683.bugfix new file mode 100644 index 0000000000..2ce84a223a --- /dev/null +++ b/changelog.d/12683.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.57.0 where `/messages` would throw a 500 error when querying for a non-existent room. diff --git a/changelog.d/12689.misc b/changelog.d/12689.misc new file mode 100644 index 0000000000..daa484ea30 --- /dev/null +++ b/changelog.d/12689.misc @@ -0,0 +1 @@ +Refactor `EventContext` class. diff --git a/changelog.d/12691.misc b/changelog.d/12691.misc new file mode 100644 index 0000000000..c635434211 --- /dev/null +++ b/changelog.d/12691.misc @@ -0,0 +1 @@ +Remove an unneeded class in the push code. diff --git a/changelog.d/12693.misc b/changelog.d/12693.misc new file mode 100644 index 0000000000..8bd1e1cb0c --- /dev/null +++ b/changelog.d/12693.misc @@ -0,0 +1 @@ +Consolidate parsing of relation information from events. diff --git a/changelog.d/12694.misc b/changelog.d/12694.misc new file mode 100644 index 0000000000..e1e956a513 --- /dev/null +++ b/changelog.d/12694.misc @@ -0,0 +1 @@ +Capture the `Deferred` for request cancellation in `_AsyncResource`. diff --git a/changelog.d/12695.misc b/changelog.d/12695.misc new file mode 100644 index 0000000000..1b39d969a4 --- /dev/null +++ b/changelog.d/12695.misc @@ -0,0 +1 @@ +Fixes an incorrect type hint for `Filter._check_event_relations`. diff --git a/changelog.d/12696.bugfix b/changelog.d/12696.bugfix new file mode 100644 index 0000000000..e410184a22 --- /dev/null +++ b/changelog.d/12696.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where an empty room would be created when a user with an insufficient power level tried to upgrade a room. diff --git a/changelog.d/12698.misc b/changelog.d/12698.misc new file mode 100644 index 0000000000..5d626352f9 --- /dev/null +++ b/changelog.d/12698.misc @@ -0,0 +1 @@ +Respect the `@cancellable` flag for `DirectServe{Html,Json}Resource`s. diff --git a/changelog.d/12699.misc b/changelog.d/12699.misc new file mode 100644 index 0000000000..d278a956c7 --- /dev/null +++ b/changelog.d/12699.misc @@ -0,0 +1 @@ +Respect the `@cancellable` flag for `RestServlet`s and `BaseFederationServlet`s. diff --git a/changelog.d/12700.misc b/changelog.d/12700.misc new file mode 100644 index 0000000000..d93eb5dada --- /dev/null +++ b/changelog.d/12700.misc @@ -0,0 +1 @@ +Respect the `@cancellable` flag for `ReplicationEndpoint`s. diff --git a/changelog.d/12701.feature b/changelog.d/12701.feature new file mode 100644 index 0000000000..bb2264602c --- /dev/null +++ b/changelog.d/12701.feature @@ -0,0 +1 @@ +Add a config options to allow for auto-tuning of caches. diff --git a/changelog.d/12705.misc b/changelog.d/12705.misc new file mode 100644 index 0000000000..a913d8bb85 --- /dev/null +++ b/changelog.d/12705.misc @@ -0,0 +1 @@ +Complain if a federation endpoint has the `@cancellable` flag, since some of the wrapper code may not handle cancellation correctly yet. diff --git a/changelog.d/12708.misc b/changelog.d/12708.misc new file mode 100644 index 0000000000..aa99e7311b --- /dev/null +++ b/changelog.d/12708.misc @@ -0,0 +1 @@ +Enable cancellation of `GET /rooms/$room_id/members`, `GET /rooms/$room_id/state` and `GET /rooms/$room_id/state/$event_type/*` requests. diff --git a/changelog.d/12709.removal b/changelog.d/12709.removal new file mode 100644 index 0000000000..6bb03e2894 --- /dev/null +++ b/changelog.d/12709.removal @@ -0,0 +1 @@ +Require a body in POST requests to `/rooms/{roomId}/receipt/{receiptType}/{eventId}`, as required by the [Matrix specification](https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidreceiptreceipttypeeventid). This breaks compatibility with Element Android 1.2.0 and earlier: users of those clients will be unable to send read receipts. diff --git a/changelog.d/12711.misc b/changelog.d/12711.misc new file mode 100644 index 0000000000..0831ce0452 --- /dev/null +++ b/changelog.d/12711.misc @@ -0,0 +1 @@ +Optimize private read receipt filtering. diff --git a/changelog.d/12713.bugfix b/changelog.d/12713.bugfix new file mode 100644 index 0000000000..91e70f102c --- /dev/null +++ b/changelog.d/12713.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.30.0 where empty rooms could be automatically created if a monthly active users limit is set. diff --git a/changelog.d/12715.doc b/changelog.d/12715.doc new file mode 100644 index 0000000000..150d78c3f6 --- /dev/null +++ b/changelog.d/12715.doc @@ -0,0 +1 @@ +Fix a typo in the Media Admin API documentation. diff --git a/changelog.d/12716.misc b/changelog.d/12716.misc new file mode 100644 index 0000000000..b07e1b52ee --- /dev/null +++ b/changelog.d/12716.misc @@ -0,0 +1 @@ +Add type annotations to increase the number of modules passing `disallow-untyped-defs`. \ No newline at end of file diff --git a/changelog.d/12720.misc b/changelog.d/12720.misc new file mode 100644 index 0000000000..01b427f200 --- /dev/null +++ b/changelog.d/12720.misc @@ -0,0 +1 @@ +Drop the logging level of status messages for the URL preview cache expiry job from INFO to DEBUG. \ No newline at end of file diff --git a/changelog.d/12726.misc b/changelog.d/12726.misc new file mode 100644 index 0000000000..b07e1b52ee --- /dev/null +++ b/changelog.d/12726.misc @@ -0,0 +1 @@ +Add type annotations to increase the number of modules passing `disallow-untyped-defs`. \ No newline at end of file diff --git a/changelog.d/12727.doc b/changelog.d/12727.doc new file mode 100644 index 0000000000..c41e50c85b --- /dev/null +++ b/changelog.d/12727.doc @@ -0,0 +1 @@ +Update the OpenID Connect example for Keycloak to be compatible with newer versions of Keycloak. Contributed by @nhh. diff --git a/changelog.d/12731.misc b/changelog.d/12731.misc new file mode 100644 index 0000000000..962100d516 --- /dev/null +++ b/changelog.d/12731.misc @@ -0,0 +1 @@ +Update configs used by Complement to allow more invites/3PID validations during tests. \ No newline at end of file diff --git a/changelog.d/12734.misc b/changelog.d/12734.misc new file mode 100644 index 0000000000..ffbfb0d632 --- /dev/null +++ b/changelog.d/12734.misc @@ -0,0 +1 @@ +Tidy up and type-hint the database engine modules. diff --git a/changelog.d/12742.doc b/changelog.d/12742.doc new file mode 100644 index 0000000000..0084e27a7d --- /dev/null +++ b/changelog.d/12742.doc @@ -0,0 +1 @@ +Fix typo in server listener documentation. \ No newline at end of file diff --git a/changelog.d/12747.bugfix b/changelog.d/12747.bugfix new file mode 100644 index 0000000000..0fb0059237 --- /dev/null +++ b/changelog.d/12747.bugfix @@ -0,0 +1 @@ +Fix poor database performance when reading the cache invalidation stream for large servers with lots of workers. diff --git a/changelog.d/12749.doc b/changelog.d/12749.doc new file mode 100644 index 0000000000..4560319ee4 --- /dev/null +++ b/changelog.d/12749.doc @@ -0,0 +1 @@ +Fix typo in 'run_background_tasks_on' option name in configuration manual documentation. diff --git a/docker/complement/conf-workers/workers-shared.yaml b/docker/complement/conf-workers/workers-shared.yaml index 8b69870377..86ee11ecd0 100644 --- a/docker/complement/conf-workers/workers-shared.yaml +++ b/docker/complement/conf-workers/workers-shared.yaml @@ -53,6 +53,18 @@ rc_joins: per_second: 9999 burst_count: 9999 +rc_3pid_validation: + per_second: 1000 + burst_count: 1000 + +rc_invites: + per_room: + per_second: 1000 + burst_count: 1000 + per_user: + per_second: 1000 + burst_count: 1000 + federation_rr_transactions_per_room_per_second: 9999 ## Experimental Features ## diff --git a/docker/complement/conf/homeserver.yaml b/docker/complement/conf/homeserver.yaml index 174f87f52e..e2be540bbb 100644 --- a/docker/complement/conf/homeserver.yaml +++ b/docker/complement/conf/homeserver.yaml @@ -87,6 +87,18 @@ rc_joins: per_second: 9999 burst_count: 9999 +rc_3pid_validation: + per_second: 1000 + burst_count: 1000 + +rc_invites: + per_room: + per_second: 1000 + burst_count: 1000 + per_user: + per_second: 1000 + burst_count: 1000 + federation_rr_transactions_per_room_per_second: 9999 ## API Configuration ## diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 96b3668f2a..d57c5aedae 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -289,7 +289,7 @@ POST /_synapse/admin/v1/purge_media_cache?before_ts=<unix_timestamp_in_ms> URL Parameters -* `unix_timestamp_in_ms`: string representing a positive integer - Unix timestamp in milliseconds. +* `before_ts`: string representing a positive integer - Unix timestamp in milliseconds. All cached media that was last accessed before this timestamp will be removed. Response: diff --git a/docs/openid.md b/docs/openid.md index 19cacaafef..e899db63d6 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -159,7 +159,7 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to oidc_providers: - idp_id: keycloak idp_name: "My KeyCloak server" - issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" + issuer: "https://127.0.0.1:8443/realms/{realm_name}" client_id: "synapse" client_secret: "copy secret generated from above" scopes: ["openid", "profile"] diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a803b8261d..ee98d193cb 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -289,7 +289,7 @@ presence: # federation: the server-server API (/_matrix/federation). Also implies # 'media', 'keys', 'openid' # -# keys: the key discovery API (/_matrix/keys). +# keys: the key discovery API (/_matrix/key). # # media: the media API (/_matrix/media). # @@ -730,6 +730,12 @@ retention: # A cache 'factor' is a multiplier that can be applied to each of # Synapse's caches in order to increase or decrease the maximum # number of entries that can be stored. +# +# The configuration for cache factors (caches.global_factor and +# caches.per_cache_factors) can be reloaded while the application is running, +# by sending a SIGHUP signal to the Synapse process. Changes to other parts of +# the caching config will NOT be applied after a SIGHUP is received; a restart +# is necessary. # The number of events to cache in memory. Not affected by # caches.global_factor. @@ -778,6 +784,24 @@ caches: # #cache_entry_ttl: 30m + # This flag enables cache autotuning, and is further specified by the sub-options `max_cache_memory_usage`, + # `target_cache_memory_usage`, `min_cache_ttl`. These flags work in conjunction with each other to maintain + # a balance between cache memory usage and cache entry availability. You must be using jemalloc to utilize + # this option, and all three of the options must be specified for this feature to work. + #cache_autotuning: + # This flag sets a ceiling on much memory the cache can use before caches begin to be continuously evicted. + # They will continue to be evicted until the memory usage drops below the `target_memory_usage`, set in + # the flag below, or until the `min_cache_ttl` is hit. + #max_cache_memory_usage: 1024M + + # This flag sets a rough target for the desired memory usage of the caches. + #target_cache_memory_usage: 758M + + # 'min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when + # caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches + # from being emptied while Synapse is evicting due to memory. + #min_cache_ttl: 5m + # Controls how long the results of a /sync request are cached for after # a successful response is returned. A higher duration can help clients with # intermittent connections, at the cost of higher memory usage. @@ -2462,6 +2486,40 @@ push: # #encryption_enabled_by_default_for_room_type: invite +# Override the default power levels for rooms created on this server, per +# room creation preset. +# +# The appropriate dictionary for the room preset will be applied on top +# of the existing power levels content. +# +# Useful if you know that your users need special permissions in rooms +# that they create (e.g. to send particular types of state events without +# needing an elevated power level). This takes the same shape as the +# `power_level_content_override` parameter in the /createRoom API, but +# is applied before that parameter. +# +# Valid keys are some or all of `private_chat`, `trusted_private_chat` +# and `public_chat`. Inside each of those should be any of the +# properties allowed in `power_level_content_override` in the +# /createRoom API. If any property is missing, its default value will +# continue to be used. If any property is present, it will overwrite +# the existing default completely (so if the `events` property exists, +# the default event power levels will be ignored). +# +#default_power_level_content_override: +# private_chat: +# "events": +# "com.example.myeventtype" : 0 +# "m.room.avatar": 50 +# "m.room.canonical_alias": 50 +# "m.room.encryption": 100 +# "m.room.history_visibility": 100 +# "m.room.name": 50 +# "m.room.power_levels": 100 +# "m.room.server_acl": 100 +# "m.room.tombstone": 100 +# "events_default": 1 + # Uncomment to allow non-server-admin users to create groups on this server # diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 21dad0ac41..3e2031f08a 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -467,13 +467,13 @@ Sub-options for each listener include: Valid resource names are: -* `client`: the client-server API (/_matrix/client), and the synapse admin API (/_synapse/admin). Also implies 'media' and 'static'. +* `client`: the client-server API (/_matrix/client), and the synapse admin API (/_synapse/admin). Also implies `media` and `static`. * `consent`: user consent forms (/_matrix/consent). See [here](../../consent_tracking.md) for more. * `federation`: the server-server API (/_matrix/federation). Also implies `media`, `keys`, `openid` -* `keys`: the key discovery API (/_matrix/keys). +* `keys`: the key discovery API (/_matrix/key). * `media`: the media API (/_matrix/media). @@ -1119,7 +1119,17 @@ Caching can be configured through the following sub-options: with intermittent connections, at the cost of higher memory usage. By default, this is zero, which means that sync responses are not cached at all. - +* `cache_autotuning` and its sub-options `max_cache_memory_usage`, `target_cache_memory_usage`, and + `min_cache_ttl` work in conjunction with each other to maintain a balance between cache memory + usage and cache entry availability. You must be using [jemalloc](https://github.com/matrix-org/synapse#help-synapse-is-slow-and-eats-all-my-ramcpu) + to utilize this option, and all three of the options must be specified for this feature to work. + * `max_cache_memory_usage` sets a ceiling on how much memory the cache can use before caches begin to be continuously evicted. + They will continue to be evicted until the memory usage drops below the `target_memory_usage`, set in + the flag below, or until the `min_cache_ttl` is hit. + * `target_memory_usage` sets a rough target for the desired memory usage of the caches. + * `min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when + caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches + from being emptied while Synapse is evicting due to memory. Example configuration: ```yaml @@ -1127,9 +1137,29 @@ caches: global_factor: 1.0 per_cache_factors: get_users_who_share_room_with_user: 2.0 - expire_caches: false sync_response_cache_duration: 2m + cache_autotuning: + max_cache_memory_usage: 1024M + target_cache_memory_usage: 758M + min_cache_ttl: 5m +``` + +### Reloading cache factors + +The cache factors (i.e. `caches.global_factor` and `caches.per_cache_factors`) may be reloaded at any time by sending a +[`SIGHUP`](https://en.wikipedia.org/wiki/SIGHUP) signal to Synapse using e.g. + +```commandline +kill -HUP [PID_OF_SYNAPSE_PROCESS] ``` + +If you are running multiple workers, you must individually update the worker +config file and send this signal to each worker process. + +If you're using the [example systemd service](https://github.com/matrix-org/synapse/blob/develop/contrib/systemd/matrix-synapse.service) +file in Synapse's `contrib` directory, you can send a `SIGHUP` signal by using +`systemctl reload matrix-synapse`. + --- ## Database ## Config options related to database settings. @@ -3298,6 +3328,32 @@ room_list_publication_rules: room_id: "*" action: allow ``` + +--- +Config option: `default_power_level_content_override` + +The `default_power_level_content_override` option controls the default power +levels for rooms. + +Useful if you know that your users need special permissions in rooms +that they create (e.g. to send particular types of state events without +needing an elevated power level). This takes the same shape as the +`power_level_content_override` parameter in the /createRoom API, but +is applied before that parameter. + +Note that each key provided inside a preset (for example `events` in the example +below) will overwrite all existing defaults inside that key. So in the example +below, newly-created private_chat rooms will have no rules for any event types +except `com.example.foo`. + +Example configuration: +```yaml +default_power_level_content_override: + private_chat: { "events": { "com.example.foo" : 0 } } + trusted_private_chat: null + public_chat: null +``` + --- ## Opentracing ## Configuration options related to Opentracing support. @@ -3398,7 +3454,7 @@ stream_writers: typing: worker1 ``` --- -Config option: `run_background_task_on` +Config option: `run_background_tasks_on` The worker that is used to run background tasks (e.g. cleaning up expired data). If not provided this defaults to the main process. diff --git a/mypy.ini b/mypy.ini index ba0de419f5..b5b907973f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -119,15 +119,39 @@ disallow_untyped_defs = True [mypy-synapse.federation.transport.client] disallow_untyped_defs = False +[mypy-synapse.groups.*] +disallow_untyped_defs = True + [mypy-synapse.handlers.*] disallow_untyped_defs = True +[mypy-synapse.http.federation.*] +disallow_untyped_defs = True + +[mypy-synapse.http.connectproxyclient] +disallow_untyped_defs = True + +[mypy-synapse.http.proxyagent] +disallow_untyped_defs = True + +[mypy-synapse.http.request_metrics] +disallow_untyped_defs = True + [mypy-synapse.http.server] disallow_untyped_defs = True +[mypy-synapse.logging._remote] +disallow_untyped_defs = True + [mypy-synapse.logging.context] disallow_untyped_defs = True +[mypy-synapse.logging.formatter] +disallow_untyped_defs = True + +[mypy-synapse.logging.handlers] +disallow_untyped_defs = True + [mypy-synapse.metrics.*] disallow_untyped_defs = True @@ -157,6 +181,9 @@ disallow_untyped_defs = True [mypy-synapse.state.*] disallow_untyped_defs = True +[mypy-synapse.storage.databases.background_updates] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.account_data] disallow_untyped_defs = True @@ -196,18 +223,39 @@ disallow_untyped_defs = True [mypy-synapse.storage.databases.main.state_deltas] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.stream] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.transactions] disallow_untyped_defs = True [mypy-synapse.storage.databases.main.user_erasure_store] disallow_untyped_defs = True +[mypy-synapse.storage.engines.*] +disallow_untyped_defs = True + +[mypy-synapse.storage.prepare_database] +disallow_untyped_defs = True + +[mypy-synapse.storage.persist_events] +disallow_untyped_defs = True + +[mypy-synapse.storage.state] +disallow_untyped_defs = True + +[mypy-synapse.storage.types] +disallow_untyped_defs = True + [mypy-synapse.storage.util.*] disallow_untyped_defs = True [mypy-synapse.streams.*] disallow_untyped_defs = True +[mypy-synapse.types] +disallow_untyped_defs = True + [mypy-synapse.util.*] disallow_untyped_defs = True diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 4a808e33fe..b91ce06de7 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -19,6 +19,7 @@ from typing import ( TYPE_CHECKING, Awaitable, Callable, + Collection, Dict, Iterable, List, @@ -444,9 +445,9 @@ class Filter: return room_ids async def _check_event_relations( - self, events: Iterable[FilterEvent] + self, events: Collection[FilterEvent] ) -> List[FilterEvent]: - # The event IDs to check, mypy doesn't understand the ifinstance check. + # The event IDs to check, mypy doesn't understand the isinstance check. event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] event_ids_to_keep = set( await self._store.events_have_relations( diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 3623c1724d..a3446ac6e8 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -49,9 +49,12 @@ from twisted.logger import LoggingFile, LogLevel from twisted.protocols.tls import TLSMemoryBIOFactory from twisted.python.threadpool import ThreadPool +import synapse.util.caches from synapse.api.constants import MAX_PDU_SIZE from synapse.app import check_bind_error from synapse.app.phone_stats_home import start_phone_stats_home +from synapse.config import ConfigError +from synapse.config._base import format_config_error from synapse.config.homeserver import HomeServerConfig from synapse.config.server import ManholeConfig from synapse.crypto import context_factory @@ -432,6 +435,10 @@ async def start(hs: "HomeServer") -> None: signal.signal(signal.SIGHUP, run_sighup) register_sighup(refresh_certificate, hs) + register_sighup(reload_cache_config, hs.config) + + # Apply the cache config. + hs.config.caches.resize_all_caches() # Load the certificate from disk. refresh_certificate(hs) @@ -486,6 +493,43 @@ async def start(hs: "HomeServer") -> None: atexit.register(gc.freeze) +def reload_cache_config(config: HomeServerConfig) -> None: + """Reload cache config from disk and immediately apply it.resize caches accordingly. + + If the config is invalid, a `ConfigError` is logged and no changes are made. + + Otherwise, this: + - replaces the `caches` section on the given `config` object, + - resizes all caches according to the new cache factors, and + + Note that the following cache config keys are read, but not applied: + - event_cache_size: used to set a max_size and _original_max_size on + EventsWorkerStore._get_event_cache when it is created. We'd have to update + the _original_max_size (and maybe + - sync_response_cache_duration: would have to update the timeout_sec attribute on + HomeServer -> SyncHandler -> ResponseCache. + - track_memory_usage. This affects synapse.util.caches.TRACK_MEMORY_USAGE which + influences Synapse's self-reported metrics. + + Also, the HTTPConnectionPool in SimpleHTTPClient sets its maxPersistentPerHost + parameter based on the global_factor. This won't be applied on a config reload. + """ + try: + previous_cache_config = config.reload_config_section("caches") + except ConfigError as e: + logger.warning("Failed to reload cache config") + for f in format_config_error(e): + logger.warning(f) + else: + logger.debug( + "New cache config. Was:\n %s\nNow:\n", + previous_cache_config.__dict__, + config.caches.__dict__, + ) + synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage + config.caches.resize_all_caches() + + def setup_sentry(hs: "HomeServer") -> None: """Enable sentry integration, if enabled in configuration""" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 0f75e7b9d4..4c6c0658ab 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -16,7 +16,7 @@ import logging import os import sys -from typing import Dict, Iterable, Iterator, List +from typing import Dict, Iterable, List from matrix_common.versionstring import get_distribution_version_string @@ -45,7 +45,7 @@ from synapse.app._base import ( redirect_stdio_to_logs, register_start, ) -from synapse.config._base import ConfigError +from synapse.config._base import ConfigError, format_config_error from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.homeserver import HomeServerConfig from synapse.config.server import ListenerConfig @@ -399,38 +399,6 @@ def setup(config_options: List[str]) -> SynapseHomeServer: return hs -def format_config_error(e: ConfigError) -> Iterator[str]: - """ - Formats a config error neatly - - The idea is to format the immediate error, plus the "causes" of those errors, - hopefully in a way that makes sense to the user. For example: - - Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template': - Failed to parse config for module 'JinjaOidcMappingProvider': - invalid jinja template: - unexpected end of template, expected 'end of print statement'. - - Args: - e: the error to be formatted - - Returns: An iterator which yields string fragments to be formatted - """ - yield "Error in configuration" - - if e.path: - yield " at '%s'" % (".".join(e.path),) - - yield ":\n %s" % (e.msg,) - - parent_e = e.__cause__ - indent = 1 - while parent_e: - indent += 1 - yield ":\n%s%s" % (" " * indent, str(parent_e)) - parent_e = parent_e.__cause__ - - def run(hs: HomeServer) -> None: _base.start_reactor( "synapse-homeserver", diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 179aa7ff88..42364fc133 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -16,14 +16,18 @@ import argparse import errno +import logging import os from collections import OrderedDict from hashlib import sha256 from textwrap import dedent from typing import ( Any, + ClassVar, + Collection, Dict, Iterable, + Iterator, List, MutableMapping, Optional, @@ -40,6 +44,8 @@ import yaml from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter +logger = logging.getLogger(__name__) + class ConfigError(Exception): """Represents a problem parsing the configuration @@ -55,6 +61,38 @@ class ConfigError(Exception): self.path = path +def format_config_error(e: ConfigError) -> Iterator[str]: + """ + Formats a config error neatly + + The idea is to format the immediate error, plus the "causes" of those errors, + hopefully in a way that makes sense to the user. For example: + + Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template': + Failed to parse config for module 'JinjaOidcMappingProvider': + invalid jinja template: + unexpected end of template, expected 'end of print statement'. + + Args: + e: the error to be formatted + + Returns: An iterator which yields string fragments to be formatted + """ + yield "Error in configuration" + + if e.path: + yield " at '%s'" % (".".join(e.path),) + + yield ":\n %s" % (e.msg,) + + parent_e = e.__cause__ + indent = 1 + while parent_e: + indent += 1 + yield ":\n%s%s" % (" " * indent, str(parent_e)) + parent_e = parent_e.__cause__ + + # We split these messages out to allow packages to override with package # specific instructions. MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\ @@ -119,7 +157,7 @@ class Config: defined in subclasses. """ - section: str + section: ClassVar[str] def __init__(self, root_config: "RootConfig" = None): self.root = root_config @@ -309,9 +347,12 @@ class RootConfig: class, lower-cased and with "Config" removed. """ - config_classes = [] + config_classes: List[Type[Config]] = [] + + def __init__(self, config_files: Collection[str] = ()): + # Capture absolute paths here, so we can reload config after we daemonize. + self.config_files = [os.path.abspath(path) for path in config_files] - def __init__(self): for config_class in self.config_classes: if config_class.section is None: raise ValueError("%r requires a section name" % (config_class,)) @@ -512,12 +553,10 @@ class RootConfig: object from parser.parse_args(..)` """ - obj = cls() - config_args = parser.parse_args(argv) config_files = find_config_files(search_paths=config_args.config_path) - + obj = cls(config_files) if not config_files: parser.error("Must supply a config file.") @@ -627,7 +666,7 @@ class RootConfig: generate_missing_configs = config_args.generate_missing_configs - obj = cls() + obj = cls(config_files) if config_args.generate_config: if config_args.report_stats is None: @@ -727,6 +766,34 @@ class RootConfig: ) -> None: self.invoke_all("generate_files", config_dict, config_dir_path) + def reload_config_section(self, section_name: str) -> Config: + """Reconstruct the given config section, leaving all others unchanged. + + This works in three steps: + + 1. Create a new instance of the relevant `Config` subclass. + 2. Call `read_config` on that instance to parse the new config. + 3. Replace the existing config instance with the new one. + + :raises ValueError: if the given `section` does not exist. + :raises ConfigError: for any other problems reloading config. + + :returns: the previous config object, which no longer has a reference to this + RootConfig. + """ + existing_config: Optional[Config] = getattr(self, section_name, None) + if existing_config is None: + raise ValueError(f"Unknown config section '{section_name}'") + logger.info("Reloading config section '%s'", section_name) + + new_config_data = read_config_files(self.config_files) + new_config = type(existing_config)(self) + new_config.read_config(new_config_data) + setattr(self, section_name, new_config) + + existing_config.root = None + return existing_config + def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: """Read the config files into a dict diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index bd092f956d..71d6655fda 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,15 +1,19 @@ import argparse from typing import ( Any, + Collection, Dict, Iterable, + Iterator, List, + Literal, MutableMapping, Optional, Tuple, Type, TypeVar, Union, + overload, ) import jinja2 @@ -64,6 +68,8 @@ class ConfigError(Exception): self.msg = msg self.path = path +def format_config_error(e: ConfigError) -> Iterator[str]: ... + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str MISSING_REPORT_STATS_SPIEL: str MISSING_SERVER_NAME: str @@ -117,7 +123,8 @@ class RootConfig: background_updates: background_updates.BackgroundUpdateConfig config_classes: List[Type["Config"]] = ... - def __init__(self) -> None: ... + config_files: List[str] + def __init__(self, config_files: Collection[str] = ...) -> None: ... def invoke_all( self, func_name: str, *args: Any, **kwargs: Any ) -> MutableMapping[str, Any]: ... @@ -157,6 +164,12 @@ class RootConfig: def generate_missing_files( self, config_dict: dict, config_dir_path: str ) -> None: ... + @overload + def reload_config_section( + self, section_name: Literal["caches"] + ) -> cache.CacheConfig: ... + @overload + def reload_config_section(self, section_name: str) -> Config: ... class Config: root: RootConfig diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 94d852f413..d2f55534d7 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -69,11 +69,11 @@ def _canonicalise_cache_name(cache_name: str) -> str: def add_resizable_cache( cache_name: str, cache_resize_callback: Callable[[float], None] ) -> None: - """Register a cache that's size can dynamically change + """Register a cache whose size can dynamically change Args: cache_name: A reference to the cache - cache_resize_callback: A callback function that will be ran whenever + cache_resize_callback: A callback function that will run whenever the cache needs to be resized """ # Some caches have '*' in them which we strip out. @@ -96,6 +96,13 @@ class CacheConfig(Config): section = "caches" _environ = os.environ + event_cache_size: int + cache_factors: Dict[str, float] + global_factor: float + track_memory_usage: bool + expiry_time_msec: Optional[int] + sync_response_cache_duration: int + @staticmethod def reset() -> None: """Resets the caches to their defaults. Used for tests.""" @@ -115,6 +122,12 @@ class CacheConfig(Config): # A cache 'factor' is a multiplier that can be applied to each of # Synapse's caches in order to increase or decrease the maximum # number of entries that can be stored. + # + # The configuration for cache factors (caches.global_factor and + # caches.per_cache_factors) can be reloaded while the application is running, + # by sending a SIGHUP signal to the Synapse process. Changes to other parts of + # the caching config will NOT be applied after a SIGHUP is received; a restart + # is necessary. # The number of events to cache in memory. Not affected by # caches.global_factor. @@ -163,6 +176,24 @@ class CacheConfig(Config): # #cache_entry_ttl: 30m + # This flag enables cache autotuning, and is further specified by the sub-options `max_cache_memory_usage`, + # `target_cache_memory_usage`, `min_cache_ttl`. These flags work in conjunction with each other to maintain + # a balance between cache memory usage and cache entry availability. You must be using jemalloc to utilize + # this option, and all three of the options must be specified for this feature to work. + #cache_autotuning: + # This flag sets a ceiling on much memory the cache can use before caches begin to be continuously evicted. + # They will continue to be evicted until the memory usage drops below the `target_memory_usage`, set in + # the flag below, or until the `min_cache_ttl` is hit. + #max_cache_memory_usage: 1024M + + # This flag sets a rough target for the desired memory usage of the caches. + #target_cache_memory_usage: 758M + + # 'min_cache_ttl` sets a limit under which newer cache entries are not evicted and is only applied when + # caches are actively being evicted/`max_cache_memory_usage` has been exceeded. This is to protect hot caches + # from being emptied while Synapse is evicting due to memory. + #min_cache_ttl: 5m + # Controls how long the results of a /sync request are cached for after # a successful response is returned. A higher duration can help clients with # intermittent connections, at the cost of higher memory usage. @@ -174,21 +205,21 @@ class CacheConfig(Config): """ def read_config(self, config: JsonDict, **kwargs: Any) -> None: + """Populate this config object with values from `config`. + + This method does NOT resize existing or future caches: use `resize_all_caches`. + We use two separate methods so that we can reject bad config before applying it. + """ self.event_cache_size = self.parse_size( config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) ) - self.cache_factors: Dict[str, float] = {} + self.cache_factors = {} cache_config = config.get("caches") or {} - self.global_factor = cache_config.get( - "global_factor", properties.default_factor_size - ) + self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE) if not isinstance(self.global_factor, (int, float)): raise ConfigError("caches.global_factor must be a number.") - # Set the global one so that it's reflected in new caches - properties.default_factor_size = self.global_factor - # Load cache factors from the config individual_factors = cache_config.get("per_cache_factors") or {} if not isinstance(individual_factors, dict): @@ -230,7 +261,7 @@ class CacheConfig(Config): cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m") if expire_caches: - self.expiry_time_msec: Optional[int] = self.parse_duration(cache_entry_ttl) + self.expiry_time_msec = self.parse_duration(cache_entry_ttl) else: self.expiry_time_msec = None @@ -250,23 +281,38 @@ class CacheConfig(Config): ) self.expiry_time_msec = self.parse_duration(expiry_time) + self.cache_autotuning = cache_config.get("cache_autotuning") + if self.cache_autotuning: + max_memory_usage = self.cache_autotuning.get("max_cache_memory_usage") + self.cache_autotuning["max_cache_memory_usage"] = self.parse_size( + max_memory_usage + ) + + target_mem_size = self.cache_autotuning.get("target_cache_memory_usage") + self.cache_autotuning["target_cache_memory_usage"] = self.parse_size( + target_mem_size + ) + + min_cache_ttl = self.cache_autotuning.get("min_cache_ttl") + self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl) + self.sync_response_cache_duration = self.parse_duration( cache_config.get("sync_response_cache_duration", 0) ) - # Resize all caches (if necessary) with the new factors we've loaded - self.resize_all_caches() - - # Store this function so that it can be called from other classes without - # needing an instance of Config - properties.resize_all_caches_func = self.resize_all_caches - def resize_all_caches(self) -> None: - """Ensure all cache sizes are up to date + """Ensure all cache sizes are up-to-date. For each cache, run the mapped callback function with either a specific cache factor or the default, global one. """ + # Set the global factor size, so that new caches are appropriately sized. + properties.default_factor_size = self.global_factor + + # Store this function so that it can be called from other classes without + # needing an instance of CacheConfig + properties.resize_all_caches_func = self.resize_all_caches + # block other threads from modifying _CACHES while we iterate it. with _CACHES_LOCK: for cache_name, callback in _CACHES.items(): diff --git a/synapse/config/room.py b/synapse/config/room.py index e18a87ea37..462d85ac1d 100644 --- a/synapse/config/room.py +++ b/synapse/config/room.py @@ -63,6 +63,19 @@ class RoomConfig(Config): "Invalid value for encryption_enabled_by_default_for_room_type" ) + self.default_power_level_content_override = config.get( + "default_power_level_content_override", + None, + ) + if self.default_power_level_content_override is not None: + for preset in self.default_power_level_content_override: + if preset not in vars(RoomCreationPreset).values(): + raise ConfigError( + "Unrecognised room preset %s in default_power_level_content_override" + % preset + ) + # We validate the actual overrides when we try to apply them. + def generate_config_section(self, **kwargs: Any) -> str: return """\ ## Rooms ## @@ -83,4 +96,38 @@ class RoomConfig(Config): # will also not affect rooms created by other servers. # #encryption_enabled_by_default_for_room_type: invite + + # Override the default power levels for rooms created on this server, per + # room creation preset. + # + # The appropriate dictionary for the room preset will be applied on top + # of the existing power levels content. + # + # Useful if you know that your users need special permissions in rooms + # that they create (e.g. to send particular types of state events without + # needing an elevated power level). This takes the same shape as the + # `power_level_content_override` parameter in the /createRoom API, but + # is applied before that parameter. + # + # Valid keys are some or all of `private_chat`, `trusted_private_chat` + # and `public_chat`. Inside each of those should be any of the + # properties allowed in `power_level_content_override` in the + # /createRoom API. If any property is missing, its default value will + # continue to be used. If any property is present, it will overwrite + # the existing default completely (so if the `events` property exists, + # the default event power levels will be ignored). + # + #default_power_level_content_override: + # private_chat: + # "events": + # "com.example.myeventtype" : 0 + # "m.room.avatar": 50 + # "m.room.canonical_alias": 50 + # "m.room.encryption": 100 + # "m.room.history_visibility": 100 + # "m.room.name": 50 + # "m.room.power_levels": 100 + # "m.room.server_acl": 100 + # "m.room.tombstone": 100 + # "events_default": 1 """ diff --git a/synapse/config/server.py b/synapse/config/server.py index 005a3ee48c..f73d5e1f66 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -996,7 +996,7 @@ class ServerConfig(Config): # federation: the server-server API (/_matrix/federation). Also implies # 'media', 'keys', 'openid' # - # keys: the key discovery API (/_matrix/keys). + # keys: the key discovery API (/_matrix/key). # # media: the media API (/_matrix/media). # diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index c238376caf..39ad2793d9 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -15,6 +15,7 @@ # limitations under the License. import abc +import collections.abc import os from typing import ( TYPE_CHECKING, @@ -32,9 +33,11 @@ from typing import ( overload, ) +import attr from typing_extensions import Literal from unpaddedbase64 import encode_base64 +from synapse.api.constants import RelationTypes from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict @@ -615,3 +618,45 @@ def make_event_from_dict( return event_type( event_dict, room_version, internal_metadata_dict or {}, rejected_reason ) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventRelation: + # The target event of the relation. + parent_id: str + # The relation type. + rel_type: str + # The aggregation key. Will be None if the rel_type is not m.annotation or is + # not a string. + aggregation_key: Optional[str] + + +def relation_from_event(event: EventBase) -> Optional[_EventRelation]: + """ + Attempt to parse relation information an event. + + Returns: + The event relation information, if it is valid. None, otherwise. + """ + relation = event.content.get("m.relates_to") + if not relation or not isinstance(relation, collections.abc.Mapping): + # No relation information. + return None + + # Relations must have a type and parent event ID. + rel_type = relation.get("rel_type") + if not isinstance(rel_type, str): + return None + + parent_id = relation.get("event_id") + if not isinstance(parent_id, str): + return None + + # Annotations have a key field. + aggregation_key = None + if rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.get("key") + if not isinstance(aggregation_key, str): + aggregation_key = None + + return _EventRelation(parent_id, rel_type, aggregation_key) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 46042b2bf7..9ccd24b298 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,12 +15,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union import attr from frozendict import frozendict - -from twisted.internet.defer import Deferred +from typing_extensions import Literal from synapse.appservice import ApplicationService from synapse.events import EventBase -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import JsonDict, StateMap if TYPE_CHECKING: @@ -60,6 +58,9 @@ class EventContext: If ``state_group`` is None (ie, the event is an outlier), ``state_group_before_event`` will always also be ``None``. + state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None + then this is the delta of the state between the two groups. + prev_group: If it is known, ``state_group``'s prev_group. Note that this being None does not necessarily mean that ``state_group`` does not have a prev_group! @@ -78,73 +79,47 @@ class EventContext: app_service: If this event is being sent by a (local) application service, that app service. - _current_state_ids: The room state map, including this event - ie, the state - in ``state_group``. - - (type, state_key) -> event_id - - For an outlier, this is {} - - Note that this is a private attribute: it should be accessed via - ``get_current_state_ids``. _AsyncEventContext impl calculates this - on-demand: it will be None until that happens. - - _prev_state_ids: The room state map, excluding this event - ie, the state - in ``state_group_before_event``. For a non-state - event, this will be the same as _current_state_events. - - Note that it is a completely different thing to prev_group! - - (type, state_key) -> event_id - - For an outlier, this is {} - - As with _current_state_ids, this is a private attribute. It should be - accessed via get_prev_state_ids. - partial_state: if True, we may be storing this event with a temporary, incomplete state. """ - rejected: Union[bool, str] = False + _storage: "Storage" + rejected: Union[Literal[False], str] = False _state_group: Optional[int] = None state_group_before_event: Optional[int] = None + _state_delta_due_to_event: Optional[StateMap[str]] = None prev_group: Optional[int] = None delta_ids: Optional[StateMap[str]] = None app_service: Optional[ApplicationService] = None - _current_state_ids: Optional[StateMap[str]] = None - _prev_state_ids: Optional[StateMap[str]] = None - partial_state: bool = False @staticmethod def with_state( + storage: "Storage", state_group: Optional[int], state_group_before_event: Optional[int], - current_state_ids: Optional[StateMap[str]], - prev_state_ids: Optional[StateMap[str]], + state_delta_due_to_event: Optional[StateMap[str]], partial_state: bool, prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ) -> "EventContext": return EventContext( - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + storage=storage, state_group=state_group, state_group_before_event=state_group_before_event, + state_delta_due_to_event=state_delta_due_to_event, prev_group=prev_group, delta_ids=delta_ids, partial_state=partial_state, ) @staticmethod - def for_outlier() -> "EventContext": + def for_outlier( + storage: "Storage", + ) -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" - return EventContext( - current_state_ids={}, - prev_state_ids={}, - ) + return EventContext(storage=storage) async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: """Converts self to a type that can be serialized as JSON, and then @@ -157,24 +132,14 @@ class EventContext: The serialized event. """ - # We don't serialize the full state dicts, instead they get pulled out - # of the DB on the other side. However, the other side can't figure out - # the prev_state_ids, so if we're a state event we include the event - # id that we replaced in the state. - if event.is_state(): - prev_state_ids = await self.get_prev_state_ids() - prev_state_id = prev_state_ids.get((event.type, event.state_key)) - else: - prev_state_id = None - return { - "prev_state_id": prev_state_id, - "event_type": event.type, - "event_state_key": event.get_state_key(), "state_group": self._state_group, "state_group_before_event": self.state_group_before_event, "rejected": self.rejected, "prev_group": self.prev_group, + "state_delta_due_to_event": _encode_state_dict( + self._state_delta_due_to_event + ), "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, "partial_state": self.partial_state, @@ -192,16 +157,16 @@ class EventContext: Returns: The event context. """ - context = _AsyncEventContextImpl( + context = EventContext( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. storage=storage, - prev_state_id=input["prev_state_id"], - event_type=input["event_type"], - event_state_key=input["event_state_key"], state_group=input["state_group"], state_group_before_event=input["state_group_before_event"], prev_group=input["prev_group"], + state_delta_due_to_event=_decode_state_dict( + input["state_delta_due_to_event"] + ), delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], partial_state=input.get("partial_state", False), @@ -249,8 +214,15 @@ class EventContext: if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - await self._ensure_fetched() - return self._current_state_ids + assert self._state_delta_due_to_event is not None + + prev_state_ids = await self.get_prev_state_ids() + + if self._state_delta_due_to_event: + prev_state_ids = dict(prev_state_ids) + prev_state_ids.update(self._state_delta_due_to_event) + + return prev_state_ids async def get_prev_state_ids(self) -> StateMap[str]: """ @@ -265,94 +237,10 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - await self._ensure_fetched() - # There *should* be previous state IDs now. - assert self._prev_state_ids is not None - return self._prev_state_ids - - def get_cached_current_state_ids(self) -> Optional[StateMap[str]]: - """Gets the current state IDs if we have them already cached. - - It is an error to access this for a rejected event, since rejected state should - not make it into the room state. This method will raise an exception if - ``rejected`` is set. - - Returns: - Returns None if we haven't cached the state or if state_group is None - (which happens when the associated event is an outlier). - - Otherwise, returns the the current state IDs. - """ - if self.rejected: - raise RuntimeError("Attempt to access state_ids of rejected event") - - return self._current_state_ids - - async def _ensure_fetched(self) -> None: - return None - - -@attr.s(slots=True) -class _AsyncEventContextImpl(EventContext): - """ - An implementation of EventContext which fetches _current_state_ids and - _prev_state_ids from the database on demand. - - Attributes: - - _storage - - _fetching_state_deferred: Resolves when *_state_ids have been calculated. - None if we haven't started calculating yet - - _event_type: The type of the event the context is associated with. - - _event_state_key: The state_key of the event the context is associated with. - - _prev_state_id: If the event associated with the context is a state event, - then `_prev_state_id` is the event_id of the state that was replaced. - """ - - # This needs to have a default as we're inheriting - _storage: "Storage" = attr.ib(default=None) - _prev_state_id: Optional[str] = attr.ib(default=None) - _event_type: str = attr.ib(default=None) - _event_state_key: Optional[str] = attr.ib(default=None) - _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None) - - async def _ensure_fetched(self) -> None: - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background(self._fill_out_state) - - await make_deferred_yieldable(self._fetching_state_deferred) - - async def _fill_out_state(self) -> None: - """Called to populate the _current_state_ids and _prev_state_ids - attributes by loading from the database. - """ - if self.state_group is None: - # No state group means the event is an outlier. Usually the state_ids dicts are also - # pre-set to empty dicts, but they get reset when the context is serialized, so set - # them to empty dicts again here. - self._current_state_ids = {} - self._prev_state_ids = {} - return - - current_state_ids = await self._storage.state.get_state_ids_for_group( - self.state_group + assert self.state_group_before_event is not None + return await self._storage.state.get_state_ids_for_group( + self.state_group_before_event ) - # Set this separately so mypy knows current_state_ids is not None. - self._current_state_ids = current_state_ids - if self._event_state_key is not None: - self._prev_state_ids = dict(current_state_ids) - - key = (self._event_type, self._event_state_key) - if self._prev_state_id: - self._prev_state_ids[key] = self._prev_state_id - else: - self._prev_state_ids.pop(key, None) - else: - self._prev_state_ids = current_state_ids def _encode_state_dict( diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 3b6795d40f..f30207376a 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -32,6 +32,7 @@ from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserProfile from synapse.util.async_helpers import delay_cancellation, maybe_awaitable +from synapse.util.metrics import Measure if TYPE_CHECKING: import synapse.events @@ -162,7 +163,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: class SpamChecker: - def __init__(self) -> None: + def __init__(self, hs: "synapse.server.HomeServer") -> None: + self.hs = hs + self.clock = hs.get_clock() + self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] @@ -255,7 +259,10 @@ class SpamChecker: will be used as the error message returned to the user. """ for callback in self._check_event_for_spam_callbacks: - res: Union[bool, str] = await delay_cancellation(callback(event)) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + res: Union[bool, str] = await delay_cancellation(callback(event)) if res: return res @@ -276,9 +283,12 @@ class SpamChecker: Whether the user may join the room """ for callback in self._user_may_join_room_callbacks: - may_join_room = await delay_cancellation( - callback(user_id, room_id, is_invited) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_join_room = await delay_cancellation( + callback(user_id, room_id, is_invited) + ) if may_join_room is False: return False @@ -300,9 +310,12 @@ class SpamChecker: True if the user may send an invite, otherwise False """ for callback in self._user_may_invite_callbacks: - may_invite = await delay_cancellation( - callback(inviter_userid, invitee_userid, room_id) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_invite = await delay_cancellation( + callback(inviter_userid, invitee_userid, room_id) + ) if may_invite is False: return False @@ -328,9 +341,12 @@ class SpamChecker: True if the user may send the invite, otherwise False """ for callback in self._user_may_send_3pid_invite_callbacks: - may_send_3pid_invite = await delay_cancellation( - callback(inviter_userid, medium, address, room_id) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_send_3pid_invite = await delay_cancellation( + callback(inviter_userid, medium, address, room_id) + ) if may_send_3pid_invite is False: return False @@ -348,7 +364,10 @@ class SpamChecker: True if the user may create a room, otherwise False """ for callback in self._user_may_create_room_callbacks: - may_create_room = await delay_cancellation(callback(userid)) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_create_room = await delay_cancellation(callback(userid)) if may_create_room is False: return False @@ -369,9 +388,12 @@ class SpamChecker: True if the user may create a room alias, otherwise False """ for callback in self._user_may_create_room_alias_callbacks: - may_create_room_alias = await delay_cancellation( - callback(userid, room_alias) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_create_room_alias = await delay_cancellation( + callback(userid, room_alias) + ) if may_create_room_alias is False: return False @@ -390,7 +412,10 @@ class SpamChecker: True if the user may publish the room, otherwise False """ for callback in self._user_may_publish_room_callbacks: - may_publish_room = await delay_cancellation(callback(userid, room_id)) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_publish_room = await delay_cancellation(callback(userid, room_id)) if may_publish_room is False: return False @@ -412,9 +437,13 @@ class SpamChecker: True if the user is spammy. """ for callback in self._check_username_for_spam_callbacks: - # Make a copy of the user profile object to ensure the spam checker cannot - # modify it. - if await delay_cancellation(callback(user_profile.copy())): + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + # Make a copy of the user profile object to ensure the spam checker cannot + # modify it. + res = await delay_cancellation(callback(user_profile.copy())) + if res: return True return False @@ -442,9 +471,12 @@ class SpamChecker: """ for callback in self._check_registration_for_spam_callbacks: - behaviour = await delay_cancellation( - callback(email_threepid, username, request_info, auth_provider_id) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + behaviour = await delay_cancellation( + callback(email_threepid, username, request_info, auth_provider_id) + ) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour @@ -486,7 +518,10 @@ class SpamChecker: """ for callback in self._check_media_file_for_spam_callbacks: - spam = await delay_cancellation(callback(file_wrapper, file_info)) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + spam = await delay_cancellation(callback(file_wrapper, file_info)) if spam: return True diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index d629a3ecb5..103861644a 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tupl from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.urls import FEDERATION_V1_PREFIX -from synapse.http.server import HttpServer, ServletCallback +from synapse.http.server import HttpServer, ServletCallback, is_method_cancellable from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging.context import run_in_background @@ -373,6 +373,17 @@ class BaseFederationServlet: if code is None: continue + if is_method_cancellable(code): + # The wrapper added by `self._wrap` will inherit the cancellable flag, + # but the wrapper itself does not support cancellation yet. + # Once resolved, the cancellation tests in + # `tests/federation/transport/server/test__base.py` can be re-enabled. + raise Exception( + f"{self.__class__.__name__}.on_{method} has been marked as " + "cancellable, but federation servlets do not support cancellation " + "yet." + ) + server.register_paths( method, (pattern,), diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 4c3a5a6e24..dfd24af695 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): # Before deleting the group lets kick everyone out of it users = await self.store.get_users_in_group(group_id, include_private=True) - async def _kick_user_from_group(user_id): + async def _kick_user_from_group(user_id: str) -> None: if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() assert isinstance( diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 4af9fbc5d1..0478448b47 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -23,7 +23,7 @@ from synapse.replication.http.account_data import ( ReplicationUserAccountDataRestServlet, ) from synapse.streams import EventSource -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, StreamKeyType, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -105,7 +105,7 @@ class AccountDataHandler: ) self._notifier.on_new_event( - "account_data_key", max_stream_id, users=[user_id] + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] ) await self._notify_modules(user_id, room_id, account_data_type, content) @@ -141,7 +141,7 @@ class AccountDataHandler: ) self._notifier.on_new_event( - "account_data_key", max_stream_id, users=[user_id] + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] ) await self._notify_modules(user_id, None, account_data_type, content) @@ -176,7 +176,7 @@ class AccountDataHandler: ) self._notifier.on_new_event( - "account_data_key", max_stream_id, users=[user_id] + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] ) return max_stream_id else: @@ -201,7 +201,7 @@ class AccountDataHandler: ) self._notifier.on_new_event( - "account_data_key", max_stream_id, users=[user_id] + StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] ) return max_stream_id else: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 85bd5e4768..1da7bcc85b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -38,6 +38,7 @@ from synapse.types import ( JsonDict, RoomAlias, RoomStreamToken, + StreamKeyType, UserID, ) from synapse.util.async_helpers import Linearizer @@ -213,8 +214,8 @@ class ApplicationServicesHandler: Args: stream_key: The stream the event came from. - `stream_key` can be "typing_key", "receipt_key", "presence_key", - "to_device_key" or "device_list_key". Any other value for `stream_key` + `stream_key` can be StreamKeyType.TYPING, StreamKeyType.RECEIPT, StreamKeyType.PRESENCE, + StreamKeyType.TO_DEVICE or StreamKeyType.DEVICE_LIST. Any other value for `stream_key` will cause this function to return early. Ephemeral events will only be pushed to appservices that have opted into @@ -235,11 +236,11 @@ class ApplicationServicesHandler: # Only the following streams are currently supported. # FIXME: We should use constants for these values. if stream_key not in ( - "typing_key", - "receipt_key", - "presence_key", - "to_device_key", - "device_list_key", + StreamKeyType.TYPING, + StreamKeyType.RECEIPT, + StreamKeyType.PRESENCE, + StreamKeyType.TO_DEVICE, + StreamKeyType.DEVICE_LIST, ): return @@ -258,14 +259,14 @@ class ApplicationServicesHandler: # Ignore to-device messages if the feature flag is not enabled if ( - stream_key == "to_device_key" + stream_key == StreamKeyType.TO_DEVICE and not self._msc2409_to_device_messages_enabled ): return # Ignore device lists if the feature flag is not enabled if ( - stream_key == "device_list_key" + stream_key == StreamKeyType.DEVICE_LIST and not self._msc3202_transaction_extensions_enabled ): return @@ -283,15 +284,15 @@ class ApplicationServicesHandler: if ( stream_key in ( - "typing_key", - "receipt_key", - "presence_key", - "to_device_key", + StreamKeyType.TYPING, + StreamKeyType.RECEIPT, + StreamKeyType.PRESENCE, + StreamKeyType.TO_DEVICE, ) and service.supports_ephemeral ) or ( - stream_key == "device_list_key" + stream_key == StreamKeyType.DEVICE_LIST and service.msc3202_transaction_extensions ) ] @@ -317,7 +318,7 @@ class ApplicationServicesHandler: logger.debug("Checking interested services for %s", stream_key) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: - if stream_key == "typing_key": + if stream_key == StreamKeyType.TYPING: # Note that we don't persist the token (via set_appservice_stream_type_pos) # for typing_key due to performance reasons and due to their highly # ephemeral nature. @@ -333,7 +334,7 @@ class ApplicationServicesHandler: async with self._ephemeral_events_linearizer.queue( (service.id, stream_key) ): - if stream_key == "receipt_key": + if stream_key == StreamKeyType.RECEIPT: events = await self._handle_receipts(service, new_token) self.scheduler.enqueue_for_appservice(service, ephemeral=events) @@ -342,7 +343,7 @@ class ApplicationServicesHandler: service, "read_receipt", new_token ) - elif stream_key == "presence_key": + elif stream_key == StreamKeyType.PRESENCE: events = await self._handle_presence(service, users, new_token) self.scheduler.enqueue_for_appservice(service, ephemeral=events) @@ -351,7 +352,7 @@ class ApplicationServicesHandler: service, "presence", new_token ) - elif stream_key == "to_device_key": + elif stream_key == StreamKeyType.TO_DEVICE: # Retrieve a list of to-device message events, as well as the # maximum stream token of the messages we were able to retrieve. to_device_messages = await self._get_to_device_messages( @@ -366,7 +367,7 @@ class ApplicationServicesHandler: service, "to_device", new_token ) - elif stream_key == "device_list_key": + elif stream_key == StreamKeyType.DEVICE_LIST: device_list_summary = await self._get_device_list_summary( service, new_token ) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index a91b1ee4d5..1d6d1f8a92 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -43,6 +43,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.types import ( JsonDict, + StreamKeyType, StreamToken, UserID, get_domain_from_id, @@ -502,7 +503,7 @@ class DeviceHandler(DeviceWorkerHandler): # specify the user ID too since the user should always get their own device list # updates, even if they aren't in any rooms. self.notifier.on_new_event( - "device_list_key", position, users={user_id}, rooms=room_ids + StreamKeyType.DEVICE_LIST, position, users={user_id}, rooms=room_ids ) # We may need to do some processing asynchronously for local user IDs. @@ -523,7 +524,9 @@ class DeviceHandler(DeviceWorkerHandler): from_user_id, user_ids ) - self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) + self.notifier.on_new_event( + StreamKeyType.DEVICE_LIST, position, users=[from_user_id] + ) async def user_left_room(self, user: UserID, room_id: str) -> None: user_id = user.to_string() diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 4cb725d027..53668cce3b 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -26,7 +26,7 @@ from synapse.logging.opentracing import ( set_tag, ) from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet -from synapse.types import JsonDict, Requester, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.stringutils import random_string @@ -151,7 +151,7 @@ class DeviceMessageHandler: # Notify listeners that there are new to-device messages to process, # handing them the latest stream id. self.notifier.on_new_event( - "to_device_key", last_stream_id, users=local_messages.keys() + StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys() ) async def _check_for_unknown_devices( @@ -285,7 +285,7 @@ class DeviceMessageHandler: # Notify listeners that there are new to-device messages to process, # handing them the latest stream id. self.notifier.on_new_event( - "to_device_key", last_stream_id, users=local_messages.keys() + StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys() ) if self.federation_sender: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d6714228ef..e6c2cfb8c8 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -1105,22 +1105,19 @@ class E2eKeysHandler: # can request over federation raise NotFoundError("No %s key found for %s" % (key_type, user_id)) - ( - key, - key_id, - verify_key, - ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) - - if key is None: + cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user( + user, key_type + ) + if cross_signing_keys is None: raise NotFoundError("No %s key found for %s" % (key_type, user_id)) - return key, key_id, verify_key + return cross_signing_keys async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, - ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: + ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database Only the key specified by `key_type` will be returned, while all retrieved keys @@ -1146,12 +1143,10 @@ class E2eKeysHandler: type(e), e, ) - return None, None, None + return None # Process each of the retrieved cross-signing keys - desired_key = None - desired_key_id = None - desired_verify_key = None + desired_key_data = None retrieved_device_ids = [] for key_type in ["master", "self_signing"]: key_content = remote_result.get(key_type + "_key") @@ -1196,9 +1191,7 @@ class E2eKeysHandler: # If this is the desired key type, save it and its ID/VerifyKey if key_type == desired_key_type: - desired_key = key_content - desired_verify_key = verify_key - desired_key_id = key_id + desired_key_data = key_content, key_id, verify_key # At the same time, store this key in the db for subsequent queries await self.store.set_e2e_cross_signing_key( @@ -1212,7 +1205,7 @@ class E2eKeysHandler: user.to_string(), retrieved_device_ids ) - return desired_key, desired_key_id, desired_verify_key + return desired_key_data def _check_cross_signing_key( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 38dc5b1f6e..be5099b507 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -659,7 +659,7 @@ class FederationHandler: # in the invitee's sync stream. It is stripped out for all other local users. event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] - context = EventContext.for_outlier() + context = EventContext.for_outlier(self.storage) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -848,7 +848,7 @@ class FederationHandler: ) ) - context = EventContext.for_outlier() + context = EventContext.for_outlier(self.storage) await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -877,7 +877,7 @@ class FederationHandler: await self.federation_client.send_leave(host_list, event) - context = EventContext.for_outlier() + context = EventContext.for_outlier(self.storage) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 6cf927e4ff..761caa04b7 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -103,7 +103,7 @@ class FederationEventHandler: self._event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() self._message_handler = hs.get_message_handler() - self._action_generator = hs.get_action_generator() + self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() self._state_resolution_handler = hs.get_state_resolution_handler() # avoid a circular dependency by deferring execution here self._get_room_member_handler = hs.get_room_member_handler @@ -1423,7 +1423,7 @@ class FederationEventHandler: # we're not bothering about room state, so flag the event as an outlier. event.internal_metadata.outlier = True - context = EventContext.for_outlier() + context = EventContext.for_outlier(self._storage) try: validate_event_for_room_version(room_version_obj, event) check_auth_rules_for_event(room_version_obj, event, auth) @@ -1874,10 +1874,10 @@ class FederationEventHandler: ) return EventContext.with_state( + storage=self._storage, state_group=state_group, state_group_before_event=context.state_group_before_event, - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + state_delta_due_to_event=state_updates, prev_group=prev_group, delta_ids=state_updates, partial_state=context.partial_state, @@ -1913,7 +1913,7 @@ class FederationEventHandler: min_depth, ) else: - await self._action_generator.handle_push_actions_for_event( + await self._bulk_push_rule_evaluator.action_for_event_by_user( event, context ) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 7b94770f97..d79248ad90 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -30,6 +30,7 @@ from synapse.types import ( Requester, RoomStreamToken, StateMap, + StreamKeyType, StreamToken, UserID, ) @@ -143,7 +144,7 @@ class InitialSyncHandler: to_key=int(now_token.receipt_key), ) if self.hs.config.experimental.msc2285_enabled: - receipt = ReceiptEventSource.filter_out_private(receipt, user_id) + receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -220,8 +221,10 @@ class InitialSyncHandler: self.storage, user_id, messages ) - start_token = now_token.copy_and_replace("room_key", token) - end_token = now_token.copy_and_replace("room_key", room_end_token) + start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) + end_token = now_token.copy_and_replace( + StreamKeyType.ROOM, room_end_token + ) time_now = self.clock.time_msec() d["messages"] = { @@ -369,8 +372,8 @@ class InitialSyncHandler: self.storage, user_id, messages, is_peeking=is_peeking ) - start_token = StreamToken.START.copy_and_replace("room_key", token) - end_token = StreamToken.START.copy_and_replace("room_key", stream_token) + start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) + end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token) time_now = self.clock.time_msec() @@ -449,7 +452,9 @@ class InitialSyncHandler: if not receipts: return [] if self.hs.config.experimental.msc2285_enabled: - receipts = ReceiptEventSource.filter_out_private(receipts, user_id) + receipts = ReceiptEventSource.filter_out_private_receipts( + receipts, user_id + ) return receipts presence, receipts, (messages, token) = await make_deferred_yieldable( @@ -472,7 +477,7 @@ class InitialSyncHandler: self.storage, user_id, messages, is_peeking=is_peeking ) - start_token = now_token.copy_and_replace("room_key", token) + start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) end_token = now_token time_now = self.clock.time_msec() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c28b792e6f..0951b9c71f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -44,7 +44,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator @@ -426,7 +426,7 @@ class EventCreationHandler: # This is to stop us from diverging history *too* much. self.limiter = Linearizer(max_count=5, name="room_event_creation_limit") - self.action_generator = hs.get_action_generator() + self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() self.spam_checker = hs.get_spam_checker() self.third_party_event_rules: "ThirdPartyEventRules" = ( @@ -757,6 +757,10 @@ class EventCreationHandler: The previous version of the event is returned, if it is found in the event context. Otherwise, None is returned. """ + if event.internal_metadata.is_outlier(): + # This can happen due to out of band memberships + return None + prev_state_ids = await context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: @@ -1001,7 +1005,7 @@ class EventCreationHandler: # after it is created if builder.internal_metadata.outlier: event.internal_metadata.outlier = True - context = EventContext.for_outlier() + context = EventContext.for_outlier(self.storage) elif ( event.type == EventTypes.MSC2716_INSERTION and state_event_ids @@ -1056,20 +1060,11 @@ class EventCreationHandler: SynapseError if the event is invalid. """ - relation = event.content.get("m.relates_to") + relation = relation_from_event(event) if not relation: return - relation_type = relation.get("rel_type") - if not relation_type: - return - - # Ensure the parent is real. - relates_to = relation.get("event_id") - if not relates_to: - return - - parent_event = await self.store.get_event(relates_to, allow_none=True) + parent_event = await self.store.get_event(relation.parent_id, allow_none=True) if parent_event: # And in the same room. if parent_event.room_id != event.room_id: @@ -1078,28 +1073,31 @@ class EventCreationHandler: else: # There must be some reason that the client knows the event exists, # see if there are existing relations. If so, assume everything is fine. - if not await self.store.event_is_target_of_relation(relates_to): + if not await self.store.event_is_target_of_relation(relation.parent_id): # Otherwise, the client can't know about the parent event! raise SynapseError(400, "Can't send relation to unknown event") # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an # event multiple times). - if relation_type == RelationTypes.ANNOTATION: - aggregation_key = relation["key"] + if relation.rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.aggregation_key + + if aggregation_key is None: + raise SynapseError(400, "Missing aggregation key") if len(aggregation_key) > 500: raise SynapseError(400, "Aggregation key is too long") already_exists = await self.store.has_user_annotated_event( - relates_to, event.type, aggregation_key, event.sender + relation.parent_id, event.type, aggregation_key, event.sender ) if already_exists: raise SynapseError(400, "Can't send same reaction twice") # Don't attempt to start a thread if the parent event is a relation. - elif relation_type == RelationTypes.THREAD: - if await self.store.event_includes_relation(relates_to): + elif relation.rel_type == RelationTypes.THREAD: + if await self.store.event_includes_relation(relation.parent_id): raise SynapseError( 400, "Cannot start threads from an event with a relation" ) @@ -1245,7 +1243,9 @@ class EventCreationHandler: # and `state_groups` because they have `prev_events` that aren't persisted yet # (historical messages persisted in reverse-chronological order). if not event.internal_metadata.is_historical(): - await self.action_generator.handle_push_actions_for_event(event, context) + await self._bulk_push_rule_evaluator.action_for_event_by_user( + event, context + ) try: # If we're a worker we need to hit out to the master. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 7ee3340373..6ae88add95 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -27,7 +27,7 @@ from synapse.handlers.room import ShutdownRoomResponse from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester +from synapse.types import JsonDict, Requester, StreamKeyType from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -448,7 +448,7 @@ class PaginationHandler: ) # We expect `/messages` to use historic pagination tokens by default but # `/messages` should still works with live tokens when manually provided. - assert from_token.room_key.topological + assert from_token.room_key.topological is not None if pagin_config.limit is None: # This shouldn't happen as we've set a default limit before this @@ -491,7 +491,7 @@ class PaginationHandler: if leave_token.topological < curr_topo: from_token = from_token.copy_and_replace( - "room_key", leave_token + StreamKeyType.ROOM, leave_token ) await self.hs.get_federation_handler().maybe_backfill( @@ -513,7 +513,7 @@ class PaginationHandler: event_filter=event_filter, ) - next_token = from_token.copy_and_replace("room_key", next_key) + next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key) if events: if event_filter: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 268481ec19..dd84e6c88b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -66,7 +66,7 @@ from synapse.replication.tcp.commands import ClearUserSyncsCommand from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream from synapse.storage.databases.main import DataStore from synapse.streams import EventSource -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.metrics import Measure @@ -522,7 +522,7 @@ class WorkerPresenceHandler(BasePresenceHandler): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", + StreamKeyType.PRESENCE, stream_id, rooms=room_ids_to_states.keys(), users=users_to_states.keys(), @@ -1145,7 +1145,7 @@ class PresenceHandler(BasePresenceHandler): room_ids_to_states, users_to_states = parties self.notifier.on_new_event( - "presence_key", + StreamKeyType.PRESENCE, stream_id, rooms=room_ids_to_states.keys(), users=[UserID.from_string(u) for u in users_to_states], diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 43d615357b..e6a35f1d09 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -17,7 +17,13 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse.api.constants import ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource -from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id +from synapse.types import ( + JsonDict, + ReadReceipt, + StreamKeyType, + UserID, + get_domain_from_id, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -129,7 +135,9 @@ class ReceiptsHandler: affected_room_ids = list({r.room_id for r in receipts}) - self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids) + self.notifier.on_new_event( + StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids + ) # Note that the min here shouldn't be relied upon to be accurate. await self.hs.get_pusherpool().on_new_receipts( min_batch_id, max_batch_id, affected_room_ids @@ -165,43 +173,69 @@ class ReceiptEventSource(EventSource[int, JsonDict]): self.config = hs.config @staticmethod - def filter_out_private(events: List[JsonDict], user_id: str) -> List[JsonDict]: - """ - This method takes in what is returned by - get_linearized_receipts_for_rooms() and goes through read receipts - filtering out m.read.private receipts if they were not sent by the - current user. + def filter_out_private_receipts( + rooms: List[JsonDict], user_id: str + ) -> List[JsonDict]: """ + Filters a list of serialized receipts (as returned by /sync and /initialSync) + and removes private read receipts of other users. - visible_events = [] - - # filter out private receipts the user shouldn't see - for event in events: - content = event.get("content", {}) - new_event = event.copy() - new_event["content"] = {} - - for event_id, event_content in content.items(): - receipt_event = {} - for receipt_type, receipt_content in event_content.items(): - if receipt_type == ReceiptTypes.READ_PRIVATE: - user_rr = receipt_content.get(user_id, None) - if user_rr: - receipt_event[ReceiptTypes.READ_PRIVATE] = { - user_id: user_rr.copy() - } - else: - receipt_event[receipt_type] = receipt_content.copy() + This operates on the return value of get_linearized_receipts_for_rooms(), + which is wrapped in a cache. Care must be taken to ensure that the input + values are not modified. - # Only include the receipt event if it is non-empty. - if receipt_event: - new_event["content"][event_id] = receipt_event + Args: + rooms: A list of mappings, each mapping has a `content` field, which + is a map of event ID -> receipt type -> user ID -> receipt information. - # Append new_event to visible_events unless empty - if len(new_event["content"].keys()) > 0: - visible_events.append(new_event) + Returns: + The same as rooms, but filtered. + """ - return visible_events + result = [] + + # Iterate through each room's receipt content. + for room in rooms: + # The receipt content with other user's private read receipts removed. + content = {} + + # Iterate over each event ID / receipts for that event. + for event_id, orig_event_content in room.get("content", {}).items(): + event_content = orig_event_content + # If there are private read receipts, additional logic is necessary. + if ReceiptTypes.READ_PRIVATE in event_content: + # Make a copy without private read receipts to avoid leaking + # other user's private read receipts.. + event_content = { + receipt_type: receipt_value + for receipt_type, receipt_value in event_content.items() + if receipt_type != ReceiptTypes.READ_PRIVATE + } + + # Copy the current user's private read receipt from the + # original content, if it exists. + user_private_read_receipt = orig_event_content[ + ReceiptTypes.READ_PRIVATE + ].get(user_id, None) + if user_private_read_receipt: + event_content[ReceiptTypes.READ_PRIVATE] = { + user_id: user_private_read_receipt + } + + # Include the event if there is at least one non-private read + # receipt or the current user has a private read receipt. + if event_content: + content[event_id] = event_content + + # Include the event if there is at least one non-private read receipt + # or the current user has a private read receipt. + if content: + # Build a new event to avoid mutating the cache. + new_room = {k: v for k, v in room.items() if k != "content"} + new_room["content"] = content + result.append(new_room) + + return result async def get_new_events( self, @@ -223,7 +257,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]): ) if self.config.experimental.msc2285_enabled: - events = ReceiptEventSource.filter_out_private(events, user.to_string()) + events = ReceiptEventSource.filter_out_private_receipts( + events, user.to_string() + ) return events, to_key diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index c2754ec918..ab7e54857d 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections.abc import logging from typing import ( TYPE_CHECKING, @@ -28,7 +27,7 @@ import attr from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.storage.databases.main.relations import _RelatedEvent from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -373,20 +372,21 @@ class RelationsHandler: if event.is_state(): continue - relates_to = event.content.get("m.relates_to") - relation_type = None - if isinstance(relates_to, collections.abc.Mapping): - relation_type = relates_to.get("rel_type") + relates_to = relation_from_event(event) + if relates_to: # An event which is a replacement (ie edit) or annotation (ie, # reaction) may not have any other event related to it. - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + if relates_to.rel_type in ( + RelationTypes.ANNOTATION, + RelationTypes.REPLACE, + ): continue + # Track the event's relation information for later. + relations_by_id[event.event_id] = relates_to.rel_type + # The event should get bundled aggregations. events_by_id[event.event_id] = event - # Track the event's relation information for later. - if isinstance(relation_type, str): - relations_by_id[event.event_id] = relation_type # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 604eb6ec15..a2973109ad 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -33,6 +33,7 @@ from typing import ( import attr from typing_extensions import TypedDict +import synapse.events.snapshot from synapse.api.constants import ( EventContentFields, EventTypes, @@ -72,12 +73,12 @@ from synapse.types import ( RoomID, RoomStreamToken, StateMap, + StreamKeyType, StreamToken, UserID, create_requester, ) from synapse.util import stringutils -from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client @@ -149,10 +150,11 @@ class RoomCreationHandler: ) preset_config["encrypted"] = encrypted - self._replication = hs.get_replication_data_handler() + self._default_power_level_content_override = ( + self.config.room.default_power_level_content_override + ) - # linearizer to stop two upgrades happening at once - self._upgrade_linearizer = Linearizer("room_upgrade_linearizer") + self._replication = hs.get_replication_data_handler() # If a user tries to update the same room multiple times in quick # succession, only process the first attempt and return its result to @@ -196,6 +198,39 @@ class RoomCreationHandler: 400, "An upgrade for this room is currently in progress" ) + # Check whether the room exists and 404 if it doesn't. + # We could go straight for the auth check, but that will raise a 403 instead. + old_room = await self.store.get_room(old_room_id) + if old_room is None: + raise NotFoundError("Unknown room id %s" % (old_room_id,)) + + new_room_id = self._generate_room_id() + + # Check whether the user has the power level to carry out the upgrade. + # `check_auth_rules_from_context` will check that they are in the room and have + # the required power level to send the tombstone event. + ( + tombstone_event, + tombstone_context, + ) = await self.event_creation_handler.create_event( + requester, + { + "type": EventTypes.Tombstone, + "state_key": "", + "room_id": old_room_id, + "sender": user_id, + "content": { + "body": "This room has been replaced", + "replacement_room": new_room_id, + }, + }, + ) + old_room_version = await self.store.get_room_version(old_room_id) + validate_event_for_room_version(old_room_version, tombstone_event) + await self._event_auth_handler.check_auth_rules_from_context( + old_room_version, tombstone_event, tombstone_context + ) + # Upgrade the room # # If this user has sent multiple upgrade requests for the same room @@ -206,19 +241,35 @@ class RoomCreationHandler: self._upgrade_room, requester, old_room_id, - new_version, # args for _upgrade_room + old_room, # args for _upgrade_room + new_room_id, + new_version, + tombstone_event, + tombstone_context, ) return ret async def _upgrade_room( - self, requester: Requester, old_room_id: str, new_version: RoomVersion + self, + requester: Requester, + old_room_id: str, + old_room: Dict[str, Any], + new_room_id: str, + new_version: RoomVersion, + tombstone_event: EventBase, + tombstone_context: synapse.events.snapshot.EventContext, ) -> str: """ Args: requester: the user requesting the upgrade old_room_id: the id of the room to be replaced - new_versions: the version to upgrade the room to + old_room: a dict containing room information for the room to be replaced, + as returned by `RoomWorkerStore.get_room`. + new_room_id: the id of the replacement room + new_version: the version to upgrade the room to + tombstone_event: the tombstone event to send to the old room + tombstone_context: the context for the tombstone event Raises: ShadowBanError if the requester is shadow-banned. @@ -226,40 +277,15 @@ class RoomCreationHandler: user_id = requester.user.to_string() assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,) - # start by allocating a new room id - r = await self.store.get_room(old_room_id) - if r is None: - raise NotFoundError("Unknown room id %s" % (old_room_id,)) - new_room_id = await self._generate_room_id( - creator_id=user_id, - is_public=r["is_public"], - room_version=new_version, - ) - logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) - # we create and auth the tombstone event before properly creating the new - # room, to check our user has perms in the old room. - ( - tombstone_event, - tombstone_context, - ) = await self.event_creation_handler.create_event( - requester, - { - "type": EventTypes.Tombstone, - "state_key": "", - "room_id": old_room_id, - "sender": user_id, - "content": { - "body": "This room has been replaced", - "replacement_room": new_room_id, - }, - }, - ) - old_room_version = await self.store.get_room_version(old_room_id) - validate_event_for_room_version(old_room_version, tombstone_event) - await self._event_auth_handler.check_auth_rules_from_context( - old_room_version, tombstone_event, tombstone_context + # create the new room. may raise a `StoreError` in the exceedingly unlikely + # event of a room ID collision. + await self.store.store_room( + room_id=new_room_id, + room_creator_user_id=user_id, + is_public=old_room["is_public"], + room_version=new_version, ) await self.clone_existing_room( @@ -778,7 +804,7 @@ class RoomCreationHandler: visibility = config.get("visibility", "private") is_public = visibility == "public" - room_id = await self._generate_room_id( + room_id = await self._generate_and_create_room_id( creator_id=user_id, is_public=is_public, room_version=room_version, @@ -1042,9 +1068,19 @@ class RoomCreationHandler: for invitee in invite_list: power_level_content["users"][invitee] = 100 - # Power levels overrides are defined per chat preset + # If the user supplied a preset name e.g. "private_chat", + # we apply that preset power_level_content.update(config["power_level_content_override"]) + # If the server config contains default_power_level_content_override, + # and that contains information for this room preset, apply it. + if self._default_power_level_content_override: + override = self._default_power_level_content_override.get(preset_config) + if override is not None: + power_level_content.update(override) + + # Finally, if the user supplied specific permissions for this room, + # apply those. if power_level_content_override: power_level_content.update(power_level_content_override) @@ -1090,7 +1126,26 @@ class RoomCreationHandler: return last_sent_stream_id - async def _generate_room_id( + def _generate_room_id(self) -> str: + """Generates a random room ID. + + Room IDs look like "!opaque_id:domain" and are case-sensitive as per the spec + at https://spec.matrix.org/v1.2/appendices/#room-ids-and-event-ids. + + Does not check for collisions with existing rooms or prevent future calls from + returning the same room ID. To ensure the uniqueness of a new room ID, use + `_generate_and_create_room_id` instead. + + Synapse's room IDs are 18 [a-zA-Z] characters long, which comes out to around + 102 bits. + + Returns: + A random room ID of the form "!opaque_id:domain". + """ + random_string = stringutils.random_string(18) + return RoomID(random_string, self.hs.hostname).to_string() + + async def _generate_and_create_room_id( self, creator_id: str, is_public: bool, @@ -1101,8 +1156,7 @@ class RoomCreationHandler: attempts = 0 while attempts < 5: try: - random_string = stringutils.random_string(18) - gen_room_id = RoomID(random_string, self.hs.hostname).to_string() + gen_room_id = self._generate_room_id() await self.store.store_room( room_id=gen_room_id, room_creator_user_id=creator_id, @@ -1239,10 +1293,10 @@ class RoomContextHandler: events_after=events_after, state=await filter_evts(state_events), aggregations=aggregations, - start=await token.copy_and_replace("room_key", results.start).to_string( - self.store - ), - end=await token.copy_and_replace("room_key", results.end).to_string( + start=await token.copy_and_replace( + StreamKeyType.ROOM, results.start + ).to_string(self.store), + end=await token.copy_and_replace(StreamKeyType.ROOM, results.end).to_string( self.store ), ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 5619f8f50e..cd1c47dae8 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -24,7 +24,7 @@ from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.storage.state import StateFilter -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, StreamKeyType, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -655,11 +655,11 @@ class SearchHandler: "events_before": events_before, "events_after": events_after, "start": await now_token.copy_and_replace( - "room_key", res.start + StreamKeyType.ROOM, res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace( + StreamKeyType.ROOM, res.end ).to_string(self.store), - "end": await now_token.copy_and_replace("room_key", res.end).to_string( - self.store - ), } if include_profile: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 2c555a66d0..4be08fe7cb 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ from synapse.types import ( Requester, RoomStreamToken, StateMap, + StreamKeyType, StreamToken, UserID, ) @@ -449,7 +450,7 @@ class SyncHandler: room_ids=room_ids, is_guest=sync_config.is_guest, ) - now_token = now_token.copy_and_replace("typing_key", typing_key) + now_token = now_token.copy_and_replace(StreamKeyType.TYPING, typing_key) ephemeral_by_room: JsonDict = {} @@ -471,7 +472,7 @@ class SyncHandler: room_ids=room_ids, is_guest=sync_config.is_guest, ) - now_token = now_token.copy_and_replace("receipt_key", receipt_key) + now_token = now_token.copy_and_replace(StreamKeyType.RECEIPT, receipt_key) for event in receipts: room_id = event["room_id"] @@ -537,7 +538,9 @@ class SyncHandler: prev_batch_token = now_token if recents: room_key = recents[0].internal_metadata.before - prev_batch_token = now_token.copy_and_replace("room_key", room_key) + prev_batch_token = now_token.copy_and_replace( + StreamKeyType.ROOM, room_key + ) return TimelineBatch( events=recents, prev_batch=prev_batch_token, limited=False @@ -611,7 +614,7 @@ class SyncHandler: recents = recents[-timeline_limit:] room_key = recents[0].internal_metadata.before - prev_batch_token = now_token.copy_and_replace("room_key", room_key) + prev_batch_token = now_token.copy_and_replace(StreamKeyType.ROOM, room_key) # Don't bother to bundle aggregations if the timeline is unlimited, # as clients will have all the necessary information. @@ -1398,7 +1401,7 @@ class SyncHandler: now_token.to_device_key, ) sync_result_builder.now_token = now_token.copy_and_replace( - "to_device_key", stream_id + StreamKeyType.TO_DEVICE, stream_id ) sync_result_builder.to_device = messages else: @@ -1503,7 +1506,7 @@ class SyncHandler: ) assert presence_key sync_result_builder.now_token = now_token.copy_and_replace( - "presence_key", presence_key + StreamKeyType.PRESENCE, presence_key ) extra_users_ids = set(newly_joined_or_invited_users) @@ -1826,7 +1829,7 @@ class SyncHandler: # stream token as it'll only be used in the context of this # room. (c.f. the docstring of `to_room_stream_token`). leave_token = since_token.copy_and_replace( - "room_key", leave_position.to_room_stream_token() + StreamKeyType.ROOM, leave_position.to_room_stream_token() ) # If this is an out of band message, like a remote invite @@ -1875,7 +1878,9 @@ class SyncHandler: if room_entry: events, start_key = room_entry - prev_batch_token = now_token.copy_and_replace("room_key", start_key) + prev_batch_token = now_token.copy_and_replace( + StreamKeyType.ROOM, start_key + ) entry = RoomSyncResultBuilder( room_id=room_id, @@ -1972,7 +1977,7 @@ class SyncHandler: continue leave_token = now_token.copy_and_replace( - "room_key", RoomStreamToken(None, event.stream_ordering) + StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering) ) room_entries.append( RoomSyncResultBuilder( diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 6854428b7c..bb00750bfd 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -25,7 +25,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -382,7 +382,7 @@ class TypingWriterHandler(FollowerTypingHandler): ) self.notifier.on_new_event( - "typing_key", self._latest_room_serial, rooms=[member.room_id] + StreamKeyType.TYPING, self._latest_room_serial, rooms=[member.room_id] ) async def get_all_typing_updates( diff --git a/synapse/http/client.py b/synapse/http/client.py index 8310fb466a..084d0a5b84 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.interfaces import ( IAddress, + IDelayedCall, IHostResolution, IReactorPluggableNameResolver, + IReactorTime, IResolutionReceiver, ITCPTransport, ) @@ -121,13 +123,15 @@ def check_against_blacklist( _EPSILON = 0.00000001 -def _make_scheduler(reactor): +def _make_scheduler( + reactor: IReactorTime, +) -> Callable[[Callable[[], object]], IDelayedCall]: """Makes a schedular suitable for a Cooperator using the given reactor. (This is effectively just a copy from `twisted.internet.task`) """ - def _scheduler(x): + def _scheduler(x: Callable[[], object]) -> IDelayedCall: return reactor.callLater(_EPSILON, x) return _scheduler @@ -348,7 +352,7 @@ class SimpleHttpClient: # XXX: The justification for using the cache factor here is that larger instances # will need both more cache and more connections. # Still, this should probably be a separate dial - pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) + pool.maxPersistentPerHost = max(int(100 * hs.config.caches.global_factor), 5) pool.cachedConnectionTimeout = 2 * 60 self.agent: IAgent = ProxyAgent( @@ -775,7 +779,7 @@ class SimpleHttpClient: ) -def _timeout_to_request_timed_out_error(f: Failure): +def _timeout_to_request_timed_out_error(f: Failure) -> Failure: if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError): # The TCP connection has its own timeout (set by the 'connectTimeout' param # on the Agent), which raises twisted_error.TimeoutError exception. @@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): def __init__(self, deferred: defer.Deferred): self.deferred = deferred - def _maybe_fail(self): + def _maybe_fail(self) -> None: """ Report a max size exceed error and disconnect the first time this is called. """ @@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): Do not use this since it allows an attacker to intercept your communications. """ - def __init__(self): + def __init__(self) -> None: self._context = SSL.Context(SSL.SSLv23_METHOD) self._context.set_verify(VERIFY_NONE, lambda *_: False) def getContext(self, hostname=None, port=None): return self._context - def creatorForNetloc(self, hostname, port): + def creatorForNetloc(self, hostname: bytes, port: int): return self diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index 203e995bb7..23a60af171 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -14,15 +14,22 @@ import base64 import logging -from typing import Optional +from typing import Optional, Union import attr from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import ConnectError -from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint +from twisted.internet.interfaces import ( + IAddress, + IConnector, + IProtocol, + IReactorCore, + IStreamClientEndpoint, +) from twisted.internet.protocol import ClientFactory, Protocol, connectionDone +from twisted.python.failure import Failure from twisted.web import http logger = logging.getLogger(__name__) @@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint: self._port = port self._proxy_creds = proxy_creds - def __repr__(self): + def __repr__(self) -> str: return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) # Mypy encounters a false positive here: it complains that ClientFactory # is incompatible with IProtocolFactory. But ClientFactory inherits from # Factory, which implements IProtocolFactory. So I think this is a bug # in mypy-zope. - def connect(self, protocolFactory: ClientFactory): # type: ignore[override] + def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]": # type: ignore[override] f = HTTPProxiedClientFactory( self._host, self._port, protocolFactory, self._proxy_creds ) @@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): self.proxy_creds = proxy_creds self.on_connection: "defer.Deferred[None]" = defer.Deferred() - def startedConnecting(self, connector): + def startedConnecting(self, connector: IConnector) -> None: return self.wrapped_factory.startedConnecting(connector) - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol": wrapped_protocol = self.wrapped_factory.buildProtocol(addr) if wrapped_protocol is None: raise TypeError("buildProtocol produced None instead of a Protocol") @@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): self.proxy_creds, ) - def clientConnectionFailed(self, connector, reason): + def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: logger.debug("Connection to proxy failed: %s", reason) if not self.on_connection.called: self.on_connection.errback(reason) return self.wrapped_factory.clientConnectionFailed(connector, reason) - def clientConnectionLost(self, connector, reason): + def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: logger.debug("Connection to proxy lost: %s", reason) if not self.on_connection.called: self.on_connection.errback(reason) @@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol): ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) - def connectionMade(self): + def connectionMade(self) -> None: self.http_setup_client.makeConnection(self.transport) - def connectionLost(self, reason=connectionDone): + def connectionLost(self, reason: Failure = connectionDone) -> None: if self.wrapped_protocol.connected: self.wrapped_protocol.connectionLost(reason) @@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol): if not self.connected_deferred.called: self.connected_deferred.errback(reason) - def proxyConnected(self, _): + def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None: self.wrapped_protocol.makeConnection(self.transport) self.connected_deferred.callback(self.wrapped_protocol) @@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol): if buf: self.wrapped_protocol.dataReceived(buf) - def dataReceived(self, data: bytes): + def dataReceived(self, data: bytes) -> None: # if we've set up the HTTP protocol, we can send the data there if self.wrapped_protocol.connected: return self.wrapped_protocol.dataReceived(data) @@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient): self.proxy_creds = proxy_creds self.on_connected: "defer.Deferred[None]" = defer.Deferred() - def connectionMade(self): + def connectionMade(self) -> None: logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) @@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient): self.endHeaders() - def handleStatus(self, version: bytes, status: bytes, message: bytes): + def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None: logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}") - def handleEndHeaders(self): + def handleEndHeaders(self) -> None: logger.debug("End Headers") self.on_connected.callback(None) - def handleResponse(self, body): + def handleResponse(self, body: bytes) -> None: pass diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index a8a520f809..2f0177f1e2 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory: self._srv_resolver = srv_resolver - def endpointForURI(self, parsed_uri: URI): + def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint": return MatrixHostnameEndpoint( self._reactor, self._proxy_reactor, diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index f68646fd0d..de0e882b33 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -16,7 +16,7 @@ import logging import random import time -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List import attr @@ -109,7 +109,7 @@ class SrvResolver: def __init__( self, - dns_client=client, + dns_client: Any = client, cache: Dict[bytes, List[Server]] = SERVER_CACHE, get_time: Callable[[], float] = time.time, ): diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 43f2140429..71b685fade 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known") _had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known") -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class WellKnownLookupResult: - delegated_server = attr.ib() + delegated_server: Optional[bytes] class WellKnownResolver: @@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: class _FetchWellKnownFailure(Exception): # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be # a temporary failure. - temporary = attr.ib() + temporary: bool = attr.ib() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c2ec3caa0e..725b5c33b8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -23,6 +23,8 @@ from http import HTTPStatus from io import BytesIO, StringIO from typing import ( TYPE_CHECKING, + Any, + BinaryIO, Callable, Dict, Generic, @@ -44,7 +46,7 @@ from typing_extensions import Literal from twisted.internet import defer from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IReactorTime -from twisted.internet.task import _EPSILON, Cooperator +from twisted.internet.task import Cooperator from twisted.web.client import ResponseFailed from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer, IResponse @@ -58,11 +60,13 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( BlacklistingAgentWrapper, BodyExceededMaxSize, ByteWriteable, + _make_scheduler, encode_query_args, read_body_with_max_size, ) @@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]): CONTENT_TYPE = "application/json" - def __init__(self): + def __init__(self) -> None: self._buffer = StringIO() self._binary_wrapper = BinaryIOWrapper(self._buffer) @@ -299,7 +303,9 @@ async def _handle_response( class BinaryIOWrapper: """A wrapper for a TextIO which converts from bytes on the fly.""" - def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"): + def __init__( + self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict" + ): self.decoder = codecs.getincrementaldecoder(encoding)(errors) self.file = file @@ -317,7 +323,11 @@ class MatrixFederationHttpClient: requests. """ - def __init__(self, hs: "HomeServer", tls_client_options_factory): + def __init__( + self, + hs: "HomeServer", + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + ): self.hs = hs self.signing_key = hs.signing_key self.server_name = hs.hostname @@ -348,10 +358,7 @@ class MatrixFederationHttpClient: self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 - def schedule(x): - self.reactor.callLater(_EPSILON, x) - - self._cooperator = Cooperator(scheduler=schedule) + self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor)) self._sleeper = AwakenableSleeper(self.reactor) @@ -364,7 +371,7 @@ class MatrixFederationHttpClient: self, request: MatrixFederationRequest, try_trailing_slash_on_400: bool = False, - **send_request_args, + **send_request_args: Any, ) -> IResponse: """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a @@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - output_stream, + output_stream: BinaryIO, args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, max_size: Optional[int] = None, @@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient: return length, headers -def _flatten_response_never_received(e): +def _flatten_response_never_received(e: BaseException) -> str: if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) for f in e.reasons + _flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined] ) return "%s:[%s]" % (type(e).__name__, reasons) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index a16dde2380..b2a50c9105 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -245,7 +245,7 @@ def http_proxy_endpoint( proxy: Optional[bytes], reactor: IReactorCore, tls_options_factory: Optional[IPolicyForHTTPS], - **kwargs, + **kwargs: object, ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: """Parses an http proxy setting and returns an endpoint for the proxy diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 4886626d50..2b6d113544 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -162,7 +162,7 @@ class RequestMetrics: with _in_flight_requests_lock: _in_flight_requests.add(self) - def stop(self, time_sec, response_code, sent_bytes): + def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None: with _in_flight_requests_lock: _in_flight_requests.discard(self) @@ -186,13 +186,13 @@ class RequestMetrics: ) return - response_code = str(response_code) + response_code_str = str(response_code) - outgoing_responses_counter.labels(self.method, response_code).inc() + outgoing_responses_counter.labels(self.method, response_code_str).inc() response_count.labels(self.method, self.name, tag).inc() - response_timer.labels(self.method, self.name, tag, response_code).observe( + response_timer.labels(self.method, self.name, tag, response_code_str).observe( time_sec - self.start_ts ) @@ -221,7 +221,7 @@ class RequestMetrics: # flight. self.update_metrics() - def update_metrics(self): + def update_metrics(self) -> None: """Updates the in flight metrics with values from this request.""" if not self.start_context: logger.error( diff --git a/synapse/http/server.py b/synapse/http/server.py index 657bffcddd..e3dcc3f3dd 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -33,6 +33,7 @@ from typing import ( Optional, Pattern, Tuple, + TypeVar, Union, ) @@ -92,6 +93,68 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> HTTP_STATUS_REQUEST_CANCELLED = 499 +F = TypeVar("F", bound=Callable[..., Any]) + + +_cancellable_method_names = frozenset( + { + # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet` + # methods + "on_GET", + "on_PUT", + "on_POST", + "on_DELETE", + # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource` + # methods + "_async_render_GET", + "_async_render_PUT", + "_async_render_POST", + "_async_render_DELETE", + "_async_render_OPTIONS", + # `ReplicationEndpoint` methods + "_handle_request", + } +) + + +def cancellable(method: F) -> F: + """Marks a servlet method as cancellable. + + Methods with this decorator will be cancelled if the client disconnects before we + finish processing the request. + + During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping + the method. The `cancel()` call will propagate down to the `Deferred` that is + currently being waited on. That `Deferred` will raise a `CancelledError`, which will + propagate up, as per normal exception handling. + + Before applying this decorator to a new endpoint, you MUST recursively check + that all `await`s in the function are on `async` functions or `Deferred`s that + handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from + premature logging context closure, to stuck requests, to database corruption. + + Usage: + class SomeServlet(RestServlet): + @cancellable + async def on_GET(self, request: SynapseRequest) -> ...: + ... + """ + if method.__name__ not in _cancellable_method_names and not any( + method.__name__.startswith(prefix) for prefix in _cancellable_method_names + ): + raise ValueError( + "@cancellable decorator can only be applied to servlet methods." + ) + + method.cancellable = True # type: ignore[attr-defined] + return method + + +def is_method_cancellable(method: Callable[..., Any]) -> bool: + """Checks whether a servlet method has the `@cancellable` flag.""" + return getattr(method, "cancellable", False) + + def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: """Sends a JSON error response to clients.""" @@ -253,6 +316,9 @@ class HttpServer(Protocol): If the regex contains groups these gets passed to the callback via an unpacked tuple. + The callback may be marked with the `@cancellable` decorator, which will + cause request processing to be cancelled when clients disconnect early. + Args: method: The HTTP method to listen to. path_patterns: The regex used to match requests. @@ -283,7 +349,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): def render(self, request: SynapseRequest) -> int: """This gets called by twisted every time someone sends us a request.""" - defer.ensureDeferred(self._async_render_wrapper(request)) + request.render_deferred = defer.ensureDeferred( + self._async_render_wrapper(request) + ) return NOT_DONE_YET @wrap_async_request_handler @@ -319,6 +387,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): method_handler = getattr(self, "_async_render_%s" % (request_method,), None) if method_handler: + request.is_render_cancellable = is_method_cancellable(method_handler) + raw_callback_return = method_handler(request) # Is it synchronous? We'll allow this for now. @@ -479,6 +549,8 @@ class JsonResource(DirectServeJsonResource): async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) + request.is_render_cancellable = is_method_cancellable(callback) + # Make sure we have an appropriate name for this handler in prometheus # (rather than the default of JsonResource). request.request_metrics.name = servlet_classname diff --git a/synapse/http/site.py b/synapse/http/site.py index 0b85a57d77..eeec74b78a 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union import attr from zope.interface import implementer +from twisted.internet.defer import Deferred from twisted.internet.interfaces import IAddress, IReactorTime from twisted.python.failure import Failure from twisted.web.http import HTTPChannel @@ -91,6 +92,14 @@ class SynapseRequest(Request): # we can't yet create the logcontext, as we don't know the method. self.logcontext: Optional[LoggingContext] = None + # The `Deferred` to cancel if the client disconnects early and + # `is_render_cancellable` is set. Expected to be set by `Resource.render`. + self.render_deferred: Optional["Deferred[None]"] = None + # A boolean indicating whether `render_deferred` should be cancelled if the + # client disconnects early. Expected to be set by the coroutine started by + # `Resource.render`, if rendering is asynchronous. + self.is_render_cancellable = False + global _next_request_seq self.request_seq = _next_request_seq _next_request_seq += 1 @@ -357,7 +366,21 @@ class SynapseRequest(Request): {"event": "client connection lost", "reason": str(reason.value)} ) - if not self._is_processing: + if self._is_processing: + if self.is_render_cancellable: + if self.render_deferred is not None: + # Throw a cancellation into the request processing, in the hope + # that it will finish up sooner than it normally would. + # The `self.processing()` context manager will call + # `_finished_processing()` when done. + with PreserveLoggingContext(): + self.render_deferred.cancel() + else: + logger.error( + "Connection from client lost, but have no Deferred to " + "cancel even though the request is marked as cancellable." + ) + else: self._finished_processing() def _started_processing(self, servlet_name: str) -> None: diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 475756f1db..5a61b21eaf 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -31,7 +31,11 @@ from twisted.internet.endpoints import ( TCP4ClientEndpoint, TCP6ClientEndpoint, ) -from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint +from twisted.internet.interfaces import ( + IPushProducer, + IReactorTCP, + IStreamClientEndpoint, +) from twisted.internet.protocol import Factory, Protocol from twisted.internet.tcp import Connection from twisted.python.failure import Failure @@ -59,14 +63,14 @@ class LogProducer: _buffer: Deque[logging.LogRecord] _paused: bool = attr.ib(default=False, init=False) - def pauseProducing(self): + def pauseProducing(self) -> None: self._paused = True - def stopProducing(self): + def stopProducing(self) -> None: self._paused = True self._buffer = deque() - def resumeProducing(self): + def resumeProducing(self) -> None: # If we're already producing, nothing to do. self._paused = False @@ -102,8 +106,8 @@ class RemoteHandler(logging.Handler): host: str, port: int, maximum_buffer: int = 1000, - level=logging.NOTSET, - _reactor=None, + level: int = logging.NOTSET, + _reactor: Optional[IReactorTCP] = None, ): super().__init__(level=level) self.host = host @@ -118,7 +122,7 @@ class RemoteHandler(logging.Handler): if _reactor is None: from twisted.internet import reactor - _reactor = reactor + _reactor = reactor # type: ignore[assignment] try: ip = ip_address(self.host) @@ -139,7 +143,7 @@ class RemoteHandler(logging.Handler): self._stopping = False self._connect() - def close(self): + def close(self) -> None: self._stopping = True self._service.stopService() diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index c0f12ecd15..c88b8ae545 100644 --- a/synapse/logging/formatter.py +++ b/synapse/logging/formatter.py @@ -16,6 +16,8 @@ import logging import traceback from io import StringIO +from types import TracebackType +from typing import Optional, Tuple, Type class LogFormatter(logging.Formatter): @@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter): where it was caught are logged). """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def formatException(self, ei): + def formatException( + self, + ei: Tuple[ + Optional[Type[BaseException]], + Optional[BaseException], + Optional[TracebackType], + ], + ) -> str: sio = StringIO() (typ, val, tb) = ei diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py index 478b527494..dec2a2c3dd 100644 --- a/synapse/logging/handlers.py +++ b/synapse/logging/handlers.py @@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): ) self._flushing_thread.start() - def on_reactor_running(): + def on_reactor_running() -> None: self._reactor_started = True reactor_to_use: IReactorCore @@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler): else: return True - def _flush_periodically(self): + def _flush_periodically(self) -> None: """ Whilst this handler is active, flush the handler periodically. """ diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index d57e7c5324..a26a1a58e7 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -13,6 +13,8 @@ # limitations under the License.import logging import logging +from types import TracebackType +from typing import Optional, Type from opentracing import Scope, ScopeManager @@ -107,19 +109,26 @@ class _LogContextScope(Scope): and - if enter_logcontext was set - the logcontext is finished too. """ - def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close): + def __init__( + self, + manager: LogContextScopeManager, + span, + logcontext, + enter_logcontext: bool, + finish_on_close: bool, + ): """ Args: - manager (LogContextScopeManager): + manager: the manager that is responsible for this scope. span (Span): the opentracing span which this scope represents the local lifetime for. logcontext (LogContext): the logcontext to which this scope is attached. - enter_logcontext (Boolean): + enter_logcontext: if True the logcontext will be exited when the scope is finished - finish_on_close (Boolean): + finish_on_close: if True finish the span when the scope is closed """ super().__init__(manager, span) @@ -127,16 +136,21 @@ class _LogContextScope(Scope): self._finish_on_close = finish_on_close self._enter_logcontext = enter_logcontext - def __exit__(self, exc_type, value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: if exc_type == twisted.internet.defer._DefGen_Return: # filter out defer.returnValue() calls exc_type = value = traceback = None super().__exit__(exc_type, value, traceback) - def __str__(self): + def __str__(self) -> str: return f"Scope<{self.span}>" - def close(self): + def close(self) -> None: active_scope = self.manager.active if active_scope is not self: logger.error( diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index 6bc329f04a..1fc8a0e888 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -18,6 +18,7 @@ import os import re from typing import Iterable, Optional, overload +import attr from prometheus_client import REGISTRY, Metric from typing_extensions import Literal @@ -27,52 +28,24 @@ from synapse.metrics._types import Collector logger = logging.getLogger(__name__) -def _setup_jemalloc_stats() -> None: - """Checks to see if jemalloc is loaded, and hooks up a collector to record - statistics exposed by jemalloc. - """ - - # Try to find the loaded jemalloc shared library, if any. We need to - # introspect into what is loaded, rather than loading whatever is on the - # path, as if we load a *different* jemalloc version things will seg fault. - - # We look in `/proc/self/maps`, which only exists on linux. - if not os.path.exists("/proc/self/maps"): - logger.debug("Not looking for jemalloc as no /proc/self/maps exist") - return - - # We're looking for a path at the end of the line that includes - # "libjemalloc". - regex = re.compile(r"/\S+/libjemalloc.*$") - - jemalloc_path = None - with open("/proc/self/maps") as f: - for line in f: - match = regex.search(line.strip()) - if match: - jemalloc_path = match.group() - - if not jemalloc_path: - # No loaded jemalloc was found. - logger.debug("jemalloc not found") - return - - logger.debug("Found jemalloc at %s", jemalloc_path) - - jemalloc = ctypes.CDLL(jemalloc_path) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class JemallocStats: + jemalloc: ctypes.CDLL @overload def _mallctl( - name: str, read: Literal[True] = True, write: Optional[int] = None + self, name: str, read: Literal[True] = True, write: Optional[int] = None ) -> int: ... @overload - def _mallctl(name: str, read: Literal[False], write: Optional[int] = None) -> None: + def _mallctl( + self, name: str, read: Literal[False], write: Optional[int] = None + ) -> None: ... def _mallctl( - name: str, read: bool = True, write: Optional[int] = None + self, name: str, read: bool = True, write: Optional[int] = None ) -> Optional[int]: """Wrapper around `mallctl` for reading and writing integers to jemalloc. @@ -120,7 +93,7 @@ def _setup_jemalloc_stats() -> None: # Where oldp/oldlenp is a buffer where the old value will be written to # (if not null), and newp/newlen is the buffer with the new value to set # (if not null). Note that they're all references *except* newlen. - result = jemalloc.mallctl( + result = self.jemalloc.mallctl( name.encode("ascii"), input_var_ref, input_len_ref, @@ -136,21 +109,80 @@ def _setup_jemalloc_stats() -> None: return input_var.value - def _jemalloc_refresh_stats() -> None: + def refresh_stats(self) -> None: """Request that jemalloc updates its internal statistics. This needs to be called before querying for stats, otherwise it will return stale values. """ try: - _mallctl("epoch", read=False, write=1) + self._mallctl("epoch", read=False, write=1) except Exception as e: logger.warning("Failed to reload jemalloc stats: %s", e) + def get_stat(self, name: str) -> int: + """Request the stat of the given name at the time of the last + `refresh_stats` call. This may throw if we fail to read + the stat. + """ + return self._mallctl(f"stats.{name}") + + +_JEMALLOC_STATS: Optional[JemallocStats] = None + + +def get_jemalloc_stats() -> Optional[JemallocStats]: + """Returns an interface to jemalloc, if it is being used. + + Note that this will always return None until `setup_jemalloc_stats` has been + called. + """ + return _JEMALLOC_STATS + + +def _setup_jemalloc_stats() -> None: + """Checks to see if jemalloc is loaded, and hooks up a collector to record + statistics exposed by jemalloc. + """ + + global _JEMALLOC_STATS + + # Try to find the loaded jemalloc shared library, if any. We need to + # introspect into what is loaded, rather than loading whatever is on the + # path, as if we load a *different* jemalloc version things will seg fault. + + # We look in `/proc/self/maps`, which only exists on linux. + if not os.path.exists("/proc/self/maps"): + logger.debug("Not looking for jemalloc as no /proc/self/maps exist") + return + + # We're looking for a path at the end of the line that includes + # "libjemalloc". + regex = re.compile(r"/\S+/libjemalloc.*$") + + jemalloc_path = None + with open("/proc/self/maps") as f: + for line in f: + match = regex.search(line.strip()) + if match: + jemalloc_path = match.group() + + if not jemalloc_path: + # No loaded jemalloc was found. + logger.debug("jemalloc not found") + return + + logger.debug("Found jemalloc at %s", jemalloc_path) + + jemalloc_dll = ctypes.CDLL(jemalloc_path) + + stats = JemallocStats(jemalloc_dll) + _JEMALLOC_STATS = stats + class JemallocCollector(Collector): """Metrics for internal jemalloc stats.""" def collect(self) -> Iterable[Metric]: - _jemalloc_refresh_stats() + stats.refresh_stats() g = GaugeMetricFamily( "jemalloc_stats_app_memory_bytes", @@ -184,7 +216,7 @@ def _setup_jemalloc_stats() -> None: "metadata", ): try: - value = _mallctl(f"stats.{t}") + value = stats.get_stat(t) except Exception as e: # There was an error fetching the value, skip. logger.warning("Failed to read jemalloc stats.%s: %s", t, e) diff --git a/synapse/notifier.py b/synapse/notifier.py index 01a50b9d62..ba23257f54 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -46,6 +46,7 @@ from synapse.types import ( JsonDict, PersistedEventPosition, RoomStreamToken, + StreamKeyType, StreamToken, UserID, ) @@ -370,7 +371,7 @@ class Notifier: if users or rooms: self.on_new_event( - "room_key", + StreamKeyType.ROOM, max_room_stream_token, users=users, rooms=rooms, @@ -440,7 +441,7 @@ class Notifier: for room in rooms: user_streams |= self.room_to_user_streams.get(room, set()) - if stream_key == "to_device_key": + if stream_key == StreamKeyType.TO_DEVICE: issue9533_logger.debug( "to-device messages stream id %s, awaking streams for %s", new_token, diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index a1b7711098..57c4d70466 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -12,6 +12,80 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +This module implements the push rules & notifications portion of the Matrix +specification. + +There's a few related features: + +* Push notifications (i.e. email or outgoing requests to a Push Gateway). +* Calculation of unread notifications (for /sync and /notifications). + +When Synapse receives a new event (locally, via the Client-Server API, or via +federation), the following occurs: + +1. The push rules get evaluated to generate a set of per-user actions. +2. The event is persisted into the database. +3. (In the background) The notifier is notified about the new event. + +The per-user actions are initially stored in the event_push_actions_staging table, +before getting moved into the event_push_actions table when the event is persisted. +The event_push_actions table is periodically summarised into the event_push_summary +and event_push_summary_stream_ordering tables. + +Since push actions block an event from being persisted the generation of push +actions is performance sensitive. + +The general interaction of the classes are: + + +---------------------------------------------+ + | FederationEventHandler/EventCreationHandler | + +---------------------------------------------+ + | + v + +-----------------------+ +---------------------------+ + | BulkPushRuleEvaluator |---->| PushRuleEvaluatorForEvent | + +-----------------------+ +---------------------------+ + | + v + +-----------------------------+ + | EventPushActionsWorkerStore | + +-----------------------------+ + +The notifier notifies the pusher pool of the new event, which checks for affected +users. Each user-configured pusher of the affected users then performs the +previously calculated action. + +The general interaction of the classes are: + + +----------+ + | Notifier | + +----------+ + | + v + +------------+ +--------------+ + | PusherPool |---->| PusherConfig | + +------------+ +--------------+ + | + | +---------------+ + +<--->| PusherFactory | + | +---------------+ + v + +------------------------+ +-----------------------------------------------+ + | EmailPusher/HttpPusher |---->| EventPushActionsWorkerStore/PusherWorkerStore | + +------------------------+ +-----------------------------------------------+ + | + v + +-------------------------+ + | Mailer/SimpleHttpClient | + +-------------------------+ + +The Pusher instance also calls out to various utilities for generating payloads +(or email templates), but those interactions are not detailed in this diagram +(and are specific to the type of pusher). + +""" + import abc from typing import TYPE_CHECKING, Any, Dict, Optional diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py deleted file mode 100644 index 60758df016..0000000000 --- a/synapse/push/action_generator.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2015 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING - -from synapse.events import EventBase -from synapse.events.snapshot import EventContext -from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator -from synapse.util.metrics import Measure - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class ActionGenerator: - def __init__(self, hs: "HomeServer"): - self.clock = hs.get_clock() - self.bulk_evaluator = BulkPushRuleEvaluator(hs) - # really we want to get all user ids and all profile tags too, - # since we want the actions for each profile tag for every user and - # also actions for a client with no profile tag for each user. - # Currently the event stream doesn't support profile tags on an - # event stream, so we just run the rules for a client with no profile - # tag (ie. we just need all the users). - - async def handle_push_actions_for_event( - self, event: EventBase, context: EventContext - ) -> None: - with Measure(self.clock, "action_for_event_by_user"): - await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index b07cf2eee7..4ac2c546bf 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -21,7 +21,7 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, RelationTypes from synapse.event_auth import get_user_power_level -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership @@ -29,6 +29,7 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.descriptors import lru_cache from synapse.util.caches.lrucache import LruCache +from synapse.util.metrics import measure_func from .push_rule_evaluator import PushRuleEvaluatorForEvent @@ -77,8 +78,8 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: return False # Exclude edits. - relates_to = event.content.get("m.relates_to", {}) - if relates_to.get("rel_type") == RelationTypes.REPLACE: + relates_to = relation_from_event(event) + if relates_to and relates_to.rel_type == RelationTypes.REPLACE: return False # Mark events that have a non-empty string body as unread. @@ -105,6 +106,7 @@ class BulkPushRuleEvaluator: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self.clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() # Used by `RulesForRoom` to ensure only one thing mutates the cache at a @@ -185,6 +187,7 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level + @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext ) -> None: @@ -192,6 +195,10 @@ class BulkPushRuleEvaluator: should increment the unread count, and insert the results into the event_push_actions_staging table. """ + if event.internal_metadata.is_outlier(): + # This can happen due to out of band memberships + return + count_as_unread = _should_count_as_unread(event, context) rules_by_user = await self._get_rules_for_event(event, context) @@ -208,8 +215,6 @@ class BulkPushRuleEvaluator: event, len(room_members), sender_power_level, power_levels ) - condition_cache: Dict[str, bool] = {} - # If the event is not a state event check if any users ignore the sender. if not event.is_state(): ignorers = await self.store.ignored_by(event.sender) @@ -247,8 +252,8 @@ class BulkPushRuleEvaluator: if "enabled" in rule and not rule["enabled"]: continue - matches = _condition_checker( - evaluator, rule["conditions"], uid, display_name, condition_cache + matches = evaluator.check_conditions( + rule["conditions"], uid, display_name ) if matches: actions = [x for x in rule["actions"] if x != "dont_notify"] @@ -267,32 +272,6 @@ class BulkPushRuleEvaluator: ) -def _condition_checker( - evaluator: PushRuleEvaluatorForEvent, - conditions: List[dict], - uid: str, - display_name: Optional[str], - cache: Dict[str, bool], -) -> bool: - for cond in conditions: - _cache_key = cond.get("_cache_key", None) - if _cache_key: - res = cache.get(_cache_key, None) - if res is False: - return False - elif res is True: - continue - - res = evaluator.matches(cond, uid, display_name) - if _cache_key: - cache[_cache_key] = bool(res) - - if not res: - return False - - return True - - MemberMap = Dict[str, Optional[EventIdMembership]] Rule = Dict[str, dict] RulesByUser = Dict[str, List[Rule]] diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index f617c759e6..54db6b5612 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -129,9 +129,55 @@ class PushRuleEvaluatorForEvent: # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) + # Maps cache keys to final values. + self._condition_cache: Dict[str, bool] = {} + + def check_conditions( + self, conditions: List[dict], uid: str, display_name: Optional[str] + ) -> bool: + """ + Returns true if a user's conditions/user ID/display name match the event. + + Args: + conditions: The user's conditions to match. + uid: The user's MXID. + display_name: The display name. + + Returns: + True if all conditions match the event, False otherwise. + """ + for cond in conditions: + _cache_key = cond.get("_cache_key", None) + if _cache_key: + res = self._condition_cache.get(_cache_key, None) + if res is False: + return False + elif res is True: + continue + + res = self.matches(cond, uid, display_name) + if _cache_key: + self._condition_cache[_cache_key] = bool(res) + + if not res: + return False + + return True + def matches( self, condition: Dict[str, Any], user_id: str, display_name: Optional[str] ) -> bool: + """ + Returns true if a user's condition/user ID/display name match the event. + + Args: + condition: The user's condition to match. + uid: The user's MXID. + display_name: The display name, or None if there is not one. + + Returns: + True if the condition matches the event, False otherwise. + """ if condition["kind"] == "event_match": return self._event_match(condition, user_id) elif condition["kind"] == "contains_display_name": @@ -146,6 +192,16 @@ class PushRuleEvaluatorForEvent: return True def _event_match(self, condition: dict, user_id: str) -> bool: + """ + Check an "event_match" push rule condition. + + Args: + condition: The "event_match" push rule condition to match. + user_id: The user's MXID. + + Returns: + True if the condition matches the event, False otherwise. + """ pattern = condition.get("pattern", None) if not pattern: @@ -167,13 +223,22 @@ class PushRuleEvaluatorForEvent: return _glob_matches(pattern, body, word_boundary=True) else: - haystack = self._get_value(condition["key"]) + haystack = self._value_cache.get(condition["key"], None) if haystack is None: return False return _glob_matches(pattern, haystack) def _contains_display_name(self, display_name: Optional[str]) -> bool: + """ + Check an "event_match" push rule condition. + + Args: + display_name: The display name, or None if there is not one. + + Returns: + True if the display name is found in the event body, False otherwise. + """ if not display_name: return False @@ -191,9 +256,6 @@ class PushRuleEvaluatorForEvent: return bool(r.search(body)) - def _get_value(self, dotted_key: str) -> Optional[str]: - return self._value_cache.get(dotted_key, None) - # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 2bd244ed79..a4ae4040c3 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -26,7 +26,8 @@ from twisted.web.server import Request from synapse.api.errors import HttpResponseException, SynapseError from synapse.http import RequestTimedOutError -from synapse.http.server import HttpServer +from synapse.http.server import HttpServer, is_method_cancellable +from synapse.http.site import SynapseRequest from synapse.logging import opentracing from synapse.logging.opentracing import trace from synapse.types import JsonDict @@ -310,6 +311,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): url_args = list(self.PATH_ARGS) method = self.METHOD + if self.CACHE and is_method_cancellable(self._handle_request): + raise Exception( + f"{self.__class__.__name__} has been marked as cancellable, but CACHE " + "is set. The cancellable flag would have no effect." + ) + if self.CACHE: url_args.append("txn_id") @@ -324,7 +331,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): ) async def _check_auth_and_handle( - self, request: Request, **kwargs: Any + self, request: SynapseRequest, **kwargs: Any ) -> Tuple[int, JsonDict]: """Called on new incoming requests when caching is enabled. Checks if there is a cached response for the request and returns that, @@ -340,8 +347,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): if self.CACHE: txn_id = kwargs.pop("txn_id") + # We ignore the `@cancellable` flag, since cancellation wouldn't interupt + # `_handle_request` and `ResponseCache` does not handle cancellation + # correctly yet. In particular, there may be issues to do with logging + # context lifetimes. + return await self.response_cache.wrap( txn_id, self._handle_request, request, **kwargs ) + # The `@cancellable` decorator may be applied to `_handle_request`. But we + # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`, + # so we have to set up the cancellable flag ourselves. + request.is_render_cancellable = is_method_cancellable(self._handle_request) + return await self._handle_request(request, **kwargs) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 350762f494..a52e25c1af 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -43,7 +43,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) -from synapse.types import PersistedEventPosition, ReadReceipt, UserID +from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID from synapse.util.async_helpers import Linearizer, timeout_deferred from synapse.util.metrics import Measure @@ -153,19 +153,19 @@ class ReplicationDataHandler: if stream_name == TypingStream.NAME: self._typing_handler.process_replication_rows(token, rows) self.notifier.on_new_event( - "typing_key", token, rooms=[row.room_id for row in rows] + StreamKeyType.TYPING, token, rooms=[row.room_id for row in rows] ) elif stream_name == PushRulesStream.NAME: self.notifier.on_new_event( - "push_rules_key", token, users=[row.user_id for row in rows] + StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows] ) elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME): self.notifier.on_new_event( - "account_data_key", token, users=[row.user_id for row in rows] + StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows] ) elif stream_name == ReceiptsStream.NAME: self.notifier.on_new_event( - "receipt_key", token, rooms=[row.room_id for row in rows] + StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows] ) await self._pusher_pool.on_new_receipts( token, token, {row.room_id for row in rows} @@ -173,14 +173,18 @@ class ReplicationDataHandler: elif stream_name == ToDeviceStream.NAME: entities = [row.entity for row in rows if row.entity.startswith("@")] if entities: - self.notifier.on_new_event("to_device_key", token, users=entities) + self.notifier.on_new_event( + StreamKeyType.TO_DEVICE, token, users=entities + ) elif stream_name == DeviceListsStream.NAME: all_room_ids: Set[str] = set() for row in rows: if row.entity.startswith("@"): room_ids = await self.store.get_rooms_for_user(row.entity) all_room_ids.update(room_ids) - self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids) + self.notifier.on_new_event( + StreamKeyType.DEVICE_LIST, token, rooms=all_room_ids + ) elif stream_name == GroupServerStream.NAME: self.notifier.on_new_event( "groups_key", token, users=[row.user_id for row in rows] diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index f9caab6635..4b03eb876b 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -13,12 +13,10 @@ # limitations under the License. import logging -import re from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReceiptTypes from synapse.api.errors import SynapseError -from synapse.http import get_request_user_agent from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -26,8 +24,6 @@ from synapse.types import JsonDict from ._base import client_patterns -pattern = re.compile(r"(?:Element|SchildiChat)/1\.[012]\.") - if TYPE_CHECKING: from synapse.server import HomeServer @@ -69,14 +65,7 @@ class ReceiptRestServlet(RestServlet): ): raise SynapseError(400, "Receipt type must be 'm.read'") - # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. - user_agent = get_request_user_agent(request) - allow_empty_body = False - if "Android" in user_agent: - if pattern.match(user_agent) or "Riot" in user_agent: - allow_empty_body = True - # This call makes sure possible empty body is handled correctly - parse_json_object_from_request(request, allow_empty_body) + parse_json_object_from_request(request, allow_empty_body=False) await self.presence_handler.bump_presence_active_time(requester.user) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 906fe09e97..4b8bfbffcb 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -34,7 +34,7 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 -from synapse.http.server import HttpServer +from synapse.http.server import HttpServer, cancellable from synapse.http.servlet import ( ResolveRoomIdMixin, RestServlet, @@ -143,6 +143,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): self.__class__.__name__, ) + @cancellable def on_GET_no_state_key( self, request: SynapseRequest, room_id: str, event_type: str ) -> Awaitable[Tuple[int, JsonDict]]: @@ -153,6 +154,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_PUT(request, room_id, event_type, "") + @cancellable async def on_GET( self, request: SynapseRequest, room_id: str, event_type: str, state_key: str ) -> Tuple[int, JsonDict]: @@ -481,6 +483,7 @@ class RoomMemberListRestServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastores().main + @cancellable async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: @@ -602,6 +605,7 @@ class RoomStateRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() + @cancellable async def on_GET( self, request: SynapseRequest, room_id: str ) -> Tuple[int, List[JsonDict]]: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 50383bdbd1..2b2db63bf7 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -668,7 +668,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("Running url preview cache expiry") if not (await self.store.db_pool.updates.has_completed_background_updates()): - logger.info("Still running DB updates; skipping expiry") + logger.debug("Still running DB updates; skipping url preview cache expiry") return def try_remove_parent_dirs(dirs: Iterable[str]) -> None: @@ -688,7 +688,9 @@ class PreviewUrlResource(DirectServeJsonResource): # Failed, skip deleting the rest of the parent dirs if e.errno != errno.ENOTEMPTY: logger.warning( - "Failed to remove media directory: %r: %s", dir, e + "Failed to remove media directory while clearing url preview cache: %r: %s", + dir, + e, ) break @@ -703,7 +705,11 @@ class PreviewUrlResource(DirectServeJsonResource): except FileNotFoundError: pass # If the path doesn't exist, meh except OSError as e: - logger.warning("Failed to remove media: %r: %s", media_id, e) + logger.warning( + "Failed to remove media while clearing url preview cache: %r: %s", + media_id, + e, + ) continue removed_media.append(media_id) @@ -714,9 +720,11 @@ class PreviewUrlResource(DirectServeJsonResource): await self.store.delete_url_cache(removed_media) if removed_media: - logger.info("Deleted %d entries from url cache", len(removed_media)) + logger.debug( + "Deleted %d entries from url preview cache", len(removed_media) + ) else: - logger.debug("No entries removed from url cache") + logger.debug("No entries removed from url preview cache") # Now we delete old images associated with the url cache. # These may be cached for a bit on the client (i.e., they @@ -733,7 +741,9 @@ class PreviewUrlResource(DirectServeJsonResource): except FileNotFoundError: pass # If the path doesn't exist, meh except OSError as e: - logger.warning("Failed to remove media: %r: %s", media_id, e) + logger.warning( + "Failed to remove media from url preview cache: %r: %s", media_id, e + ) continue dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) @@ -745,7 +755,9 @@ class PreviewUrlResource(DirectServeJsonResource): except FileNotFoundError: pass # If the path doesn't exist, meh except OSError as e: - logger.warning("Failed to remove media: %r: %s", media_id, e) + logger.warning( + "Failed to remove media from url preview cache: %r: %s", media_id, e + ) continue removed_media.append(media_id) @@ -758,9 +770,9 @@ class PreviewUrlResource(DirectServeJsonResource): await self.store.delete_url_cache_media(removed_media) if removed_media: - logger.info("Deleted %d media from url cache", len(removed_media)) + logger.debug("Deleted %d media from url preview cache", len(removed_media)) else: - logger.debug("No media removed from url cache") + logger.debug("No media removed from url preview cache") def _is_media(content_type: str) -> bool: diff --git a/synapse/server.py b/synapse/server.py index d49c76518a..ee60cce8eb 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -119,7 +119,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.module_api import ModuleApi from synapse.notifier import Notifier -from synapse.push.action_generator import ActionGenerator +from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.external_cache import ExternalCache @@ -644,8 +644,8 @@ class HomeServer(metaclass=abc.ABCMeta): return ReplicationCommandHandler(self) @cache_in_self - def get_action_generator(self) -> ActionGenerator: - return ActionGenerator(self) + def get_bulk_push_rule_evaluator(self) -> BulkPushRuleEvaluator: + return BulkPushRuleEvaluator(self) @cache_in_self def get_user_directory_handler(self) -> UserDirectoryHandler: @@ -681,7 +681,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_spam_checker(self) -> SpamChecker: - return SpamChecker() + return SpamChecker(self) @cache_in_self def get_third_party_event_rules(self) -> ThirdPartyEventRules: diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 015dd08f05..b5f3a0c74e 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -21,7 +21,6 @@ from synapse.api.constants import ( ServerNoticeMsgType, ) from synapse.api.errors import AuthError, ResourceLimitError, SynapseError -from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG if TYPE_CHECKING: from synapse.server import HomeServer @@ -71,18 +70,19 @@ class ResourceLimitsServerNotices: # In practice, not sure we can ever get here return - room_id = await self._server_notices_manager.get_or_create_notice_room_for_user( + # Check if there's a server notice room for this user. + room_id = await self._server_notices_manager.maybe_get_notice_room_for_user( user_id ) - if not room_id: - logger.warning("Failed to get server notices room") - return - - await self._check_and_set_tags(user_id, room_id) - - # Determine current state of room - currently_blocked, ref_events = await self._is_room_currently_blocked(room_id) + if room_id is not None: + # Determine current state of room + currently_blocked, ref_events = await self._is_room_currently_blocked( + room_id + ) + else: + currently_blocked = False + ref_events = [] limit_msg = None limit_type = None @@ -161,26 +161,6 @@ class ResourceLimitsServerNotices: user_id, content, EventTypes.Pinned, "" ) - async def _check_and_set_tags(self, user_id: str, room_id: str) -> None: - """ - Since server notices rooms were originally not with tags, - important to check that tags have been set correctly - Args: - user_id(str): the user in question - room_id(str): the server notices room for that user - """ - tags = await self._store.get_tags_for_room(user_id, room_id) - need_to_set_tag = True - if tags: - if SERVER_NOTICE_ROOM_TAG in tags: - # tag already present, nothing to do here - need_to_set_tag = False - if need_to_set_tag: - max_id = await self._account_data_handler.add_tag_to_room( - user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} - ) - self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) - async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str]]: """ Determines if the room is currently blocked diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 48eae5fa06..8ecab86ec7 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Optional from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.events import EventBase -from synapse.types import Requester, UserID, create_requester +from synapse.types import Requester, StreamKeyType, UserID, create_requester from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -91,6 +91,35 @@ class ServerNoticesManager: return event @cached() + async def maybe_get_notice_room_for_user(self, user_id: str) -> Optional[str]: + """Try to look up the server notice room for this user if it exists. + + Does not create one if none can be found. + + Args: + user_id: the user we want a server notice room for. + + Returns: + The room's ID, or None if no room could be found. + """ + rooms = await self._store.get_rooms_for_local_user_where_membership_is( + user_id, [Membership.INVITE, Membership.JOIN] + ) + for room in rooms: + # it's worth noting that there is an asymmetry here in that we + # expect the user to be invited or joined, but the system user must + # be joined. This is kinda deliberate, in that if somebody somehow + # manages to invite the system user to a room, that doesn't make it + # the server notices room. + user_ids = await self._store.get_users_in_room(room.room_id) + if len(user_ids) <= 2 and self.server_notices_mxid in user_ids: + # we found a room which our user shares with the system notice + # user + return room.room_id + + return None + + @cached() async def get_or_create_notice_room_for_user(self, user_id: str) -> str: """Get the room for notices for a given user @@ -112,31 +141,20 @@ class ServerNoticesManager: self.server_notices_mxid, authenticated_entity=self._server_name ) - rooms = await self._store.get_rooms_for_local_user_where_membership_is( - user_id, [Membership.INVITE, Membership.JOIN] - ) - for room in rooms: - # it's worth noting that there is an asymmetry here in that we - # expect the user to be invited or joined, but the system user must - # be joined. This is kinda deliberate, in that if somebody somehow - # manages to invite the system user to a room, that doesn't make it - # the server notices room. - user_ids = await self._store.get_users_in_room(room.room_id) - if len(user_ids) <= 2 and self.server_notices_mxid in user_ids: - # we found a room which our user shares with the system notice - # user - logger.info( - "Using existing server notices room %s for user %s", - room.room_id, - user_id, - ) - await self._update_notice_user_profile_if_changed( - requester, - room.room_id, - self._config.servernotices.server_notices_mxid_display_name, - self._config.servernotices.server_notices_mxid_avatar_url, - ) - return room.room_id + room_id = await self.maybe_get_notice_room_for_user(user_id) + if room_id is not None: + logger.info( + "Using existing server notices room %s for user %s", + room_id, + user_id, + ) + await self._update_notice_user_profile_if_changed( + requester, + room_id, + self._config.servernotices.server_notices_mxid_display_name, + self._config.servernotices.server_notices_mxid_avatar_url, + ) + return room_id # apparently no existing notice room: create a new one logger.info("Creating server notices room for %s", user_id) @@ -166,10 +184,12 @@ class ServerNoticesManager: ) room_id = info["room_id"] + self.maybe_get_notice_room_for_user.invalidate((user_id,)) + max_id = await self._account_data_handler.add_tag_to_room( user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} ) - self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) + self._notifier.on_new_event(StreamKeyType.ACCOUNT_DATA, max_id, users=[user_id]) logger.info("Created server notices room %s for %s", room_id, user_id) return room_id diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index cad3b42640..54e41d5375 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -130,6 +130,7 @@ class StateHandler: self.state_store = hs.get_storage().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() + self._storage = hs.get_storage() @overload async def get_current_state( @@ -361,10 +362,10 @@ class StateHandler: if not event.is_state(): return EventContext.with_state( + storage=self._storage, state_group_before_event=state_group_before_event, state_group=state_group_before_event, - current_state_ids=state_ids_before_event, - prev_state_ids=state_ids_before_event, + state_delta_due_to_event={}, prev_group=state_group_before_event_prev_group, delta_ids=deltas_to_state_group_before_event, partial_state=partial_state, @@ -393,10 +394,10 @@ class StateHandler: ) return EventContext.with_state( + storage=self._storage, state_group=state_group_after_event, state_group_before_event=state_group_before_event, - current_state_ids=state_ids_after_event, - prev_state_ids=state_ids_before_event, + state_delta_due_to_event=delta_ids, prev_group=state_group_before_event, delta_ids=delta_ids, partial_state=partial_state, diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 08c6eabc6d..c2bbbb574e 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,20 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from types import TracebackType from typing import ( TYPE_CHECKING, + Any, AsyncContextManager, Awaitable, Callable, Dict, Iterable, + List, Optional, + Type, ) import attr from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.types import Connection +from synapse.storage.types import Connection, Cursor from synapse.types import JsonDict from synapse.util import Clock, json_encoder @@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager: return self._update_duration_ms - async def __aexit__(self, *exc) -> None: + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: pass @@ -352,7 +361,7 @@ class BackgroundUpdater: True if we have finished running all the background updates, otherwise False """ - def get_background_updates_txn(txn): + def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]: txn.execute( """ SELECT update_name, depends_on FROM background_updates @@ -469,7 +478,7 @@ class BackgroundUpdater: self, update_name: str, update_handler: Callable[[JsonDict, int], Awaitable[int]], - ): + ) -> None: """Register a handler for doing a background update. The handler should take two arguments: @@ -603,7 +612,7 @@ class BackgroundUpdater: else: runner = create_index_sqlite - async def updater(progress, batch_size): + async def updater(progress: JsonDict, batch_size: int) -> int: if runner is not None: logger.info("Adding index %s to %s", index_name, table) await self.db_pool.runWithConnection(runner) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 41f566b648..5ddb58a8a2 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -31,6 +31,7 @@ from typing import ( List, Optional, Tuple, + Type, TypeVar, cast, overload, @@ -41,6 +42,7 @@ from prometheus_client import Histogram from typing_extensions import Concatenate, Literal, ParamSpec from twisted.enterprise import adbapi +from twisted.internet.interfaces import IReactorCore from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { def make_pool( - reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + reactor: IReactorCore, + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, ) -> adbapi.ConnectionPool: """Get the connection pool for the database.""" @@ -101,7 +105,7 @@ def make_pool( db_args = dict(db_config.config.get("args", {})) db_args.setdefault("cp_reconnect", True) - def _on_new_connection(conn): + def _on_new_connection(conn: Connection) -> None: # Ensure we have a logging context so we can correctly track queries, # etc. with LoggingContext("db.on_new_connection"): @@ -157,7 +161,11 @@ class LoggingDatabaseConnection: default_txn_name: str def cursor( - self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + self, + *, + txn_name: Optional[str] = None, + after_callbacks: Optional[List["_CallbackListEntry"]] = None, + exception_callbacks: Optional[List["_CallbackListEntry"]] = None, ) -> "LoggingTransaction": if not txn_name: txn_name = self.default_txn_name @@ -183,11 +191,16 @@ class LoggingDatabaseConnection: self.conn.__enter__() return self - def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> Optional[bool]: return self.conn.__exit__(exc_type, exc_value, traceback) # Proxy through any unknown lookups to the DB conn class. - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self.conn, name) @@ -391,17 +404,22 @@ class LoggingTransaction: def __enter__(self) -> "LoggingTransaction": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: self.close() class PerformanceCounters: - def __init__(self): - self.current_counters = {} - self.previous_counters = {} + def __init__(self) -> None: + self.current_counters: Dict[str, Tuple[int, float]] = {} + self.previous_counters: Dict[str, Tuple[int, float]] = {} def update(self, key: str, duration_secs: float) -> None: - count, cum_time = self.current_counters.get(key, (0, 0)) + count, cum_time = self.current_counters.get(key, (0, 0.0)) count += 1 cum_time += duration_secs self.current_counters[key] = (count, cum_time) @@ -527,7 +545,7 @@ class DatabasePool: def start_profiling(self) -> None: self._previous_loop_ts = monotonic_time() - def loop(): + def loop() -> None: curr = self._current_txn_total_time prev = self._previous_txn_total_time self._previous_txn_total_time = curr @@ -1186,7 +1204,7 @@ class DatabasePool: if lock: self.engine.lock_table(txn, table) - def _getwhere(key): + def _getwhere(key: str) -> str: # If the value we're passing in is None (aka NULL), we need to use # IS, not =, as NULL = NULL equals NULL (False). if keyvalues[key] is None: @@ -2258,7 +2276,7 @@ class DatabasePool: term: Optional[str], col: str, retcols: Collection[str], - desc="simple_search_list", + desc: str = "simple_search_list", ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index dd4e83a2ad..1653a6a9b6 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -57,6 +57,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._instance_name = hs.get_instance_name() + self.db_pool.updates.register_background_index_update( + update_name="cache_invalidation_index_by_instance", + index_name="cache_invalidation_stream_by_instance_instance_index", + table="cache_invalidation_stream_by_instance", + columns=("instance_name", "stream_id"), + psql_only=True, # The table is only on postgres DBs. + ) + async def get_all_updated_caches( self, instance_name: str, last_id: int, current_id: int, limit: int ) -> Tuple[List[Tuple[int, tuple]], int, bool]: diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py index b789a588a5..af59be6b48 100644 --- a/synapse/storage/databases/main/e2e_room_keys.py +++ b/synapse/storage/databases/main/e2e_room_keys.py @@ -21,7 +21,7 @@ from synapse.api.errors import StoreError from synapse.logging.opentracing import log_kv, trace from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import LoggingTransaction -from synapse.types import JsonDict, JsonSerializable +from synapse.types import JsonDict, JsonSerializable, StreamKeyType from synapse.util import json_encoder @@ -126,7 +126,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): "message": "Set room key", "room_id": room_id, "session_id": session_id, - "room_key": room_key, + StreamKeyType.ROOM: room_key, } ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index ed29a0a5e2..42d484dc98 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -36,9 +36,8 @@ from prometheus_client import Counter import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import RoomVersions -from synapse.crypto.event_signing import compute_event_reference_hash -from synapse.events import EventBase # noqa: F401 -from synapse.events.snapshot import EventContext # noqa: F401 +from synapse.events import EventBase, relation_from_event +from synapse.events.snapshot import EventContext from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines.postgres import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator -from synapse.types import StateMap, get_domain_from_id +from synapse.types import JsonDict, StateMap, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically @@ -129,7 +128,6 @@ class PersistEventsStore: self, events_and_contexts: List[Tuple[EventBase, EventContext]], *, - current_state_for_room: Dict[str, StateMap[str]], state_delta_for_room: Dict[str, DeltaState], new_forward_extremities: Dict[str, Set[str]], use_negative_stream_ordering: bool = False, @@ -140,8 +138,6 @@ class PersistEventsStore: Args: events_and_contexts: - current_state_for_room: Map from room_id to the current state of - the room based on forward extremities state_delta_for_room: Map from room_id to the delta to apply to room state new_forward_extremities: Map from room_id to set of event IDs @@ -216,9 +212,6 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() - for room_id, new_state in current_state_for_room.items(): - self.store.get_current_state_ids.prefill((room_id,), new_state) - for room_id, latest_event_ids in new_forward_extremities.items(): self.store.get_latest_event_ids_in_room.prefill( (room_id,), list(latest_event_ids) @@ -236,7 +229,9 @@ class PersistEventsStore: """ results: List[str] = [] - def _get_events_which_are_prevs_txn(txn, batch): + def _get_events_which_are_prevs_txn( + txn: LoggingTransaction, batch: Collection[str] + ) -> None: sql = """ SELECT prev_event_id, internal_metadata FROM event_edges @@ -286,7 +281,9 @@ class PersistEventsStore: # and their prev events. existing_prevs = set() - def _get_prevs_before_rejected_txn(txn, batch): + def _get_prevs_before_rejected_txn( + txn: LoggingTransaction, batch: Collection[str] + ) -> None: to_recursively_check = batch while to_recursively_check: @@ -516,7 +513,7 @@ class PersistEventsStore: @classmethod def _add_chain_cover_index( cls, - txn, + txn: LoggingTransaction, db_pool: DatabasePool, event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], @@ -810,7 +807,7 @@ class PersistEventsStore: @staticmethod def _allocate_chain_ids( - txn, + txn: LoggingTransaction, db_pool: DatabasePool, event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], @@ -944,7 +941,7 @@ class PersistEventsStore: self, txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], - ): + ) -> None: """Persist the mapping from transaction IDs to event IDs (if defined).""" to_insert = [] @@ -998,7 +995,7 @@ class PersistEventsStore: txn: LoggingTransaction, state_delta_by_room: Dict[str, DeltaState], stream_id: int, - ): + ) -> None: for room_id, delta_state in state_delta_by_room.items(): to_delete = delta_state.to_delete to_insert = delta_state.to_insert @@ -1156,7 +1153,7 @@ class PersistEventsStore: txn, room_id, members_changed ) - def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str): + def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None: """Update the room version in the database based off current state events. @@ -1190,7 +1187,7 @@ class PersistEventsStore: txn: LoggingTransaction, new_forward_extremities: Dict[str, Set[str]], max_stream_order: int, - ): + ) -> None: for room_id in new_forward_extremities.keys(): self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} @@ -1255,9 +1252,9 @@ class PersistEventsStore: def _update_room_depths_txn( self, - txn, + txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], - ): + ) -> None: """Update min_depth for each room Args: @@ -1386,7 +1383,7 @@ class PersistEventsStore: # nothing to do here return - def event_dict(event): + def event_dict(event: EventBase) -> JsonDict: d = event.get_dict() d.pop("redacted", None) d.pop("redacted_because", None) @@ -1477,18 +1474,20 @@ class PersistEventsStore: ), ) - def _store_rejected_events_txn(self, txn, events_and_contexts): + def _store_rejected_events_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> List[Tuple[EventBase, EventContext]]: """Add rows to the 'rejections' table for received events which were rejected Args: - txn (twisted.enterprise.adbapi.Connection): db connection - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting + txn: db connection + events_and_contexts: events we are persisting Returns: - list[(EventBase, EventContext)] new list, without the rejected - events. + new list, without the rejected events. """ # Remove the rejected events from the list now that we've added them # to the events table and the events_json table. @@ -1509,7 +1508,7 @@ class PersistEventsStore: events_and_contexts: List[Tuple[EventBase, EventContext]], all_events_and_contexts: List[Tuple[EventBase, EventContext]], inhibit_local_membership_updates: bool = False, - ): + ) -> None: """Update all the miscellaneous tables for new events Args: @@ -1600,15 +1599,14 @@ class PersistEventsStore: inhibit_local_membership_updates=inhibit_local_membership_updates, ) - # Insert event_reference_hashes table. - self._store_event_reference_hashes_txn( - txn, [event for event, _ in events_and_contexts] - ) - # Prefill the event cache self._add_to_cache(txn, events_and_contexts) - def _add_to_cache(self, txn, events_and_contexts): + def _add_to_cache( + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> None: to_prefill = [] rows = [] @@ -1639,7 +1637,7 @@ class PersistEventsStore: if not row["rejects"] and not row["redacts"]: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - def prefill(): + def prefill() -> None: for cache_entry in to_prefill: self.store._get_event_cache.set( (cache_entry.event.event_id,), cache_entry @@ -1669,19 +1667,24 @@ class PersistEventsStore: ) def insert_labels_for_event_txn( - self, txn, event_id, labels, room_id, topological_ordering - ): + self, + txn: LoggingTransaction, + event_id: str, + labels: List[str], + room_id: str, + topological_ordering: int, + ) -> None: """Store the mapping between an event's ID and its labels, with one row per (event_id, label) tuple. Args: - txn (LoggingTransaction): The transaction to execute. - event_id (str): The event's ID. - labels (list[str]): A list of text labels. - room_id (str): The ID of the room the event was sent to. - topological_ordering (int): The position of the event in the room's topology. + txn: The transaction to execute. + event_id: The event's ID. + labels: A list of text labels. + room_id: The ID of the room the event was sent to. + topological_ordering: The position of the event in the room's topology. """ - return self.db_pool.simple_insert_many_txn( + self.db_pool.simple_insert_many_txn( txn=txn, table="event_labels", keys=("event_id", "label", "room_id", "topological_ordering"), @@ -1690,44 +1693,32 @@ class PersistEventsStore: ], ) - def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + def _insert_event_expiry_txn( + self, txn: LoggingTransaction, event_id: str, expiry_ts: int + ) -> None: """Save the expiry timestamp associated with a given event ID. Args: - txn (LoggingTransaction): The database transaction to use. - event_id (str): The event ID the expiry timestamp is associated with. - expiry_ts (int): The timestamp at which to expire (delete) the event. + txn: The database transaction to use. + event_id: The event ID the expiry timestamp is associated with. + expiry_ts: The timestamp at which to expire (delete) the event. """ - return self.db_pool.simple_insert_txn( + self.db_pool.simple_insert_txn( txn=txn, table="event_expiry", values={"event_id": event_id, "expiry_ts": expiry_ts}, ) - def _store_event_reference_hashes_txn(self, txn, events): - """Store a hash for a PDU - Args: - txn (cursor): - events (list): list of Events. - """ - - vals = [] - for event in events: - ref_alg, ref_hash_bytes = compute_event_reference_hash(event) - vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes))) - - self.db_pool.simple_insert_many_txn( - txn, - table="event_reference_hashes", - keys=("event_id", "algorithm", "hash"), - values=vals, - ) - def _store_room_members_txn( - self, txn, events, *, inhibit_local_membership_updates: bool = False - ): + self, + txn: LoggingTransaction, + events: List[EventBase], + *, + inhibit_local_membership_updates: bool = False, + ) -> None: """ Store a room member in the database. + Args: txn: The transaction to use. events: List of events to store. @@ -1767,6 +1758,7 @@ class PersistEventsStore: ) for event in events: + assert event.internal_metadata.stream_ordering is not None txn.call_after( self.store._membership_stream_cache.entity_has_changed, event.state_key, @@ -1815,55 +1807,50 @@ class PersistEventsStore: txn: The current database transaction. event: The event which might have relations. """ - relation = event.content.get("m.relates_to") + relation = relation_from_event(event) if not relation: - # No relations - return - - # Relations must have a type and parent event ID. - rel_type = relation.get("rel_type") - if not isinstance(rel_type, str): + # No relation, nothing to do. return - parent_id = relation.get("event_id") - if not isinstance(parent_id, str): - return - - # Annotations have a key field. - aggregation_key = None - if rel_type == RelationTypes.ANNOTATION: - aggregation_key = relation.get("key") - self.db_pool.simple_insert_txn( txn, table="event_relations", values={ "event_id": event.event_id, - "relates_to_id": parent_id, - "relation_type": rel_type, - "aggregation_key": aggregation_key, + "relates_to_id": relation.parent_id, + "relation_type": relation.rel_type, + "aggregation_key": relation.aggregation_key, }, ) - txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,)) txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate, (parent_id,) + self.store.get_relations_for_event.invalidate, (relation.parent_id,) + ) + txn.call_after( + self.store.get_aggregation_groups_for_event.invalidate, + (relation.parent_id,), ) - if rel_type == RelationTypes.REPLACE: - txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) + if relation.rel_type == RelationTypes.REPLACE: + txn.call_after( + self.store.get_applicable_edit.invalidate, (relation.parent_id,) + ) - if rel_type == RelationTypes.THREAD: - txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) + if relation.rel_type == RelationTypes.THREAD: + txn.call_after( + self.store.get_thread_summary.invalidate, (relation.parent_id,) + ) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and # potentially error-prone) so it is always invalidated. txn.call_after( self.store.get_thread_participated.invalidate, - (parent_id, event.sender), + (relation.parent_id, event.sender), ) - def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): + def _handle_insertion_event( + self, txn: LoggingTransaction, event: EventBase + ) -> None: """Handles keeping track of insertion events and edges/connections. Part of MSC2716. @@ -1924,7 +1911,7 @@ class PersistEventsStore: }, ) - def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase): + def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None: """Handles inserting the batch edges/connections between the batch event and an insertion event. Part of MSC2716. @@ -2024,25 +2011,29 @@ class PersistEventsStore: txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) - def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("topic"), str): self.store_event_search_txn( txn, event, "content.topic", event.content["topic"] ) - def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None: if isinstance(event.content.get("name"), str): self.store_event_search_txn( txn, event, "content.name", event.content["name"] ) - def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase): + def _store_room_message_txn( + self, txn: LoggingTransaction, event: EventBase + ) -> None: if isinstance(event.content.get("body"), str): self.store_event_search_txn( txn, event, "content.body", event.content["body"] ) - def _store_retention_policy_for_room_txn(self, txn, event): + def _store_retention_policy_for_room_txn( + self, txn: LoggingTransaction, event: EventBase + ) -> None: if not event.is_state(): logger.debug("Ignoring non-state m.room.retention event") return @@ -2102,8 +2093,11 @@ class PersistEventsStore: ) def _set_push_actions_for_event_and_users_txn( - self, txn, events_and_contexts, all_events_and_contexts - ): + self, + txn: LoggingTransaction, + events_and_contexts: List[Tuple[EventBase, EventContext]], + all_events_and_contexts: List[Tuple[EventBase, EventContext]], + ) -> None: """Handles moving push actions from staging table to main event_push_actions table for all events in `events_and_contexts`. @@ -2111,12 +2105,10 @@ class PersistEventsStore: from the push action staging area. Args: - events_and_contexts (list[(EventBase, EventContext)]): events - we are persisting - all_events_and_contexts (list[(EventBase, EventContext)]): all - events that we were going to persist. This includes events - we've already persisted, etc, that wouldn't appear in - events_and_context. + events_and_contexts: events we are persisting + all_events_and_contexts: all events that we were going to persist. + This includes events we've already persisted, etc, that wouldn't + appear in events_and_context. """ # Only non outlier events will have push actions associated with them, @@ -2185,7 +2177,9 @@ class PersistEventsStore: ), ) - def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): + def _remove_push_actions_for_event_id_txn( + self, txn: LoggingTransaction, room_id: str, event_id: str + ) -> None: # Sad that we have to blow away the cache for the whole room here txn.call_after( self.store.get_unread_event_push_actions_by_room_for_user.invalidate, @@ -2196,7 +2190,9 @@ class PersistEventsStore: (room_id, event_id), ) - def _store_rejections_txn(self, txn, event_id, reason): + def _store_rejections_txn( + self, txn: LoggingTransaction, event_id: str, reason: str + ) -> None: self.db_pool.simple_insert_txn( txn, table="rejections", @@ -2208,8 +2204,10 @@ class PersistEventsStore: ) def _store_event_state_mappings_txn( - self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] - ): + self, + txn: LoggingTransaction, + events_and_contexts: Collection[Tuple[EventBase, EventContext]], + ) -> None: state_groups = {} for event, context in events_and_contexts: if event.internal_metadata.is_outlier(): @@ -2266,7 +2264,9 @@ class PersistEventsStore: state_group_id, ) - def _update_min_depth_for_room_txn(self, txn, room_id, depth): + def _update_min_depth_for_room_txn( + self, txn: LoggingTransaction, room_id: str, depth: int + ) -> None: min_depth = self.store._get_min_depth_interaction(txn, room_id) if min_depth is not None and depth >= min_depth: @@ -2279,7 +2279,9 @@ class PersistEventsStore: values={"min_depth": depth}, ) - def _handle_mult_prev_events(self, txn, events): + def _handle_mult_prev_events( + self, txn: LoggingTransaction, events: List[EventBase] + ) -> None: """ For the given event, update the event edges table and forward and backward extremities tables. @@ -2297,7 +2299,9 @@ class PersistEventsStore: self._update_backward_extremeties(txn, events) - def _update_backward_extremeties(self, txn, events): + def _update_backward_extremeties( + self, txn: LoggingTransaction, events: List[EventBase] + ) -> None: """Updates the event_backward_extremities tables based on the new/updated events being persisted. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index a4a604a499..5b22d6b452 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -14,6 +14,7 @@ import logging import threading +import weakref from enum import Enum, auto from typing import ( TYPE_CHECKING, @@ -23,6 +24,7 @@ from typing import ( Dict, Iterable, List, + MutableMapping, Optional, Set, Tuple, @@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore): str, ObservableDeferred[Dict[str, EventCacheEntry]] ] = {} + # We keep track of the events we have currently loaded in memory so that + # we can reuse them even if they've been evicted from the cache. We only + # track events that don't need redacting in here (as then we don't need + # to track redaction status). + self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary() + self._event_fetch_lock = threading.Condition() self._event_fetch_list: List[ Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] @@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore): def _invalidate_get_event_cache(self, event_id: str) -> None: self._get_event_cache.invalidate((event_id,)) + self._event_ref.pop(event_id, None) + self._current_event_fetches.pop(event_id, None) def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True @@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore): event_map = {} for event_id in events: + # First check if it's in the event cache ret = self._get_event_cache.get( (event_id,), None, update_metrics=update_metrics ) - if not ret: + if ret: + event_map[event_id] = ret continue - event_map[event_id] = ret + # Otherwise check if we still have the event in memory. + event = self._event_ref.get(event_id) + if event: + # Reconstruct an event cache entry + + cache_entry = EventCacheEntry( + event=event, + # We don't cache weakrefs to redacted events, so we know + # this is None. + redacted_event=None, + ) + event_map[event_id] = cache_entry + + # We add the entry back into the cache as we want to keep + # recently queried events in the cache. + self._get_event_cache.set((event_id,), cache_entry) return event_map @@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore): self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry + if not redacted_event: + # We only cache references to unredacted events. + self._event_ref[event_id] = original_ev + return result_map async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 1480a0f048..d03555a585 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +from synapse.storage.types import Cursor if TYPE_CHECKING: from synapse.server import HomeServer @@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): self._last_user_visit_update = self._get_start_of_day() @wrap_as_background_process("read_forward_extremities") - async def _read_forward_extremities(self): + async def _read_forward_extremities(self) -> None: def fetch(txn): txn.execute( """ @@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (x[0] - 1) * x[1] for x in res if x[1] ) - async def count_daily_e2ee_messages(self): + async def count_daily_e2ee_messages(self) -> int: """ Returns an estimate of the number of messages sent in the last day. @@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) - async def count_daily_sent_e2ee_messages(self): + async def count_daily_sent_e2ee_messages(self) -> int: def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then that's your own fault. @@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_daily_sent_e2ee_messages", _count_messages ) - async def count_daily_active_e2ee_rooms(self): + async def count_daily_active_e2ee_rooms(self) -> int: def _count(txn): sql = """ SELECT COUNT(DISTINCT room_id) FROM events @@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_daily_active_e2ee_rooms", _count ) - async def count_daily_messages(self): + async def count_daily_messages(self) -> int: """ Returns an estimate of the number of messages sent in the last day. @@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return await self.db_pool.runInteraction("count_messages", _count_messages) - async def count_daily_sent_messages(self): + async def count_daily_sent_messages(self) -> int: def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then that's your own fault. @@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_daily_sent_messages", _count_messages ) - async def count_daily_active_rooms(self): + async def count_daily_active_rooms(self) -> int: def _count(txn): sql = """ SELECT COUNT(DISTINCT room_id) FROM events @@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_monthly_users", self._count_users, thirty_days_ago ) - def _count_users(self, txn, time_from): + def _count_users(self, txn: Cursor, time_from: int) -> int: """ Returns number of users seen in the past time_from period """ @@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): ) u """ txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() + # Mypy knows that fetchone() might return None if there are no rows. + # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always + # returns exactly one row. + (count,) = txn.fetchone() # type: ignore[misc] return count async def count_r30_users(self) -> Dict[str, int]: @@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_r30v2_users", _count_r30v2_users ) - def _get_start_of_day(self): + def _get_start_of_day(self) -> int: """ Returns millisecond unixtime for start of UTC day. """ diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index bfc85b3add..38ba91af4c 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): # event_forward_extremities # event_json # event_push_actions - # event_reference_hashes # event_relations # event_search # event_to_state_groups @@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_auth", "event_edges", "event_forward_extremities", - "event_reference_hashes", "event_relations", "event_search", "rejections", @@ -369,7 +367,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_edges", "event_json", "event_push_actions_staging", - "event_reference_hashes", "event_relations", "event_to_state_groups", "event_auth_chains", diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 484976ca6b..fe8fded88b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -34,7 +34,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList logger = logging.getLogger(__name__) @@ -161,7 +161,9 @@ class RelationsWorkerStore(SQLBaseStore): if len(events) > limit and last_topo_id and last_stream_id: next_key = RoomStreamToken(last_topo_id, last_stream_id) if from_token: - next_token = from_token.copy_and_replace("room_key", next_key) + next_token = from_token.copy_and_replace( + StreamKeyType.ROOM, next_key + ) else: next_token = StreamToken( room_key=next_key, diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 3c49e7ec98..78e0773b2a 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -14,7 +14,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set +from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple import attr @@ -27,7 +27,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.types import JsonDict if TYPE_CHECKING: @@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings ) - async def _background_reindex_search(self, progress, batch_size): + async def _background_reindex_search( + self, progress: JsonDict, batch_size: int + ) -> int: # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): TYPES = ["m.room.name", "m.room.message", "m.room.topic"] - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> int: sql = ( "SELECT stream_ordering, event_id, room_id, type, json, " " origin_server_ts FROM events" @@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): return result - async def _background_reindex_gin_search(self, progress, batch_size): + async def _background_reindex_gin_search( + self, progress: JsonDict, batch_size: int + ) -> int: """This handles old synapses which used GIST indexes, if any; converting them back to be GIN as per the actual schema. """ - def create_index(conn): + def create_index(conn: LoggingDatabaseConnection) -> None: conn.rollback() # we have to set autocommit, because postgres refuses to @@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): ) return 1 - async def _background_reindex_search_order(self, progress, batch_size): + async def _background_reindex_search_order( + self, progress: JsonDict, batch_size: int + ) -> int: target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) @@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): if not have_added_index: - def create_index(conn): + def create_index(conn: LoggingDatabaseConnection) -> None: conn.rollback() conn.set_session(autocommit=True) c = conn.cursor() @@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): pg, ) - def reindex_search_txn(txn): + def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]: sql = ( "UPDATE event_search AS es SET stream_ordering = e.stream_ordering," " origin_server_ts = e.origin_server_ts" @@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore): else: raise Exception("Unrecognized database engine") - args.append(limit) + # mypy expects to append only a `str`, not an `int` + args.append(limit) # type: ignore[arg-type] results = await self.db_pool.execute( "search_rooms", self.db_pool.cursor_to_dict, sql, *args @@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore): A set of strings. """ - def f(txn): + def f(txn: LoggingTransaction) -> Set[str]: highlight_words = set() for event in events: # As a hack we simply join values of all possible keys. This is @@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore): return await self.db_pool.runInteraction("_find_highlights", f) -def _to_postgres_options(options_dict): +def _to_postgres_options(options_dict: JsonDict) -> str: return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),) -def _parse_query(database_engine, search_term): +def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str: """Takes a plain unicode string from the user and converts it into a form that can be passed to database. We use this so that we can add prefix matching, which isn't something diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 0373af86c8..0e3a23a140 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -788,30 +788,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return None async def get_current_room_stream_token_for_room_id( - self, room_id: Optional[str] = None + self, room_id: str ) -> RoomStreamToken: - """Returns the current position of the rooms stream. - - By default, it returns a live token with the current global stream - token. Specifying a `room_id` causes it to return a historic token with - the room specific topological token. - """ + """Returns the current position of the rooms stream (historic token).""" stream_ordering = self.get_room_max_stream_ordering() - if room_id is None: - return RoomStreamToken(None, stream_ordering) - else: - topo = await self.db_pool.runInteraction( - "_get_max_topological_txn", self._get_max_topological_txn, room_id - ) - return RoomStreamToken(topo, stream_ordering) + topo = await self.db_pool.runInteraction( + "_get_max_topological_txn", self._get_max_topological_txn, room_id + ) + return RoomStreamToken(topo, stream_ordering) def get_stream_id_for_event_txn( self, txn: LoggingTransaction, event_id: str, - allow_none=False, - ) -> int: - return self.db_pool.simple_select_one_onecol_txn( + allow_none: bool = False, + ) -> Optional[int]: + # Type ignore: we pass keyvalues a Dict[str, str]; the function wants + # Dict[str, Any]. I think mypy is unhappy because Dict is invariant? + return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload] txn=txn, table="events", keyvalues={"event_id": event_id}, @@ -873,7 +867,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) rows = txn.fetchall() - return rows[0][0] if rows else 0 + # An aggregate function like MAX() will always return one row per group + # so we can safely rely on the lookup here. For example, when a we + # lookup a `room_id` which does not exist, `rows` will look like + # `[(None,)]` + return rows[0][0] if rows[0][0] is not None else 0 @staticmethod def _set_before_and_after( diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index afb7d5054d..f51b3d228e 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -11,25 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Mapping from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -def create_engine(database_config) -> BaseDatabaseEngine: +def create_engine(database_config: Mapping[str, Any]) -> BaseDatabaseEngine: name = database_config["name"] if name == "sqlite3": - import sqlite3 - - return Sqlite3Engine(sqlite3, database_config) + return Sqlite3Engine(database_config) if name == "psycopg2": - # Note that psycopg2cffi-compat provides the psycopg2 module on pypy. - import psycopg2 - - return PostgresEngine(psycopg2, database_config) + return PostgresEngine(database_config) raise RuntimeError("Unsupported database engine '%s'" % (name,)) diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index 143cd98ca2..971ff82693 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -13,9 +13,12 @@ # limitations under the License. import abc from enum import IntEnum -from typing import Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Mapping, Optional, TypeVar -from synapse.storage.types import Connection +from synapse.storage.types import Connection, Cursor, DBAPI2Module + +if TYPE_CHECKING: + from synapse.storage.database import LoggingDatabaseConnection class IsolationLevel(IntEnum): @@ -32,7 +35,7 @@ ConnectionType = TypeVar("ConnectionType", bound=Connection) class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): - def __init__(self, module, database_config: dict): + def __init__(self, module: DBAPI2Module, config: Mapping[str, Any]): self.module = module @property @@ -69,7 +72,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): ... @abc.abstractmethod - def check_new_database(self, txn) -> None: + def check_new_database(self, txn: Cursor) -> None: """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ @@ -79,8 +82,11 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): def convert_param_style(self, sql: str) -> str: ... + # This method would ideally take a plain ConnectionType, but it seems that + # the Sqlite engine expects to use LoggingDatabaseConnection.cursor + # instead of sqlite3.Connection.cursor: only the former takes a txn_name. @abc.abstractmethod - def on_new_connection(self, db_conn: ConnectionType) -> None: + def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: ... @abc.abstractmethod @@ -92,7 +98,7 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): ... @abc.abstractmethod - def lock_table(self, txn, table: str) -> None: + def lock_table(self, txn: Cursor, table: str) -> None: ... @property @@ -102,12 +108,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): ... @abc.abstractmethod - def in_transaction(self, conn: Connection) -> bool: + def in_transaction(self, conn: ConnectionType) -> bool: """Whether the connection is currently in a transaction.""" ... @abc.abstractmethod - def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + def attempt_to_set_autocommit(self, conn: ConnectionType, autocommit: bool) -> None: """Attempt to set the connections autocommit mode. When True queries are run outside of transactions. @@ -119,8 +125,8 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): @abc.abstractmethod def attempt_to_set_isolation_level( - self, conn: Connection, isolation_level: Optional[int] - ): + self, conn: ConnectionType, isolation_level: Optional[int] + ) -> None: """Attempt to set the connections isolation level. Note: This has no effect on SQLite3, as transactions are SERIALIZABLE by default. diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index e8d29e2870..391f8ed24a 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,39 +13,47 @@ # limitations under the License. import logging -from typing import Mapping, Optional +from typing import TYPE_CHECKING, Any, Mapping, NoReturn, Optional, Tuple, cast from synapse.storage.engines._base import ( BaseDatabaseEngine, IncorrectDatabaseSetup, IsolationLevel, ) -from synapse.storage.types import Connection +from synapse.storage.types import Cursor + +if TYPE_CHECKING: + import psycopg2 # noqa: F401 + + from synapse.storage.database import LoggingDatabaseConnection + logger = logging.getLogger(__name__) -class PostgresEngine(BaseDatabaseEngine): - def __init__(self, database_module, database_config): - super().__init__(database_module, database_config) - self.module.extensions.register_type(self.module.extensions.UNICODE) +class PostgresEngine(BaseDatabaseEngine["psycopg2.connection"]): + def __init__(self, database_config: Mapping[str, Any]): + import psycopg2.extensions + + super().__init__(psycopg2, database_config) + psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) # Disables passing `bytes` to txn.execute, c.f. #6186. If you do # actually want to use bytes than wrap it in `bytearray`. - def _disable_bytes_adapter(_): + def _disable_bytes_adapter(_: bytes) -> NoReturn: raise Exception("Passing bytes to DB is disabled.") - self.module.extensions.register_adapter(bytes, _disable_bytes_adapter) - self.synchronous_commit = database_config.get("synchronous_commit", True) - self._version = None # unknown as yet + psycopg2.extensions.register_adapter(bytes, _disable_bytes_adapter) + self.synchronous_commit: bool = database_config.get("synchronous_commit", True) + self._version: Optional[int] = None # unknown as yet self.isolation_level_map: Mapping[int, int] = { - IsolationLevel.READ_COMMITTED: self.module.extensions.ISOLATION_LEVEL_READ_COMMITTED, - IsolationLevel.REPEATABLE_READ: self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ, - IsolationLevel.SERIALIZABLE: self.module.extensions.ISOLATION_LEVEL_SERIALIZABLE, + IsolationLevel.READ_COMMITTED: psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED, + IsolationLevel.REPEATABLE_READ: psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ, + IsolationLevel.SERIALIZABLE: psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE, } self.default_isolation_level = ( - self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ + psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) self.config = database_config @@ -53,19 +61,21 @@ class PostgresEngine(BaseDatabaseEngine): def single_threaded(self) -> bool: return False - def get_db_locale(self, txn): + def get_db_locale(self, txn: Cursor) -> Tuple[str, str]: txn.execute( "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" ) - collation, ctype = txn.fetchone() + collation, ctype = cast(Tuple[str, str], txn.fetchone()) return collation, ctype - def check_database(self, db_conn, allow_outdated_version: bool = False): + def check_database( + self, db_conn: "psycopg2.connection", allow_outdated_version: bool = False + ) -> None: # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and # revision numbers into two-decimal-digit numbers and appending them # together. For example, version 8.1.5 will be returned as 80105 - self._version = db_conn.server_version + self._version = cast(int, db_conn.server_version) allow_unsafe_locale = self.config.get("allow_unsafe_locale", False) # Are we on a supported PostgreSQL version? @@ -108,7 +118,7 @@ class PostgresEngine(BaseDatabaseEngine): ctype, ) - def check_new_database(self, txn): + def check_new_database(self, txn: Cursor) -> None: """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ @@ -129,10 +139,10 @@ class PostgresEngine(BaseDatabaseEngine): "See docs/postgres.md for more information." % ("\n".join(errors)) ) - def convert_param_style(self, sql): + def convert_param_style(self, sql: str) -> str: return sql.replace("?", "%s") - def on_new_connection(self, db_conn): + def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: db_conn.set_isolation_level(self.default_isolation_level) # Set the bytea output to escape, vs the default of hex @@ -149,14 +159,14 @@ class PostgresEngine(BaseDatabaseEngine): db_conn.commit() @property - def can_native_upsert(self): + def can_native_upsert(self) -> bool: """ Can we use native UPSERTs? """ return True @property - def supports_using_any_list(self): + def supports_using_any_list(self) -> bool: """Do we support using `a = ANY(?)` and passing a list""" return True @@ -165,27 +175,25 @@ class PostgresEngine(BaseDatabaseEngine): """Do we support the `RETURNING` clause in insert/update/delete?""" return True - def is_deadlock(self, error): - if isinstance(error, self.module.DatabaseError): + def is_deadlock(self, error: Exception) -> bool: + import psycopg2.extensions + + if isinstance(error, psycopg2.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html # "40001" serialization_failure # "40P01" deadlock_detected return error.pgcode in ["40001", "40P01"] return False - def is_connection_closed(self, conn): + def is_connection_closed(self, conn: "psycopg2.connection") -> bool: return bool(conn.closed) - def lock_table(self, txn, table): + def lock_table(self, txn: Cursor, table: str) -> None: txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) @property - def server_version(self): - """Returns a string giving the server version. For example: '8.1.5' - - Returns: - string - """ + def server_version(self) -> str: + """Returns a string giving the server version. For example: '8.1.5'.""" # note that this is a bit of a hack because it relies on check_database # having been called. Still, that should be a safe bet here. numver = self._version @@ -197,17 +205,21 @@ class PostgresEngine(BaseDatabaseEngine): else: return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) - def in_transaction(self, conn: Connection) -> bool: - return conn.status != self.module.extensions.STATUS_READY # type: ignore + def in_transaction(self, conn: "psycopg2.connection") -> bool: + import psycopg2.extensions + + return conn.status != psycopg2.extensions.STATUS_READY - def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): - return conn.set_session(autocommit=autocommit) # type: ignore + def attempt_to_set_autocommit( + self, conn: "psycopg2.connection", autocommit: bool + ) -> None: + return conn.set_session(autocommit=autocommit) def attempt_to_set_isolation_level( - self, conn: Connection, isolation_level: Optional[int] - ): + self, conn: "psycopg2.connection", isolation_level: Optional[int] + ) -> None: if isolation_level is None: isolation_level = self.default_isolation_level else: isolation_level = self.isolation_level_map[isolation_level] - return conn.set_isolation_level(isolation_level) # type: ignore + return conn.set_isolation_level(isolation_level) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 6c19e55999..621f2c5efe 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import platform +import sqlite3 import struct import threading -import typing -from typing import Optional +from typing import TYPE_CHECKING, Any, List, Mapping, Optional from synapse.storage.engines import BaseDatabaseEngine -from synapse.storage.types import Connection +from synapse.storage.types import Cursor -if typing.TYPE_CHECKING: - import sqlite3 # noqa: F401 +if TYPE_CHECKING: + from synapse.storage.database import LoggingDatabaseConnection -class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): - def __init__(self, database_module, database_config): - super().__init__(database_module, database_config) +class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]): + def __init__(self, database_config: Mapping[str, Any]): + super().__init__(sqlite3, database_config) database = database_config.get("args", {}).get("database") self._is_in_memory = database in ( @@ -37,7 +37,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): if platform.python_implementation() == "PyPy": # pypy's sqlite3 module doesn't handle bytearrays, convert them # back to bytes. - database_module.register_adapter(bytearray, lambda array: bytes(array)) + sqlite3.register_adapter(bytearray, lambda array: bytes(array)) # The current max state_group, or None if we haven't looked # in the DB yet. @@ -49,41 +49,43 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): return True @property - def can_native_upsert(self): + def can_native_upsert(self) -> bool: """ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some more work we haven't done yet to tell what was inserted vs updated. """ - return self.module.sqlite_version_info >= (3, 24, 0) + return sqlite3.sqlite_version_info >= (3, 24, 0) @property - def supports_using_any_list(self): + def supports_using_any_list(self) -> bool: """Do we support using `a = ANY(?)` and passing a list""" return False @property def supports_returning(self) -> bool: """Do we support the `RETURNING` clause in insert/update/delete?""" - return self.module.sqlite_version_info >= (3, 35, 0) + return sqlite3.sqlite_version_info >= (3, 35, 0) - def check_database(self, db_conn, allow_outdated_version: bool = False): + def check_database( + self, db_conn: sqlite3.Connection, allow_outdated_version: bool = False + ) -> None: if not allow_outdated_version: - version = self.module.sqlite_version_info + version = sqlite3.sqlite_version_info # Synapse is untested against older SQLite versions, and we don't want # to let users upgrade to a version of Synapse with broken support for their # sqlite version, because it risks leaving them with a half-upgraded db. if version < (3, 22, 0): raise RuntimeError("Synapse requires sqlite 3.22 or above.") - def check_new_database(self, txn): + def check_new_database(self, txn: Cursor) -> None: """Gets called when setting up a brand new database. This allows us to apply stricter checks on new databases versus existing database. """ - def convert_param_style(self, sql): + def convert_param_style(self, sql: str) -> str: return sql - def on_new_connection(self, db_conn): + def on_new_connection(self, db_conn: "LoggingDatabaseConnection") -> None: # We need to import here to avoid an import loop. from synapse.storage.prepare_database import prepare_database @@ -97,48 +99,46 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): db_conn.execute("PRAGMA foreign_keys = ON;") db_conn.commit() - def is_deadlock(self, error): + def is_deadlock(self, error: Exception) -> bool: return False - def is_connection_closed(self, conn): + def is_connection_closed(self, conn: sqlite3.Connection) -> bool: return False - def lock_table(self, txn, table): + def lock_table(self, txn: Cursor, table: str) -> None: return @property - def server_version(self): - """Gets a string giving the server version. For example: '3.22.0' + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0'.""" + return "%i.%i.%i" % sqlite3.sqlite_version_info - Returns: - string - """ - return "%i.%i.%i" % self.module.sqlite_version_info - - def in_transaction(self, conn: Connection) -> bool: - return conn.in_transaction # type: ignore + def in_transaction(self, conn: sqlite3.Connection) -> bool: + return conn.in_transaction - def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool): + def attempt_to_set_autocommit( + self, conn: sqlite3.Connection, autocommit: bool + ) -> None: # Twisted doesn't let us set attributes on the connections, so we can't # set the connection to autocommit mode. pass def attempt_to_set_isolation_level( - self, conn: Connection, isolation_level: Optional[int] - ): - # All transactions are SERIALIZABLE by default in sqllite + self, conn: sqlite3.Connection, isolation_level: Optional[int] + ) -> None: + # All transactions are SERIALIZABLE by default in sqlite pass # Following functions taken from: https://github.com/coleifer/peewee -def _parse_match_info(buf): +def _parse_match_info(buf: bytes) -> List[int]: bufsize = len(buf) return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)] -def _rank(raw_match_info): +def _rank(raw_match_info: bytes) -> float: """Handle match_info called w/default args 'pcx' - based on the example rank function http://sqlite.org/fts3.html#appendix_a """ diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 97118045a1..0fc282866b 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -25,6 +25,7 @@ from typing import ( Collection, Deque, Dict, + Generator, Generic, Iterable, List, @@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): return res - def _handle_queue(self, room_id): + def _handle_queue(self, room_id: str) -> None: """Attempts to handle the queue for a room if not already being handled. The queue's callback will be invoked with for each item in the queue, @@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): self._currently_persisting_rooms.add(room_id) - async def handle_queue_loop(): + async def handle_queue_loop() -> None: try: queue = self._get_drainining_queue(room_id) for item in queue: @@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]): with PreserveLoggingContext(): item.deferred.callback(ret) finally: - queue = self._event_persist_queues.pop(room_id, None) - if queue: - self._event_persist_queues[room_id] = queue + remaining_queue = self._event_persist_queues.pop(room_id, None) + if remaining_queue: + self._event_persist_queues[room_id] = remaining_queue self._currently_persisting_rooms.discard(room_id) # set handle_queue_loop off in the background run_as_background_process("persist_events", handle_queue_loop) - def _get_drainining_queue(self, room_id): + def _get_drainining_queue( + self, room_id: str + ) -> Generator[_EventPersistQueueItem, None, None]: queue = self._event_persist_queues.setdefault(room_id, deque()) try: @@ -317,7 +320,9 @@ class EventsPersistenceStorage: for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) - async def enqueue(item): + async def enqueue( + item: Tuple[str, List[Tuple[EventBase, EventContext]]] + ) -> Dict[str, str]: room_id, evs_ctxs = item return await self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled @@ -487,12 +492,6 @@ class EventsPersistenceStorage: # extremities in each room new_forward_extremities: Dict[str, Set[str]] = {} - # map room_id->(type,state_key)->event_id tracking the full - # state in each room after adding these events. - # This is simply used to prefill the get_current_state_ids - # cache - current_state_for_room: Dict[str, StateMap[str]] = {} - # map room_id->(to_delete, to_insert) where to_delete is a list # of type/state keys to remove from current state, and to_insert # is a map (type,key)->event_id giving the state delta in each @@ -628,14 +627,8 @@ class EventsPersistenceStorage: state_delta_for_room[room_id] = delta - # If we have the current_state then lets prefill - # the cache with it. - if current_state is not None: - current_state_for_room[room_id] = current_state - await self.persist_events_store._persist_events_and_state_updates( chunk, - current_state_for_room=current_state_for_room, state_delta_for_room=state_delta_for_room, new_forward_extremities=new_forward_extremities, use_negative_stream_ordering=backfilled, @@ -733,7 +726,8 @@ class EventsPersistenceStorage: The first state map is the full new current state and the second is the delta to the existing current state. If both are None then - there has been no change. + there has been no change. Either or neither can be None if there + has been a change. The function may prune some old entries from the set of new forward extremities if it's safe to do so. @@ -743,9 +737,6 @@ class EventsPersistenceStorage: the new current state is only returned if we've already calculated it. """ - # map from state_group to ((type, key) -> event_id) state map - state_groups_map = {} - # Map from (prev state group, new state group) -> delta state dict state_group_deltas = {} @@ -759,16 +750,6 @@ class EventsPersistenceStorage: ) continue - if ctx.state_group in state_groups_map: - continue - - # We're only interested in pulling out state that has already - # been cached in the context. We'll pull stuff out of the DB later - # if necessary. - current_state_ids = ctx.get_cached_current_state_ids() - if current_state_ids is not None: - state_groups_map[ctx.state_group] = current_state_ids - if ctx.prev_group: state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids @@ -826,18 +807,14 @@ class EventsPersistenceStorage: delta_ids = state_group_deltas.get((old_state_group, new_state_group), None) if delta_ids is not None: # We have a delta from the existing to new current state, - # so lets just return that. If we happen to already have - # the current state in memory then lets also return that, - # but it doesn't matter if we don't. - new_state = state_groups_map.get(new_state_group) - return new_state, delta_ids, new_latest_event_ids + # so lets just return that. + return None, delta_ids, new_latest_event_ids # Now that we have calculated new_state_groups we need to get # their state IDs so we can resolve to a single state set. - missing_state = new_state_groups - set(state_groups_map) - if missing_state: - group_to_state = await self.state_store._get_state_for_groups(missing_state) - state_groups_map.update(group_to_state) + state_groups_map = await self.state_store._get_state_for_groups( + new_state_groups + ) if len(new_state_groups) == 1: # If there is only one state group, then we know what the current @@ -1130,7 +1107,7 @@ class EventsPersistenceStorage: return False - async def _handle_potentially_left_users(self, user_ids: Set[str]): + async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None: """Given a set of remote users check if the server still shares a room with them. If not then mark those users' device cache as stale. """ diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 546d6bae6e..c33df42084 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -85,7 +85,7 @@ def prepare_database( database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], databases: Collection[str] = ("main", "state"), -): +) -> None: """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 871d4ace12..20c344faea 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 69 # remember to update the list below when updating +SCHEMA_VERSION = 70 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -62,6 +62,9 @@ Changes in SCHEMA_VERSION = 68: Changes in SCHEMA_VERSION = 69: - We now write to `device_lists_changes_in_room` table. - Use sequence to generate future `application_services_txns.txn_id`s + +Changes in SCHEMA_VERSION = 70: + - event_reference_hashes is no longer written to. """ diff --git a/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql new file mode 100644 index 0000000000..22ae3b8c00 --- /dev/null +++ b/synapse/storage/schema/main/delta/69/02cache_invalidation_index.sql @@ -0,0 +1,18 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Background update to clear the inboxes of hidden and deleted devices. +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6902, 'cache_invalidation_index_by_instance', '{}'); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index d1d5859214..d4a1bd4f9d 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -62,7 +62,7 @@ class StateFilter: types: "frozendict[str, Optional[FrozenSet[str]]]" include_others: bool = False - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: @@ -138,7 +138,9 @@ class StateFilter: ) @staticmethod - def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool): + def freeze( + types: Mapping[str, Optional[Collection[str]]], include_others: bool + ) -> "StateFilter": """ Returns a (frozen) StateFilter with the same contents as the parameters specified here, which can be made of mutable types. diff --git a/synapse/storage/types.py b/synapse/storage/types.py index d7d6f1d90e..0031df1e06 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from types import TracebackType +from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union from typing_extensions import Protocol @@ -86,5 +87,80 @@ class Connection(Protocol): def __enter__(self) -> "Connection": ... - def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: + ... + + +class DBAPI2Module(Protocol): + """The module-level attributes that we use from PEP 249. + + This is NOT a comprehensive stub for the entire DBAPI2.""" + + __name__: str + + # Exceptions. See https://peps.python.org/pep-0249/#exceptions + + # For our specific drivers: + # - Python's sqlite3 module doesn't contains the same descriptions as the + # DBAPI2 spec, see https://docs.python.org/3/library/sqlite3.html#exceptions + # - Psycopg2 maps every Postgres error code onto a unique exception class which + # extends from this hierarchy. See + # https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions + # https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE + Warning: Type[Exception] + Error: Type[Exception] + + # Errors are divided into `InterfaceError`s (something went wrong in the database + # driver) and `DatabaseError`s (something went wrong in the database). These are + # both subclasses of `Error`, but we can't currently express this in type + # annotations due to https://github.com/python/mypy/issues/8397 + InterfaceError: Type[Exception] + DatabaseError: Type[Exception] + + # Everything below is a subclass of `DatabaseError`. + + # Roughly: the database rejected a nonsensical value. Examples: + # - An integer was too big for its data type. + # - An invalid date time was provided. + # - A string contained a null code point. + DataError: Type[Exception] + + # Roughly: something went wrong in the database, but it's not within the application + # programmer's control. Examples: + # - We failed to establish a connection to the database. + # - The connection to the database was lost. + # - A deadlock was detected. + # - A serialisation failure occurred. + # - The database ran out of resources, such as storage, memory, connections, etc. + # - The database encountered an error from the operating system. + OperationalError: Type[Exception] + + # Roughly: we've given the database data which breaks a rule we asked it to enforce. + # Examples: + # - Stop, criminal scum! You violated the foreign key constraint + # - Also check constraints, non-null constraints, etc. + IntegrityError: Type[Exception] + + # Roughly: something went wrong within the database server itself. + InternalError: Type[Exception] + + # Roughly: the application did something silly that needs to be fixed. Examples: + # - We don't have permissions to do something. + # - We tried to create a table with duplicate column names. + # - We tried to use a reserved name. + # - We referred to a column that doesn't exist. + ProgrammingError: Type[Exception] + + # Roughly: we've tried to do something that this database doesn't support. + NotSupportedError: Type[Exception] + + def connect(self, **parameters: object) -> Connection: ... + + +__all__ = ["Cursor", "Connection", "DBAPI2Module"] diff --git a/synapse/types.py b/synapse/types.py index 9ac688b23b..bd8071d51d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -24,6 +24,7 @@ from typing import ( Mapping, Match, MutableMapping, + NoReturn, Optional, Set, Tuple, @@ -35,7 +36,8 @@ from typing import ( import attr from frozendict import frozendict from signedjson.key import decode_verify_key_bytes -from typing_extensions import TypedDict +from signedjson.types import VerifyKey +from typing_extensions import Final, TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore, PurgeEventsStore + from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore # Define a state map type from type/state_key to T (usually an event ID or # event) @@ -114,7 +117,7 @@ class Requester: app_service: Optional["ApplicationService"] authenticated_entity: str - def serialize(self): + def serialize(self) -> Dict[str, Any]: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -132,7 +135,9 @@ class Requester: } @staticmethod - def deserialize(store, input): + def deserialize( + store: "ApplicationServiceWorkerStore", input: Dict[str, Any] + ) -> "Requester": """Converts a dict that was produced by `serialize` back into a Requester. @@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta): domain: str # Because this is a frozen class, it is deeply immutable. - def __copy__(self): + def __copy__(self: DS) -> DS: return self - def __deepcopy__(self, memo): + def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: return self @classmethod @@ -625,6 +630,22 @@ class RoomStreamToken: return "s%d" % (self.stream,) +class StreamKeyType: + """Known stream types. + + A stream is a list of entities ordered by an incrementing "stream token". + """ + + ROOM: Final = "room_key" + PRESENCE: Final = "presence_key" + TYPING: Final = "typing_key" + RECEIPT: Final = "receipt_key" + ACCOUNT_DATA: Final = "account_data_key" + PUSH_RULES: Final = "push_rules_key" + TO_DEVICE: Final = "to_device_key" + DEVICE_LIST: Final = "device_list_key" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class StreamToken: """A collection of keys joined together by underscores in the following @@ -729,16 +750,18 @@ class StreamToken: ) @property - def room_stream_id(self): + def room_stream_id(self) -> int: return self.room_key.stream - def copy_and_advance(self, key, new_value) -> "StreamToken": + def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": """Advance the given key in the token to a new value if and only if the new value is after the old value. + + :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. """ - if key == "room_key": + if key == StreamKeyType.ROOM: new_token = self.copy_and_replace( - "room_key", self.room_key.copy_and_advance(new_value) + StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) ) return new_token @@ -751,7 +774,7 @@ class StreamToken: else: return self - def copy_and_replace(self, key, new_value) -> "StreamToken": + def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": return attr.evolve(self, **{key: new_value}) @@ -793,14 +816,14 @@ class ThirdPartyInstanceID: # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) - def __iter__(self): + def __iter__(self) -> NoReturn: raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) # Because this class is a frozen class, it is deeply immutable. - def __copy__(self): + def __copy__(self) -> "ThirdPartyInstanceID": return self - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": return self @classmethod @@ -852,25 +875,28 @@ class DeviceListUpdates: return bool(self.changed or self.left) -def get_verify_key_from_cross_signing_key(key_info): +def get_verify_key_from_cross_signing_key( + key_info: Mapping[str, Any] +) -> Tuple[str, VerifyKey]: """Get the key ID and signedjson verify key from a cross-signing key dict Args: - key_info (dict): a cross-signing key dict, which must have a "keys" + key_info: a cross-signing key dict, which must have a "keys" property that has exactly one item in it Returns: - (str, VerifyKey): the key ID and verify key for the cross-signing key + the key ID and verify key for the cross-signing key """ - # make sure that exactly one key is provided + # make sure that a `keys` field is provided if "keys" not in key_info: raise ValueError("Invalid key") keys = key_info["keys"] - if len(keys) != 1: - raise ValueError("Invalid key") - # and return that one key - for key_id, key_data in keys.items(): + # and that it contains exactly one key + if len(keys) == 1: + key_id, key_data = next(iter(keys.items())) return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) + else: + raise ValueError("Invalid key") @attr.s(auto_attribs=True, frozen=True, slots=True) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 45ff0de638..a3b60578e3 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import math import threading import weakref from enum import Enum @@ -40,6 +41,7 @@ from twisted.internet.interfaces import IReactorTime from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.metrics.jemalloc import get_jemalloc_stats from synapse.util import Clock, caches from synapse.util.caches import CacheMetric, EvictionReason, register_cache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -106,10 +108,16 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node() @wrap_as_background_process("LruCache._expire_old_entries") -async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None: +async def _expire_old_entries( + clock: Clock, expiry_seconds: int, autotune_config: Optional[dict] +) -> None: """Walks the global cache list to find cache entries that haven't been - accessed in the given number of seconds. + accessed in the given number of seconds, or if a given memory threshold has been breached. """ + if autotune_config: + max_cache_memory_usage = autotune_config["max_cache_memory_usage"] + target_cache_memory_usage = autotune_config["target_cache_memory_usage"] + min_cache_ttl = autotune_config["min_cache_ttl"] / 1000 now = int(clock.time()) node = GLOBAL_ROOT.prev_node @@ -119,11 +127,36 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None: logger.debug("Searching for stale caches") + evicting_due_to_memory = False + + # determine if we're evicting due to memory + jemalloc_interface = get_jemalloc_stats() + if jemalloc_interface and autotune_config: + try: + jemalloc_interface.refresh_stats() + mem_usage = jemalloc_interface.get_stat("allocated") + if mem_usage > max_cache_memory_usage: + logger.info("Begin memory-based cache eviction.") + evicting_due_to_memory = True + except Exception: + logger.warning( + "Unable to read allocated memory, skipping memory-based cache eviction." + ) + while node is not GLOBAL_ROOT: # Only the root node isn't a `_TimedListNode`. assert isinstance(node, _TimedListNode) - if node.last_access_ts_secs > now - expiry_seconds: + # if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's + # nothing to do here + if ( + node.last_access_ts_secs > now - expiry_seconds + and not evicting_due_to_memory + ): + break + + # if entry is newer than min_cache_entry_ttl then do not evict and don't evict anything newer + if evicting_due_to_memory and now - node.last_access_ts_secs < min_cache_ttl: break cache_entry = node.get_cache_entry() @@ -136,10 +169,29 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None: assert cache_entry is not None cache_entry.drop_from_cache() + # Check mem allocation periodically if we are evicting a bunch of caches + if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0: + try: + jemalloc_interface.refresh_stats() + mem_usage = jemalloc_interface.get_stat("allocated") + if mem_usage < target_cache_memory_usage: + evicting_due_to_memory = False + logger.info("Stop memory-based cache eviction.") + except Exception: + logger.warning( + "Unable to read allocated memory, this may affect memory-based cache eviction." + ) + # If we've failed to read the current memory usage then we + # should stop trying to evict based on memory usage + evicting_due_to_memory = False + # If we do lots of work at once we yield to allow other stuff to happen. if (i + 1) % 10000 == 0: logger.debug("Waiting during drop") - await clock.sleep(0) + if node.last_access_ts_secs > now - expiry_seconds: + await clock.sleep(0.5) + else: + await clock.sleep(0) logger.debug("Waking during drop") node = next_node @@ -156,21 +208,28 @@ async def _expire_old_entries(clock: Clock, expiry_seconds: int) -> None: def setup_expire_lru_cache_entries(hs: "HomeServer") -> None: """Start a background job that expires all cache entries if they have not - been accessed for the given number of seconds. + been accessed for the given number of seconds, or if a given memory usage threshold has been + breached. """ - if not hs.config.caches.expiry_time_msec: + if not hs.config.caches.expiry_time_msec and not hs.config.caches.cache_autotuning: return - logger.info( - "Expiring LRU caches after %d seconds", hs.config.caches.expiry_time_msec / 1000 - ) + if hs.config.caches.expiry_time_msec: + expiry_time = hs.config.caches.expiry_time_msec / 1000 + logger.info("Expiring LRU caches after %d seconds", expiry_time) + else: + expiry_time = math.inf global USE_GLOBAL_LIST USE_GLOBAL_LIST = True clock = hs.get_clock() clock.looping_call( - _expire_old_entries, 30 * 1000, clock, hs.config.caches.expiry_time_msec / 1000 + _expire_old_entries, + 30 * 1000, + clock, + expiry_time, + hs.config.caches.cache_autotuning, ) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index 4bb82e810e..d2b3c299e3 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -38,6 +38,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_NOT_CACHE": "BLAH", } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) @@ -52,6 +53,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_CACHE_FACTOR_FOO": 1, } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual( dict(self.config.cache_factors), @@ -71,6 +73,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"per_cache_factors": {"foo": 3}}} self.config.read_config(config) + self.config.resize_all_caches() self.assertEqual(cache.max_size, 300) @@ -82,6 +85,7 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"per_cache_factors": {"foo": 2}}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -99,6 +103,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"global_factor": 4}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() self.assertEqual(cache.max_size, 400) @@ -110,6 +115,7 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"global_factor": 1.5}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -128,6 +134,7 @@ class CacheConfigTests(TestCase): "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, } self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache_a = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) @@ -148,6 +155,7 @@ class CacheConfigTests(TestCase): config = {"caches": {"event_cache_size": "10k"}} self.config.read_config(config, config_dir_path="", data_dir_path="") + self.config.resize_all_caches() cache = LruCache( max_size=self.config.event_cache_size, diff --git a/tests/federation/transport/server/__init__.py b/tests/federation/transport/server/__init__.py new file mode 100644 index 0000000000..3a5f22c022 --- /dev/null +++ b/tests/federation/transport/server/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py new file mode 100644 index 0000000000..ac3695a8cc --- /dev/null +++ b/tests/federation/transport/server/test__base.py @@ -0,0 +1,114 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Dict, List, Tuple + +from synapse.api.errors import Codes +from synapse.federation.transport.server import BaseFederationServlet +from synapse.federation.transport.server._base import Authenticator +from synapse.http.server import JsonResource, cancellable +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin + + +class CancellableFederationServlet(BaseFederationServlet): + PATH = "/sleep" + + def __init__( + self, + hs: HomeServer, + authenticator: Authenticator, + ratelimiter: FederationRateLimiter, + server_name: str, + ): + super().__init__(hs, authenticator, ratelimiter, server_name) + self.clock = hs.get_clock() + + @cancellable + async def on_GET( + self, origin: str, content: None, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class BaseFederationServletCancellationTests( + unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `BaseFederationServlet` cancellation.""" + + skip = "`BaseFederationServlet` does not support cancellation yet." + + path = f"{CancellableFederationServlet.PREFIX}{CancellableFederationServlet.PATH}" + + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`.""" + resource = JsonResource(self.hs) + + CancellableFederationServlet( + hs=self.hs, + authenticator=Authenticator(self.hs), + ratelimiter=self.hs.get_federation_ratelimiter(), + server_name=self.hs.hostname, + ).register(resource) + + return resource + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = self.make_signed_federation_request( + "GET", self.path, await_result=False + ) + + # Advance past all the rate limiting logic. If we disconnect too early, the + # request won't be processed. + self.pump() + + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = self.make_signed_federation_request( + "POST", + self.path, + content={}, + await_result=False, + ) + + # Advance past all the rate limiting logic. If we disconnect too early, the + # request won't be processed. + self.pump() + + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 489ba57736..e64b28f28b 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -148,7 +148,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): prev_event.internal_metadata.outlier = True persistence = self.hs.get_storage().persistence self.get_success( - persistence.persist_event(prev_event, EventContext.for_outlier()) + persistence.persist_event( + prev_event, EventContext.for_outlier(self.hs.get_storage()) + ) ) else: diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 0482a1ea34..78807cdcfc 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from copy import deepcopy from typing import List from synapse.api.constants import ReceiptTypes @@ -125,42 +125,6 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_handles_missing_content_of_m_read(self): - self._test_filters_private( - [ - { - "content": { - "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, - "$1435641916114394fHBLK:matrix.org": { - ReceiptTypes.READ: { - "@user:jki.re": { - "ts": 1436451550453, - } - } - }, - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - } - ], - [ - { - "content": { - "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, - "$1435641916114394fHBLK:matrix.org": { - ReceiptTypes.READ: { - "@user:jki.re": { - "ts": 1436451550453, - } - } - }, - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - } - ], - ) - def test_handles_empty_event(self): self._test_filters_private( [ @@ -332,9 +296,33 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) + def test_we_do_not_mutate(self): + """Ensure the input values are not modified.""" + events = [ + { + "content": { + "$1435641916114394fHBLK:matrix.org": { + ReceiptTypes.READ_PRIVATE: { + "@rikj:jki.re": { + "ts": 1436451550453, + } + } + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ] + original_events = deepcopy(events) + self._test_filters_private(events, []) + # Since the events are fed in from a cache they should not be modified. + self.assertEqual(events, original_events) + def _test_filters_private( self, events: List[JsonDict], expected_output: List[JsonDict] ): """Tests that the _filter_out_private returns the expected output""" - filtered_events = self.event_source.filter_out_private(events, "@me:server.org") + filtered_events = self.event_source.filter_out_private_receipts( + events, "@me:server.org" + ) self.assertEqual(filtered_events, expected_output) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 865b8b7e47..db3302a4c7 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -160,6 +160,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() self.store._get_event_cache.clear() + self.store._event_ref.clear() # The rooms should be excluded from the sync response. # Get a new request key. diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py new file mode 100644 index 0000000000..3a5f22c022 --- /dev/null +++ b/tests/http/server/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py new file mode 100644 index 0000000000..b9f1a381aa --- /dev/null +++ b/tests/http/server/_base.py @@ -0,0 +1,100 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unles4s required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Any, Callable, Optional, Union +from unittest import mock + +from twisted.internet.error import ConnectionDone + +from synapse.http.server import ( + HTTP_STATUS_REQUEST_CANCELLED, + respond_with_html_bytes, + respond_with_json, +) +from synapse.types import JsonDict + +from tests import unittest +from tests.server import FakeChannel, ThreadedMemoryReactorClock + + +class EndpointCancellationTestHelperMixin(unittest.TestCase): + """Provides helper methods for testing cancellation of endpoints.""" + + def _test_disconnect( + self, + reactor: ThreadedMemoryReactorClock, + channel: FakeChannel, + expect_cancellation: bool, + expected_body: Union[bytes, JsonDict], + expected_code: Optional[int] = None, + ) -> None: + """Disconnects an in-flight request and checks the response. + + Args: + reactor: The twisted reactor running the request handler. + channel: The `FakeChannel` for the request. + expect_cancellation: `True` if request processing is expected to be + cancelled, `False` if the request should run to completion. + expected_body: The expected response for the request. + expected_code: The expected status code for the request. Defaults to `200` + or `499` depending on `expect_cancellation`. + """ + # Determine the expected status code. + if expected_code is None: + if expect_cancellation: + expected_code = HTTP_STATUS_REQUEST_CANCELLED + else: + expected_code = HTTPStatus.OK + + request = channel.request + self.assertFalse( + channel.is_finished(), + "Request finished before we could disconnect - " + "was `await_result=False` passed to `make_request`?", + ) + + # We're about to disconnect the request. This also disconnects the channel, so + # we have to rely on mocks to extract the response. + respond_method: Callable[..., Any] + if isinstance(expected_body, bytes): + respond_method = respond_with_html_bytes + else: + respond_method = respond_with_json + + with mock.patch( + f"synapse.http.server.{respond_method.__name__}", wraps=respond_method + ) as respond_mock: + # Disconnect the request. + request.connectionLost(reason=ConnectionDone()) + + if expect_cancellation: + # An immediate cancellation is expected. + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code, body = args[1], args[2] + self.assertEqual(code, expected_code) + self.assertEqual(request.code, expected_code) + self.assertEqual(body, expected_body) + else: + respond_mock.assert_not_called() + + # The handler is expected to run to completion. + reactor.pump([1.0]) + respond_mock.assert_called_once() + args, _kwargs = respond_mock.call_args + code, body = args[1], args[2] + self.assertEqual(code, expected_code) + self.assertEqual(request.code, expected_code) + self.assertEqual(body, expected_body) diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index a80bfb9f4e..ad521525cf 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -12,16 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from http import HTTPStatus from io import BytesIO +from typing import Tuple from unittest.mock import Mock -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import cancellable from synapse.http.servlet import ( + RestServlet, parse_json_object_from_request, parse_json_value_from_request, ) +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns +from synapse.server import HomeServer +from synapse.types import JsonDict from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin def make_request(content): @@ -76,3 +85,52 @@ class TestServletUtils(unittest.TestCase): # Test not an object with self.assertRaises(SynapseError): parse_json_object_from_request(make_request(b'["foo"]')) + + +class CancellableRestServlet(RestServlet): + """A `RestServlet` with a mix of cancellable and uncancellable handlers.""" + + PATTERNS = client_patterns("/sleep$") + + def __init__(self, hs: HomeServer): + super().__init__() + self.clock = hs.get_clock() + + @cancellable + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class TestRestServletCancellation( + unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `RestServlet` cancellation.""" + + servlets = [ + lambda hs, http_server: CancellableRestServlet(hs).register(http_server) + ] + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = self.make_request("GET", "/sleep", await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = self.make_request("POST", "/sleep", await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) diff --git a/tests/replication/http/__init__.py b/tests/replication/http/__init__.py new file mode 100644 index 0000000000..3a5f22c022 --- /dev/null +++ b/tests/replication/http/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py new file mode 100644 index 0000000000..a5ab093a27 --- /dev/null +++ b/tests/replication/http/test__base.py @@ -0,0 +1,106 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Tuple + +from twisted.web.server import Request + +from synapse.api.errors import Codes +from synapse.http.server import JsonResource, cancellable +from synapse.replication.http import REPLICATION_PREFIX +from synapse.replication.http._base import ReplicationEndpoint +from synapse.server import HomeServer +from synapse.types import JsonDict + +from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin + + +class CancellableReplicationEndpoint(ReplicationEndpoint): + NAME = "cancellable_sleep" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: HomeServer): + super().__init__(hs) + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload() -> JsonDict: + return {} + + @cancellable + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class UncancellableReplicationEndpoint(ReplicationEndpoint): + NAME = "uncancellable_sleep" + PATH_ARGS = () + CACHE = False + + def __init__(self, hs: HomeServer): + super().__init__(hs) + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload() -> JsonDict: + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request + ) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class ReplicationEndpointCancellationTestCase( + unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin +): + """Tests for `ReplicationEndpoint` cancellation.""" + + def create_test_resource(self): + """Overrides `HomeserverTestCase.create_test_resource`.""" + resource = JsonResource(self.hs) + + CancellableReplicationEndpoint(self.hs).register(resource) + UncancellableReplicationEndpoint(self.hs).register(resource) + + return resource + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/" + channel = self.make_request("POST", path, await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/" + channel = self.make_request("POST", path, await_result=False) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 5f142e84c3..a7ca68069e 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -14,7 +14,6 @@ import logging from unittest.mock import patch -from synapse.api.room_versions import RoomVersion from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -64,21 +63,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # We control the room ID generation by patching out the # `_generate_room_id` method - async def generate_room( - creator_id: str, is_public: bool, room_version: RoomVersion - ): - await self.store.store_room( - room_id=room_id, - room_creator_user_id=creator_id, - is_public=is_public, - room_version=room_version, - ) - return room_id - with patch( "synapse.handlers.room.RoomCreationHandler._generate_room_id" ) as mock: - mock.side_effect = generate_room + mock.side_effect = lambda: room_id self.helper.create_room_as(user_id, tok=tok) def test_basic(self): diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 9443daa056..d0197aca94 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -925,7 +925,7 @@ class RoomJoinTestCase(RoomBase): ) -> bool: return return_value - callback_mock = Mock(side_effect=user_may_join_room) + callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None) self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock) # Join a first room, without being invited to it. @@ -1116,6 +1116,264 @@ class RoomMessagesTestCase(RoomBase): self.assertEqual(200, channel.code, msg=channel.result["body"]) +class RoomPowerLevelOverridesTestCase(RoomBase): + """Tests that the power levels can be overridden with server config.""" + + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.admin_user_id = self.register_user("admin", "pass") + self.admin_access_token = self.login("admin", "pass") + + def power_levels(self, room_id: str) -> Dict[str, Any]: + return self.helper.get_state( + room_id, "m.room.power_levels", self.admin_access_token + ) + + def test_default_power_levels_with_room_override(self) -> None: + """ + Create a room, providing power level overrides. + Confirm that the room's power levels reflect the overrides. + + See https://github.com/matrix-org/matrix-spec/issues/492 + - currently we overwrite each key of power_level_content_override + completely. + """ + + room_id = self.helper.create_room_as( + self.user_id, + extra_content={ + "power_level_content_override": {"events": {"custom.event": 0}} + }, + ) + self.assertEqual( + { + "custom.event": 0, + }, + self.power_levels(room_id)["events"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_power_levels_with_server_override(self) -> None: + """ + With a server configured to modify the room-level defaults, + Create a room, without providing any extra power level overrides. + Confirm that the room's power levels reflect the server-level overrides. + + Similar to https://github.com/matrix-org/matrix-spec/issues/492, + we overwrite each key of power_level_content_override completely. + """ + + room_id = self.helper.create_room_as(self.user_id) + self.assertEqual( + { + "custom.event": 0, + }, + self.power_levels(room_id)["events"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": { + "events": {"server.event": 0}, + "ban": 13, + }, + } + }, + ) + def test_power_levels_with_server_and_room_overrides(self) -> None: + """ + With a server configured to modify the room-level defaults, + create a room, providing different overrides. + Confirm that the room's power levels reflect both overrides, and + choose the room overrides where they clash. + """ + + room_id = self.helper.create_room_as( + self.user_id, + extra_content={ + "power_level_content_override": {"events": {"room.event": 0}} + }, + ) + + # Room override wins over server config + self.assertEqual( + {"room.event": 0}, + self.power_levels(room_id)["events"], + ) + + # But where there is no room override, server config wins + self.assertEqual(13, self.power_levels(room_id)["ban"]) + + +class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): + """ + Tests that we can really do various otherwise-prohibited actions + based on overriding the power levels in config. + """ + + user_id = "@sid1:red" + + def test_creator_can_post_state_event(self) -> None: + # Given I am the creator of a room + room_id = self.helper.create_room_as(self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am allowed + self.assertEqual(200, channel.code, msg=channel.result["body"]) + + def test_normal_user_can_not_post_state_event(self) -> None: + # Given I am a normal member of a room + room_id = self.helper.create_room_as("@some_other_guy:red") + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed because state events require PL>=50 + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + "user_level (0) < send_level (50)", + channel.json_body["error"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_with_config_override_normal_user_can_post_state_event(self) -> None: + # Given the server has config allowing normal users to post my event type, + # and I am a normal member of a room + room_id = self.helper.create_room_as("@some_other_guy:red") + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am allowed + self.assertEqual(200, channel.code, msg=channel.result["body"]) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_any_room_override_defeats_config_override(self) -> None: + # Given the server has config allowing normal users to post my event type + # And I am a normal member of a room + # But the room was created with special permissions + extra_content: Dict[str, Any] = { + "power_level_content_override": {"events": {}}, + } + room_id = self.helper.create_room_as( + "@some_other_guy:red", extra_content=extra_content + ) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed + self.assertEqual(403, channel.code, msg=channel.result["body"]) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + } + }, + ) + def test_specific_room_override_defeats_config_override(self) -> None: + # Given the server has config allowing normal users to post my event type, + # and I am a normal member of a room, + # but the room was created with special permissions for this event type + extra_content = { + "power_level_content_override": {"events": {"custom.event": 1}}, + } + room_id = self.helper.create_room_as( + "@some_other_guy:red", extra_content=extra_content + ) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + + "user_level (0) < send_level (1)", + channel.json_body["error"], + ) + + @unittest.override_config( + { + "default_power_level_content_override": { + "public_chat": {"events": {"custom.event": 0}}, + "private_chat": None, + "trusted_private_chat": None, + } + }, + ) + def test_config_override_applies_only_to_specific_preset(self) -> None: + # Given the server has config for public_chats, + # and I am a normal member of a private_chat room + room_id = self.helper.create_room_as("@some_other_guy:red", is_public=False) + self.helper.invite(room=room_id, src="@some_other_guy:red", targ=self.user_id) + self.helper.join(room=room_id, user=self.user_id) + + # When I send a state event + path = "/rooms/{room_id}/state/custom.event/my_state_key".format( + room_id=urlparse.quote(room_id), + ) + channel = self.make_request("PUT", path, "{}") + + # Then I am not allowed because the public_chat config does not + # affect this room, because this room is a private_chat + self.assertEqual(403, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You don't have permission to post that to the room. " + + "user_level (0) < send_level (50)", + channel.json_body["error"], + ) + + class RoomInitialSyncTestCase(RoomBase): """Tests /rooms/$room_id/initialSync.""" @@ -2598,7 +2856,9 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. - mock = Mock(return_value=make_awaitable(True)) + # `spec` argument is needed for this function mock to have `__qualname__`, which + # is needed for `Measure` metrics buried in SpamChecker. + mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None) self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock) # Send a 3PID invite into the room and check that it succeeded. diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 0108337649..74b6560cbc 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from http import HTTPStatus from typing import List, Optional from parameterized import parameterized @@ -485,30 +486,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Test that we didn't override the public read receipt self.assertIsNone(self._get_read_receipt()) - @parameterized.expand( - [ - # Old Element version, expected to send an empty body - ( - "agent1", - "Element/1.2.2 (Linux; U; Android 9; MatrixAndroidSDK_X 0.0.1)", - 200, - ), - # Old SchildiChat version, expected to send an empty body - ("agent2", "SchildiChat/1.2.1 (Android 10)", 200), - # Expected 400: Denies empty body starting at version 1.3+ - ("agent3", "Element/1.3.6 (Android 10)", 400), - ("agent4", "SchildiChat/1.3.6 (Android 11)", 400), - # Contains "Riot": Receipts with empty bodies expected - ("agent5", "Element (Riot.im) (Android 9)", 200), - # Expected 400: Does not contain "Android" - ("agent6", "Element/1.2.1", 400), - # Expected 400: Different format, missing "/" after Element; existing build that should allow empty bodies, but minimal ongoing usage - ("agent7", "Element dbg/1.1.8-dev (Android)", 400), - ] - ) - def test_read_receipt_with_empty_body( - self, name: str, user_agent: str, expected_status_code: int - ) -> None: + def test_read_receipt_with_empty_body_is_rejected(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) @@ -517,9 +495,9 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): "POST", f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", access_token=self.tok2, - custom_headers=[("User-Agent", user_agent)], ) - self.assertEqual(channel.code, expected_status_code) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body) def _get_read_receipt(self) -> Optional[JsonDict]: """Syncs and returns the read receipt.""" @@ -678,12 +656,13 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(3) # Check that custom events with a body increase the unread counter. - self.helper.send_event( + result = self.helper.send_event( self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, ) + event_id = result["event_id"] self._check_unread_count(4) # Check that edits don't increase the unread counter. @@ -693,7 +672,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): content={ "body": "hello", "msgtype": "m.text", - "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": event_id, + }, }, tok=self.tok2, ) diff --git a/tests/server.py b/tests/server.py index 8f30e250c8..b9f465971f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -109,6 +109,17 @@ class FakeChannel: _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None resource_usage: Optional[ContextResourceUsage] = None + _request: Optional[Request] = None + + @property + def request(self) -> Request: + assert self._request is not None + return self._request + + @request.setter + def request(self, request: Request) -> None: + assert self._request is None + self._request = request @property def json_body(self): @@ -322,6 +333,8 @@ def make_request( channel = FakeChannel(site, reactor, ip=client_ip) req = request(channel, site) + channel.request = req + req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. req.content.seek(0, SEEK_END) @@ -736,6 +749,7 @@ def setup_test_homeserver( if config is None: config = default_config(name, parse=True) + config.caches.resize_all_caches() config.ldap_enabled = False if "clock" not in kwargs: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 9ee9509d3a..07e29788e5 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -75,6 +75,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( return_value=make_awaitable("!something:localhost") ) + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( + return_value=make_awaitable("!something:localhost") + ) self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @@ -102,6 +105,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once() self._send_notice.assert_called_once() def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self): @@ -300,7 +304,10 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): hasn't been reached (since it's the only user and the limit is 5), so users shouldn't receive a server notice. """ - self.register_user("user", "password") + m = Mock(return_value=make_awaitable(None)) + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m + + user_id = self.register_user("user", "password") tok = self.login("user", "password") channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) @@ -309,6 +316,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): "rooms", channel.json_body, "Got invites without server notice" ) + m.assert_called_once_with(user_id) + def test_invite_with_notice(self): """Tests that, if the MAU limit is hit, the server notices user invites each user to a room in which it has sent a notice. diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index c237a8c7e2..38963ce4a7 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -154,6 +154,31 @@ class EventCacheTestCase(unittest.HomeserverTestCase): # We should have fetched the event from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + def test_event_ref(self): + """Test that we reuse events that are still in memory but have fallen + out of the cache, rather than requesting them from the DB. + """ + + # Reset the event cache + self.store._get_event_cache.clear() + + with LoggingContext("test") as ctx: + # We keep hold of the event event though we never use it. + event = self.get_success(self.store.get_event(self.event_id)) # noqa: F841 + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + # Reset the event cache + self.store._get_event_cache.clear() + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # Since the event is still in memory we shouldn't have fetched it + # from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0) + def test_dedupe(self): """Test that if we request the same event multiple times we only pull it out once. diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index 401020fd63..c7661e7186 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -393,7 +393,7 @@ class EventChainStoreTestCase(HomeserverTestCase): # We need to persist the events to the events and state_events # tables. persist_events_store._store_event_txn( - txn, [(e, EventContext()) for e in events] + txn, [(e, EventContext(self.hs.get_storage())) for e in events] ) # Actually call the function that calculates the auth chain stuff. diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 645d564d1c..d92a9ac5b7 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -58,15 +58,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): (room_id, event_id), ) - txn.execute( - ( - "INSERT INTO event_reference_hashes " - "(event_id, algorithm, hash) " - "VALUES (?, 'sha256', ?)" - ), - (event_id, bytearray(b"ffff")), - ) - for i in range(0, 20): self.get_success( self.store.db_pool.runInteraction("insert", insert_event, i) diff --git a/tests/test_server.py b/tests/test_server.py index f2ffbc895b..0f1eb43cbc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -13,18 +13,28 @@ # limitations under the License. import re +from http import HTTPStatus +from typing import Tuple from twisted.internet.defer import Deferred from twisted.web.resource import Resource from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.config.server import parse_listener_def -from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource -from synapse.http.site import SynapseSite +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + JsonResource, + OptionsResource, + cancellable, +) +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import make_deferred_yieldable +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.server import ( FakeSite, ThreadedMemoryReactorClock, @@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) + + +class CancellableDirectServeJsonResource(DirectServeJsonResource): + def __init__(self, clock: Clock): + super().__init__() + self.clock = clock + + @cancellable + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class CancellableDirectServeHtmlResource(DirectServeHtmlResource): + ERROR_TEMPLATE = "{code} {msg}" + + def __init__(self, clock: Clock): + super().__init__() + self.clock = clock + + @cancellable + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, b"ok" + + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, b"ok" + + +class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin): + """Tests for `DirectServeJsonResource` cancellation.""" + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + self.resource = CancellableDirectServeJsonResource(self.clock) + self.site = FakeSite(self.resource, self.reactor) + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = make_request( + self.reactor, self.site, "GET", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = make_request( + self.reactor, self.site, "POST", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) + + +class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin): + """Tests for `DirectServeHtmlResource` cancellation.""" + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + self.resource = CancellableDirectServeHtmlResource(self.clock) + self.site = FakeSite(self.resource, self.reactor) + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = make_request( + self.reactor, self.site, "GET", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body=b"499 Request cancelled", + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = make_request( + self.reactor, self.site, "POST", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, channel, expect_cancellation=False, expected_body=b"ok" + ) diff --git a/tests/test_state.py b/tests/test_state.py index e4baa69137..651ec1c7d4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -88,6 +88,9 @@ class _DummyStore: return groups + async def get_state_ids_for_group(self, state_group): + return self._group_to_state[state_group] + async def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids ): diff --git a/tests/test_visibility.py b/tests/test_visibility.py index d0230f9ebb..7a9b01ef9d 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -234,7 +234,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[])) event.internal_metadata.outlier = True self.get_success( - self.storage.persistence.persist_event(event, EventContext.for_outlier()) + self.storage.persistence.persist_event( + event, EventContext.for_outlier(self.storage) + ) ) return event diff --git a/tests/unittest.py b/tests/unittest.py index 9afa68c164..e7f255b4fa 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -831,7 +831,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase): self.site, method=method, path=path, - content=content or "", + content=content if content is not None else "", shorthand=False, await_result=await_result, custom_headers=custom_headers, diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 321fc1776f..67173a4f5b 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -14,8 +14,9 @@ from typing import List -from unittest.mock import Mock +from unittest.mock import Mock, patch +from synapse.metrics.jemalloc import JemallocStats from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries from synapse.util.caches.treecache import TreeCache @@ -316,3 +317,58 @@ class TimeEvictionTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key1"), None) self.assertEqual(cache.get("key2"), 3) + + +class MemoryEvictionTestCase(unittest.HomeserverTestCase): + @override_config( + { + "caches": { + "cache_autotuning": { + "max_cache_memory_usage": "700M", + "target_cache_memory_usage": "500M", + "min_cache_ttl": "5m", + } + } + } + ) + @patch("synapse.util.caches.lrucache.get_jemalloc_stats") + def test_evict_memory(self, jemalloc_interface) -> None: + mock_jemalloc_class = Mock(spec=JemallocStats) + jemalloc_interface.return_value = mock_jemalloc_class + + # set the return value of get_stat() to be greater than max_cache_memory_usage + mock_jemalloc_class.get_stat.return_value = 924288000 + + setup_expire_lru_cache_entries(self.hs) + cache = LruCache(4, clock=self.hs.get_clock()) + + cache["key1"] = 1 + cache["key2"] = 2 + + # advance the reactor less than the min_cache_ttl + self.reactor.advance(60 * 2) + + # our items should still be in the cache + self.assertEqual(cache.get("key1"), 1) + self.assertEqual(cache.get("key2"), 2) + + # advance the reactor past the min_cache_ttl + self.reactor.advance(60 * 6) + + # the items should be cleared from cache + self.assertEqual(cache.get("key1"), None) + self.assertEqual(cache.get("key2"), None) + + # add more stuff to caches + cache["key1"] = 1 + cache["key2"] = 2 + + # set the return value of get_stat() to be lower than target_cache_memory_usage + mock_jemalloc_class.get_stat.return_value = 10000 + + # advance the reactor past the min_cache_ttl + self.reactor.advance(60 * 6) + + # the items should still be in the cache + self.assertEqual(cache.get("key1"), 1) + self.assertEqual(cache.get("key2"), 2) |