diff options
163 files changed, 3382 insertions, 2381 deletions
diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist index 7950d19db3..158ab79154 100644 --- a/.buildkite/worker-blacklist +++ b/.buildkite/worker-blacklist @@ -34,33 +34,8 @@ Device list doesn't change if remote server is down Remote servers cannot set power levels in rooms without existing powerlevels Remote servers should reject attempts by non-creators to set the power levels -# new failures as of https://github.com/matrix-org/sytest/pull/753 -GET /rooms/:room_id/messages returns a message -GET /rooms/:room_id/messages lazy loads members correctly -Read receipts are sent as events -Only original members of the room can see messages from erased users -Device deletion propagates over federation -If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes -Changing user-signing key notifies local users -Newly updated tags appear in an incremental v2 /sync +# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5 Server correctly handles incoming m.device_list_update -Local device key changes get to remote servers with correct prev_id -AS-ghosted users can use rooms via AS -Ghost user must register before joining room -Test that a message is pushed -Invites are pushed -Rooms with aliases are correctly named in pushed -Rooms with names are correctly named in pushed -Rooms with canonical alias are correctly named in pushed -Rooms with many users are correctly pushed -Don't get pushed for rooms you've muted -Rejected events are not pushed -Test that rejected pushers are removed. -Events come down the correct room - -# https://buildkite.com/matrix-dot-org/sytest/builds/326#cca62404-a88a-4fcb-ad41-175fd3377603 -Presence changes to UNAVAILABLE are reported to remote room members -If remote user leaves room, changes device and rejoins we see update in sync -uploading self-signing key notifies over federation -Inbound federation can receive redacted events -Outbound federation can request missing events + +# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536 +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state diff --git a/AUTHORS.rst b/AUTHORS.rst index b8b31a5b47..014f16d4a2 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -46,3 +46,6 @@ Joseph Weston <joseph at weston.cloud> Benjamin Saunders <ben.e.saunders at gmail dot com> * Documentation improvements + +Werner Sembach <werner.sembach at fau dot de> + * Automatically remove a group/community when it is empty diff --git a/changelog.d/6245.misc b/changelog.d/6245.misc new file mode 100644 index 0000000000..a3e6b8296e --- /dev/null +++ b/changelog.d/6245.misc @@ -0,0 +1 @@ +Split out state storage into separate data store. diff --git a/changelog.d/6349.feature b/changelog.d/6349.feature new file mode 100644 index 0000000000..56c4fbf78e --- /dev/null +++ b/changelog.d/6349.feature @@ -0,0 +1 @@ +Implement v2 APIs for the `send_join` and `send_leave` federation endpoints (as described in [MSC1802](https://github.com/matrix-org/matrix-doc/pull/1802)). diff --git a/changelog.d/6377.bugfix b/changelog.d/6377.bugfix new file mode 100644 index 0000000000..ccda96962f --- /dev/null +++ b/changelog.d/6377.bugfix @@ -0,0 +1 @@ +Prevent redacted events from being returned during message search. \ No newline at end of file diff --git a/changelog.d/6385.bugfix b/changelog.d/6385.bugfix new file mode 100644 index 0000000000..7a2bc02170 --- /dev/null +++ b/changelog.d/6385.bugfix @@ -0,0 +1 @@ +Prevent error on trying to search a upgraded room when the server is not in the predecessor room. \ No newline at end of file diff --git a/changelog.d/6394.feature b/changelog.d/6394.feature new file mode 100644 index 0000000000..1a0e8845ad --- /dev/null +++ b/changelog.d/6394.feature @@ -0,0 +1 @@ +Add a develop script to generate full SQL schemas. \ No newline at end of file diff --git a/changelog.d/6411.feature b/changelog.d/6411.feature new file mode 100644 index 0000000000..ebea4a208d --- /dev/null +++ b/changelog.d/6411.feature @@ -0,0 +1 @@ +Allow custom SAML username mapping functinality through an external provider plugin. \ No newline at end of file diff --git a/changelog.d/6453.feature b/changelog.d/6453.feature new file mode 100644 index 0000000000..e7bb801c6a --- /dev/null +++ b/changelog.d/6453.feature @@ -0,0 +1 @@ +Automatically delete empty groups/communities. diff --git a/changelog.d/6486.bugfix b/changelog.d/6486.bugfix new file mode 100644 index 0000000000..b98c5a9ae5 --- /dev/null +++ b/changelog.d/6486.bugfix @@ -0,0 +1 @@ +Improve performance of looking up cross-signing keys. diff --git a/changelog.d/6496.misc b/changelog.d/6496.misc new file mode 100644 index 0000000000..19c6e926b8 --- /dev/null +++ b/changelog.d/6496.misc @@ -0,0 +1 @@ +Port synapse.handlers.initial_sync to async/await. diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal new file mode 100644 index 0000000000..0b72261d58 --- /dev/null +++ b/changelog.d/6502.removal @@ -0,0 +1 @@ +Remove redundant code from event authorisation implementation. diff --git a/changelog.d/6504.misc b/changelog.d/6504.misc new file mode 100644 index 0000000000..7c873459af --- /dev/null +++ b/changelog.d/6504.misc @@ -0,0 +1 @@ +Port handlers.account_data and handlers.account_validity to async/await. diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc new file mode 100644 index 0000000000..3a75b2d9dd --- /dev/null +++ b/changelog.d/6505.misc @@ -0,0 +1 @@ +Make `make_deferred_yieldable` to work with async/await. diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc new file mode 100644 index 0000000000..99d7a70bcf --- /dev/null +++ b/changelog.d/6506.misc @@ -0,0 +1 @@ +Remove `SnapshotCache` in favour of `ResponseCache`. diff --git a/changelog.d/6510.misc b/changelog.d/6510.misc new file mode 100644 index 0000000000..214f06539b --- /dev/null +++ b/changelog.d/6510.misc @@ -0,0 +1 @@ +Change phone home stats to not assume there is a single database and report information about the database used by the main data store. diff --git a/changelog.d/6511.misc b/changelog.d/6511.misc new file mode 100644 index 0000000000..19ce435e68 --- /dev/null +++ b/changelog.d/6511.misc @@ -0,0 +1 @@ +Move database config from apps into HomeServer object. diff --git a/changelog.d/6512.misc b/changelog.d/6512.misc new file mode 100644 index 0000000000..37a8099eec --- /dev/null +++ b/changelog.d/6512.misc @@ -0,0 +1 @@ +Silence mypy errors for files outside those specified. diff --git a/changelog.d/6513.misc b/changelog.d/6513.misc new file mode 100644 index 0000000000..36700f5657 --- /dev/null +++ b/changelog.d/6513.misc @@ -0,0 +1 @@ +Remove all assumptions of there being a single phyiscal DB apart from the `synapse.config`. diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix new file mode 100644 index 0000000000..6dc1985c24 --- /dev/null +++ b/changelog.d/6514.bugfix @@ -0,0 +1 @@ +Fix race which occasionally caused deleted devices to reappear. diff --git a/changelog.d/6515.misc b/changelog.d/6515.misc new file mode 100644 index 0000000000..a9c303ed1c --- /dev/null +++ b/changelog.d/6515.misc @@ -0,0 +1 @@ +Clean up some logging when handling incoming events over federation. diff --git a/changelog.d/6517.misc b/changelog.d/6517.misc new file mode 100644 index 0000000000..c6ffed9952 --- /dev/null +++ b/changelog.d/6517.misc @@ -0,0 +1 @@ +Port some of FederationHandler to async/await. \ No newline at end of file diff --git a/changelog.d/6522.bugfix b/changelog.d/6522.bugfix new file mode 100644 index 0000000000..ccda96962f --- /dev/null +++ b/changelog.d/6522.bugfix @@ -0,0 +1 @@ +Prevent redacted events from being returned during message search. \ No newline at end of file diff --git a/changelog.d/6523.feature b/changelog.d/6523.feature new file mode 100644 index 0000000000..798fa143df --- /dev/null +++ b/changelog.d/6523.feature @@ -0,0 +1 @@ +Add option `limit_profile_requests_to_users_who_share_rooms` to prevent requirement of a local user sharing a room with another user to query their profile information. diff --git a/changelog.d/6534.misc b/changelog.d/6534.misc new file mode 100644 index 0000000000..7df6bb442a --- /dev/null +++ b/changelog.d/6534.misc @@ -0,0 +1 @@ +Test more folders against mypy. diff --git a/changelog.d/6537.misc b/changelog.d/6537.misc new file mode 100644 index 0000000000..3543153584 --- /dev/null +++ b/changelog.d/6537.misc @@ -0,0 +1 @@ +Update `mypy` to new version. diff --git a/changelog.d/6538.misc b/changelog.d/6538.misc new file mode 100644 index 0000000000..cb4fd56948 --- /dev/null +++ b/changelog.d/6538.misc @@ -0,0 +1 @@ +Adjust the sytest blacklist for worker mode. diff --git a/changelog.d/6541.doc b/changelog.d/6541.doc new file mode 100644 index 0000000000..c20029edc0 --- /dev/null +++ b/changelog.d/6541.doc @@ -0,0 +1 @@ +Document the Room Shutdown Admin API. \ No newline at end of file diff --git a/changelog.d/6546.feature b/changelog.d/6546.feature new file mode 100644 index 0000000000..954aacb0d0 --- /dev/null +++ b/changelog.d/6546.feature @@ -0,0 +1 @@ +Add an export_signing_key script to extract the public part of signing keys when rotating them. diff --git a/changelog.d/6555.bugfix b/changelog.d/6555.bugfix new file mode 100644 index 0000000000..86a5a56cf6 --- /dev/null +++ b/changelog.d/6555.bugfix @@ -0,0 +1 @@ +Fix missing row in device_max_stream_id that could cause unable to decrypt errors after server restart. \ No newline at end of file diff --git a/changelog.d/6557.misc b/changelog.d/6557.misc new file mode 100644 index 0000000000..80e7eaedb8 --- /dev/null +++ b/changelog.d/6557.misc @@ -0,0 +1 @@ +Remove unused `get_pagination_rows` methods from `EventSource` classes. diff --git a/changelog.d/6558.misc b/changelog.d/6558.misc new file mode 100644 index 0000000000..a7572f1a85 --- /dev/null +++ b/changelog.d/6558.misc @@ -0,0 +1 @@ +Clean up logs from the push notifier at startup. \ No newline at end of file diff --git a/changelog.d/6559.misc b/changelog.d/6559.misc new file mode 100644 index 0000000000..8bca37457d --- /dev/null +++ b/changelog.d/6559.misc @@ -0,0 +1 @@ +Port `synapse.handlers.admin` and `synapse.handlers.deactivate_account` to async/await. diff --git a/changelog.d/6564.misc b/changelog.d/6564.misc new file mode 100644 index 0000000000..f644f5868b --- /dev/null +++ b/changelog.d/6564.misc @@ -0,0 +1 @@ +Change `EventContext` to use the `Storage` class, in preparation for moving state database queries to a separate data store. diff --git a/changelog.d/6565.misc b/changelog.d/6565.misc new file mode 100644 index 0000000000..e83f245bf0 --- /dev/null +++ b/changelog.d/6565.misc @@ -0,0 +1 @@ +Add assertion that schema delta file names are unique. diff --git a/changelog.d/6570.misc b/changelog.d/6570.misc new file mode 100644 index 0000000000..e89955a51e --- /dev/null +++ b/changelog.d/6570.misc @@ -0,0 +1 @@ +Improve diagnostics on database upgrade failure. diff --git a/changelog.d/6571.bugfix b/changelog.d/6571.bugfix new file mode 100644 index 0000000000..e38ea7b4f7 --- /dev/null +++ b/changelog.d/6571.bugfix @@ -0,0 +1 @@ +Fix a bug which meant that we did not send systemd notifications on startup if acme was enabled. diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md new file mode 100644 index 0000000000..54ce1cd234 --- /dev/null +++ b/docs/admin_api/shutdown_room.md @@ -0,0 +1,72 @@ +# Shutdown room API + +Shuts down a room, preventing new joins and moves local users and room aliases automatically +to a new room. The new room will be created with the user specified by the +`new_room_user_id` parameter as room administrator and will contain a message +explaining what happened. Users invited to the new room will have power level +-10 by default, and thus be unable to speak. The old room's power levels will be changed to +disallow any further invites or joins. + +The local server will only have the power to move local user and room aliases to +the new room. Users on other servers will be unaffected. + +## API + +You will need to authenticate with an access token for an admin user. + +### URL + +`POST /_synapse/admin/v1/shutdown_room/{room_id}` + +### URL Parameters + +* `room_id` - The ID of the room (e.g `!someroom:example.com`) + +### JSON Body Parameters + +* `new_room_user_id` - Required. A string representing the user ID of the user that will admin + the new room that all users in the old room will be moved to. +* `room_name` - Optional. A string representing the name of the room that new users will be + invited to. +* `message` - Optional. A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly convey why the + original room was shut down. + +If not specified, the default value of `room_name` is "Content Violation +Notification". The default value of `message` is "Sharing illegal content on +othis server is not permitted and rooms in violation will be blocked." + +### Response Parameters + +* `kicked_users` - An integer number representing the number of users that + were kicked. +* `failed_to_kick_users` - An integer number representing the number of users + that were not kicked. +* `local_aliases` - An array of strings representing the local aliases that were migrated from + the old room to the new. +* `new_room_id` - A string representing the room ID of the new room. + +## Example + +Request: + +``` +POST /_synapse/admin/v1/shutdown_room/!somebadroom%3Aexample.com + +{ + "new_room_user_id": "@someuser:example.com", + "room_name": "Content Violation Notification", + "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service." +} +``` + +Response: + +``` +{ + "kicked_users": 5, + "failed_to_kick_users": 0, + "local_aliases": ["#badroom:example.com", "#evilsaloon:example.com], + "new_room_id": "!newroomid:example.com", +}, +``` diff --git a/docs/code_style.md b/docs/code_style.md index f983f72d6c..71aecd41f7 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -137,6 +137,7 @@ Some guidelines follow: correctly handles the top-level option being set to `None` (as it will be if no sub-options are enabled). - Lines should be wrapped at 80 characters. +- Use two-space indents. Example: @@ -155,13 +156,13 @@ Example: # Settings for the frobber # frobber: - # frobbing speed. Defaults to 1. - # - #speed: 10 + # frobbing speed. Defaults to 1. + # + #speed: 10 - # frobbing distance. Defaults to 1000. - # - #distance: 100 + # frobbing distance. Defaults to 1000. + # + #distance: 100 Note that the sample configuration is generated from the synapse code and is maintained by a script, `scripts-dev/generate_sample_config`. diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md new file mode 100644 index 0000000000..92f2380488 --- /dev/null +++ b/docs/saml_mapping_providers.md @@ -0,0 +1,77 @@ +# SAML Mapping Providers + +A SAML mapping provider is a Python class (loaded via a Python module) that +works out how to map attributes of a SAML response object to Matrix-specific +user attributes. Details such as user ID localpart, displayname, and even avatar +URLs are all things that can be mapped from talking to a SSO service. + +As an example, a SSO service may return the email address +"john.smith@example.com" for a user, whereas Synapse will need to figure out how +to turn that into a displayname when creating a Matrix user for this individual. +It may choose `John Smith`, or `Smith, John [Example.com]` or any number of +variations. As each Synapse configuration may want something different, this is +where SAML mapping providers come into play. + +## Enabling Providers + +External mapping providers are provided to Synapse in the form of an external +Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere, +then tell Synapse where to look for the handler class by editing the +`saml2_config.user_mapping_provider.module` config option. + +`saml2_config.user_mapping_provider.config` allows you to provide custom +configuration options to the module. Check with the module's documentation for +what options it provides (if any). The options listed by default are for the +user mapping provider built in to Synapse. If using a custom module, you should +comment these options out and use those specified by the module instead. + +## Building a Custom Mapping Provider + +A custom mapping provider must specify the following methods: + +* `__init__(self, parsed_config)` + - Arguments: + - `parsed_config` - A configuration object that is the return value of the + `parse_config` method. You should set any configuration options needed by + the module here. +* `saml_response_to_user_attributes(self, saml_response, failures)` + - Arguments: + - `saml_response` - A `saml2.response.AuthnResponse` object to extract user + information from. + - `failures` - An `int` that represents the amount of times the returned + mxid localpart mapping has failed. This should be used + to create a deduplicated mxid localpart which should be + returned instead. For example, if this method returns + `john.doe` as the value of `mxid_localpart` in the returned + dict, and that is already taken on the homeserver, this + method will be called again with the same parameters but + with failures=1. The method should then return a different + `mxid_localpart` value, such as `john.doe1`. + - This method must return a dictionary, which will then be used by Synapse + to build a new user. The following keys are allowed: + * `mxid_localpart` - Required. The mxid localpart of the new user. + * `displayname` - The displayname of the new user. If not provided, will default to + the value of `mxid_localpart`. +* `parse_config(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A `dict` representing the parsed content of the + `saml2_config.user_mapping_provider.config` homeserver config option. + Runs on homeserver startup. Providers should extract any option values + they need here. + - Whatever is returned will be passed back to the user mapping provider module's + `__init__` method during construction. +* `get_saml_attributes(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A object resulting from a call to `parse_config`. + - Returns a tuple of two sets. The first set equates to the saml auth + response attributes that are required for the module to function, whereas + the second set consists of those attributes which can be used if available, + but are not necessary. + +## Synapse's Default Provider + +Synapse has a built-in SAML mapping provider if a custom provider isn't +specified in the config. It is located at +[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py). diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 10664ae8f7..e3b05423b8 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,6 +54,13 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true +# Uncomment to require a user to share a room with another user in order +# to retrieve their profile information. Only checked on Client-Server +# requests. Profile requests from other servers should be checked by the +# requesting server. Defaults to 'false'. +# +#limit_profile_requests_to_users_who_share_rooms: true + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. @@ -1115,14 +1122,19 @@ metrics_flags: signing_key_path: "CONFDIR/SERVERNAME.signing.key" # The keys that the server used to sign messages with but won't use -# to sign new messages. E.g. it has lost its private key +# to sign new messages. # -#old_signing_keys: -# "ed25519:auto": -# # Base64 encoded public key -# key: "The public part of your old signing key." -# # Millisecond POSIX timestamp when the key expired. -# expired_ts: 123456789123 +old_signing_keys: + # For each key, `key` should be the base64-encoded public key, and + # `expired_ts`should be the time (in milliseconds since the unix epoch) that + # it was last used. + # + # It is possible to build an entry from an old signing.key file using the + # `export_signing_key` script which is provided with synapse. + # + # For example: + # + #"ed25519:id": { key: "base64string", expired_ts: 123456789123 } # How long key response published by this server is valid for. # Used to set the valid_until_ts in /key/v2 APIs. @@ -1250,33 +1262,58 @@ saml2_config: # #config_path: "CONFDIR/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName - - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # diff --git a/mypy.ini b/mypy.ini index 1d77c0ecc8..a66434b76b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,7 +1,7 @@ [mypy] namespace_packages = True plugins = mypy_zope:plugin -follow_imports = normal +follow_imports = silent check_untyped_defs = True show_error_codes = True show_traceback = True diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh new file mode 100755 index 0000000000..60e8970a35 --- /dev/null +++ b/scripts-dev/make_full_schema.sh @@ -0,0 +1,184 @@ +#!/bin/bash +# +# This script generates SQL files for creating a brand new Synapse DB with the latest +# schema, on both SQLite3 and Postgres. +# +# It does so by having Synapse generate an up-to-date SQLite DB, then running +# synapse_port_db to convert it to Postgres. It then dumps the contents of both. + +POSTGRES_HOST="localhost" +POSTGRES_DB_NAME="synapse_full_schema.$$" + +SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite" +POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres" + +REQUIRED_DEPS=("matrix-synapse" "psycopg2") + +usage() { + echo + echo "Usage: $0 -p <postgres_username> -o <path> [-c] [-n] [-h]" + echo + echo "-p <postgres_username>" + echo " Username to connect to local postgres instance. The password will be requested" + echo " during script execution." + echo "-c" + echo " CI mode. Enables coverage tracking and prints every command that the script runs." + echo "-o <path>" + echo " Directory to output full schema files to." + echo "-h" + echo " Display this help text." +} + +while getopts "p:co:h" opt; do + case $opt in + p) + POSTGRES_USERNAME=$OPTARG + ;; + c) + # Print all commands that are being executed + set -x + + # Modify required dependencies for coverage + REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess") + + COVERAGE=1 + ;; + o) + command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1) + OUTPUT_DIR="$(realpath "$OPTARG")" + ;; + h) + usage + exit + ;; + \?) + echo "ERROR: Invalid option: -$OPTARG" >&2 + usage + exit + ;; + esac +done + +# Check that required dependencies are installed +unsatisfied_requirements=() +for dep in "${REQUIRED_DEPS[@]}"; do + pip show "$dep" --quiet || unsatisfied_requirements+=("$dep") +done +if [ ${#unsatisfied_requirements} -ne 0 ]; then + echo "Please install the following python packages: ${unsatisfied_requirements[*]}" + exit 1 +fi + +if [ -z "$POSTGRES_USERNAME" ]; then + echo "No postgres username supplied" + usage + exit 1 +fi + +if [ -z "$OUTPUT_DIR" ]; then + echo "No output directory supplied" + usage + exit 1 +fi + +# Create the output directory if it doesn't exist +mkdir -p "$OUTPUT_DIR" + +read -rsp "Postgres password for '$POSTGRES_USERNAME': " POSTGRES_PASSWORD +echo "" + +# Exit immediately if a command fails +set -e + +# cd to root of the synapse directory +cd "$(dirname "$0")/.." + +# Create temporary SQLite and Postgres homeserver db configs and key file +TMPDIR=$(mktemp -d) +KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path +SQLITE_CONFIG=$TMPDIR/sqlite.conf +SQLITE_DB=$TMPDIR/homeserver.db +POSTGRES_CONFIG=$TMPDIR/postgres.conf + +# Ensure these files are delete on script exit +trap 'rm -rf $TMPDIR' EXIT + +cat > "$SQLITE_CONFIG" <<EOF +server_name: "test" + +signing_key_path: "$KEY_FILE" +macaroon_secret_key: "abcde" + +report_stats: false + +database: + name: "sqlite3" + args: + database: "$SQLITE_DB" + +# Suppress the key server warning. +trusted_key_servers: [] +EOF + +cat > "$POSTGRES_CONFIG" <<EOF +server_name: "test" + +signing_key_path: "$KEY_FILE" +macaroon_secret_key: "abcde" + +report_stats: false + +database: + name: "psycopg2" + args: + user: "$POSTGRES_USERNAME" + host: "$POSTGRES_HOST" + password: "$POSTGRES_PASSWORD" + database: "$POSTGRES_DB_NAME" + +# Suppress the key server warning. +trusted_key_servers: [] +EOF + +# Generate the server's signing key. +echo "Generating SQLite3 db schema..." +python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG" + +# Make sure the SQLite3 database is using the latest schema and has no pending background update. +echo "Running db background jobs..." +scripts-dev/update_database --database-config "$SQLITE_CONFIG" + +# Create the PostgreSQL database. +echo "Creating postgres database..." +createdb $POSTGRES_DB_NAME + +echo "Copying data from SQLite3 to Postgres with synapse_port_db..." +if [ -z "$COVERAGE" ]; then + # No coverage needed + scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" +else + # Coverage desired + coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG" +fi + +# Delete schema_version, applied_schema_deltas and applied_module_schemas tables +# This needs to be done after synapse_port_db is run +echo "Dropping unwanted db tables..." +SQL=" +DROP TABLE schema_version; +DROP TABLE applied_schema_deltas; +DROP TABLE applied_module_schemas; +" +sqlite3 "$SQLITE_DB" <<< "$SQL" +psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL" + +echo "Dumping SQLite3 schema to '$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE'..." +sqlite3 "$SQLITE_DB" ".dump" > "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE" + +echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..." +pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE" + +echo "Cleaning up temporary Postgres database..." +dropdb $POSTGRES_DB_NAME + +echo "Done! Files dumped to: $OUTPUT_DIR" diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 1776d202c5..1d62f0403a 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -26,8 +26,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import create_engine -from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("update_database") @@ -35,21 +33,11 @@ logger = logging.getLogger("update_database") class MockHomeserver(HomeServer): DATASTORE_CLASS = DataStore - def __init__(self, config, database_engine, db_conn, **kwargs): + def __init__(self, config, **kwargs): super(MockHomeserver, self).__init__( - config.server_name, - reactor=reactor, - config=config, - database_engine=database_engine, - **kwargs + config.server_name, reactor=reactor, config=config, **kwargs ) - self.database_engine = database_engine - self.db_conn = db_conn - - def get_db_conn(self): - return self.db_conn - if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -85,25 +73,11 @@ if __name__ == "__main__": config = HomeServerConfig() config.parse_config_dict(hs_config, "", "") - # Create the database engine and a connection to it. - database_engine = create_engine(config.database_config) - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in config.database_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) - - # Update the database to the latest schema. - prepare_database(db_conn, database_engine, config=config) - db_conn.commit() - # Instantiate and initialise the homeserver object. - hs = MockHomeserver( - config, database_engine, db_conn, db_config=config.database_config, - ) - # setup instantiates the store within the homeserver object. + hs = MockHomeserver(config) + + # Setup instantiates the store within the homeserver object and updates the + # DB. hs.setup() store = hs.get_datastore() diff --git a/scripts/export_signing_key b/scripts/export_signing_key new file mode 100755 index 0000000000..8aec9d802b --- /dev/null +++ b/scripts/export_signing_key @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import sys +import time +from typing import Optional + +import nacl.signing +from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys + + +def exit(status: int = 0, message: Optional[str] = None): + if message: + print(message, file=sys.stderr) + sys.exit(status) + + +def format_plain(public_key: nacl.signing.VerifyKey): + print( + "%s:%s %s" + % (public_key.alg, public_key.version, encode_verify_key_base64(public_key),) + ) + + +def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): + print( + ' "%s:%s": { key: "%s", expired_ts: %i }' + % ( + public_key.alg, + public_key.version, + encode_verify_key_base64(public_key), + expiry_ts, + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read", + ) + + parser.add_argument( + "-x", + action="store_true", + dest="for_config", + help="format the output for inclusion in the old_signing_keys config setting", + ) + + parser.add_argument( + "--expiry-ts", + type=int, + default=int(time.time() * 1000) + 6*3600000, + help=( + "The expiry time to use for -x, in milliseconds since 1970. The default " + "is (now+6h)." + ), + ) + + args = parser.parse_args() + + formatter = ( + (lambda k: format_for_config(k, args.expiry_ts)) + if args.for_config + else format_plain + ) + + keys = [] + for file in args.key_file: + try: + res = read_signing_keys(file) + except Exception as e: + exit( + status=1, + message="Error reading key from file %s: %s %s" + % (file.name, type(e), e), + ) + res = [] + for key in res: + formatter(get_verify_key(key)) diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index e393a9b2f7..eb927f2094 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -30,6 +30,7 @@ import yaml from twisted.enterprise import adbapi from twisted.internet import defer, reactor +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import PreserveLoggingContext from synapse.storage._base import LoggingTransaction @@ -50,12 +51,13 @@ from synapse.storage.data_stores.main.registration import ( from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore -from synapse.storage.data_stores.main.state import StateBackgroundUpdateStore +from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.database import Database +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -137,6 +139,7 @@ class Store( RoomMemberBackgroundUpdateStore, SearchBackgroundUpdateStore, StateBackgroundUpdateStore, + MainStateBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore, StatsStore, ): @@ -165,23 +168,17 @@ class Store( class MockHomeserver: - def __init__(self, config, database_engine, db_conn, db_pool): - self.database_engine = database_engine - self.db_conn = db_conn - self.db_pool = db_pool + def __init__(self, config): self.clock = Clock(reactor) self.config = config self.hostname = config.server_name - def get_db_conn(self): - return self.db_conn - - def get_db_pool(self): - return self.db_pool - def get_clock(self): return self.clock + def get_reactor(self): + return reactor + class Porter(object): def __init__(self, **kwargs): @@ -445,45 +442,36 @@ class Porter(object): else: return - def setup_db(self, db_config, database_engine): - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in db_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) - - prepare_database(db_conn, database_engine, config=None) + def setup_db(self, db_config: DatabaseConnectionConfig, engine): + db_conn = make_conn(db_config, engine) + prepare_database(db_conn, engine, config=None) db_conn.commit() return db_conn @defer.inlineCallbacks - def build_db_store(self, config): + def build_db_store(self, db_config: DatabaseConnectionConfig): """Builds and returns a database store using the provided configuration. Args: - config: The database configuration, i.e. a dict following the structure of - the "database" section of Synapse's configuration file. + config: The database configuration Returns: The built Store object. """ - engine = create_engine(config) - - self.progress.set_state("Preparing %s" % config["name"]) - conn = self.setup_db(config, engine) + self.progress.set_state("Preparing %s" % db_config.config["name"]) - db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) + engine = create_engine(db_config.config) + conn = self.setup_db(db_config, engine) - hs = MockHomeserver(self.hs_config, engine, conn, db_pool) + hs = MockHomeserver(self.hs_config) - store = Store(Database(hs), conn, hs) + store = Store(Database(hs, db_config, engine), conn, hs) yield store.db.runInteraction( - "%s_engine.check_database" % config["name"], engine.check_database, + "%s_engine.check_database" % db_config.config["name"], + engine.check_database, ) return store @@ -509,7 +497,9 @@ class Porter(object): @defer.inlineCallbacks def run(self): try: - self.sqlite_store = yield self.build_db_store(self.sqlite_config) + self.sqlite_store = yield self.build_db_store( + DatabaseConnectionConfig("master-sqlite", self.sqlite_config) + ) # Check if all background updates are done, abort if not. updates_complete = ( @@ -524,7 +514,7 @@ class Porter(object): defer.returnValue(None) self.postgres_store = yield self.build_db_store( - self.hs_config.database_config + self.hs_config.get_single_database() ) yield self.run_background_updates_on_postgres() diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9fd52a8c77..abbc7079a3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -79,7 +79,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, room_version, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9c96816096..0e8b467a3e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -237,6 +237,12 @@ def start(hs, listeners=None): """ Start a Synapse server or worker. + Should be called once the reactor is running and (if we're using ACME) the + TLS certificates are in place. + + Will start the main HTTP listeners and do some other startup tasks, and then + notify systemd. + Args: hs (synapse.server.HomeServer) listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml) @@ -311,9 +317,7 @@ def setup_sdnotify(hs): # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. - hs.get_reactor().addSystemEventTrigger( - "after", "startup", sdnotify, b"READY=1\nMAINPID=%i" % (os.getpid(),) - ) + sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),)) hs.get_reactor().addSystemEventTrigger( "before", "shutdown", sdnotify, b"STOPPING=1" diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 04751a6a5e..8e36bc57d3 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -105,8 +104,10 @@ def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = yield hs.get_handlers().admin_handler.export_user_data( - user_id, FileExfiltrationWriter(user_id, directory=directory) + res = yield defer.ensureDeferred( + hs.get_handlers().admin_handler.export_user_data( + user_id, FileExfiltrationWriter(user_id, directory=directory) + ) ) print(res) @@ -229,14 +230,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = AdminCmdServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 02b900f382..e82e0f11e3 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -143,8 +142,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" @@ -159,10 +156,8 @@ def start(config_options): ps = AppserviceServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dadb487d5f..3edfe19567 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -181,14 +180,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = ClientReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index d110599a35..d0ddbe38fc 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import ( ) from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -180,14 +179,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = EventCreatorServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 418c086254..311523e0ed 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -162,14 +161,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FederationReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index f24920a7d6..83c436229c 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree @@ -174,8 +173,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" @@ -190,10 +187,8 @@ def start(config_options): ss = FederationSenderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index e647459d0e..30e435eead 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v2_alpha._base import client_patterns from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -234,14 +233,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FrontendProxyServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index df65d0a989..0e9bf7f53a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import IncorrectDatabaseSetup, create_engine +from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree @@ -328,15 +328,10 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection - hs = SynapseHomeServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) synapse.config.logger.setup_logging(hs, config, use_worker_options=False) @@ -347,13 +342,8 @@ def setup(config_options): hs.setup() except IncorrectDatabaseSetup as e: quit_with_error(str(e)) - except UpgradeDatabaseException: - sys.stderr.write( - "\nFailed to upgrade database.\n" - "Have you checked for version specific instructions in" - " UPGRADES.rst?\n" - ) - sys.exit(1) + except UpgradeDatabaseException as e: + quit_with_error("Failed to upgrade database: %s" % (e,)) hs.setup_master() @@ -519,8 +509,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # Database version # - stats["database_engine"] = hs.database_engine.module.__name__ - stats["database_server_version"] = hs.database_engine.server_version + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db.engine.server_version + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: yield hs.get_proxied_http_client().put_json( diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 2c6dd3ef02..4c80f257e2 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -157,14 +156,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = MediaRepositoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index dd52a9fc2d..09e639040a 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -37,7 +37,6 @@ from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -203,14 +202,10 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.start_pushers = True - database_engine = create_engine(config.database_config) - ps = PusherServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 288ee64b42..dd2132e608 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.stringutils import random_string @@ -437,14 +436,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = SynchrotronServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, application_service_handler=SynchrotronApplicationService(), ) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index c01fb34a9b..1257098f92 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -200,8 +199,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" @@ -216,10 +213,8 @@ def start(config_options): ss = UserDirectoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/config/database.py b/synapse/config/database.py index 0e2509f0b1..134824789c 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -12,12 +12,45 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from textwrap import indent +from typing import List import yaml -from ._base import Config +from synapse.config._base import Config, ConfigError + +logger = logging.getLogger(__name__) + + +class DatabaseConnectionConfig: + """Contains the connection config for a particular database. + + Args: + name: A label for the database, used for logging. + db_config: The config for a particular database, as per `database` + section of main config. Has two fields: `name` for database + module name, and `args` for the args to give to the database + connector. + data_stores: The list of data stores that should be provisioned on the + database. Defaults to all data stores. + """ + + def __init__( + self, name: str, db_config: dict, data_stores: List[str] = ["main", "state"] + ): + if db_config["name"] not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + + if db_config["name"] == "sqlite3": + db_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) + + self.name = name + self.config = db_config + self.data_stores = data_stores class DatabaseConfig(Config): @@ -26,20 +59,12 @@ class DatabaseConfig(Config): def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - self.database_config = config.get("database") + database_config = config.get("database") - if self.database_config is None: - self.database_config = {"name": "sqlite3", "args": {}} + if database_config is None: + database_config = {"name": "sqlite3", "args": {}} - name = self.database_config.get("name", None) - if name == "psycopg2": - pass - elif name == "sqlite3": - self.database_config.setdefault("args", {}).update( - {"cp_min": 1, "cp_max": 1, "check_same_thread": False} - ) - else: - raise RuntimeError("Unsupported database type '%s'" % (name,)) + self.databases = [DatabaseConnectionConfig("master", database_config)] self.set_databasepath(config.get("database_path")) @@ -76,11 +101,24 @@ class DatabaseConfig(Config): self.set_databasepath(args.database_path) def set_databasepath(self, database_path): + if database_path is None: + return + if database_path != ":memory:": database_path = self.abspath(database_path) - if self.database_config.get("name", None) == "sqlite3": - if database_path is not None: - self.database_config["args"]["database"] = database_path + + # We only support setting a database path if we have a single sqlite3 + # database. + if len(self.databases) != 1: + raise ConfigError("Cannot specify 'database_path' with multiple databases") + + database = self.get_single_database() + if database.config["name"] != "sqlite3": + # We don't raise here as we haven't done so before for this case. + logger.warn("Ignoring 'database_path' for non-sqlite3 database") + return + + database.config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -91,3 +129,11 @@ class DatabaseConfig(Config): metavar="SQLITE_DATABASE_PATH", help="The path to a sqlite database to use.", ) + + def get_single_database(self) -> DatabaseConnectionConfig: + """Returns the database if there is only one, useful for e.g. tests + """ + if len(self.databases) != 1: + raise Exception("More than one database exists") + + return self.databases[0] diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 18f42a87f9..35756bed87 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -21,6 +21,7 @@ from __future__ import print_function import email.utils import os from enum import Enum +from typing import Optional import pkg_resources @@ -101,7 +102,7 @@ class EmailConfig(Config): # both in RegistrationConfig and here. We should factor this bit out self.account_threepid_delegate_email = self.trusted_third_party_id_servers[ 0 - ] + ] # type: Optional[str] self.using_identity_server_from_trusted_list = True else: raise ConfigError( diff --git a/synapse/config/key.py b/synapse/config/key.py index 52ff1b2621..066e7838c3 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -108,7 +108,7 @@ class KeyConfig(Config): self.signing_key = self.read_signing_keys(signing_key_path, "signing_key") self.old_signing_keys = self.read_old_signing_keys( - config.get("old_signing_keys", {}) + config.get("old_signing_keys") ) self.key_refresh_interval = self.parse_duration( config.get("key_refresh_interval", "1d") @@ -199,14 +199,19 @@ class KeyConfig(Config): signing_key_path: "%(base_key_name)s.signing.key" # The keys that the server used to sign messages with but won't use - # to sign new messages. E.g. it has lost its private key + # to sign new messages. # - #old_signing_keys: - # "ed25519:auto": - # # Base64 encoded public key - # key: "The public part of your old signing key." - # # Millisecond POSIX timestamp when the key expired. - # expired_ts: 123456789123 + old_signing_keys: + # For each key, `key` should be the base64-encoded public key, and + # `expired_ts`should be the time (in milliseconds since the unix epoch) that + # it was last used. + # + # It is possible to build an entry from an old signing.key file using the + # `export_signing_key` script which is provided with synapse. + # + # For example: + # + #"ed25519:id": { key: "base64string", expired_ts: 123456789123 } # How long key response published by this server is valid for. # Used to set the valid_until_ts in /key/v2 APIs. @@ -290,6 +295,8 @@ class KeyConfig(Config): raise ConfigError("Error reading %s: %s" % (name, str(e))) def read_old_signing_keys(self, old_signing_keys): + if old_signing_keys is None: + return {} keys = {} for key_id, key_data in old_signing_keys.items(): if is_signing_algorithm_supported(key_id): diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 947f653e03..4a3bfc4354 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -83,10 +83,9 @@ class RatelimitConfig(Config): ) rc_admin_redaction = config.get("rc_admin_redaction") + self.rc_admin_redaction = None if rc_admin_redaction: self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) - else: - self.rc_admin_redaction = None def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c5ea2d43a1..b91414aa35 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -14,17 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re +import logging from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import ( - map_username_to_mxid_localpart, - mxid_localpart_allowed_characters, -) -from synapse.util.module_loader import load_python_module +from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + +DEFAULT_USER_MAPPING_PROVIDER = ( + "synapse.handlers.saml_handler.DefaultSamlMappingProvider" +) + def _dict_merge(merge_dict, into_dict): """Do a deep merge of two dicts @@ -75,15 +77,69 @@ class SAML2Config(Config): self.saml2_enabled = True - self.saml2_mxid_source_attribute = saml2_config.get( - "mxid_source_attribute", "uid" - ) - self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) - saml2_config_dict = self._default_saml_config_dict() + # user_mapping_provider may be None if the key is present but has no value + ump_dict = saml2_config.get("user_mapping_provider") or {} + + # Use the default user mapping provider if not set + ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + + # Ensure a config is present + ump_dict["config"] = ump_dict.get("config") or {} + + if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: + # Load deprecated options for use by the default module + old_mxid_source_attribute = saml2_config.get("mxid_source_attribute") + if old_mxid_source_attribute: + logger.warning( + "The config option saml2_config.mxid_source_attribute is deprecated. " + "Please use saml2_config.user_mapping_provider.config" + ".mxid_source_attribute instead." + ) + ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute + + old_mxid_mapping = saml2_config.get("mxid_mapping") + if old_mxid_mapping: + logger.warning( + "The config option saml2_config.mxid_mapping is deprecated. Please " + "use saml2_config.user_mapping_provider.config.mxid_mapping instead." + ) + ump_dict["config"]["mxid_mapping"] = old_mxid_mapping + + # Retrieve an instance of the module's class + # Pass the config dictionary to the module for processing + ( + self.saml2_user_mapping_provider_class, + self.saml2_user_mapping_provider_config, + ) = load_module(ump_dict) + + # Ensure loaded user mapping module has defined all necessary methods + # Note parse_config() is already checked during the call to load_module + required_methods = [ + "get_saml_attributes", + "saml_response_to_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.saml2_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by saml2_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + # Get the desired saml auth response attributes from the module + saml2_config_dict = self._default_saml_config_dict( + *self.saml2_user_mapping_provider_class.get_saml_attributes( + self.saml2_user_mapping_provider_config + ) + ) _dict_merge( merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict ) @@ -103,22 +159,27 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "5m") ) - mapping = saml2_config.get("mxid_mapping", "hexencode") - try: - self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping] - except KeyError: - raise ConfigError("%s is not a known mxid_mapping" % (mapping,)) - - def _default_saml_config_dict(self): + def _default_saml_config_dict( + self, required_attributes: set, optional_attributes: set + ): + """Generate a configuration dictionary with required and optional attributes that + will be needed to process new user registration + + Args: + required_attributes: SAML auth response attributes that are + necessary to function + optional_attributes: SAML auth response attributes that can be used to add + additional information to Synapse user accounts, but are not required + + Returns: + dict: A SAML configuration dictionary + """ import saml2 public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") - required_attributes = {"uid", self.saml2_mxid_source_attribute} - - optional_attributes = {"displayName"} if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes @@ -207,33 +268,58 @@ class SAML2Config(Config): # #config_path: "%(config_dir_path)s/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName - - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace - - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider + + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName + + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # @@ -241,23 +327,3 @@ class SAML2Config(Config): """ % { "config_dir_path": config_dir_path } - - -DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) -) - - -def dot_replace_for_mxid(username: str) -> str: - username = username.lower() - username = DOT_REPLACE_PATTERN.sub(".", username) - - # regular mxids aren't allowed to start with an underscore either - username = re.sub("^_", "", username) - return username - - -MXID_MAPPER_MAP = { - "hexencode": map_username_to_mxid_localpart, - "dotreplace": dot_replace_for_mxid, -} diff --git a/synapse/config/server.py b/synapse/config/server.py index a4bef00936..38f6ff9edc 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -102,6 +102,12 @@ class ServerConfig(Config): "require_auth_for_profile_requests", False ) + # Whether to require sharing a room with a user to retrieve their + # profile data + self.limit_profile_requests_to_users_who_share_rooms = config.get( + "limit_profile_requests_to_users_who_share_rooms", False, + ) + if "restrict_public_rooms_to_local_users" in config and ( "allow_public_rooms_without_auth" in config or "allow_public_rooms_over_federation" in config @@ -200,7 +206,7 @@ class ServerConfig(Config): self.admin_contact = config.get("admin_contact", None) # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None + self.federation_domain_whitelist = None # type: Optional[dict] federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: @@ -621,6 +627,13 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true + # Uncomment to require a user to share a room with another user in order + # to retrieve their profile information. Only checked on Client-Server + # requests. Profile requests from other servers should be checked by the + # requesting server. Defaults to 'false'. + # + #limit_profile_requests_to_users_who_share_rooms: true + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 350ed9351f..1033e5e121 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -43,6 +43,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru Returns: if the auth checks pass. """ + assert isinstance(auth_events, dict) + if do_size_check: _check_size_limits(event) @@ -87,12 +89,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warning("Trusting event: %s", event.event_id) - return - if event.type == EventTypes.Create: sender_domain = get_domain_from_id(event.sender) room_id_domain = get_domain_from_id(event.room_id) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 64e898f40c..a44baea365 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -149,7 +149,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids(store) + prev_state_ids = yield self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -167,12 +167,13 @@ class EventContext: } @staticmethod - def deserialize(store, input): + def deserialize(storage, input): """Converts a dict that was produced by `serialize` back into a EventContext. Args: - store (DataStore): Used to convert AS ID to AS object + storage (Storage): Used to convert AS ID to AS object and fetch + state. input (dict): A dict produced by `serialize` Returns: @@ -181,6 +182,7 @@ class EventContext: context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. + storage=storage, prev_state_id=input["prev_state_id"], event_type=input["event_type"], event_state_key=input["event_state_key"], @@ -193,7 +195,7 @@ class EventContext: app_service_id = input["app_service_id"] if app_service_id: - context.app_service = store.get_app_service_by_id(app_service_id) + context.app_service = storage.main.get_app_service_by_id(app_service_id) return context @@ -216,7 +218,7 @@ class EventContext: return self._state_group @defer.inlineCallbacks - def get_current_state_ids(self, store): + def get_current_state_ids(self): """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -234,11 +236,11 @@ class EventContext: if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._current_state_ids @defer.inlineCallbacks - def get_prev_state_ids(self, store): + def get_prev_state_ids(self): """ Gets the room state map, excluding this event. @@ -250,7 +252,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -270,7 +272,7 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self, store): + def _ensure_fetched(self): return defer.succeed(None) @@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext): Attributes: + _storage (Storage) + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have been calculated. None if we haven't started calculating yet @@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext): that was replaced. """ + # This needs to have a default as we're inheriting + _storage = attr.ib(default=None) _prev_state_id = attr.ib(default=None) _event_type = attr.ib(default=None) _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self, store): + def _ensure_fetched(self): if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) + self._fetching_state_deferred = run_in_background(self._fill_out_state) return make_deferred_yieldable(self._fetching_state_deferred) @defer.inlineCallbacks - def _fill_out_state(self, store): + def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) + self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self.state_group + ) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 714a9b1579..86f7e5f8aa 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -53,7 +53,7 @@ class ThirdPartyEventRules(object): if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() # Retrieve the state events from the database. state_events = {} diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d396e6564f..af652a7659 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -526,13 +526,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + content = yield self._do_send_join(destination, pdu) logger.debug("Got content: %s", content) @@ -600,6 +594,44 @@ class FederationClient(FederationBase): return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks + def _do_send_join(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = yield self.transport_layer.send_join_v2( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() + + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") + + resp = yield self.transport_layer.send_join_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] + + @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): room_version = yield self.store.get_room_version(room_id) @@ -708,18 +740,50 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_leave( + content = yield self._do_send_leave(destination, pdu) + + logger.debug("Got content: %s", content) + return None + + return self._try_destination_list("send_leave", destinations, send_request) + + @defer.inlineCallbacks + def _do_send_leave(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = yield self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - logger.debug("Got content: %s", content) - return None + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() - return self._try_destination_list("send_leave", destinations, send_request) + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") + + resp = yield self.transport_layer.send_leave_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] def get_public_rooms( self, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 84d4eca041..d7ce333822 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -384,15 +384,10 @@ class FederationServer(FederationBase): res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - return ( - 200, - { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - }, - ) + return { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], + } async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) @@ -419,7 +414,7 @@ class FederationServer(FederationBase): pdu = await self._check_sigs_and_hash(room_version, pdu) await self.handler.on_send_leave_request(origin, pdu) - return 200, {} + return {} async def on_event_auth(self, origin, room_id, event_id): with (await self._server_linearizer.queue((origin, room_id))): diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 46dba84cac..198257414b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -243,7 +243,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_join(self, destination, room_id, event_id, content): + def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -254,7 +254,18 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_leave(self, destination, room_id, event_id, content): + def send_join_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, path=path, data=content + ) + + return response + + @defer.inlineCallbacks + @log_function + def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -272,6 +283,24 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function + def send_leave_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, + path=path, + data=content, + # we want to do our best to send this through. The problem is + # that if it fails, we won't retry it later, so if the remote + # server was just having a momentary blip, the room will be out of + # sync. + ignore_backoff=True, + ) + + return response + + @defer.inlineCallbacks + @log_function def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fefc789c85..b4cbf23394 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet): return 200, content -class FederationSendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): content = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, content) + + +class FederationV2SendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, content @@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet): return await self.handler.on_event_auth(origin, context, event_id) -class FederationSendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServlet): + PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + + async def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = await self.handler.on_send_join_request(origin, content, context) + return 200, (200, content) + + +class FederationV2SendJoinServlet(BaseFederationServlet): PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + PREFIX = FEDERATION_V2_PREFIX + async def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content @@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = ( FederationMakeJoinServlet, FederationMakeLeaveServlet, FederationEventServlet, - FederationSendJoinServlet, - FederationSendLeaveServlet, + FederationV1SendJoinServlet, + FederationV2SendJoinServlet, + FederationV1SendLeaveServlet, + FederationV2SendLeaveServlet, FederationV1InviteServlet, FederationV2InviteServlet, FederationQueryAuthServlet, diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 29e8ffc295..0ec9be3cb5 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -773,6 +773,11 @@ class GroupsServerHandler(object): if not self.hs.is_mine_id(user_id): yield self.store.maybe_delete_remote_profile_cache(user_id) + # Delete group if the last user has left + users = yield self.store.get_users_in_group(group_id, include_private=True) + if not users: + yield self.store.delete_group(group_id) + return {} @defer.inlineCallbacks diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d15c6282fb..51413d910e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -134,7 +134,7 @@ class BaseHandler(object): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state = yield self.store.get_events( list(current_state_ids.values()) ) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 2d7e6df6e4..a8d3fbc6de 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - class AccountDataEventSource(object): def __init__(self, hs): @@ -23,15 +21,14 @@ class AccountDataEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() - @defer.inlineCallbacks - def get_new_events(self, user, from_key, **kwargs): + async def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key - current_stream_id = yield self.store.get_max_account_data_stream_id() + current_stream_id = self.store.get_max_account_data_stream_id() results = [] - tags = yield self.store.get_updated_tags(user_id, last_stream_id) + tags = await self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): results.append( @@ -41,7 +38,7 @@ class AccountDataEventSource(object): ( account_data, room_account_data, - ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) + ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) @@ -53,7 +50,3 @@ class AccountDataEventSource(object): ) return results, current_stream_id - - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): - return [], config.to_id diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d04e0fe576..829f52eca1 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,8 +18,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText - -from twisted.internet import defer +from typing import List from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable @@ -78,42 +77,39 @@ class AccountValidityHandler(object): # run as a background process to make sure that the database transactions # have a logcontext to report to return run_as_background_process( - "send_renewals", self.send_renewal_emails + "send_renewals", self._send_renewal_emails ) self.clock.looping_call(send_emails, 30 * 60 * 1000) - @defer.inlineCallbacks - def send_renewal_emails(self): + async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they have an email 3PID attached to their account. """ - expiring_users = yield self.store.get_users_expiring_soon() + expiring_users = await self.store.get_users_expiring_soon() if expiring_users: for user in expiring_users: - yield self._send_renewal_email( + await self._send_renewal_email( user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - @defer.inlineCallbacks - def send_renewal_email_to_user(self, user_id): - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - yield self._send_renewal_email(user_id, expiration_ts) + async def send_renewal_email_to_user(self, user_id: str): + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + await self._send_renewal_email(user_id, expiration_ts) - @defer.inlineCallbacks - def _send_renewal_email(self, user_id, expiration_ts): + async def _send_renewal_email(self, user_id: str, expiration_ts: int): """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. Args: - user_id (str): ID of the user to send email(s) to. - expiration_ts (int): Timestamp in milliseconds for the expiration date of + user_id: ID of the user to send email(s) to. + expiration_ts: Timestamp in milliseconds for the expiration date of this user's account (used in the email templates). """ - addresses = yield self._get_email_addresses_for_user(user_id) + addresses = await self._get_email_addresses_for_user(user_id) # Stop right here if the user doesn't have at least one email address. # In this case, they will have to ask their server admin to renew their @@ -125,7 +121,7 @@ class AccountValidityHandler(object): return try: - user_display_name = yield self.store.get_profile_displayname( + user_display_name = await self.store.get_profile_displayname( UserID.from_string(user_id).localpart ) if user_display_name is None: @@ -133,7 +129,7 @@ class AccountValidityHandler(object): except StoreError: user_display_name = user_id - renewal_token = yield self._get_renewal_token(user_id) + renewal_token = await self._get_renewal_token(user_id) url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( self.hs.config.public_baseurl, renewal_token, @@ -165,7 +161,7 @@ class AccountValidityHandler(object): logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable( + await make_deferred_yieldable( self.sendmail( self.hs.config.email_smtp_host, self._raw_from, @@ -180,19 +176,18 @@ class AccountValidityHandler(object): ) ) - yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) + await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) - @defer.inlineCallbacks - def _get_email_addresses_for_user(self, user_id): + async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: """Retrieve the list of email addresses attached to a user's account. Args: - user_id (str): ID of the user to lookup email addresses for. + user_id: ID of the user to lookup email addresses for. Returns: - defer.Deferred[list[str]]: Email addresses for this account. + Email addresses for this account. """ - threepids = yield self.store.user_get_threepids(user_id) + threepids = await self.store.user_get_threepids(user_id) addresses = [] for threepid in threepids: @@ -201,16 +196,15 @@ class AccountValidityHandler(object): return addresses - @defer.inlineCallbacks - def _get_renewal_token(self, user_id): + async def _get_renewal_token(self, user_id: str) -> str: """Generates a 32-byte long random string that will be inserted into the user's renewal email's unique link, then saves it into the database. Args: - user_id (str): ID of the user to generate a string for. + user_id: ID of the user to generate a string for. Returns: - defer.Deferred[str]: The generated string. + The generated string. Raises: StoreError(500): Couldn't generate a unique string after 5 attempts. @@ -219,52 +213,52 @@ class AccountValidityHandler(object): while attempts < 5: try: renewal_token = stringutils.random_string(32) - yield self.store.set_renewal_token_for_user(user_id, renewal_token) + await self.store.set_renewal_token_for_user(user_id, renewal_token) return renewal_token except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - @defer.inlineCallbacks - def renew_account(self, renewal_token): + async def renew_account(self, renewal_token: str) -> bool: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. + renewal_token: Token sent with the renewal request. Returns: - bool: Whether the provided token is valid. + Whether the provided token is valid. """ try: - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + user_id = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - defer.returnValue(False) + return False logger.debug("Renewing an account for user %s", user_id) - yield self.renew_account_for_user(user_id) + await self.renew_account_for_user(user_id) - defer.returnValue(True) + return True - @defer.inlineCallbacks - def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): + async def renew_account_for_user( + self, user_id: str, expiration_ts: int = None, email_sent: bool = False + ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. - expiration_ts (int): New expiration date. Defaults to now + validity period. - email_sent (bool): Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. + expiration_ts: New expiration date. Defaults to now + validity period. + email_sen: Whether an email has been sent for this validity period. Defaults to False. Returns: - defer.Deferred[int]: New expiration date for this account, as a timestamp - in milliseconds since epoch. + New expiration date for this account, as a timestamp in + milliseconds since epoch. """ if expiration_ts is None: expiration_ts = self.clock.time_msec() + self._account_validity.period - yield self.store.set_account_validity_for_user( + await self.store.set_account_validity_for_user( user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 14449b9a1e..1a4ba12385 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import Membership from synapse.types import RoomStreamToken from synapse.visibility import filter_events_for_client @@ -33,11 +31,10 @@ class AdminHandler(BaseHandler): self.storage = hs.get_storage() self.state_store = self.storage.state - @defer.inlineCallbacks - def get_whois(self, user): + async def get_whois(self, user): connections = [] - sessions = yield self.store.get_user_ip_and_agents(user) + sessions = await self.store.get_user_ip_and_agents(user) for session in sessions: connections.append( { @@ -54,20 +51,18 @@ class AdminHandler(BaseHandler): return ret - @defer.inlineCallbacks - def get_users(self): + async def get_users(self): """Function to retrieve a list of users in users table. Args: Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - ret = yield self.store.get_users() + ret = await self.store.get_users() return ret - @defer.inlineCallbacks - def get_users_paginate(self, start, limit, name, guests, deactivated): + async def get_users_paginate(self, start, limit, name, guests, deactivated): """Function to retrieve a paginated list of users from users list. This will return a json list of users. @@ -80,14 +75,13 @@ class AdminHandler(BaseHandler): Returns: defer.Deferred: resolves to json list[dict[str, Any]] """ - ret = yield self.store.get_users_paginate( + ret = await self.store.get_users_paginate( start, limit, name, guests, deactivated ) return ret - @defer.inlineCallbacks - def search_users(self, term): + async def search_users(self, term): """Function to search users list for one or more users with the matched term. @@ -96,7 +90,7 @@ class AdminHandler(BaseHandler): Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - ret = yield self.store.search_users(term) + ret = await self.store.search_users(term) return ret @@ -119,8 +113,7 @@ class AdminHandler(BaseHandler): """ return self.store.set_server_admin(user, admin) - @defer.inlineCallbacks - def export_user_data(self, user_id, writer): + async def export_user_data(self, user_id, writer): """Write all data we have on the user to the given writer. Args: @@ -132,7 +125,7 @@ class AdminHandler(BaseHandler): The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = yield self.store.get_rooms_for_user_where_membership_is( + rooms = await self.store.get_rooms_for_user_where_membership_is( user_id, membership_list=( Membership.JOIN, @@ -145,7 +138,7 @@ class AdminHandler(BaseHandler): # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle # those seperately. - rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id) + rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): room_id = room.room_id @@ -154,7 +147,7 @@ class AdminHandler(BaseHandler): "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) ) - forgotten = yield self.store.did_forget(user_id, room_id) + forgotten = await self.store.did_forget(user_id, room_id) if forgotten: logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) continue @@ -166,7 +159,7 @@ class AdminHandler(BaseHandler): if room.membership == Membership.INVITE: event_id = room.event_id - invite = yield self.store.get_event(event_id, allow_none=True) + invite = await self.store.get_event(event_id, allow_none=True) if invite: invited_state = invite.unsigned["invite_room_state"] writer.write_invite(room_id, invite, invited_state) @@ -177,7 +170,7 @@ class AdminHandler(BaseHandler): # were joined. We estimate that point by looking at the # stream_ordering of the last membership if it wasn't a join. if room.membership == Membership.JOIN: - stream_ordering = yield self.store.get_room_max_stream_ordering() + stream_ordering = self.store.get_room_max_stream_ordering() else: stream_ordering = room.stream_ordering @@ -203,7 +196,7 @@ class AdminHandler(BaseHandler): # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. while True: - events, _ = yield self.store.paginate_room_events( + events, _ = await self.store.paginate_room_events( room_id, from_key, to_key, limit=100, direction="f" ) if not events: @@ -211,7 +204,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.storage, user_id, events) + events = await filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -247,7 +240,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.state_store.get_state_for_event(event_id) + state = await self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 6dedaaff8d..4426967f88 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, create_requester @@ -46,8 +44,7 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled - @defer.inlineCallbacks - def deactivate_account(self, user_id, erase_data, id_server=None): + async def deactivate_account(self, user_id, erase_data, id_server=None): """Deactivate a user's account Args: @@ -74,11 +71,11 @@ class DeactivateAccountHandler(BaseHandler): identity_server_supports_unbinding = True # Retrieve the 3PIDs this user has bound to an identity server - threepids = yield self.store.user_get_bound_threepids(user_id) + threepids = await self.store.user_get_bound_threepids(user_id) for threepid in threepids: try: - result = yield self._identity_handler.try_unbind_threepid( + result = await self._identity_handler.try_unbind_threepid( user_id, { "medium": threepid["medium"], @@ -91,33 +88,33 @@ class DeactivateAccountHandler(BaseHandler): # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") - yield self.store.user_delete_threepid( + await self.store.user_delete_threepid( user_id, threepid["medium"], threepid["address"] ) # Remove all 3PIDs this user has bound to the homeserver - yield self.store.user_delete_threepids(user_id) + await self.store.user_delete_threepids(user_id) # delete any devices belonging to the user, which will also # delete corresponding access tokens. - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # then delete any remaining access tokens which weren't associated with # a device. - yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) - yield self.store.user_set_password_hash(user_id, None) + await self.store.user_set_password_hash(user_id, None) # Add the user to a table of users pending deactivation (ie. # removal from all the rooms they're a member of) - yield self.store.add_user_pending_deactivation(user_id) + await self.store.add_user_pending_deactivation(user_id) # delete from user directory - yield self.user_directory_handler.handle_user_deactivated(user_id) + await self.user_directory_handler.handle_user_deactivated(user_id) # Mark the user as erased, if they asked for that if erase_data: logger.info("Marking %s as erased", user_id) - yield self.store.mark_user_erased(user_id) + await self.store.mark_user_erased(user_id) # Now start the process that goes through that list and # parts users from rooms (if it isn't already running) @@ -125,30 +122,29 @@ class DeactivateAccountHandler(BaseHandler): # Reject all pending invites for the user, so that the user doesn't show up in the # "invited" section of rooms' members list. - yield self._reject_pending_invites_for_user(user_id) + await self._reject_pending_invites_for_user(user_id) # Remove all information on the user from the account_validity table. if self._account_validity_enabled: - yield self.store.delete_account_validity_for_user(user_id) + await self.store.delete_account_validity_for_user(user_id) # Mark the user as deactivated. - yield self.store.set_user_deactivated_status(user_id, True) + await self.store.set_user_deactivated_status(user_id, True) return identity_server_supports_unbinding - @defer.inlineCallbacks - def _reject_pending_invites_for_user(self, user_id): + async def _reject_pending_invites_for_user(self, user_id): """Reject pending invites addressed to a given user ID. Args: user_id (str): The user ID to reject pending invites for. """ user = UserID.from_string(user_id) - pending_invites = yield self.store.get_invited_rooms_for_user(user_id) + pending_invites = await self.store.get_invited_rooms_for_user(user_id) for room in pending_invites: try: - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( create_requester(user), user, room.room_id, @@ -180,8 +176,7 @@ class DeactivateAccountHandler(BaseHandler): if not self._user_parter_running: run_as_background_process("user_parter_loop", self._user_parter_loop) - @defer.inlineCallbacks - def _user_parter_loop(self): + async def _user_parter_loop(self): """Loop that parts deactivated users from rooms Returns: @@ -191,19 +186,18 @@ class DeactivateAccountHandler(BaseHandler): logger.info("Starting user parter") try: while True: - user_id = yield self.store.get_user_pending_deactivation() + user_id = await self.store.get_user_pending_deactivation() if user_id is None: break logger.info("User parter parting %r", user_id) - yield self._part_user(user_id) - yield self.store.del_user_pending_deactivation(user_id) + await self._part_user(user_id) + await self.store.del_user_pending_deactivation(user_id) logger.info("User parter finished parting %r", user_id) logger.info("User parter finished: stopping") finally: self._user_parter_running = False - @defer.inlineCallbacks - def _part_user(self, user_id): + async def _part_user(self, user_id): """Causes the given user_id to leave all the rooms they're joined to Returns: @@ -211,11 +205,11 @@ class DeactivateAccountHandler(BaseHandler): """ user = UserID.from_string(user_id) - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + rooms_for_user = await self.store.get_rooms_for_user(user_id) for room_id in rooms_for_user: logger.info("User parter parting %r from %r", user_id, room_id) try: - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( create_requester(user), user, room_id, diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 57a10daefd..2d889364d4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,6 +264,7 @@ class E2eKeysHandler(object): return ret + @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -283,14 +284,32 @@ class E2eKeysHandler(object): self_signing_keys = {} user_signing_keys = {} - # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 - return defer.succeed( - { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } - ) + user_ids = list(query) + + keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + + for user_id, user_info in keys.items(): + if user_info is None: + continue + if "master" in user_info: + master_keys[user_id] = user_info["master"] + if "self_signing" in user_info: + self_signing_keys[user_id] = user_info["self_signing"] + + if ( + from_user_id in keys + and keys[from_user_id] is not None + and "user_signing" in keys[from_user_id] + ): + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + + return { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } @trace @defer.inlineCallbacks diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 6fb453ce60..72a0febc2b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,7 +19,7 @@ import itertools import logging -from typing import Dict, Iterable, Optional, Sequence, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple import six from six import iteritems, itervalues @@ -63,6 +63,7 @@ from synapse.replication.http.federation import ( ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room @@ -163,8 +164,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): + async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -174,17 +174,15 @@ class FederationHandler(BaseHandler): pdu (FrozenEvent): received PDU sent_to_us_directly (bool): True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. - - Returns (Deferred): completes with None """ room_id = pdu.room_id event_id = pdu.event_id - logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) + logger.info("handling received PDU: %s", pdu) # We reprocess pdus when we have seen them only as outliers - existing = yield self.store.get_event( + existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -228,7 +226,7 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", @@ -243,12 +241,12 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context(pdu.room_id) + min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -268,7 +266,7 @@ class FederationHandler(BaseHandler): len(missing_prevs), shortstr(missing_prevs), ) - with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", room_id, @@ -276,13 +274,19 @@ class FederationHandler(BaseHandler): len(missing_prevs), ) - yield self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) + try: + await self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) # Update the set of things we've seen after trying to # fetch the missing stuff - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -290,14 +294,6 @@ class FederationHandler(BaseHandler): room_id, event_id, ) - elif missing_prevs: - logger.info( - "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, - event_id, - len(missing_prevs), - shortstr(missing_prevs), - ) if prevs - seen: # We've still not been able to get all of the prev_events for this event. @@ -342,12 +338,18 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) + logger.info( + "Event %s is missing prev_events: calculating state for a " + "backwards extremity", + event_id, + ) + # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.state_store.get_state_groups_ids(room_id, seen) + ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -361,17 +363,14 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "[%s %s] Requesting state at missing prev_event %s", - room_id, - event_id, - p, + "Requesting state at missing prev_event %s", event_id, ) with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - (remote_state, _,) = yield self._get_state_for_room( + (remote_state, _,) = await self._get_state_for_room( origin, room_id, p, include_event_in_state=True ) @@ -383,8 +382,8 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - room_version = yield self.store.get_room_version(room_id) - state_map = yield resolve_events_with_store( + room_version = await self.store.get_room_version(room_id) + state_map = await resolve_events_with_store( room_id, room_version, state_maps, @@ -397,10 +396,10 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = yield self.store.get_events( + evs = await self.store.get_events( list(state_map.values()), get_prev_content=False, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, ) event_map.update(evs) @@ -420,10 +419,9 @@ class FederationHandler(BaseHandler): affected=event_id, ) - yield self._process_received_pdu(origin, pdu, state=state) + await self._process_received_pdu(origin, pdu, state=state) - @defer.inlineCallbacks - def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ Args: origin (str): Origin of the pdu. Will be called to get the missing events @@ -435,12 +433,12 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room(room_id) + latest = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -504,7 +502,7 @@ class FederationHandler(BaseHandler): # All that said: Let's try increasing the timout to 60s and see what happens. try: - missing_events = yield self.federation_client.get_missing_events( + missing_events = await self.federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -543,7 +541,7 @@ class FederationHandler(BaseHandler): ) with nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) + await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warning( @@ -555,29 +553,30 @@ class FederationHandler(BaseHandler): else: raise - @defer.inlineCallbacks - @log_function - def _get_state_for_room( - self, destination, room_id, event_id, include_event_in_state - ): + async def _get_state_for_room( + self, + destination: str, + room_id: str, + event_id: str, + include_event_in_state: bool = False, + ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. include_event_in_state: if true, the event itself will be included in the returned state event list. Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + A list of events in the state, possibly including the event itself, and + a list of events in the auth chain for the given event. """ ( state_event_ids, auth_event_ids, - ) = yield self.federation_client.get_room_state_ids( + ) = await self.federation_client.get_room_state_ids( destination, room_id, event_id=event_id ) @@ -586,15 +585,15 @@ class FederationHandler(BaseHandler): if include_event_in_state: desired_events.add(event_id) - event_map = yield self._get_events_from_store_or_dest( + event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) failed_to_fetch = desired_events - event_map.keys() if failed_to_fetch: logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, + "Failed to fetch missing state/auth events for %s %s", + event_id, failed_to_fetch, ) @@ -614,15 +613,11 @@ class FederationHandler(BaseHandler): return remote_state, auth_chain - @defer.inlineCallbacks - def _get_events_from_store_or_dest(self, destination, room_id, event_ids): + async def _get_events_from_store_or_dest( + self, destination: str, room_id: str, event_ids: Iterable[str] + ) -> Dict[str, EventBase]: """Fetch events from a remote destination, checking if we already have them. - Args: - destination (str) - room_id (str) - event_ids (Iterable[str]) - Persists any events we don't already have as outliers. If we fail to fetch any of the events, a warning will be logged, and the event @@ -630,10 +625,9 @@ class FederationHandler(BaseHandler): be in the given room. Returns: - Deferred[dict[str, EventBase]]: A deferred resolving to a map - from event_id to event + map from event_id to event """ - fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) + fetched_events = await self.store.get_events(event_ids, allow_rejected=True) missing_events = set(event_ids) - fetched_events.keys() @@ -644,14 +638,14 @@ class FederationHandler(BaseHandler): room_id, ) - yield self._get_events_and_persist( + await self._get_events_and_persist( destination=destination, room_id=room_id, events=missing_events ) # we need to make sure we re-load from the database to get the rejected # state correct. fetched_events.update( - (yield self.store.get_events(missing_events, allow_rejected=True)) + (await self.store.get_events(missing_events, allow_rejected=True)) ) # check for events which were in the wrong room. @@ -677,12 +671,14 @@ class FederationHandler(BaseHandler): bad_room_id, room_id, ) + del fetched_events[bad_event_id] return fetched_events - @defer.inlineCallbacks - def _process_received_pdu(self, origin, event, state): + async def _process_received_pdu( + self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + ): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -701,15 +697,15 @@ class FederationHandler(BaseHandler): logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) try: - context = yield self._handle_new_event(origin, event, state=state) + context = await self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if not room: try: - yield self.store.store_room( + await self.store.store_room( room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: @@ -722,11 +718,11 @@ class FederationHandler(BaseHandler): # changing their profile info. newly_joined = True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: - prev_state = yield self.store.get_event( + prev_state = await self.store.get_event( prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: @@ -734,11 +730,10 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, room_id) + await self.user_joined_room(user, room_id) @log_function - @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities): + async def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -755,7 +750,7 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - events = yield self.federation_client.backfill( + events = await self.federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -770,7 +765,7 @@ class FederationHandler(BaseHandler): # self._sanity_check_event(ev) # Don't bother processing events we already have. - seen_events = yield self.store.have_events_in_timeline( + seen_events = await self.store.have_events_in_timeline( set(e.event_id for e in events) ) @@ -796,7 +791,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self._get_state_for_room( + state, auth = await self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id, @@ -843,7 +838,7 @@ class FederationHandler(BaseHandler): ) ) - yield self._handle_new_events(dest, ev_infos, backfilled=True) + await self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -859,16 +854,15 @@ class FederationHandler(BaseHandler): # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event(dest, event, backfilled=True) + await self._handle_new_event(dest, event, backfilled=True) return events - @defer.inlineCallbacks - def maybe_backfill(self, room_id, current_depth): + async def maybe_backfill(self, room_id, current_depth): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) + extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -900,15 +894,17 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events(list(extremities)) + forward_events = await self.store.get_successor_events(list(extremities)) - extremities_events = yield self.store.get_events( - forward_events, check_redacted=False, get_prev_content=False + extremities_events = await self.store.get_events( + forward_events, + redact_behaviour=EventRedactBehaviour.AS_IS, + get_prev_content=False, ) # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. - filtered_extremities = yield filter_events_for_server( + filtered_extremities = await filter_events_for_server( self.storage, self.server_name, list(extremities_events.values()), @@ -938,7 +934,7 @@ class FederationHandler(BaseHandler): # First we try hosts that are already in the room # TODO: HEURISTIC ALERT. - curr_state = yield self.state_handler.get_current_state(room_id) + curr_state = await self.state_handler.get_current_state(room_id) def get_domains_from_state(state): """Get joined domains from state @@ -977,12 +973,11 @@ class FederationHandler(BaseHandler): domain for domain, depth in curr_domains if domain != self.server_name ] - @defer.inlineCallbacks - def try_backfill(domains): + async def try_backfill(domains): # TODO: Should we try multiple of these at a time? for dom in domains: try: - yield self.backfill( + await self.backfill( dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the @@ -1013,7 +1008,7 @@ class FederationHandler(BaseHandler): return False - success = yield try_backfill(likely_domains) + success = await try_backfill(likely_domains) if success: return True @@ -1027,7 +1022,7 @@ class FederationHandler(BaseHandler): logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states = yield make_deferred_yieldable( + states = await make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) @@ -1037,7 +1032,7 @@ class FederationHandler(BaseHandler): # event_ids. states = dict(zip(event_ids, [s.state for s in states])) - state_map = yield self.store.get_events( + state_map = await self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], get_prev_content=False, ) @@ -1053,7 +1048,7 @@ class FederationHandler(BaseHandler): for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill( + success = await try_backfill( [dom for dom, _ in likely_domains if dom not in tried_domains] ) if success: @@ -1063,8 +1058,7 @@ class FederationHandler(BaseHandler): return False - @defer.inlineCallbacks - def _get_events_and_persist( + async def _get_events_and_persist( self, destination: str, room_id: str, events: Iterable[str] ): """Fetch the given events from a server, and persist them as outliers. @@ -1072,7 +1066,7 @@ class FederationHandler(BaseHandler): Logs a warning if we can't find the given event. """ - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) event_infos = [] @@ -1108,9 +1102,9 @@ class FederationHandler(BaseHandler): e, ) - yield concurrently_execute(get_event, events, 5) + await concurrently_execute(get_event, events, 5) - yield self._handle_new_events( + await self._handle_new_events( destination, event_infos, ) @@ -1253,7 +1247,7 @@ class FederationHandler(BaseHandler): # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = yield self.store.get_room_predecessor(room_id) - if not predecessor: + if not predecessor or not isinstance(predecessor.get("room_id"), str): return old_room_id = predecessor["room_id"] logger.debug( @@ -1281,8 +1275,7 @@ class FederationHandler(BaseHandler): return True - @defer.inlineCallbacks - def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. Args: @@ -1298,7 +1291,7 @@ class FederationHandler(BaseHandler): p.room_id, ) with nested_logging_context(p.event_id): - yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e @@ -1428,7 +1421,7 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield self.user_joined_room(user, event.room_id) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) @@ -1496,7 +1489,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, room_id, user_id, "leave", content=content, + target_hosts, room_id, user_id, "leave", content=content ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -1937,7 +1930,7 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -2346,12 +2339,12 @@ class FederationHandler(BaseHandler): k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state_ids = dict(current_state_ids) current_state_ids.update(state_updates) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) @@ -2635,7 +2628,7 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( @@ -2683,7 +2676,7 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None @@ -2857,7 +2850,7 @@ class FederationHandler(BaseHandler): room_id=room_id, user_id=user.to_string(), change="joined" ) else: - return user_joined_room(self.distributor, user, room_id) + return defer.succeed(user_joined_room(self.distributor, user, room_id)) @defer.inlineCallbacks def get_room_complexity(self, remote_room_hosts, room_id): diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 81dce96f4b..44ec3e66ae 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute -from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() + self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler): as_client_event, include_archived, ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - return self.snapshot_cache.set( - now_ms, + return self.snapshot_cache.wrap( key, - self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - ), + self._snapshot_all_rooms, + user_id, + pagin_config, + as_client_event, + include_archived, ) - @defer.inlineCallbacks - def _snapshot_all_rooms( + async def _snapshot_all_rooms( self, user_id=None, pagin_config=None, @@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler): if include_archived: memberships.append(Membership.LEAVE) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_user_where_membership_is( user_id=user_id, membership_list=memberships ) @@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler): rooms_ret = [] - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( + presence, _ = await presence_stream.get_pagination_rows( user, pagination_config.get_source_config("presence"), None ) receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( + receipt, _ = await receipt_stream.get_pagination_rows( user, pagination_config.get_source_config("receipt"), None ) - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = yield self.store.get_account_data_for_user( + account_data, account_data_by_room = await self.store.get_account_data_for_user( user_id ) - public_room_ids = yield self.store.get_public_room_ids() + public_room_ids = await self.store.get_public_room_ids() limit = pagin_config.limit if limit is None: limit = 10 - @defer.inlineCallbacks - def handle_room(event): + async def handle_room(event): d = { "room_id": event.room_id, "membership": event.membership, @@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() d["inviter"] = event.sender - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = yield self._event_serializer.serialize_event( + invite_event = await self.store.get_event(event.event_id) + d["invite"] = await self._event_serializer.serialize_event( invite_event, time_now, as_client_event ) @@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield make_deferred_yieldable( + (messages, token), current_state = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages ) @@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler): d["messages"] = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( messages, time_now=time_now, as_client_event=as_client_event ) ), @@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler): "end": end_token.to_string(), } - d["state"] = yield self._event_serializer.serialize_events( + d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, as_client_event=as_client_event, @@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler): except Exception: logger.exception("Failed to get snapshot") - yield concurrently_execute(handle_room, room_list, 10) + await concurrently_execute(handle_room, room_list, 10) account_data_events = [] for account_data_type, content in account_data.items(): @@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): + async def room_initial_sync(self, requester, room_id, pagin_config=None): """Capture the a snapshot of a room. If user is currently a member of the room this will be what is currently in the room. If the user left the room this will be what was in the room when they left. @@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler): A JSON serialisable dict with the snapshot of the room. """ - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") user_id = requester.user.to_string() - membership, member_event_id = yield self._check_in_room_or_world_readable( + membership, member_event_id = await self._check_in_room_or_world_readable( room_id, user_id ) is_peeking = member_event_id is None if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( + result = await self._room_initial_sync_joined( user_id, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( + result = await self._room_initial_sync_parted( user_id, room_id, pagin_config, membership, member_event_id, is_peeking ) account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) if tags: account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) - account_data = yield self.store.get_account_data_for_room(user_id, room_id) + account_data = await self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): account_data_events.append({"type": account_data_type, "content": content}) @@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler): return result - @defer.inlineCallbacks - def _room_initial_sync_parted( + async def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.state_store.get_state_for_events([member_event_id]) + room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event(member_event_id) + stream_token = await self.store.get_stream_token_for_event(member_event_id) - messages, token = yield self.store.get_recent_events_for_room( + messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token ) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages, is_peeking=is_peeking ) @@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), }, "state": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( room_state.values(), time_now ) ), @@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler): "receipts": [], } - @defer.inlineCallbacks - def _room_initial_sync_joined( + async def _room_initial_sync_joined( self, user_id, room_id, pagin_config, membership, is_peeking ): - current_state = yield self.state.get_current_state(room_id=room_id) + current_state = await self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() - state = yield self._event_serializer.serialize_events( + state = await self._event_serializer.serialize_events( current_state.values(), time_now ) - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None if limit is None: @@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler): presence_handler = self.hs.get_presence_handler() - @defer.inlineCallbacks - def get_presence(): + async def get_presence(): # If presence is disabled, return an empty list if not self.hs.config.use_presence: return [] - states = yield presence_handler.get_states( + states = await presence_handler.get_states( [m.user_id for m in room_members], as_event=True ) return states - @defer.inlineCallbacks - def get_receipts(): - receipts = yield self.store.get_linearized_receipts_for_room( + async def get_receipts(): + receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] return receipts - presence, receipts, (messages, token) = yield make_deferred_yieldable( + presence, receipts, (messages, token) = await make_deferred_yieldable( defer.gatherResults( [ run_in_background(get_presence), @@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler): ).addErrback(unwrapFirstError) ) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages, is_peeking=is_peeking ) @@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), @@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def _check_in_room_or_world_readable(self, room_id, user_id): + async def _check_in_room_or_world_readable(self, room_id, user_id): try: # check_user_was_in_room will return the most recent membership # event for the user if: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + member_event = await self.auth.check_user_was_in_room(room_id, user_id) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state_handler.get_current_state( + visibility = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if ( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 54fa216d83..4ad752205f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -514,7 +515,7 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( yield self.store.get_event(prev_event_id, allow_none=True) @@ -664,7 +665,7 @@ class EventCreationHandler(object): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return @@ -875,7 +876,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -913,7 +914,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() state_to_include_ids = [ e_id @@ -952,7 +953,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -966,7 +967,7 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -988,7 +989,7 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8514ddc600..00a6afc963 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -280,8 +280,7 @@ class PaginationHandler(object): await self.storage.purge_events.purge_room(room_id) - @defer.inlineCallbacks - def get_messages( + async def get_messages( self, requester, room_id=None, @@ -307,7 +306,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token_for_pagination() + await self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key @@ -319,11 +318,11 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") - with (yield self.pagination_lock.read(room_id)): + with (await self.pagination_lock.read(room_id)): ( membership, member_event_id, - ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) + ) = await self.auth.check_in_room_or_world_readable(room_id, user_id) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -331,7 +330,7 @@ class PaginationHandler(object): if room_token.topological: max_topo = room_token.topological else: - max_topo = yield self.store.get_max_topological_token( + max_topo = await self.store.get_max_topological_token( room_id, room_token.stream ) @@ -339,18 +338,18 @@ class PaginationHandler(object): # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. - leave_token = yield self.store.get_topological_token_for_event( + leave_token = await self.store.get_topological_token_for_event( member_event_id ) leave_token = RoomStreamToken.parse(leave_token) if leave_token.topological < max_topo: source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( + await self.hs.get_handlers().federation_handler.maybe_backfill( room_id, max_topo ) - events, next_key = yield self.store.paginate_room_events( + events, next_key = await self.store.paginate_room_events( room_id=room_id, from_key=source_config.from_key, to_key=source_config.to_key, @@ -365,7 +364,7 @@ class PaginationHandler(object): if event_filter: events = event_filter.filter(events) - events = yield filter_events_for_client( + events = await filter_events_for_client( self.storage, user_id, events, is_peeking=(member_event_id is None) ) @@ -385,19 +384,19 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) state = state.values() time_now = self.clock.time_msec() chunk = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event ) ), @@ -406,7 +405,7 @@ class PaginationHandler(object): } if state: - chunk["state"] = yield self._event_serializer.serialize_events( + chunk["state"] = await self._event_serializer.serialize_events( state, time_now, as_client_event=as_client_event ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index eda15bc623..240c4add12 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -230,7 +230,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.store.database.is_running(): return logger.info( diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 1e5a4613c9..f9579d69ee 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -295,12 +295,16 @@ class BaseProfileHandler(BaseHandler): be found to be in any room the server is in, and therefore the query is denied. """ + # Implementation of MSC1301: don't allow looking up profiles if the # requester isn't in the same room as the target. We expect requester to # be None when this function is called outside of a profile query, e.g. # when building a membership event. In this case, we must allow the # lookup. - if not self.hs.config.require_auth_for_profile_requests or not requester: + if ( + not self.hs.config.limit_profile_requests_to_users_who_share_rooms + or not requester + ): return # Always allow the user to query their own profile. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 60b8bbc7a5..89c9118b26 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler): requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids(self.store) + old_room_state = yield tombstone_context.get_current_state_ids() # update any aliases yield self._move_aliases_to_new_room( @@ -1011,15 +1011,3 @@ class RoomEventSource(object): def get_current_key_for_room(self, room_id): return self.store.get_room_events_max_id(room_id) - - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): - events, next_key = yield self.store.paginate_room_events( - room_id=key, - from_key=config.from_key, - to_key=config.to_key, - direction=config.direction, - limit=config.limit, - ) - - return (events, next_key) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7b7270fc61..44c5e3239c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -193,7 +193,7 @@ class RoomMemberHandler(object): requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -601,7 +601,7 @@ class RoomMemberHandler(object): if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: guest_can_join = yield self._can_guest_join(prev_state_ids) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index cc9e6b9bd0..0082f85c26 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re +from typing import Tuple import attr import saml2 +import saml2.response from saml2.client import Saml2Client from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import ( + UserID, + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() + + class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) @@ -37,11 +53,14 @@ class SamlHandler: self._datastore = hs.get_datastore() self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime - self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) - self._mxid_mapper = hs.config.saml2_mxid_mapper + + # plugin to do custom mapping from saml response to mxid + self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( + hs.config.saml2_user_mapping_provider_config + ) # identifier for the external_ids table self._auth_provider_id = "saml" @@ -118,22 +137,10 @@ class SamlHandler: remote_user_id = saml2_auth.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "uid not in SAML2 response") - - try: - mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] - except KeyError: - logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute - ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) - ) + raise SynapseError(400, "'uid' not in SAML2 response") self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - displayName = saml2_auth.ava.get("displayName", [None])[0] - with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user logger.info( @@ -173,22 +180,46 @@ class SamlHandler: ) return registered_user_id - # figure out a new mxid for this user - base_mxid_localpart = self._mxid_mapper(mxid_source) + # Map saml response to user attributes using the configured mapping provider + for i in range(1000): + attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, i + ) + + logger.debug( + "Retrieved SAML attributes from user mapping provider: %s " + "(attempt %d)", + attribute_dict, + i, + ) + + localpart = attribute_dict.get("mxid_localpart") + if not localpart: + logger.error( + "SAML mapping provider plugin did not return a " + "mxid_localpart object" + ) + raise SynapseError(500, "Error parsing SAML2 response") - suffix = 0 - while True: - localpart = base_mxid_localpart + (str(suffix) if suffix else "") + displayname = attribute_dict.get("displayname") + + # Check if this mxid already exists if not await self._datastore.get_users_by_id_case_insensitive( UserID(localpart, self._hostname).to_string() ): + # This mxid is free break - suffix += 1 - logger.info("Allocating mxid for new user with localpart %s", localpart) + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise SynapseError( + 500, "Unable to generate a Matrix ID from the SAML response" + ) registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=displayName + localpart=localpart, default_display_name=displayname ) + await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) @@ -205,9 +236,120 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) + + +def dot_replace_for_mxid(username: str) -> str: + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} + + @attr.s -class Saml2SessionData: - """Data we track about SAML2 sessions""" +class SamlConfig(object): + mxid_source_attribute = attr.ib() + mxid_mapper = attr.ib() - # time the session was created, in milliseconds - creation_time = attr.ib() + +class DefaultSamlMappingProvider(object): + __version__ = "0.0.1" + + def __init__(self, parsed_config: SamlConfig): + """The default SAML user mapping provider + + Args: + parsed_config: Module configuration + """ + self._mxid_source_attribute = parsed_config.mxid_source_attribute + self._mxid_mapper = parsed_config.mxid_mapper + + def saml_response_to_user_attributes( + self, saml_response: saml2.response.AuthnResponse, failures: int = 0, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + """ + try: + mxid_source = saml_response.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + + # Use the configured mapper for this mxid_source + base_mxid_localpart = self._mxid_mapper(mxid_source) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid + localpart = base_mxid_localpart + (str(failures) if failures else "") + + # Retrieve the display name from the saml response + # If displayname is None, the mxid_localpart will be used instead + displayname = saml_response.ava.get("displayName", [None])[0] + + return { + "mxid_localpart": localpart, + "displayname": displayname, + } + + @staticmethod + def parse_config(config: dict) -> SamlConfig: + """Parse the dict provided by the homeserver's config + Args: + config: A dictionary containing configuration options for this provider + Returns: + SamlConfig: A custom config object for this module + """ + # Parse config options and use defaults where necessary + mxid_source_attribute = config.get("mxid_source_attribute", "uid") + mapping_type = config.get("mxid_mapping", "hexencode") + + # Retrieve the associating mapping function + try: + mxid_mapper = MXID_MAPPER_MAP[mapping_type] + except KeyError: + raise ConfigError( + "saml2_config.user_mapping_provider.config: '%s' is not a valid " + "mxid_mapping value" % (mapping_type,) + ) + + return SamlConfig(mxid_source_attribute, mxid_mapper) + + @staticmethod + def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: + """Returns the required attributes of a SAML + + Args: + config: A SamlConfig object containing configuration params for this provider + + Returns: + tuple[set,set]: The first set equates to the saml auth response + attributes that are required for the module to function, whereas the + second set consists of those attributes which can be used if + available, but are not necessary + """ + return {"uid", config.mxid_source_attribute}, {"displayName"} diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 56ed262a1f..ef750d1497 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64 from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client @@ -37,6 +37,7 @@ class SearchHandler(BaseHandler): self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state + self.auth = hs.get_auth() @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -53,23 +54,38 @@ class SearchHandler(BaseHandler): room_id (str): id of the room to search through. Returns: - Deferred[iterable[unicode]]: predecessor room ids + Deferred[iterable[str]]: predecessor room ids """ historical_room_ids = [] - while True: - predecessor = yield self.store.get_room_predecessor(room_id) + # The initial room must have been known for us to get this far + predecessor = yield self.store.get_room_predecessor(room_id) - # If no predecessor, assume we've hit a dead end + while True: if not predecessor: + # We have reached the end of the chain of predecessors + break + + if not isinstance(predecessor.get("room_id"), str): + # This predecessor object is malformed. Exit here + break + + predecessor_room_id = predecessor["room_id"] + + # Don't add it to the list until we have checked that we are in the room + try: + next_predecessor_room = yield self.store.get_room_predecessor( + predecessor_room_id + ) + except NotFoundError: + # The predecessor is not a known room, so we are done here break - # Add predecessor's room ID - historical_room_ids.append(predecessor["room_id"]) + historical_room_ids.append(predecessor_room_id) - # Scan through the old room for further predecessors - room_id = predecessor["room_id"] + # And repeat + predecessor = next_predecessor_room return historical_room_ids diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 6f78454322..b635c339ed 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -317,6 +317,3 @@ class TypingNotificationEventSource(object): def get_current_key(self): return self.get_typing_handler()._latest_room_serial - - def get_pagination_rows(self, user, pagination_config, key): - return [], pagination_config.from_key diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 03934956f4..c0b9384189 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -171,7 +171,7 @@ class LogProducer(object): def stopProducing(self): self._paused = True - self._buffer = None + self._buffer = deque() def resumeProducing(self): self._paused = False diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2c1fb9ddac..33b322209d 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,6 +23,7 @@ them. See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types @@ -404,6 +405,9 @@ class LoggingContext(object): """ current = get_thread_resource_usage() + # Indicate to mypy that we know that self.usage_start is None. + assert self.usage_start is not None + utime_delta = current.ru_utime - self.usage_start.ru_utime stime_delta = current.ru_stime - self.usage_start.ru_stime @@ -612,7 +616,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +629,13 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we convert it to a deferred so that we + # run it and find out if it immediately finishes, it it does then we + # don't need to fiddle with log contexts at all and can return + # immediately. + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7881780760..7d9f5a38d9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and @@ -304,7 +304,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index f277aeb131..8ad0bf5936 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -80,9 +80,11 @@ class PusherFactory(object): return EmailPusher(self.hs, pusherdict, mailer) def _app_name_from_pusherdict(self, pusherdict): - if "data" in pusherdict and "brand" in pusherdict["data"]: - app_name = pusherdict["data"]["brand"] - else: - app_name = self.config.email_app_name + data = pusherdict["data"] - return app_name + if isinstance(data, dict): + brand = data.get("brand") + if isinstance(brand, str): + return brand + + return self.config.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0f6992202d..b9dca5bc63 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -232,7 +232,6 @@ class PusherPool: Deferred """ pushers = yield self.store.get_all_pushers() - logger.info("Starting %d pushers", len(pushers)) # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -245,7 +244,7 @@ class PusherPool: """Start the given pusher Args: - pusherdict (dict): + pusherdict (dict): dict with the values pulled from the db table Returns: Deferred[EmailPusher|HttpPusher] @@ -254,7 +253,8 @@ class PusherPool: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: logger.warning( - "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s", + "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", + pusherdict["id"], pusherdict.get("user_name"), pusherdict.get("app_id"), pusherdict.get("pushkey"), @@ -262,7 +262,9 @@ class PusherPool: ) return except Exception: - logger.exception("Couldn't start a pusher: caught Exception") + logger.exception( + "Couldn't start pusher id %i: caught Exception", pusherdict["id"], + ) return if not p: diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 9af4e7e173..49a3251372 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_handler = hs.get_handlers().federation_handler @@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = EventContext.deserialize(self.store, event_payload["context"]) + context = EventContext.deserialize( + self.storage, event_payload["context"] + ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 9bafd60b14..84b92f16ad 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() @staticmethod @@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.storage, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 0791866f55..6f6b7aed6e 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -28,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) +ALLOWED_KEYS = { + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", +} + class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -43,23 +54,11 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - allowed_keys = [ - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", - ] - - for p in pushers: - for k, v in list(p.items()): - if k not in allowed_keys: - del p[k] - - return 200, {"pushers": pushers} + filtered_pushers = list( + {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers + ) + + return 200, {"pushers": filtered_pushers} def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/server.py b/synapse/server.py index 2db3dab221..7926867b77 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,7 +25,6 @@ import abc import logging import os -from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail from twisted.web.client import BrowserLikePolicyForHTTPS @@ -34,6 +33,7 @@ from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory @@ -132,7 +132,6 @@ class HomeServer(object): DEPENDENCIES = [ "http_client", - "db_pool", "federation_client", "federation_server", "handlers", @@ -209,16 +208,18 @@ class HomeServer(object): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname, reactor=None, **kwargs): + def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. + config: The full config for the homeserver. """ if not reactor: from twisted.internet import reactor self._reactor = reactor self.hostname = hostname + self.config = config self._building = {} self._listening_services = [] self.start_time = None @@ -237,10 +238,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") - with self.get_db_conn() as conn: - self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) - conn.commit() self.start_time = int(self.get_clock().time()) + self.datastores = DataStores(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") def setup_master(self): @@ -274,6 +273,9 @@ class HomeServer(object): def get_datastore(self): return self.datastores.main + def get_datastores(self): + return self.datastores + def get_config(self): return self.config @@ -423,31 +425,6 @@ class HomeServer(object): ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_db_pool(self): - name = self.db_config["name"] - - return adbapi.ConnectionPool( - name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) - ) - - def get_db_conn(self, run_new_connection=True): - """Makes a new connection to the database, skipping the db pool - - Returns: - Connection: a connection object implementing the PEP-249 spec - """ - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v - for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def build_media_repository_resource(self): # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 0e75e94c6f..5accc071ab 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,6 +32,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -655,7 +656,7 @@ class StateResolutionStore(object): return self.store.get_events( event_ids, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=allow_rejected, ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7637b5dc0..88546ad614 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -40,7 +40,7 @@ class SQLBaseStore(object): def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine self.db = database self.rand = random.SystemRandom() diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cafedd5c0d..d20df5f076 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,24 +13,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import Database +import logging + +from synapse.storage.data_stores.state import StateGroupDataStore +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +logger = logging.getLogger(__name__) + class DataStores(object): """The various data stores. These are low level interfaces to physical databases. + + Attributes: + main (DataStore) """ - def __init__(self, main_store_class, db_conn, hs): + def __init__(self, main_store_class, hs): # Note we pass in the main store class here as workers use a different main # store. - database = Database(hs) - # Check that db is correctly configured. - database.engine.check_database(db_conn.cursor()) + self.databases = [] + + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) + + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn.cursor()) + prepare_database( + db_conn, engine, hs.config, data_stores=database_config.data_stores, + ) + + database = Database(hs, database_config, engine) + + if "main" in database_config.data_stores: + logger.info("Starting 'main' data store") + self.main = main_store_class(database, db_conn, hs) + + if "state" in database_config.data_stores: + logger.info("Starting 'state' data store") + self.state = StateGroupDataStore(database, db_conn, hs) + + db_conn.commit() - prepare_database(db_conn, database.engine, config=hs.config) + self.databases.append(database) - self.main = main_store_class(database, db_conn, hs) + logger.info("Database %r prepared", db_name) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 320c5b0f07..13f4c9c72e 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.db.is_running(): return to_update = self._batch_row_update @@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.db.simple_upsert_txn( + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - values={ + updatevalues={ "user_agent": user_agent, "last_seen": last_seen, "ip": ip, }, - lock=False, ) except Exception as e: # Failed to upsert, log and continue diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 85cfa16850..0613b49f4a 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -358,21 +358,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): - # Compatible method of performing an upsert - sql = "SELECT stream_id FROM device_max_stream_id" - - txn.execute(sql) - rows = txn.fetchone() - if rows: - db_stream_id = rows[0] - if db_stream_id < stream_id: - # Insert the new stream_id - sql = "UPDATE device_max_stream_id SET stream_id = ?" - else: - # No rows, perform an insert - sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)" - - txn.execute(sql, (stream_id,)) + sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" + txn.execute(sql, (stream_id, stream_id)) local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 38cd0ca9b8..e551606f9d 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -14,15 +14,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List + from six import iteritems from canonicaljson import encode_canonical_json, json +from twisted.enterprise.adbapi import Connection from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList class EndToEndKeyWorkerStore(SQLBaseStore): @@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Args: txn (twisted.enterprise.adbapi.Connection): db connection user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' + key_type (str): the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on @@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """Returns a user's cross-signing key. Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on the self-signing key will be included in the result @@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk( + self, user_ids: List[str] + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + + """ + return self.db.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: List[str], + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + + """ + result = {} + + batch_size = 100 + chunks = [ + user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT k.user_id, k.keytype, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE k.user_id IN (%s) + """ % ( + ",".join("?" for u in user_chunk), + ) + query_params = [] + query_params.extend(user_chunk) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + for row in rows: + user_id = row["user_id"] + key_type = row["keytype"] + key = json.loads(row["keydata"]) + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + if user_info is None: + continue + for key_type, key in user_info.items(): + device_id = None + for k in key["keys"].values(): + device_id = k + devices[(user_id, device_id)] = key_type + + device_list = list(devices) + + # split into batches + batch_size = 100 + chunks = [ + device_list[i : i + batch_size] + for i in range(0, len(device_list), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for d in devices + ) + ) + query_params = [from_user_id] + for item in devices: + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. + user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks + def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: str = None + ) -> defer.Deferred: + """Returns the cross-signing keys for a set of users. + + Args: + user_ids (list[str]): the users whose keys are being requested + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to + key data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + """ + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their @@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): }, ) + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 998bba1aad..58f35d7f56 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1757,163 +1757,6 @@ class EventsStore( return state_groups - def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: - """Deletes no longer referenced state groups and de-deltas any state - groups that reference them. - - Args: - room_id: The room the state groups belong to (must all be in the - same room). - state_groups_to_delete (Collection[int]): Set of all state groups - to delete. - """ - - return self.db.runInteraction( - "purge_unreferenced_state_groups", - self._purge_unreferenced_state_groups, - room_id, - state_groups_to_delete, - ) - - def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - rows = self.db.simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=state_groups_to_delete, - keyvalues={}, - retcols=("state_group",), - ) - - remaining_state_groups = set( - row["state_group"] - for row in rows - if row["state_group"] not in state_groups_to_delete - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self.db.simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self.db.simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): - """Fetch the previous groups of the given state groups. - - Args: - state_groups (Iterable[int]) - - Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. - """ - - rows = yield self.db.simple_select_many_batch( - table="state_group_edges", - column="prev_state_group", - iterable=state_groups, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - desc="get_previous_state_groups", - ) - - return {row["state_group"]: row["prev_state_group"] for row in rows} - - def purge_room_state(self, room_id, state_groups_to_delete): - """Deletes all record of a room from state tables - - Args: - room_id (str): - state_groups_to_delete (list[int]): State groups to delete - """ - - return self.db.runInteraction( - "purge_room_state", - self._purge_room_state_txn, - room_id, - state_groups_to_delete, - ) - - def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups_state", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_group_edges", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups", - column="id", - iterable=state_groups_to_delete, - keyvalues={}, - ) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9ee117ce0f..2c9142814c 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,8 +19,10 @@ import itertools import logging import threading from collections import namedtuple +from typing import List, Optional from canonicaljson import json +from constantly import NamedConstant, Names from twisted.internet import defer @@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) @@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_event( self, - event_id, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - check_room_id=None, + event_id: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, ): """Get an event from the database by event_id. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_id: The event_id of the event to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if + allow_rejected: If True return rejected events. + allow_none: If True, return None if no event found, if False throw a NotFoundError - check_room_id (str|None): if not None, check the room of the found event. + check_room_id: if not None, check the room of the found event. If there is a mismatch, behave as per allow_none. Returns: @@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore): events = yield self.get_events_as_list( [event_id], - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True return rejected events. Returns: Deferred : Dict from event_id to event. """ events = yield self.get_events_as_list( event_ids, - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events_as_list( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database and return in a list in the same order as given by `event_ids` arg. Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True, return rejected events. Returns: Deferred[list[EventBase]]: List of events fetched from the database. The @@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore): # Update the cache to save doing the checks again. entry.event.internal_metadata.recheck_redaction = False - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event events.append(event) diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 5ba13aa973..e2673ae073 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -244,7 +244,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index f07309ef09..6b03233262 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -15,8 +15,7 @@ # limitations under the License. import logging - -import six +from typing import Iterable, Iterator from canonicaljson import encode_canonical_json, json @@ -27,21 +26,16 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows): + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + """JSON-decode the data in the rows returned from the `pushers` table + + Drops any rows whose data cannot be decoded + """ for r in rows: dataJson = r["data"] - r["data"] = None try: - if isinstance(dataJson, db_binary_type): - dataJson = str(dataJson).decode("UTF8") - r["data"] = json.loads(dataJson) except Exception as e: logger.warning( @@ -50,12 +44,9 @@ class PusherWorkerStore(SQLBaseStore): dataJson, e.args[0], ) - pass - - if isinstance(r["pushkey"], db_binary_type): - r["pushkey"] = str(r["pushkey"]).decode("UTF8") + continue - return rows + yield r @defer.inlineCallbacks def user_has_pusher(self, user_id): diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 92e3b9c512..70ff5751b6 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql index 4219cdd06a..2de50d408c 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql @@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql new file mode 100644 index 0000000000..c2f557fde9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This line already existed in deltas/35/device_stream_id but was not included in the +-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist +INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS ( + SELECT * from device_max_stream_id +); \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql new file mode 100644 index 0000000000..4f24c1405d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql @@ -0,0 +1,29 @@ +/* Copyright 2019 Werner Sembach + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made. +DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres index 4ad2929f32..889a9a0ce4 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres @@ -975,40 +975,6 @@ CREATE TABLE state_events ( -CREATE TABLE state_group_edges ( - state_group bigint NOT NULL, - prev_state_group bigint NOT NULL -); - - - -CREATE SEQUENCE state_group_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - - -CREATE TABLE state_groups ( - id bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE state_groups_state ( - state_group bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text NOT NULL -); - - - CREATE TABLE stats_stream_pos ( lock character(1) DEFAULT 'X'::bpchar NOT NULL, stream_id bigint, @@ -1482,12 +1448,6 @@ ALTER TABLE ONLY state_events ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); - -ALTER TABLE ONLY state_groups - ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); - - - ALTER TABLE ONLY stats_stream_pos ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); @@ -1928,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); -CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group); - - - -CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); - - - -CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); - - - CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite index bad33291e7..a0411ede7e 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite @@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id); CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); CREATE INDEX room_depth_room ON room_depth(room_id); -CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); @@ -120,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); -CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL ); -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); @@ -254,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); CREATE INDEX users_creation_ts ON users (creation_ts); CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); -CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key); CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql index c265fd20e2..91d21b2921 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql @@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales INSERT INTO user_directory_stream_pos (stream_id) VALUES (0); INSERT INTO stats_stream_pos (stream_id) VALUES (0); INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); +-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md new file mode 100644 index 0000000000..bbd3f18604 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md @@ -0,0 +1,13 @@ +# Building full schema dumps + +These schemas need to be made from a database that has had all background updates run. + +To do so, use `scripts-dev/make_full_schema.sh`. This will produce +`full.sql.postgres ` and `full.sql.sqlite` files. + +Ensure postgres is installed and your user has the ability to run bash commands +such as `createdb`. + +``` +./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ +``` diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt deleted file mode 100644 index d3f6401344..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt +++ /dev/null @@ -1,19 +0,0 @@ -Building full schema dumps -========================== - -These schemas need to be made from a database that has had all background updates run. - -Postgres --------- - -$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres - -SQLite ------- - -$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite - -After ------ - -Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas". \ No newline at end of file diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 4eec2fae5e..47ebb8a214 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} @@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 9ef7b48c74..0dc39f139c 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -17,8 +17,7 @@ import logging from collections import namedtuple from typing import Iterable, Tuple -from six import iteritems, itervalues -from six.moves import range +from six import iteritems from twisted.internet import defer @@ -29,11 +28,9 @@ from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter -from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -55,207 +52,14 @@ class _GetStateGroupDelta( return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud - updates. - """ - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? " + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - - # this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore( - EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore -): +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, database: Database, db_conn, hs): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. - # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. - - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), - ) - @defer.inlineCallbacks def get_room_version(self, room_id): """Get the room_version of a given room @@ -278,7 +82,7 @@ class StateGroupWorkerStore( @defer.inlineCallbacks def get_room_predecessor(self, room_id): - """Get the predecessor room of an upgraded room if one exists. + """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: @@ -291,14 +95,22 @@ class StateGroupWorkerStore( * room_id (str): The room ID of the predecessor room * event_id (str): The ID of the tombstone event in the predecessor room + None if a predecessor key is not found, or is not a dictionary. + Raises: - NotFoundError if the room is unknown + NotFoundError if the given room is unknown """ # Retrieve the room's create event create_event = yield self.get_create_event_for_room(room_id) - # Return predecessor if present - return create_event.content.get("predecessor", None) + # Retrieve the predecessor key of the create event + predecessor = create_event.content.get("predecessor", None) + + # Ensure the key is a dictionary + if not isinstance(predecessor, dict): + return None + + return predecessor @defer.inlineCallbacks def get_create_event_for_room(self, room_id): @@ -318,7 +130,7 @@ class StateGroupWorkerStore( # If we can't find the create event, assume we've hit a dead end if not create_id: - raise NotFoundError("Unknown room %s" % (room_id)) + raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return create_event = yield self.get_event(create_id) @@ -423,229 +235,6 @@ class StateGroupWorkerStore( return event.content.get("canonical_alias") - @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): - """Given a state group try to return a previous group and a delta between - the old and the new. - - Returns: - (prev_group, delta_ids), where both may be None. - """ - - def _get_state_group_delta_txn(txn): - prev_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self.db.simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.db.runInteraction( - "get_state_group_delta", _get_state_group_delta_txn - ) - - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events - - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - if not event_ids: - return {} - - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups) - - return group_to_state - - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): - """Get the event IDs of all the state in the given state group - - Args: - state_group (int) - - Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = yield self._get_state_for_groups((state_group,)) - - return group_to_state[state_group] - - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) - - state_event_map = yield self.get_events( - [ - ev_id - for group_ids in itervalues(group_to_ids) - for ev_id in itervalues(group_ids) - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - } - - @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = yield self.db.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - - return results - - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - state_event_map = yield self.get_events( - [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in iteritems(group_to_state[group]) - if v in state_event_map - } - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from event_id -> (type, state_key) -> event_id - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_for_events([event_id], state_filter) - return state_map[event_id] - - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) - return state_map[event_id] - @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): return self.db.simple_select_one_onecol( @@ -676,329 +265,6 @@ class StateGroupWorkerStore( return {row["event_id"]: row["state_group"] for row in rows} - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. - """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. - for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. - """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in iteritems(group_state_dict): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) - - def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) - to event_id. - - Returns: - Deferred[int]: The state group ID - """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self.database_engine.get_next_state_group_id(txn) - - self.db.simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self.db.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_ids) - ], - ) - else: - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(current_state_ids) - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.db.runInteraction("store_state_group", _store_state_group_txn) - @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): """Check if the state groups are referenced by events. @@ -1023,22 +289,14 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): +class MainStateBackgroundUpdateStore(SQLBaseStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" def __init__(self, database: Database, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) + super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", @@ -1053,181 +311,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["state_group"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - - BATCH_SIZE_SCALE_FACTOR = 100 - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self.db.execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? AND room_id = ?", - (state_group, room_id), - ) - (prev_group,) = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. - continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in iteritems(curr_state) - if prev_state.get(key, None) != value - } - - self.db.simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self.db.simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_state) - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self.db.updates._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.db.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn - ) - - if finished: - yield self.db.updates._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.db.runWithConnection(reindex_txn) - - yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 - - -class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): +class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ Keeps track of the state at a given event. This is done by the concept of `state groups`. Every event is a assigned diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py new file mode 100644 index 0000000000..86e09f6229 --- /dev/null +++ b/synapse/storage/data_stores/state/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py new file mode 100644 index 0000000000..e8edaf9f7b --- /dev/null +++ b/synapse/storage/data_stores/state/bg_updates.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.db.updates.register_background_index_update( + self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, + index_name="state_groups_room_id_idx", + table="state_groups", + columns=["room_id"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self.db.execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in iteritems(curr_state) + if prev_state.get(key, None) != value + } + + self.db.simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self.db.simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_state) + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self.db.updates._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.db.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self.db.updates._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.db.runWithConnection(reindex_txn) + + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + return 1 diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql index ae09fa0065..ae09fa0065 100644 --- a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql +++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql index e85699e82e..e85699e82e 100644 --- a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql +++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql new file mode 100644 index 0000000000..1450313bfa --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql index 33980d02f0..33980d02f0 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql index 0f1fa68a89..0f1fa68a89 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/state.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql index 97e5067ef4..97e5067ef4 100644 --- a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql +++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py index 9fd1ccf6f7..9fd1ccf6f7 100644 --- a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py +++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql new file mode 100644 index 0000000000..7916ef18b2 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..35f97d6b3d --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql @@ -0,0 +1,37 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres new file mode 100644 index 0000000000..fcd926c9fb --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py new file mode 100644 index 0000000000..d53695f238 --- /dev/null +++ b/synapse/storage/data_stores/state/store.py @@ -0,0 +1,640 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple + +from six import iteritems +from six.moves import range + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database +from synapse.storage.state import StateFilter +from synapse.util.caches import get_cache_factor_for +from synapse.util.caches.descriptors import cached +from synapse.util.caches.dictionary_cache import DictionaryCache + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): + """A data store for fetching/storing state groups. + """ + + def __init__(self, database: Database, db_conn, hs): + super(StateGroupDataStore, self).__init__(database, db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache"), + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache"), + ) + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self.db.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) + + @defer.inlineCallbacks + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = yield self.db.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. + """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self.db.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self.db.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.db.runInteraction("store_state_group", _store_state_group_txn) + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.db.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self.db.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = set( + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self.db.simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self.db.simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + @defer.inlineCallbacks + def get_previous_state_groups(self, state_groups): + """Fetch the previous groups of the given state groups. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[dict[int, int]]: mapping from state group to previous + state group. + """ + + rows = yield self.db.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.db.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ec19ae1d9d..1003dd84a5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -24,9 +24,11 @@ from six.moves import intern, range from prometheus_client import Histogram +from twisted.enterprise import adbapi from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater @@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { } +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn(db_config: DatabaseConnectionConfig, engine): + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -218,10 +251,11 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() + self._database_config = database_config + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) @@ -234,7 +268,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.engine = hs.database_engine + self.engine = engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -255,6 +289,11 @@ class Database(object): self._check_safe_to_upsert, ) + def is_running(self): + """Is the database pool currently running + """ + return self._db_pool.running + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index cbc74cd302..df039a072d 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -16,8 +16,6 @@ import struct import threading -from synapse.storage.prepare_database import prepare_database - class Sqlite3Engine(object): single_threaded = True @@ -62,6 +60,10 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): + + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + if self._is_in_memory: # In memory databases need to be rebuilt each time. Ideally we'd # reuse the same connection as we do when starting up, but that diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa03ca9ff7..1ed44925fc 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -183,7 +183,7 @@ class EventsPersistenceStorage(object): # so we use separate variables here even though they point to the same # store for now. self.main_store = stores.main - self.state_store = stores.main + self.state_store = stores.state self._clock = hs.get_clock() self.is_mine_id = hs.is_mine_id diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 731e1c9d9c..e70026b80a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -18,6 +18,7 @@ import imp import logging import os import re +from collections import Counter import attr @@ -41,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config): +def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -54,11 +55,10 @@ def prepare_database(db_conn, database_engine, config): config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already + data_stores (list[str]): The name of the data stores that will be used + with this database. Defaults to all data stores. """ - # For now we only have the one datastore. - data_stores = ["main"] - try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) @@ -70,7 +70,10 @@ def prepare_database(db_conn, database_engine, config): if user_version != SCHEMA_VERSION: # If we don't pass in a config file then we are expecting to # have already upgraded the DB. - raise UpgradeDatabaseException("Database needs to be upgraded") + raise UpgradeDatabaseException( + "Expected database schema version %i but got %i" + % (SCHEMA_VERSION, user_version) + ) else: _upgrade_existing_database( cur, @@ -313,6 +316,9 @@ def _upgrade_existing_database( ) ) + # Used to check if we have any duplicate file names + file_name_counter = Counter() + # Now find which directories have anything of interest. directory_entries = [] for directory in directories: @@ -323,6 +329,9 @@ def _upgrade_existing_database( _DirectoryListing(file_name, os.path.join(directory, file_name)) for file_name in file_names ) + + for file_name in file_names: + file_name_counter[file_name] += 1 except FileNotFoundError: # Data stores can have empty entries for a given version delta. pass @@ -331,6 +340,17 @@ def _upgrade_existing_database( "Could not open delta dir for version %d: %s" % (v, directory) ) + duplicates = set( + file_name for file_name, count in file_name_counter.items() if count > 1 + ) + if duplicates: + # We don't support using the same file name in the same delta version. + raise PrepareDatabaseException( + "Found multiple delta files with the same name in v%d: %s", + v, + duplicates, + ) + # We sort to ensure that we apply the delta files in a consistent # order (to avoid bugs caused by inconsistent directory listing order) directory_entries.sort() diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index a368182034..d6a7bd7834 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -58,7 +58,7 @@ class PurgeEventsStorage(object): sg_to_delete = yield self._find_unreferenced_groups(state_groups) - yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete) + yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) @defer.inlineCallbacks def _find_unreferenced_groups(self, state_groups): @@ -102,7 +102,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.main.get_previous_state_groups(current_search) + edges = yield self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 3735846899..cbeb586014 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -342,7 +342,7 @@ class StateGroupStorage(object): (prev_group, delta_ids) """ - return self.stores.main.get_state_group_delta(state_group) + return self.stores.state.get_state_group_delta(state_group) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -362,7 +362,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups(groups) + group_to_state = yield self.stores.state._get_state_for_groups(groups) return group_to_state @@ -423,7 +423,7 @@ class StateGroupStorage(object): dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_groups_from_groups(groups, state_filter) + return self.stores.state._get_state_groups_from_groups(groups, state_filter) @defer.inlineCallbacks def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): @@ -439,7 +439,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -476,7 +476,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -532,7 +532,7 @@ class StateGroupStorage(object): Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_for_groups(groups, state_filter) + return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids @@ -552,6 +552,6 @@ class StateGroupStorage(object): Returns: Deferred[int]: The state group ID """ - return self.stores.main.store_state_group( + return self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index b91fb2db7b..fcd2aaa9c9 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict + from twisted.internet import defer from synapse.handlers.account_data import AccountDataEventSource @@ -35,7 +37,7 @@ class EventSources(object): def __init__(self, hs): self.sources = { name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() - } + } # type: Dict[str, Any] self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 84f5ae22c3..2e8f6543e5 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -271,7 +271,7 @@ class _CacheDescriptorBase(object): else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644 index 8318db8d2c..0000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index fdfa2cbbc4..854eb6c024 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,10 +183,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - test_replace_master_key.skip = ( - "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" - ) - @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" @@ -507,7 +503,3 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) - - test_upload_signatures.skip = ( - "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" - ) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 92b8726093..596ddc6970 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["put_json"]) mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + datastores = Mock() + datastores.main = Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + "get_device_updates_by_remote", + ] + ) + hs = self.setup_test_homeserver( - datastore=( - Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_device_updates_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - ) - ), - notifier=Mock(), - http_client=mock_federation_client, - keyring=mock_keyring, + notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring ) + hs.datastores = datastores + return hs def prepare(self, reactor, clock, hs): diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 358b593cd4..80187406bc 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -165,6 +165,7 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -175,6 +176,7 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -192,5 +194,6 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index af2327fb66..fe3441f081 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -104,6 +104,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -114,6 +115,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -132,6 +134,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -151,5 +154,6 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 3dae83c543..2a1e7c7166 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,7 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.storage.database import Database +from synapse.storage.database import make_conn from tests import unittest from tests.server import FakeTransport @@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): + db_config = hs.config.database.get_single_database() self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() + database = hs.get_datastores().databases[0] self.slaved_store = self.STORE_TYPE( - Database(hs), self.hs.get_db_conn(), self.hs + database, make_conn(db_config, database.engine), self.hs ) self.event_id = 0 diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 12c5e95cb5..8df58b4a63 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -237,6 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs @@ -309,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index c0d0d2b44e..d0c997e385 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Email config. self.email_attempts = [] - def sendmail(*args, **kwargs): + async def sendmail(*args, **kwargs): self.email_attempts.append((args, kwargs)) - return config["email"] = { "enable_notifs": True, diff --git a/tests/server.py b/tests/server.py index 2b7cf4242e..a554dfdd57 100644 --- a/tests/server.py +++ b/tests/server.py @@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(cleanup_func, *args, **kwargs).result + server = _sth(cleanup_func, *args, **kwargs) - if isinstance(d, Failure): - d.raiseException() + database = server.config.database.get_single_database() # Make the thread pool synchronous. - clock = d.get_clock() - pool = d.get_db_pool() - - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs - ) - - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - interaction, - *args, - **kwargs - ) + clock = server.get_clock() + + for database in server.get_datastores().databases: + pool = database._db_pool + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction pool.threadpool = ThreadPool(clock._reactor) pool.running = True - return d + + return server def get_clock(): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 2e521e9ab7..fd52512696 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = Database(hs) - self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - self.db_pool = hs.get_db_pool() - self.engine = hs.database_engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - database = Database(hs) - self.store = TestTransactionStore(database, hs.get_db_conn(), hs) + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine + + db_config = hs.config.get_single_database() + self.store = TestTransactionStore( + database, make_conn(db_config, self.engine), hs + ) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 537cfe9f64..cdee0a9e60 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True config.event_cache_size = 1 - config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer( - "test", db_pool=self.db_pool, config=config, database_engine=fake_engine - ) - self.datastore = SQLBaseStore(Database(hs), None, hs) + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool + + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index fc279340d4..bf674dd184 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(12345678) user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, @@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db.updates.do_next_background_update(100), by=0.1 ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) - # Force persisting to disk self.reactor.advance(200) @@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.get_success( self.store.db.simple_update( table="devices", - keyvalues={"user_id": user_id, "device_id": "device_id"}, + keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, desc="test_devices_last_seen_bg_update", ) @@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should now get nulls when querying result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": None, "user_agent": None, "last_seen": None, @@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should now get the correct result again result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, @@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db.updates.do_next_background_update(100), by=0.1 ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "access_token": "access_token", "ip": "ip", "user_agent": "user_agent", - "device_id": "device_id", + "device_id": device_id, "last_seen": 0, } ], @@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But we should still get the correct values for the device result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4578cc3b60..ed5786865a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 43200654f1..d6ecf102f8 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -35,7 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() self.storage = hs.get_storage() - self.state_datastore = self.store + self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/test_federation.py b/tests/test_federation.py index ad165d7295..68684460c6 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,6 +1,6 @@ from mock import Mock -from twisted.internet.defer import maybeDeferred, succeed +from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed from synapse.events import FrozenEvent from synapse.logging.context import LoggingContext @@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase): ) # Send the join, it should return None (which is not an error) - d = self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) ) self.reactor.advance(1) self.assertEqual(self.successResultOf(d), None) @@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase): ) with LoggingContext(request="lying_event"): - d = self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) ) # Step the reactor, so the database fetches come back diff --git a/tests/test_state.py b/tests/test_state.py index 176535947a..e0aae06be4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -209,7 +209,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -253,7 +253,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertSetEqual( {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} ) @@ -312,7 +312,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_e.get_prev_state_ids() self.assertSetEqual( {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} ) @@ -387,7 +387,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} ) @@ -419,10 +419,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -442,10 +442,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -479,7 +479,7 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual( set([e.event_id for e in old_state]), set(current_state_ids.values()) @@ -511,7 +511,7 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertEqual( set([e.event_id for e in old_state]), set(prev_state_ids.values()) @@ -552,7 +552,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -594,7 +594,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -649,7 +649,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -677,7 +677,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) diff --git a/tests/test_types.py b/tests/test_types.py index 9ab5f829b0..8d97c751ea 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart from tests import unittest -from tests.utils import TestHomeServer -mock_homeserver = TestHomeServer(hostname="my.domain") - -class UserIDTestCase(unittest.TestCase): +class UserIDTestCase(unittest.HomeserverTestCase): def test_parse(self): - user = UserID.from_string("@1234abcd:my.domain") + user = UserID.from_string("@1234abcd:test") self.assertEquals("1234abcd", user.localpart) - self.assertEquals("my.domain", user.domain) - self.assertEquals(True, mock_homeserver.is_mine(user)) + self.assertEquals("test", user.domain) + self.assertEquals(True, self.hs.is_mine(user)) def test_pase_empty(self): with self.assertRaises(SynapseError): @@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase): self.assertTrue(userA != userB) -class RoomAliasTestCase(unittest.TestCase): +class RoomAliasTestCase(unittest.HomeserverTestCase): def test_parse(self): - room = RoomAlias.from_string("#channel:my.domain") + room = RoomAlias.from_string("#channel:test") self.assertEquals("channel", room.localpart) - self.assertEquals("my.domain", room.domain) - self.assertEquals(True, mock_homeserver.is_mine(room)) + self.assertEquals("test", room.domain) + self.assertEquals(True, self.hs.is_mine(room)) def test_build(self): room = RoomAlias("channel", "my.domain") diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8b8455c8b7..281b32c4b8 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.request, "foo-bar") + @defer.inlineCallbacks + def test_make_deferred_yieldable_with_await(self): + # an async function which retuns an incomplete coroutine, but doesn't + # follow the synapse rules. + + async def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + await d + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py deleted file mode 100644 index 1a44f72425..0000000000 --- a/tests/util/test_snapshot_cache.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet.defer import Deferred - -from synapse.util.caches.snapshot_cache import SnapshotCache - -from .. import unittest - - -class SnapshotCacheTestCase(unittest.TestCase): - def setUp(self): - self.cache = SnapshotCache() - self.cache.DURATION_MS = 1 - - def test_get_set(self): - # Check that getting a missing key returns None - self.assertEquals(self.cache.get(0, "key"), None) - - # Check that setting a key with a deferred returns - # a deferred that resolves when the initial deferred does - d = Deferred() - set_result = self.cache.set(0, "key", d) - self.assertIsNotNone(set_result) - self.assertFalse(set_result.called) - - # Check that getting the key before the deferred has resolved - # returns a deferred that resolves when the initial deferred does. - get_result_at_10 = self.cache.get(10, "key") - self.assertIsNotNone(get_result_at_10) - self.assertFalse(get_result_at_10.called) - - # Check that the returned deferreds resolve when the initial deferred - # does. - d.callback("v") - self.assertTrue(set_result.called) - self.assertTrue(get_result_at_10.called) - - # Check that getting the key after the deferred has resolved - # before the cache expires returns a resolved deferred. - get_result_at_11 = self.cache.get(11, "key") - self.assertIsNotNone(get_result_at_11) - if isinstance(get_result_at_11, Deferred): - # The cache may return the actual result rather than a deferred - self.assertTrue(get_result_at_11.called) - - # Check that getting the key after the deferred has resolved - # after the cache expires returns None - get_result_at_12 = self.cache.get(12, "key") - self.assertIsNone(get_result_at_12) diff --git a/tests/utils.py b/tests/utils.py index c57da59191..e2e9cafd79 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,6 +30,7 @@ from twisted.internet import defer, reactor from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server @@ -177,7 +178,6 @@ class TestHomeServer(HomeServer): DATASTORE_CLASS = DataStore -@defer.inlineCallbacks def setup_test_homeserver( cleanup_func, name="test", @@ -214,7 +214,7 @@ def setup_test_homeserver( if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex - config.database_config = { + database_config = { "name": "psycopg2", "args": { "database": test_db, @@ -226,12 +226,15 @@ def setup_test_homeserver( }, } else: - config.database_config = { + database_config = { "name": "sqlite3", "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } - db_engine = create_engine(config.database_config) + database = DatabaseConnectionConfig("master", database_config) + config.database.databases = [database] + + db_engine = create_engine(database.config) # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() @@ -251,39 +254,30 @@ def setup_test_homeserver( cur.close() db_conn.close() - # we need to configure the connection pool to run the on_new_connection - # function, so that we can test code that uses custom sqlite functions - # (like rank). - config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection - if datastore is None: hs = homeserverToUse( name, config=config, - db_config=config.database_config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, **kargs ) - # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to - # date db - if not isinstance(db_engine, PostgresEngine): - db_conn = hs.get_db_conn() - yield prepare_database(db_conn, db_engine, config) - db_conn.commit() - db_conn.close() + hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_master() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] - else: # We need to do cleanup on PostgreSQL def cleanup(): import psycopg2 # Close all the db pools - hs.get_db_pool().close() + database._db_pool.close() dropped = False @@ -322,23 +316,12 @@ def setup_test_homeserver( # Register the cleanup hook cleanup_func(cleanup) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() else: - # If we have been given an explicit datastore we probably want to mock - # out the DataStores somehow too. This all feels a bit wrong, but then - # mocking the stores feels wrong too. - datastores = Mock(datastore=datastore) - hs = homeserverToUse( name, - db_pool=None, datastore=datastore, - datastores=datastores, config=config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, diff --git a/tox.ini b/tox.ini index 903a245fb0..1d6428f64f 100644 --- a/tox.ini +++ b/tox.ini @@ -171,11 +171,23 @@ basepython = python3.7 skip_install = True deps = {[base]deps} - mypy==0.730 + mypy==0.750 mypy-zope env = MYPYPATH = stubs/ extras = all commands = mypy \ + synapse/config/ \ + synapse/handlers/ui_auth \ synapse/logging/ \ - synapse/config/ + synapse/module_api \ + synapse/rest/consent \ + synapse/rest/media/v0 \ + synapse/rest/saml2 \ + synapse/spam_checker_api \ + synapse/storage/engines \ + synapse/streams + +# To find all folders that pass mypy you run: +# +# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print |