diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist
index 7950d19db3..158ab79154 100644
--- a/.buildkite/worker-blacklist
+++ b/.buildkite/worker-blacklist
@@ -34,33 +34,8 @@ Device list doesn't change if remote server is down
Remote servers cannot set power levels in rooms without existing powerlevels
Remote servers should reject attempts by non-creators to set the power levels
-# new failures as of https://github.com/matrix-org/sytest/pull/753
-GET /rooms/:room_id/messages returns a message
-GET /rooms/:room_id/messages lazy loads members correctly
-Read receipts are sent as events
-Only original members of the room can see messages from erased users
-Device deletion propagates over federation
-If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes
-Changing user-signing key notifies local users
-Newly updated tags appear in an incremental v2 /sync
+# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
Server correctly handles incoming m.device_list_update
-Local device key changes get to remote servers with correct prev_id
-AS-ghosted users can use rooms via AS
-Ghost user must register before joining room
-Test that a message is pushed
-Invites are pushed
-Rooms with aliases are correctly named in pushed
-Rooms with names are correctly named in pushed
-Rooms with canonical alias are correctly named in pushed
-Rooms with many users are correctly pushed
-Don't get pushed for rooms you've muted
-Rejected events are not pushed
-Test that rejected pushers are removed.
-Events come down the correct room
-
-# https://buildkite.com/matrix-dot-org/sytest/builds/326#cca62404-a88a-4fcb-ad41-175fd3377603
-Presence changes to UNAVAILABLE are reported to remote room members
-If remote user leaves room, changes device and rejoins we see update in sync
-uploading self-signing key notifies over federation
-Inbound federation can receive redacted events
-Outbound federation can request missing events
+
+# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
+Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
diff --git a/AUTHORS.rst b/AUTHORS.rst
index b8b31a5b47..014f16d4a2 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -46,3 +46,6 @@ Joseph Weston <joseph at weston.cloud>
Benjamin Saunders <ben.e.saunders at gmail dot com>
* Documentation improvements
+
+Werner Sembach <werner.sembach at fau dot de>
+ * Automatically remove a group/community when it is empty
diff --git a/changelog.d/6245.misc b/changelog.d/6245.misc
new file mode 100644
index 0000000000..a3e6b8296e
--- /dev/null
+++ b/changelog.d/6245.misc
@@ -0,0 +1 @@
+Split out state storage into a separate data store.
diff --git a/changelog.d/6349.feature b/changelog.d/6349.feature
new file mode 100644
index 0000000000..56c4fbf78e
--- /dev/null
+++ b/changelog.d/6349.feature
@@ -0,0 +1 @@
+Implement v2 APIs for the `send_join` and `send_leave` federation endpoints (as described in [MSC1802](https://github.com/matrix-org/matrix-doc/pull/1802)).
diff --git a/changelog.d/6377.bugfix b/changelog.d/6377.bugfix
new file mode 100644
index 0000000000..ccda96962f
--- /dev/null
+++ b/changelog.d/6377.bugfix
@@ -0,0 +1 @@
+Prevent redacted events from being returned during message search.
\ No newline at end of file
diff --git a/changelog.d/6385.bugfix b/changelog.d/6385.bugfix
new file mode 100644
index 0000000000..7a2bc02170
--- /dev/null
+++ b/changelog.d/6385.bugfix
@@ -0,0 +1 @@
+Prevent error on trying to search an upgraded room when the server is not in the predecessor room.
\ No newline at end of file
diff --git a/changelog.d/6394.feature b/changelog.d/6394.feature
new file mode 100644
index 0000000000..1a0e8845ad
--- /dev/null
+++ b/changelog.d/6394.feature
@@ -0,0 +1 @@
+Add a development script to generate full SQL schemas.
\ No newline at end of file
diff --git a/changelog.d/6411.feature b/changelog.d/6411.feature
new file mode 100644
index 0000000000..ebea4a208d
--- /dev/null
+++ b/changelog.d/6411.feature
@@ -0,0 +1 @@
+Allow custom SAML username mapping functionality through an external provider plugin.
\ No newline at end of file
diff --git a/changelog.d/6453.feature b/changelog.d/6453.feature
new file mode 100644
index 0000000000..e7bb801c6a
--- /dev/null
+++ b/changelog.d/6453.feature
@@ -0,0 +1 @@
+Automatically delete empty groups/communities.
diff --git a/changelog.d/6486.bugfix b/changelog.d/6486.bugfix
new file mode 100644
index 0000000000..b98c5a9ae5
--- /dev/null
+++ b/changelog.d/6486.bugfix
@@ -0,0 +1 @@
+Improve performance of looking up cross-signing keys.
diff --git a/changelog.d/6496.misc b/changelog.d/6496.misc
new file mode 100644
index 0000000000..19c6e926b8
--- /dev/null
+++ b/changelog.d/6496.misc
@@ -0,0 +1 @@
+Port `synapse.handlers.initial_sync` to async/await.
diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal
new file mode 100644
index 0000000000..0b72261d58
--- /dev/null
+++ b/changelog.d/6502.removal
@@ -0,0 +1 @@
+Remove redundant code from event authorisation implementation.
diff --git a/changelog.d/6504.misc b/changelog.d/6504.misc
new file mode 100644
index 0000000000..7c873459af
--- /dev/null
+++ b/changelog.d/6504.misc
@@ -0,0 +1 @@
+Port `handlers.account_data` and `handlers.account_validity` to async/await.
diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc
new file mode 100644
index 0000000000..3a75b2d9dd
--- /dev/null
+++ b/changelog.d/6505.misc
@@ -0,0 +1 @@
+Make `make_deferred_yieldable` work with async/await.
diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc
new file mode 100644
index 0000000000..99d7a70bcf
--- /dev/null
+++ b/changelog.d/6506.misc
@@ -0,0 +1 @@
+Remove `SnapshotCache` in favour of `ResponseCache`.
diff --git a/changelog.d/6510.misc b/changelog.d/6510.misc
new file mode 100644
index 0000000000..214f06539b
--- /dev/null
+++ b/changelog.d/6510.misc
@@ -0,0 +1 @@
+Change phone home stats to not assume there is a single database and report information about the database used by the main data store.
diff --git a/changelog.d/6511.misc b/changelog.d/6511.misc
new file mode 100644
index 0000000000..19ce435e68
--- /dev/null
+++ b/changelog.d/6511.misc
@@ -0,0 +1 @@
+Move database config from apps into HomeServer object.
diff --git a/changelog.d/6512.misc b/changelog.d/6512.misc
new file mode 100644
index 0000000000..37a8099eec
--- /dev/null
+++ b/changelog.d/6512.misc
@@ -0,0 +1 @@
+Silence mypy errors for files outside those specified.
diff --git a/changelog.d/6513.misc b/changelog.d/6513.misc
new file mode 100644
index 0000000000..36700f5657
--- /dev/null
+++ b/changelog.d/6513.misc
@@ -0,0 +1 @@
+Remove all assumptions of there being a single physical DB apart from the `synapse.config`.
diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix
new file mode 100644
index 0000000000..6dc1985c24
--- /dev/null
+++ b/changelog.d/6514.bugfix
@@ -0,0 +1 @@
+Fix race which occasionally caused deleted devices to reappear.
diff --git a/changelog.d/6515.misc b/changelog.d/6515.misc
new file mode 100644
index 0000000000..a9c303ed1c
--- /dev/null
+++ b/changelog.d/6515.misc
@@ -0,0 +1 @@
+Clean up some logging when handling incoming events over federation.
diff --git a/changelog.d/6517.misc b/changelog.d/6517.misc
new file mode 100644
index 0000000000..c6ffed9952
--- /dev/null
+++ b/changelog.d/6517.misc
@@ -0,0 +1 @@
+Port some of FederationHandler to async/await.
\ No newline at end of file
diff --git a/changelog.d/6522.bugfix b/changelog.d/6522.bugfix
new file mode 100644
index 0000000000..ccda96962f
--- /dev/null
+++ b/changelog.d/6522.bugfix
@@ -0,0 +1 @@
+Prevent redacted events from being returned during message search.
\ No newline at end of file
diff --git a/changelog.d/6523.feature b/changelog.d/6523.feature
new file mode 100644
index 0000000000..798fa143df
--- /dev/null
+++ b/changelog.d/6523.feature
@@ -0,0 +1 @@
+Add option `limit_profile_requests_to_users_who_share_rooms` to require that a local user share a room with another user in order to query their profile information.
diff --git a/changelog.d/6534.misc b/changelog.d/6534.misc
new file mode 100644
index 0000000000..7df6bb442a
--- /dev/null
+++ b/changelog.d/6534.misc
@@ -0,0 +1 @@
+Test more folders against mypy.
diff --git a/changelog.d/6537.misc b/changelog.d/6537.misc
new file mode 100644
index 0000000000..3543153584
--- /dev/null
+++ b/changelog.d/6537.misc
@@ -0,0 +1 @@
+Update `mypy` to new version.
diff --git a/changelog.d/6538.misc b/changelog.d/6538.misc
new file mode 100644
index 0000000000..cb4fd56948
--- /dev/null
+++ b/changelog.d/6538.misc
@@ -0,0 +1 @@
+Adjust the sytest blacklist for worker mode.
diff --git a/changelog.d/6541.doc b/changelog.d/6541.doc
new file mode 100644
index 0000000000..c20029edc0
--- /dev/null
+++ b/changelog.d/6541.doc
@@ -0,0 +1 @@
+Document the Room Shutdown Admin API.
\ No newline at end of file
diff --git a/changelog.d/6546.feature b/changelog.d/6546.feature
new file mode 100644
index 0000000000..954aacb0d0
--- /dev/null
+++ b/changelog.d/6546.feature
@@ -0,0 +1 @@
+Add an `export_signing_key` script to extract the public part of signing keys when rotating them.
diff --git a/changelog.d/6555.bugfix b/changelog.d/6555.bugfix
new file mode 100644
index 0000000000..86a5a56cf6
--- /dev/null
+++ b/changelog.d/6555.bugfix
@@ -0,0 +1 @@
+Fix missing row in `device_max_stream_id` that could cause "unable to decrypt" errors after server restart.
\ No newline at end of file
diff --git a/changelog.d/6557.misc b/changelog.d/6557.misc
new file mode 100644
index 0000000000..80e7eaedb8
--- /dev/null
+++ b/changelog.d/6557.misc
@@ -0,0 +1 @@
+Remove unused `get_pagination_rows` methods from `EventSource` classes.
diff --git a/changelog.d/6558.misc b/changelog.d/6558.misc
new file mode 100644
index 0000000000..a7572f1a85
--- /dev/null
+++ b/changelog.d/6558.misc
@@ -0,0 +1 @@
+Clean up logs from the push notifier at startup.
\ No newline at end of file
diff --git a/changelog.d/6559.misc b/changelog.d/6559.misc
new file mode 100644
index 0000000000..8bca37457d
--- /dev/null
+++ b/changelog.d/6559.misc
@@ -0,0 +1 @@
+Port `synapse.handlers.admin` and `synapse.handlers.deactivate_account` to async/await.
diff --git a/changelog.d/6564.misc b/changelog.d/6564.misc
new file mode 100644
index 0000000000..f644f5868b
--- /dev/null
+++ b/changelog.d/6564.misc
@@ -0,0 +1 @@
+Change `EventContext` to use the `Storage` class, in preparation for moving state database queries to a separate data store.
diff --git a/changelog.d/6565.misc b/changelog.d/6565.misc
new file mode 100644
index 0000000000..e83f245bf0
--- /dev/null
+++ b/changelog.d/6565.misc
@@ -0,0 +1 @@
+Add assertion that schema delta file names are unique.
diff --git a/changelog.d/6570.misc b/changelog.d/6570.misc
new file mode 100644
index 0000000000..e89955a51e
--- /dev/null
+++ b/changelog.d/6570.misc
@@ -0,0 +1 @@
+Improve diagnostics on database upgrade failure.
diff --git a/changelog.d/6571.bugfix b/changelog.d/6571.bugfix
new file mode 100644
index 0000000000..e38ea7b4f7
--- /dev/null
+++ b/changelog.d/6571.bugfix
@@ -0,0 +1 @@
+Fix a bug which meant that we did not send systemd notifications on startup if ACME was enabled.
diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md
new file mode 100644
index 0000000000..54ce1cd234
--- /dev/null
+++ b/docs/admin_api/shutdown_room.md
@@ -0,0 +1,72 @@
+# Shutdown room API
+
+Shuts down a room, preventing new joins and automatically moving local users and room aliases
+to a new room. The new room will be created with the user specified by the
+`new_room_user_id` parameter as room administrator and will contain a message
+explaining what happened. Users invited to the new room will have power level
+-10 by default, and thus be unable to speak. The old room's power levels will be changed to
+disallow any further invites or joins.
+
+The local server will only have the power to move local users and room aliases to
+the new room. Users on other servers will be unaffected.
+
+## API
+
+You will need to authenticate with an access token for an admin user.
+
+### URL
+
+`POST /_synapse/admin/v1/shutdown_room/{room_id}`
+
+### URL Parameters
+
+* `room_id` - The ID of the room (e.g. `!someroom:example.com`)
+
+### JSON Body Parameters
+
+* `new_room_user_id` - Required. A string representing the user ID of the user that will administer
+ the new room that all users in the old room will be moved to.
+* `room_name` - Optional. A string representing the name of the room that new users will be
+ invited to.
+* `message` - Optional. A string containing the first message that will be sent as
+ `new_room_user_id` in the new room. Ideally this will clearly convey why the
+ original room was shut down.
+
+If not specified, the default value of `room_name` is "Content Violation
+Notification". The default value of `message` is "Sharing illegal content on
+this server is not permitted and rooms in violation will be blocked."
+
+### Response Parameters
+
+* `kicked_users` - An integer representing the number of users that
+ were kicked.
+* `failed_to_kick_users` - An integer representing the number of users
+ that were not kicked.
+* `local_aliases` - An array of strings representing the local aliases that were migrated from
+ the old room to the new.
+* `new_room_id` - A string representing the room ID of the new room.
+
+## Example
+
+Request:
+
+```
+POST /_synapse/admin/v1/shutdown_room/!somebadroom%3Aexample.com
+
+{
+ "new_room_user_id": "@someuser:example.com",
+ "room_name": "Content Violation Notification",
+ "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service."
+}
+```
+
+Response:
+
+```
+{
+ "kicked_users": 5,
+ "failed_to_kick_users": 0,
+ "local_aliases": ["#badroom:example.com", "#evilsaloon:example.com],
+ "new_room_id": "!newroomid:example.com",
+}
+```
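+
+The same call can be made from a script. The following is a minimal sketch,
+assuming the Python `requests` package is available, the homeserver is
+reachable at `https://matrix.example.com`, and placeholder values are used for
+the admin access token and room ID. Authentication uses the standard
+`Authorization: Bearer <token>` header.
+
+```python
+from urllib.parse import quote
+
+import requests
+
+HOMESERVER = "https://matrix.example.com"  # assumed homeserver base URL
+ADMIN_TOKEN = "<admin access token>"       # placeholder
+ROOM_ID = "!somebadroom:example.com"       # placeholder
+
+resp = requests.post(
+    # The room ID must be URL-encoded when placed in the path.
+    f"{HOMESERVER}/_synapse/admin/v1/shutdown_room/{quote(ROOM_ID)}",
+    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
+    json={
+        "new_room_user_id": "@someuser:example.com",
+        "message": "This room has been shut down due to content violations.",
+    },
+)
+resp.raise_for_status()
+
+# The response contains kicked_users, failed_to_kick_users, local_aliases and
+# new_room_id, as described above.
+print(resp.json())
+```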
diff --git a/docs/code_style.md b/docs/code_style.md
index f983f72d6c..71aecd41f7 100644
--- a/docs/code_style.md
+++ b/docs/code_style.md
@@ -137,6 +137,7 @@ Some guidelines follow:
correctly handles the top-level option being set to `None` (as it
will be if no sub-options are enabled).
- Lines should be wrapped at 80 characters.
+- Use two-space indents.
Example:
@@ -155,13 +156,13 @@ Example:
# Settings for the frobber
#
frobber:
- # frobbing speed. Defaults to 1.
- #
- #speed: 10
+ # frobbing speed. Defaults to 1.
+ #
+ #speed: 10
- # frobbing distance. Defaults to 1000.
- #
- #distance: 100
+ # frobbing distance. Defaults to 1000.
+ #
+ #distance: 100
Note that the sample configuration is generated from the synapse code
and is maintained by a script, `scripts-dev/generate_sample_config`.
diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md
new file mode 100644
index 0000000000..92f2380488
--- /dev/null
+++ b/docs/saml_mapping_providers.md
@@ -0,0 +1,77 @@
+# SAML Mapping Providers
+
+A SAML mapping provider is a Python class (loaded via a Python module) that
+works out how to map attributes of a SAML response object to Matrix-specific
+user attributes. Details such as user ID localpart, displayname, and even avatar
+URLs are all things that can be mapped from talking to an SSO service.
+
+As an example, an SSO service may return the email address
+"john.smith@example.com" for a user, whereas Synapse will need to figure out how
+to turn that into a displayname when creating a Matrix user for this individual.
+It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
+variations. As each Synapse configuration may want something different, this is
+where SAML mapping providers come into play.
+
+## Enabling Providers
+
+External mapping providers are provided to Synapse in the form of an external
+Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere,
+then tell Synapse where to look for the handler class by editing the
+`saml2_config.user_mapping_provider.module` config option.
+
+`saml2_config.user_mapping_provider.config` allows you to provide custom
+configuration options to the module. Check with the module's documentation for
+what options it provides (if any). The options listed by default are for the
+user mapping provider built in to Synapse. If using a custom module, you should
+comment these options out and use those specified by the module instead.
+
+## Building a Custom Mapping Provider
+
+A custom mapping provider must specify the following methods:
+
+* `__init__(self, parsed_config)`
+ - Arguments:
+ - `parsed_config` - A configuration object that is the return value of the
+ `parse_config` method. You should set any configuration options needed by
+ the module here.
+* `saml_response_to_user_attributes(self, saml_response, failures)`
+ - Arguments:
+ - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
+ information from.
+    - `failures` - An `int` that represents the number of times the returned
+ mxid localpart mapping has failed. This should be used
+ to create a deduplicated mxid localpart which should be
+ returned instead. For example, if this method returns
+ `john.doe` as the value of `mxid_localpart` in the returned
+ dict, and that is already taken on the homeserver, this
+ method will be called again with the same parameters but
+ with failures=1. The method should then return a different
+ `mxid_localpart` value, such as `john.doe1`.
+ - This method must return a dictionary, which will then be used by Synapse
+ to build a new user. The following keys are allowed:
+ * `mxid_localpart` - Required. The mxid localpart of the new user.
+ * `displayname` - The displayname of the new user. If not provided, will default to
+ the value of `mxid_localpart`.
+* `parse_config(config)`
+  - This method should have the `@staticmethod` decorator.
+ - Arguments:
+ - `config` - A `dict` representing the parsed content of the
+ `saml2_config.user_mapping_provider.config` homeserver config option.
+ Runs on homeserver startup. Providers should extract any option values
+ they need here.
+ - Whatever is returned will be passed back to the user mapping provider module's
+ `__init__` method during construction.
+* `get_saml_attributes(config)`
+  - This method should have the `@staticmethod` decorator.
+ - Arguments:
+    - `config` - An object resulting from a call to `parse_config`.
+  - Returns a tuple of two sets. The first set consists of the SAML auth
+ response attributes that are required for the module to function, whereas
+ the second set consists of those attributes which can be used if available,
+ but are not necessary.
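+
+### Example
+
+The following is a minimal, illustrative sketch of a provider implementing the
+methods above. It is not Synapse's built-in implementation: the class name, the
+`mxid_source_attribute` option, the `displayName` attribute and the localpart
+sanitisation rule are only examples, and the sketch assumes that pysaml2
+exposes the response attributes via `saml_response.ava`.
+
+```python
+import re
+
+
+class ExampleSamlMappingProvider:
+    def __init__(self, parsed_config):
+        # parsed_config is whatever parse_config returned below.
+        self._source_attribute = parsed_config["mxid_source_attribute"]
+
+    @staticmethod
+    def parse_config(config):
+        # Extract our options from user_mapping_provider.config.
+        return {"mxid_source_attribute": config.get("mxid_source_attribute", "uid")}
+
+    @staticmethod
+    def get_saml_attributes(config):
+        # The source attribute is required; displayName is used if available.
+        return {config["mxid_source_attribute"]}, {"displayName"}
+
+    def saml_response_to_user_attributes(self, saml_response, failures):
+        source = saml_response.ava[self._source_attribute][0]
+
+        # Keep only characters that are valid in an mxid localpart, and append
+        # the failure count to deduplicate the localpart if necessary.
+        localpart = re.sub(r"[^a-z0-9._=/-]", "", source.lower())
+        if failures:
+            localpart += str(failures)
+
+        attributes = {"mxid_localpart": localpart}
+
+        displayname = saml_response.ava.get("displayName", [None])[0]
+        if displayname:
+            attributes["displayname"] = displayname
+
+        return attributes
+```
+
+Assuming the class lives in a module called `example_provider`, it could then
+be enabled by setting `saml2_config.user_mapping_provider.module` to
+`example_provider.ExampleSamlMappingProvider` in the homeserver config.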
+
+## Synapse's Default Provider
+
+Synapse ships with a built-in SAML mapping provider, which is used if a custom
+provider isn't specified in the config. It is located at
+[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 10664ae8f7..e3b05423b8 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -54,6 +54,13 @@ pid_file: DATADIR/homeserver.pid
#
#require_auth_for_profile_requests: true
+# Uncomment to require a user to share a room with another user in order
+# to retrieve their profile information. Only checked on Client-Server
+# requests. Profile requests from other servers should be checked by the
+# requesting server. Defaults to 'false'.
+#
+#limit_profile_requests_to_users_who_share_rooms: true
+
# If set to 'true', removes the need for authentication to access the server's
# public rooms directory through the client API, meaning that anyone can
# query the room directory. Defaults to 'false'.
@@ -1115,14 +1122,19 @@ metrics_flags:
signing_key_path: "CONFDIR/SERVERNAME.signing.key"
# The keys that the server used to sign messages with but won't use
-# to sign new messages. E.g. it has lost its private key
+# to sign new messages.
#
-#old_signing_keys:
-# "ed25519:auto":
-# # Base64 encoded public key
-# key: "The public part of your old signing key."
-# # Millisecond POSIX timestamp when the key expired.
-# expired_ts: 123456789123
+old_signing_keys:
+ # For each key, `key` should be the base64-encoded public key, and
+  # `expired_ts` should be the time (in milliseconds since the unix epoch) that
+ # it was last used.
+ #
+ # It is possible to build an entry from an old signing.key file using the
+ # `export_signing_key` script which is provided with synapse.
+ #
+ # For example:
+ #
+ #"ed25519:id": { key: "base64string", expired_ts: 123456789123 }
# How long key response published by this server is valid for.
# Used to set the valid_until_ts in /key/v2 APIs.
@@ -1250,33 +1262,58 @@ saml2_config:
#
#config_path: "CONFDIR/sp_conf.py"
- # the lifetime of a SAML session. This defines how long a user has to
+ # The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
- # The SAML attribute (after mapping via the attribute maps) to use to derive
- # the Matrix ID from. 'uid' by default.
+ # An external module can be provided here as a custom solution to
+ # mapping attributes returned from a saml provider onto a matrix user.
#
- #mxid_source_attribute: displayName
-
- # The mapping system to use for mapping the saml attribute onto a matrix ID.
- # Options include:
- # * 'hexencode' (which maps unpermitted characters to '=xx')
- # * 'dotreplace' (which replaces unpermitted characters with '.').
- # The default is 'hexencode'.
- #
- #mxid_mapping: dotreplace
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ #
+ #module: mapping_provider.SamlMappingProvider
- # In previous versions of synapse, the mapping from SAML attribute to MXID was
- # always calculated dynamically rather than stored in a table. For backwards-
- # compatibility, we will look for user_ids matching such a pattern before
- # creating a new account.
+      # Custom configuration values for the module. The options below are
+      # intended for the built-in provider; they should be changed if
+ # using a custom module. This section will be passed as a Python
+ # dictionary to the module's `parse_config` method.
+ #
+ config:
+ # The SAML attribute (after mapping via the attribute maps) to use
+ # to derive the Matrix ID from. 'uid' by default.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_source_attribute option. If that is still
+ # defined, its value will be used instead.
+ #
+ #mxid_source_attribute: displayName
+
+ # The mapping system to use for mapping the saml attribute onto a
+ # matrix ID.
+ #
+ # Options include:
+ # * 'hexencode' (which maps unpermitted characters to '=xx')
+ # * 'dotreplace' (which replaces unpermitted characters with
+ # '.').
+ # The default is 'hexencode'.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_mapping option. If that is still defined, its
+ # value will be used instead.
+ #
+ #mxid_mapping: dotreplace
+
+ # In previous versions of synapse, the mapping from SAML attribute to
+ # MXID was always calculated dynamically rather than stored in a
+      # table. For backwards-compatibility, we will look for user_ids
+ # matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
- # backwards-compatibility lookup. Typically it should be 'uid', but if the
- # attribute maps are changed, it may be necessary to change it.
+ # backwards-compatibility lookup. Typically it should be 'uid', but if
+ # the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
diff --git a/mypy.ini b/mypy.ini
index 1d77c0ecc8..a66434b76b 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,7 +1,7 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
-follow_imports = normal
+follow_imports = silent
check_untyped_defs = True
show_error_codes = True
show_traceback = True
diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh
new file mode 100755
index 0000000000..60e8970a35
--- /dev/null
+++ b/scripts-dev/make_full_schema.sh
@@ -0,0 +1,184 @@
+#!/bin/bash
+#
+# This script generates SQL files for creating a brand new Synapse DB with the latest
+# schema, on both SQLite3 and Postgres.
+#
+# It does so by having Synapse generate an up-to-date SQLite DB, then running
+# synapse_port_db to convert it to Postgres. It then dumps the contents of both.
+
+POSTGRES_HOST="localhost"
+POSTGRES_DB_NAME="synapse_full_schema.$$"
+
+SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite"
+POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres"
+
+REQUIRED_DEPS=("matrix-synapse" "psycopg2")
+
+usage() {
+ echo
+ echo "Usage: $0 -p <postgres_username> -o <path> [-c] [-n] [-h]"
+ echo
+ echo "-p <postgres_username>"
+ echo " Username to connect to local postgres instance. The password will be requested"
+ echo " during script execution."
+ echo "-c"
+ echo " CI mode. Enables coverage tracking and prints every command that the script runs."
+ echo "-o <path>"
+ echo " Directory to output full schema files to."
+ echo "-h"
+ echo " Display this help text."
+}
+
+while getopts "p:co:h" opt; do
+ case $opt in
+ p)
+ POSTGRES_USERNAME=$OPTARG
+ ;;
+ c)
+ # Print all commands that are being executed
+ set -x
+
+ # Modify required dependencies for coverage
+ REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess")
+
+ COVERAGE=1
+ ;;
+ o)
+ command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1)
+ OUTPUT_DIR="$(realpath "$OPTARG")"
+ ;;
+ h)
+ usage
+ exit
+ ;;
+ \?)
+ echo "ERROR: Invalid option: -$OPTARG" >&2
+ usage
+ exit
+ ;;
+ esac
+done
+
+# Check that required dependencies are installed
+unsatisfied_requirements=()
+for dep in "${REQUIRED_DEPS[@]}"; do
+ pip show "$dep" --quiet || unsatisfied_requirements+=("$dep")
+done
+if [ ${#unsatisfied_requirements[@]} -ne 0 ]; then
+ echo "Please install the following python packages: ${unsatisfied_requirements[*]}"
+ exit 1
+fi
+
+if [ -z "$POSTGRES_USERNAME" ]; then
+ echo "No postgres username supplied"
+ usage
+ exit 1
+fi
+
+if [ -z "$OUTPUT_DIR" ]; then
+ echo "No output directory supplied"
+ usage
+ exit 1
+fi
+
+# Create the output directory if it doesn't exist
+mkdir -p "$OUTPUT_DIR"
+
+read -rsp "Postgres password for '$POSTGRES_USERNAME': " POSTGRES_PASSWORD
+echo ""
+
+# Exit immediately if a command fails
+set -e
+
+# cd to root of the synapse directory
+cd "$(dirname "$0")/.."
+
+# Create temporary SQLite and Postgres homeserver db configs and key file
+TMPDIR=$(mktemp -d)
+KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path
+SQLITE_CONFIG=$TMPDIR/sqlite.conf
+SQLITE_DB=$TMPDIR/homeserver.db
+POSTGRES_CONFIG=$TMPDIR/postgres.conf
+
+# Ensure these files are deleted on script exit
+trap 'rm -rf $TMPDIR' EXIT
+
+cat > "$SQLITE_CONFIG" <<EOF
+server_name: "test"
+
+signing_key_path: "$KEY_FILE"
+macaroon_secret_key: "abcde"
+
+report_stats: false
+
+database:
+ name: "sqlite3"
+ args:
+ database: "$SQLITE_DB"
+
+# Suppress the key server warning.
+trusted_key_servers: []
+EOF
+
+cat > "$POSTGRES_CONFIG" <<EOF
+server_name: "test"
+
+signing_key_path: "$KEY_FILE"
+macaroon_secret_key: "abcde"
+
+report_stats: false
+
+database:
+ name: "psycopg2"
+ args:
+ user: "$POSTGRES_USERNAME"
+ host: "$POSTGRES_HOST"
+ password: "$POSTGRES_PASSWORD"
+ database: "$POSTGRES_DB_NAME"
+
+# Suppress the key server warning.
+trusted_key_servers: []
+EOF
+
+# Generate the server's signing key.
+echo "Generating SQLite3 db schema..."
+python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
+
+# Make sure the SQLite3 database is using the latest schema and has no pending background update.
+echo "Running db background jobs..."
+scripts-dev/update_database --database-config "$SQLITE_CONFIG"
+
+# Create the PostgreSQL database.
+echo "Creating postgres database..."
+createdb $POSTGRES_DB_NAME
+
+echo "Copying data from SQLite3 to Postgres with synapse_port_db..."
+if [ -z "$COVERAGE" ]; then
+ # No coverage needed
+ scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
+else
+ # Coverage desired
+ coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
+fi
+
+# Delete schema_version, applied_schema_deltas and applied_module_schemas tables
+# This needs to be done after synapse_port_db is run
+echo "Dropping unwanted db tables..."
+SQL="
+DROP TABLE schema_version;
+DROP TABLE applied_schema_deltas;
+DROP TABLE applied_module_schemas;
+"
+sqlite3 "$SQLITE_DB" <<< "$SQL"
+psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL"
+
+echo "Dumping SQLite3 schema to '$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE'..."
+sqlite3 "$SQLITE_DB" ".dump" > "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE"
+
+echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..."
+pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE"
+
+echo "Cleaning up temporary Postgres database..."
+dropdb $POSTGRES_DB_NAME
+
+echo "Done! Files dumped to: $OUTPUT_DIR"
diff --git a/scripts-dev/update_database b/scripts-dev/update_database
index 1776d202c5..1d62f0403a 100755
--- a/scripts-dev/update_database
+++ b/scripts-dev/update_database
@@ -26,8 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.engines import create_engine
-from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger("update_database")
@@ -35,21 +33,11 @@ logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore
- def __init__(self, config, database_engine, db_conn, **kwargs):
+ def __init__(self, config, **kwargs):
super(MockHomeserver, self).__init__(
- config.server_name,
- reactor=reactor,
- config=config,
- database_engine=database_engine,
- **kwargs
+ config.server_name, reactor=reactor, config=config, **kwargs
)
- self.database_engine = database_engine
- self.db_conn = db_conn
-
- def get_db_conn(self):
- return self.db_conn
-
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -85,25 +73,11 @@ if __name__ == "__main__":
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
- # Create the database engine and a connection to it.
- database_engine = create_engine(config.database_config)
- db_conn = database_engine.module.connect(
- **{
- k: v
- for k, v in config.database_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- )
-
- # Update the database to the latest schema.
- prepare_database(db_conn, database_engine, config=config)
- db_conn.commit()
-
# Instantiate and initialise the homeserver object.
- hs = MockHomeserver(
- config, database_engine, db_conn, db_config=config.database_config,
- )
- # setup instantiates the store within the homeserver object.
+ hs = MockHomeserver(config)
+
+ # Setup instantiates the store within the homeserver object and updates the
+ # DB.
hs.setup()
store = hs.get_datastore()
diff --git a/scripts/export_signing_key b/scripts/export_signing_key
new file mode 100755
index 0000000000..8aec9d802b
--- /dev/null
+++ b/scripts/export_signing_key
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import sys
+import time
+from typing import Optional
+
+import nacl.signing
+from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys
+
+
+def exit(status: int = 0, message: Optional[str] = None):
+ if message:
+ print(message, file=sys.stderr)
+ sys.exit(status)
+
+
+def format_plain(public_key: nacl.signing.VerifyKey):
+ print(
+ "%s:%s %s"
+ % (public_key.alg, public_key.version, encode_verify_key_base64(public_key),)
+ )
+
+
+def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
+ print(
+ ' "%s:%s": { key: "%s", expired_ts: %i }'
+ % (
+ public_key.alg,
+ public_key.version,
+ encode_verify_key_base64(public_key),
+ expiry_ts,
+ )
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read",
+ )
+
+ parser.add_argument(
+ "-x",
+ action="store_true",
+ dest="for_config",
+ help="format the output for inclusion in the old_signing_keys config setting",
+ )
+
+ parser.add_argument(
+ "--expiry-ts",
+ type=int,
+ default=int(time.time() * 1000) + 6*3600000,
+ help=(
+ "The expiry time to use for -x, in milliseconds since 1970. The default "
+ "is (now+6h)."
+ ),
+ )
+
+ args = parser.parse_args()
+
+ formatter = (
+ (lambda k: format_for_config(k, args.expiry_ts))
+ if args.for_config
+ else format_plain
+ )
+
+ keys = []
+ for file in args.key_file:
+ try:
+ res = read_signing_keys(file)
+ except Exception as e:
+ exit(
+ status=1,
+ message="Error reading key from file %s: %s %s"
+ % (file.name, type(e), e),
+ )
+ res = []
+ for key in res:
+ formatter(get_verify_key(key))
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index e393a9b2f7..eb927f2094 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -30,6 +30,7 @@ import yaml
from twisted.enterprise import adbapi
from twisted.internet import defer, reactor
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.logging.context import PreserveLoggingContext
from synapse.storage._base import LoggingTransaction
@@ -50,12 +51,13 @@ from synapse.storage.data_stores.main.registration import (
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore
-from synapse.storage.data_stores.main.state import StateBackgroundUpdateStore
+from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore
from synapse.storage.data_stores.main.stats import StatsStore
from synapse.storage.data_stores.main.user_directory import (
UserDirectoryBackgroundUpdateStore,
)
-from synapse.storage.database import Database
+from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.database import Database, make_conn
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util import Clock
@@ -137,6 +139,7 @@ class Store(
RoomMemberBackgroundUpdateStore,
SearchBackgroundUpdateStore,
StateBackgroundUpdateStore,
+ MainStateBackgroundUpdateStore,
UserDirectoryBackgroundUpdateStore,
StatsStore,
):
@@ -165,23 +168,17 @@ class Store(
class MockHomeserver:
- def __init__(self, config, database_engine, db_conn, db_pool):
- self.database_engine = database_engine
- self.db_conn = db_conn
- self.db_pool = db_pool
+ def __init__(self, config):
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server_name
- def get_db_conn(self):
- return self.db_conn
-
- def get_db_pool(self):
- return self.db_pool
-
def get_clock(self):
return self.clock
+ def get_reactor(self):
+ return reactor
+
class Porter(object):
def __init__(self, **kwargs):
@@ -445,45 +442,36 @@ class Porter(object):
else:
return
- def setup_db(self, db_config, database_engine):
- db_conn = database_engine.module.connect(
- **{
- k: v
- for k, v in db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- )
-
- prepare_database(db_conn, database_engine, config=None)
+ def setup_db(self, db_config: DatabaseConnectionConfig, engine):
+ db_conn = make_conn(db_config, engine)
+ prepare_database(db_conn, engine, config=None)
db_conn.commit()
return db_conn
@defer.inlineCallbacks
- def build_db_store(self, config):
+ def build_db_store(self, db_config: DatabaseConnectionConfig):
"""Builds and returns a database store using the provided configuration.
Args:
- config: The database configuration, i.e. a dict following the structure of
- the "database" section of Synapse's configuration file.
+ config: The database configuration
Returns:
The built Store object.
"""
- engine = create_engine(config)
-
- self.progress.set_state("Preparing %s" % config["name"])
- conn = self.setup_db(config, engine)
+ self.progress.set_state("Preparing %s" % db_config.config["name"])
- db_pool = adbapi.ConnectionPool(config["name"], **config["args"])
+ engine = create_engine(db_config.config)
+ conn = self.setup_db(db_config, engine)
- hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
+ hs = MockHomeserver(self.hs_config)
- store = Store(Database(hs), conn, hs)
+ store = Store(Database(hs, db_config, engine), conn, hs)
yield store.db.runInteraction(
- "%s_engine.check_database" % config["name"], engine.check_database,
+ "%s_engine.check_database" % db_config.config["name"],
+ engine.check_database,
)
return store
@@ -509,7 +497,9 @@ class Porter(object):
@defer.inlineCallbacks
def run(self):
try:
- self.sqlite_store = yield self.build_db_store(self.sqlite_config)
+ self.sqlite_store = yield self.build_db_store(
+ DatabaseConnectionConfig("master-sqlite", self.sqlite_config)
+ )
# Check if all background updates are done, abort if not.
updates_complete = (
@@ -524,7 +514,7 @@ class Porter(object):
defer.returnValue(None)
self.postgres_store = yield self.build_db_store(
- self.hs_config.database_config
+ self.hs_config.get_single_database()
)
yield self.run_background_updates_on_postgres()
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9fd52a8c77..abbc7079a3 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -79,7 +79,7 @@ class Auth(object):
@defer.inlineCallbacks
def check_from_context(self, room_version, event, context, do_sig_check=True):
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 9c96816096..0e8b467a3e 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -237,6 +237,12 @@ def start(hs, listeners=None):
"""
Start a Synapse server or worker.
+ Should be called once the reactor is running and (if we're using ACME) the
+ TLS certificates are in place.
+
+ Will start the main HTTP listeners and do some other startup tasks, and then
+ notify systemd.
+
Args:
hs (synapse.server.HomeServer)
listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
@@ -311,9 +317,7 @@ def setup_sdnotify(hs):
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
- hs.get_reactor().addSystemEventTrigger(
- "after", "startup", sdnotify, b"READY=1\nMAINPID=%i" % (os.getpid(),)
- )
+ sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),))
hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", sdnotify, b"STOPPING=1"
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index 04751a6a5e..8e36bc57d3 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.logcontext import LoggingContext
from synapse.util.versionstring import get_version_string
@@ -105,8 +104,10 @@ def export_data_command(hs, args):
user_id = args.user_id
directory = args.output_directory
- res = yield hs.get_handlers().admin_handler.export_user_data(
- user_id, FileExfiltrationWriter(user_id, directory=directory)
+ res = yield defer.ensureDeferred(
+ hs.get_handlers().admin_handler.export_user_data(
+ user_id, FileExfiltrationWriter(user_id, directory=directory)
+ )
)
print(res)
@@ -229,14 +230,10 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = AdminCmdServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 02b900f382..e82e0f11e3 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -143,8 +142,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
@@ -159,10 +156,8 @@ def start(config_options):
ps = AppserviceServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ps, config, use_worker_options=True)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index dadb487d5f..3edfe19567 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -181,14 +180,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = ClientReaderServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index d110599a35..d0ddbe38fc 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import (
)
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -180,14 +179,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = EventCreatorServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 418c086254..311523e0ed 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -162,14 +161,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = FederationReaderServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index f24920a7d6..83c436229c 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
from synapse.storage.database import Database
-from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
@@ -174,8 +173,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
if config.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
@@ -190,10 +187,8 @@ def start(config_options):
ss = FederationSenderServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index e647459d0e..30e435eead 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -234,14 +233,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = FrontendProxyServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index df65d0a989..0e9bf7f53a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
+from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
@@ -328,15 +328,10 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
- config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
-
hs = SynapseHomeServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
@@ -347,13 +342,8 @@ def setup(config_options):
hs.setup()
except IncorrectDatabaseSetup as e:
quit_with_error(str(e))
- except UpgradeDatabaseException:
- sys.stderr.write(
- "\nFailed to upgrade database.\n"
- "Have you checked for version specific instructions in"
- " UPGRADES.rst?\n"
- )
- sys.exit(1)
+ except UpgradeDatabaseException as e:
+ quit_with_error("Failed to upgrade database: %s" % (e,))
hs.setup_master()
@@ -519,8 +509,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
# Database version
#
- stats["database_engine"] = hs.database_engine.module.__name__
- stats["database_server_version"] = hs.database_engine.server_version
+ # This only reports info about the *main* database.
+ stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
yield hs.get_proxied_http_client().put_json(
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 2c6dd3ef02..4c80f257e2 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -157,14 +156,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = MediaRepositoryServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index dd52a9fc2d..09e639040a 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -37,7 +37,6 @@ from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage import DataStore
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -203,14 +202,10 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
- database_engine = create_engine(config.database_config)
-
ps = PusherServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ps, config, use_worker_options=True)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 288ee64b42..dd2132e608 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v2_alpha import sync
from synapse.server import HomeServer
from synapse.storage.data_stores.main.presence import UserPresenceState
-from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
@@ -437,14 +436,10 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
ss = SynchrotronServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(),
)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index c01fb34a9b..1257098f92 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.storage.database import Database
-from synapse.storage.engines import create_engine
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
@@ -200,8 +199,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
-
if config.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
@@ -216,10 +213,8 @@ def start(config_options):
ss = UserDirectoryServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 0e2509f0b1..134824789c 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -12,12 +12,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
from textwrap import indent
+from typing import List
import yaml
-from ._base import Config
+from synapse.config._base import Config, ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionConfig:
+ """Contains the connection config for a particular database.
+
+ Args:
+ name: A label for the database, used for logging.
+ db_config: The config for a particular database, as per `database`
+ section of main config. Has two fields: `name` for database
+ module name, and `args` for the args to give to the database
+ connector.
+ data_stores: The list of data stores that should be provisioned on the
+ database. Defaults to all data stores.
+ """
+
+ def __init__(
+ self, name: str, db_config: dict, data_stores: List[str] = ["main", "state"]
+ ):
+ if db_config["name"] not in ("sqlite3", "psycopg2"):
+ raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+
+ if db_config["name"] == "sqlite3":
+ db_config.setdefault("args", {}).update(
+ {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+ )
+
+ self.name = name
+ self.config = db_config
+ self.data_stores = data_stores
class DatabaseConfig(Config):
@@ -26,20 +59,12 @@ class DatabaseConfig(Config):
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
- self.database_config = config.get("database")
+ database_config = config.get("database")
- if self.database_config is None:
- self.database_config = {"name": "sqlite3", "args": {}}
+ if database_config is None:
+ database_config = {"name": "sqlite3", "args": {}}
- name = self.database_config.get("name", None)
- if name == "psycopg2":
- pass
- elif name == "sqlite3":
- self.database_config.setdefault("args", {}).update(
- {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
- )
- else:
- raise RuntimeError("Unsupported database type '%s'" % (name,))
+ self.databases = [DatabaseConnectionConfig("master", database_config)]
self.set_databasepath(config.get("database_path"))
@@ -76,11 +101,24 @@ class DatabaseConfig(Config):
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
+ if database_path is None:
+ return
+
if database_path != ":memory:":
database_path = self.abspath(database_path)
- if self.database_config.get("name", None) == "sqlite3":
- if database_path is not None:
- self.database_config["args"]["database"] = database_path
+
+ # We only support setting a database path if we have a single sqlite3
+ # database.
+ if len(self.databases) != 1:
+ raise ConfigError("Cannot specify 'database_path' with multiple databases")
+
+ database = self.get_single_database()
+ if database.config["name"] != "sqlite3":
+ # We don't raise here as we haven't done so before for this case.
+ logger.warn("Ignoring 'database_path' for non-sqlite3 database")
+ return
+
+ database.config["args"]["database"] = database_path
@staticmethod
def add_arguments(parser):
@@ -91,3 +129,11 @@ class DatabaseConfig(Config):
metavar="SQLITE_DATABASE_PATH",
help="The path to a sqlite database to use.",
)
+
+ def get_single_database(self) -> DatabaseConnectionConfig:
+ """Returns the database if there is only one, useful for e.g. tests
+ """
+ if len(self.databases) != 1:
+ raise Exception("More than one database exists")
+
+ return self.databases[0]
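
As a rough illustration (not part of the patch; the path and values are invented), the new DatabaseConnectionConfig wraps a `database` section from the homeserver config like so, pinning sqlite3 to a single-connection pool and provisioning every data store by default:

    from synapse.config.database import DatabaseConnectionConfig

    # Hypothetical usage sketch; the sqlite path below is made up.
    db_section = {"name": "sqlite3", "args": {"database": "/data/homeserver.db"}}
    conn = DatabaseConnectionConfig("master", db_section)

    assert conn.name == "master"
    assert conn.data_stores == ["main", "state"]   # all data stores by default
    assert conn.config["args"]["cp_min"] == 1      # sqlite3 pool is pinned to one connection
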
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 18f42a87f9..35756bed87 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import email.utils
import os
from enum import Enum
+from typing import Optional
import pkg_resources
@@ -101,7 +102,7 @@ class EmailConfig(Config):
# both in RegistrationConfig and here. We should factor this bit out
self.account_threepid_delegate_email = self.trusted_third_party_id_servers[
0
- ]
+ ] # type: Optional[str]
self.using_identity_server_from_trusted_list = True
else:
raise ConfigError(
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 52ff1b2621..066e7838c3 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -108,7 +108,7 @@ class KeyConfig(Config):
self.signing_key = self.read_signing_keys(signing_key_path, "signing_key")
self.old_signing_keys = self.read_old_signing_keys(
- config.get("old_signing_keys", {})
+ config.get("old_signing_keys")
)
self.key_refresh_interval = self.parse_duration(
config.get("key_refresh_interval", "1d")
@@ -199,14 +199,19 @@ class KeyConfig(Config):
signing_key_path: "%(base_key_name)s.signing.key"
# The keys that the server used to sign messages with but won't use
- # to sign new messages. E.g. it has lost its private key
+ # to sign new messages.
#
- #old_signing_keys:
- # "ed25519:auto":
- # # Base64 encoded public key
- # key: "The public part of your old signing key."
- # # Millisecond POSIX timestamp when the key expired.
- # expired_ts: 123456789123
+ old_signing_keys:
+ # For each key, `key` should be the base64-encoded public key, and
+ # `expired_ts` should be the time (in milliseconds since the unix epoch) that
+ # it was last used.
+ #
+ # It is possible to build an entry from an old signing.key file using the
+ # `export_signing_key` script which is provided with synapse.
+ #
+ # For example:
+ #
+ #"ed25519:id": { key: "base64string", expired_ts: 123456789123 }
# How long key response published by this server is valid for.
# Used to set the valid_until_ts in /key/v2 APIs.
@@ -290,6 +295,8 @@ class KeyConfig(Config):
raise ConfigError("Error reading %s: %s" % (name, str(e)))
def read_old_signing_keys(self, old_signing_keys):
+ if old_signing_keys is None:
+ return {}
keys = {}
for key_id, key_data in old_signing_keys.items():
if is_signing_algorithm_supported(key_id):
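
For reference, a sketch (key material and timestamp invented) of the dict that `read_old_signing_keys` receives when the `old_signing_keys` section is populated, and the behaviour the new guard gives when it is absent:

    # Illustrative only; this is a placeholder entry, not a real key.
    old_signing_keys = {
        "ed25519:old_key_id": {
            "key": "base64EncodedPublicKeyGoesHere",
            "expired_ts": 1546300800000,  # milliseconds since the unix epoch
        }
    }
    # With this change, `config.get("old_signing_keys")` may be None when the
    # section is present but empty, and read_old_signing_keys(None) returns {}
    # rather than failing on None.items().
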
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 947f653e03..4a3bfc4354 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -83,10 +83,9 @@ class RatelimitConfig(Config):
)
rc_admin_redaction = config.get("rc_admin_redaction")
+ self.rc_admin_redaction = None
if rc_admin_redaction:
self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
- else:
- self.rc_admin_redaction = None
def generate_config_section(self, **kwargs):
return """\
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c5ea2d43a1..b91414aa35 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -14,17 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
+import logging
from synapse.python_dependencies import DependencyException, check_requirements
-from synapse.types import (
- map_username_to_mxid_localpart,
- mxid_localpart_allowed_characters,
-)
-from synapse.util.module_loader import load_python_module
+from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+ "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
self.saml2_enabled = True
- self.saml2_mxid_source_attribute = saml2_config.get(
- "mxid_source_attribute", "uid"
- )
-
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
- saml2_config_dict = self._default_saml_config_dict()
+ # user_mapping_provider may be None if the key is present but has no value
+ ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+ # Use the default user mapping provider if not set
+ ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+ # Ensure a config is present
+ ump_dict["config"] = ump_dict.get("config") or {}
+
+ if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+ # Load deprecated options for use by the default module
+ old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+ if old_mxid_source_attribute:
+ logger.warning(
+ "The config option saml2_config.mxid_source_attribute is deprecated. "
+ "Please use saml2_config.user_mapping_provider.config"
+ ".mxid_source_attribute instead."
+ )
+ ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+ old_mxid_mapping = saml2_config.get("mxid_mapping")
+ if old_mxid_mapping:
+ logger.warning(
+ "The config option saml2_config.mxid_mapping is deprecated. Please "
+ "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+ )
+ ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+ # Retrieve an instance of the module's class
+ # Pass the config dictionary to the module for processing
+ (
+ self.saml2_user_mapping_provider_class,
+ self.saml2_user_mapping_provider_config,
+ ) = load_module(ump_dict)
+
+ # Ensure loaded user mapping module has defined all necessary methods
+ # Note parse_config() is already checked during the call to load_module
+ required_methods = [
+ "get_saml_attributes",
+ "saml_response_to_user_attributes",
+ ]
+ missing_methods = [
+ method
+ for method in required_methods
+ if not hasattr(self.saml2_user_mapping_provider_class, method)
+ ]
+ if missing_methods:
+ raise ConfigError(
+ "Class specified by saml2_config."
+ "user_mapping_provider.module is missing required "
+ "methods: %s" % (", ".join(missing_methods),)
+ )
+
+ # Get the desired saml auth response attributes from the module
+ saml2_config_dict = self._default_saml_config_dict(
+ *self.saml2_user_mapping_provider_class.get_saml_attributes(
+ self.saml2_user_mapping_provider_config
+ )
+ )
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
)
@@ -103,22 +159,27 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "5m")
)
- mapping = saml2_config.get("mxid_mapping", "hexencode")
- try:
- self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
- except KeyError:
- raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
-
- def _default_saml_config_dict(self):
+ def _default_saml_config_dict(
+ self, required_attributes: set, optional_attributes: set
+ ):
+ """Generate a configuration dictionary with required and optional attributes that
+ will be needed to process new user registration
+
+ Args:
+ required_attributes: SAML auth response attributes that are
+ necessary to function
+ optional_attributes: SAML auth response attributes that can be used to add
+ additional information to Synapse user accounts, but are not required
+
+ Returns:
+ dict: A SAML configuration dictionary
+ """
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
- required_attributes = {"uid", self.saml2_mxid_source_attribute}
-
- optional_attributes = {"displayName"}
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
#
#config_path: "%(config_dir_path)s/sp_conf.py"
- # the lifetime of a SAML session. This defines how long a user has to
+ # The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
- # The SAML attribute (after mapping via the attribute maps) to use to derive
- # the Matrix ID from. 'uid' by default.
+ # An external module can be provided here as a custom solution to
+ # mapping attributes returned from a saml provider onto a matrix user.
#
- #mxid_source_attribute: displayName
-
- # The mapping system to use for mapping the saml attribute onto a matrix ID.
- # Options include:
- # * 'hexencode' (which maps unpermitted characters to '=xx')
- # * 'dotreplace' (which replaces unpermitted characters with '.').
- # The default is 'hexencode'.
- #
- #mxid_mapping: dotreplace
-
- # In previous versions of synapse, the mapping from SAML attribute to MXID was
- # always calculated dynamically rather than stored in a table. For backwards-
- # compatibility, we will look for user_ids matching such a pattern before
- # creating a new account.
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ #
+ #module: mapping_provider.SamlMappingProvider
+
+ # Custom configuration values for the module. The options below are
+ # intended for the built-in provider; they should be changed if
+ # using a custom module. This section will be passed as a Python
+ # dictionary to the module's `parse_config` method.
+ #
+ config:
+ # The SAML attribute (after mapping via the attribute maps) to use
+ # to derive the Matrix ID from. 'uid' by default.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_source_attribute option. If that is still
+ # defined, its value will be used instead.
+ #
+ #mxid_source_attribute: displayName
+
+ # The mapping system to use for mapping the saml attribute onto a
+ # matrix ID.
+ #
+ # Options include:
+ # * 'hexencode' (which maps unpermitted characters to '=xx')
+ # * 'dotreplace' (which replaces unpermitted characters with
+ # '.').
+ # The default is 'hexencode'.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_mapping option. If that is still defined, its
+ # value will be used instead.
+ #
+ #mxid_mapping: dotreplace
+
+ # In previous versions of synapse, the mapping from SAML attribute to
+ # MXID was always calculated dynamically rather than stored in a
+ # table. For backwards-compatibility, we will look for user_ids
+ # matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
- # backwards-compatibility lookup. Typically it should be 'uid', but if the
- # attribute maps are changed, it may be necessary to change it.
+ # backwards-compatibility lookup. Typically it should be 'uid', but if
+ # the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
@@ -241,23 +327,3 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}
-
-
-DOT_REPLACE_PATTERN = re.compile(
- ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
-)
-
-
-def dot_replace_for_mxid(username: str) -> str:
- username = username.lower()
- username = DOT_REPLACE_PATTERN.sub(".", username)
-
- # regular mxids aren't allowed to start with an underscore either
- username = re.sub("^_", "", username)
- return username
-
-
-MXID_MAPPER_MAP = {
- "hexencode": map_username_to_mxid_localpart,
- "dotreplace": dot_replace_for_mxid,
-}
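
To make the new `user_mapping_provider` hook concrete, here is a rough skeleton of a custom provider. Only the method names checked above (`parse_config` via `load_module`, `get_saml_attributes`, `saml_response_to_user_attributes`) are grounded in this patch; the signatures and return values are assumptions modelled on how the config code uses them:

    class ExampleSamlMappingProvider:
        """Hypothetical custom SAML user mapping provider skeleton."""

        def __init__(self, parsed_config):
            self._config = parsed_config

        @staticmethod
        def parse_config(config: dict) -> dict:
            # Receives the user_mapping_provider.config dict via load_module().
            return config

        @staticmethod
        def get_saml_attributes(config):
            # Returns (required_attributes, optional_attributes), which feed
            # _default_saml_config_dict above.
            return {"uid"}, {"displayName"}

        def saml_response_to_user_attributes(self, saml_response, failures):
            # Map the SAML auth response onto Matrix user attributes.
            raise NotImplementedError
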
diff --git a/synapse/config/server.py b/synapse/config/server.py
index a4bef00936..38f6ff9edc 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -102,6 +102,12 @@ class ServerConfig(Config):
"require_auth_for_profile_requests", False
)
+ # Whether to require sharing a room with a user to retrieve their
+ # profile data
+ self.limit_profile_requests_to_users_who_share_rooms = config.get(
+ "limit_profile_requests_to_users_who_share_rooms", False,
+ )
+
if "restrict_public_rooms_to_local_users" in config and (
"allow_public_rooms_without_auth" in config
or "allow_public_rooms_over_federation" in config
@@ -200,7 +206,7 @@ class ServerConfig(Config):
self.admin_contact = config.get("admin_contact", None)
# FIXME: federation_domain_whitelist needs sytests
- self.federation_domain_whitelist = None
+ self.federation_domain_whitelist = None # type: Optional[dict]
federation_domain_whitelist = config.get("federation_domain_whitelist", None)
if federation_domain_whitelist is not None:
@@ -621,6 +627,13 @@ class ServerConfig(Config):
#
#require_auth_for_profile_requests: true
+ # Uncomment to require a user to share a room with another user in order
+ # to retrieve their profile information. Only checked on Client-Server
+ # requests. Profile requests from other servers should be checked by the
+ # requesting server. Defaults to 'false'.
+ #
+ #limit_profile_requests_to_users_who_share_rooms: true
+
# If set to 'true', removes the need for authentication to access the server's
# public rooms directory through the client API, meaning that anyone can
# query the room directory. Defaults to 'false'.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 350ed9351f..1033e5e121 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -43,6 +43,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
Returns:
if the auth checks pass.
"""
+ assert isinstance(auth_events, dict)
+
if do_size_check:
_check_size_limits(event)
@@ -87,12 +89,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
- if auth_events is None:
- # Oh, we don't know what the state of the room was, so we
- # are trusting that this is allowed (at least for now)
- logger.warning("Trusting event: %s", event.event_id)
- return
-
if event.type == EventTypes.Create:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 64e898f40c..a44baea365 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -149,7 +149,7 @@ class EventContext:
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
- prev_state_ids = yield self.get_prev_state_ids(store)
+ prev_state_ids = yield self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
@@ -167,12 +167,13 @@ class EventContext:
}
@staticmethod
- def deserialize(store, input):
+ def deserialize(storage, input):
"""Converts a dict that was produced by `serialize` back into a
EventContext.
Args:
- store (DataStore): Used to convert AS ID to AS object
+ storage (Storage): Used to convert AS ID to AS object and fetch
+ state.
input (dict): A dict produced by `serialize`
Returns:
@@ -181,6 +182,7 @@ class EventContext:
context = _AsyncEventContextImpl(
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
+ storage=storage,
prev_state_id=input["prev_state_id"],
event_type=input["event_type"],
event_state_key=input["event_state_key"],
@@ -193,7 +195,7 @@ class EventContext:
app_service_id = input["app_service_id"]
if app_service_id:
- context.app_service = store.get_app_service_by_id(app_service_id)
+ context.app_service = storage.main.get_app_service_by_id(app_service_id)
return context
@@ -216,7 +218,7 @@ class EventContext:
return self._state_group
@defer.inlineCallbacks
- def get_current_state_ids(self, store):
+ def get_current_state_ids(self):
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@@ -234,11 +236,11 @@ class EventContext:
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")
- yield self._ensure_fetched(store)
+ yield self._ensure_fetched()
return self._current_state_ids
@defer.inlineCallbacks
- def get_prev_state_ids(self, store):
+ def get_prev_state_ids(self):
"""
Gets the room state map, excluding this event.
@@ -250,7 +252,7 @@ class EventContext:
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
- yield self._ensure_fetched(store)
+ yield self._ensure_fetched()
return self._prev_state_ids
def get_cached_current_state_ids(self):
@@ -270,7 +272,7 @@ class EventContext:
return self._current_state_ids
- def _ensure_fetched(self, store):
+ def _ensure_fetched(self):
return defer.succeed(None)
@@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext):
Attributes:
+ _storage (Storage)
+
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
@@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext):
that was replaced.
"""
+ # This needs to have a default as we're inheriting
+ _storage = attr.ib(default=None)
_prev_state_id = attr.ib(default=None)
_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
- def _ensure_fetched(self, store):
+ def _ensure_fetched(self):
if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store
- )
+ self._fetching_state_deferred = run_in_background(self._fill_out_state)
return make_deferred_yieldable(self._fetching_state_deferred)
@defer.inlineCallbacks
- def _fill_out_state(self, store):
+ def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
- self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
+ self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+ self.state_group
+ )
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 714a9b1579..86f7e5f8aa 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -53,7 +53,7 @@ class ThirdPartyEventRules(object):
if self.third_party_rules is None:
return True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
# Retrieve the state events from the database.
state_events = {}
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d396e6564f..af652a7659 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -526,13 +526,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_join(
- destination=destination,
- room_id=pdu.room_id,
- event_id=pdu.event_id,
- content=pdu.get_pdu_json(time_now),
- )
+ content = yield self._do_send_join(destination, pdu)
logger.debug("Got content: %s", content)
@@ -600,6 +594,44 @@ class FederationClient(FederationBase):
return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks
+ def _do_send_join(self, destination, pdu):
+ time_now = self._clock.time_msec()
+
+ try:
+ content = yield self.transport_layer.send_join_v2(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ return content
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ err = e.to_synapse_error()
+
+ # If we receive an error response that isn't a generic error, or an
+ # unrecognised endpoint error, we assume that the remote understands
+ # the v2 send_join API and this is a legitimate error.
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+ raise err
+ else:
+ raise e.to_synapse_error()
+
+ logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
+
+ resp = yield self.transport_layer.send_join_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ # We expect the v1 API to respond with [200, content], so we only return the
+ # content.
+ return resp[1]
+
+ @defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
room_version = yield self.store.get_room_version(room_id)
@@ -708,18 +740,50 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_leave(
+ content = yield self._do_send_leave(destination, pdu)
+
+ logger.debug("Got content: %s", content)
+ return None
+
+ return self._try_destination_list("send_leave", destinations, send_request)
+
+ @defer.inlineCallbacks
+ def _do_send_leave(self, destination, pdu):
+ time_now = self._clock.time_msec()
+
+ try:
+ content = yield self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- logger.debug("Got content: %s", content)
- return None
+ return content
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ err = e.to_synapse_error()
- return self._try_destination_list("send_leave", destinations, send_request)
+ # If we receive an error response that isn't a generic error, or an
+ # unrecognised endpoint error, we assume that the remote understands
+ # the v2 send_leave API and this is a legitimate error.
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+ raise err
+ else:
+ raise e.to_synapse_error()
+
+ logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
+
+ resp = yield self.transport_layer.send_leave_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ # We expect the v1 API to respond with [200, content], so we only return the
+ # content.
+ return resp[1]
def get_public_rooms(
self,
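
The v2-with-v1-fallback in `_do_send_join` and `_do_send_leave` hinges on one decision: fall back only when the remote appears not to know the v2 endpoint. A condensed sketch of that check, as a hypothetical helper (not part of the patch):

    from synapse.api.errors import Codes, HttpResponseException

    def should_fall_back_to_v1(e: HttpResponseException) -> bool:
        # Fall back to the v1 endpoint only for a 400/404 whose errcode is
        # UNKNOWN or UNRECOGNIZED; any other error is treated as legitimate
        # and re-raised by the callers above.
        if e.code not in (400, 404):
            return False
        err = e.to_synapse_error()
        return err.errcode in (Codes.UNKNOWN, Codes.UNRECOGNIZED)
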
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 84d4eca041..d7ce333822 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -384,15 +384,10 @@ class FederationServer(FederationBase):
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
- return (
- 200,
- {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- },
- )
+ return {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+ }
async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
@@ -419,7 +414,7 @@ class FederationServer(FederationBase):
pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
- return 200, {}
+ return {}
async def on_event_auth(self, origin, room_id, event_id):
with (await self._server_linearizer.queue((origin, room_id))):
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 46dba84cac..198257414b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -243,7 +243,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_join(self, destination, room_id, event_id, content):
+ def send_join_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -254,7 +254,18 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def send_leave(self, destination, room_id, event_id, content):
+ def send_join_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination, path=path, data=content
+ )
+
+ return response
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_leave_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -272,6 +283,24 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
+ def send_leave_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination,
+ path=path,
+ data=content,
+ # we want to do our best to send this through. The problem is
+ # that if it fails, we won't retry it later, so if the remote
+ # server was just having a momentary blip, the room will be out of
+ # sync.
+ ignore_backoff=True,
+ )
+
+ return response
+
+ @defer.inlineCallbacks
+ @log_function
def send_invite_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fefc789c85..b4cbf23394 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -506,11 +506,21 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
-class FederationSendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, (200, content)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServlet):
+ PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+ PREFIX = FEDERATION_V2_PREFIX
+
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, content
@@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet):
return await self.handler.on_event_auth(origin, context, event_id)
-class FederationSendJoinServlet(BaseFederationServlet):
+class FederationV1SendJoinServlet(BaseFederationServlet):
+ PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+ async def on_PUT(self, origin, content, query, context, event_id):
+ # TODO(paul): assert that context/event_id parsed from path actually
+ # match those given in content
+ content = await self.handler.on_send_join_request(origin, content, context)
+ return 200, (200, content)
+
+
+class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+ PREFIX = FEDERATION_V2_PREFIX
+
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
@@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = (
FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet,
- FederationSendJoinServlet,
- FederationSendLeaveServlet,
+ FederationV1SendJoinServlet,
+ FederationV2SendJoinServlet,
+ FederationV1SendLeaveServlet,
+ FederationV2SendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationQueryAuthServlet,
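
Splitting the servlets by version preserves the wire-format difference: the legacy v1 endpoints wrap their result in a `[200, content]` array (hence `return 200, (200, content)` above), while the v2 endpoints from MSC1802 return the content object directly. An illustration with made-up, abbreviated payloads:

    # Body as seen by the joining server for each endpoint generation.
    v1_body = [200, {"state": [], "auth_chain": []}]  # PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}
    v2_body = {"state": [], "auth_chain": []}         # PUT /_matrix/federation/v2/send_join/{roomId}/{eventId}
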
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 29e8ffc295..0ec9be3cb5 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -773,6 +773,11 @@ class GroupsServerHandler(object):
if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id)
+ # Delete group if the last user has left
+ users = yield self.store.get_users_in_group(group_id, include_private=True)
+ if not users:
+ yield self.store.delete_group(group_id)
+
return {}
@defer.inlineCallbacks
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index d15c6282fb..51413d910e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -134,7 +134,7 @@ class BaseHandler(object):
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
current_state = yield self.store.get_events(
list(current_state_ids.values())
)
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 2d7e6df6e4..a8d3fbc6de 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
class AccountDataEventSource(object):
def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
- @defer.inlineCallbacks
- def get_new_events(self, user, from_key, **kwargs):
+ async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
- current_stream_id = yield self.store.get_max_account_data_stream_id()
+ current_stream_id = self.store.get_max_account_data_stream_id()
results = []
- tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+ tags = await self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
(
account_data,
room_account_data,
- ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+ ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
@@ -53,7 +50,3 @@ class AccountDataEventSource(object):
)
return results, current_stream_id
-
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
- return [], config.to_id
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d04e0fe576..829f52eca1 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,8 +18,7 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-
-from twisted.internet import defer
+from typing import List
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
- "send_renewals", self.send_renewal_emails
+ "send_renewals", self._send_renewal_emails
)
self.clock.looping_call(send_emails, 30 * 60 * 1000)
- @defer.inlineCallbacks
- def send_renewal_emails(self):
+ async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
- expiring_users = yield self.store.get_users_expiring_soon()
+ expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
- yield self._send_renewal_email(
+ await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
- @defer.inlineCallbacks
- def send_renewal_email_to_user(self, user_id):
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- yield self._send_renewal_email(user_id, expiration_ts)
+ async def send_renewal_email_to_user(self, user_id: str):
+ expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+ await self._send_renewal_email(user_id, expiration_ts)
- @defer.inlineCallbacks
- def _send_renewal_email(self, user_id, expiration_ts):
+ async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
- user_id (str): ID of the user to send email(s) to.
- expiration_ts (int): Timestamp in milliseconds for the expiration date of
+ user_id: ID of the user to send email(s) to.
+ expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
- addresses = yield self._get_email_addresses_for_user(user_id)
+ addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
return
try:
- user_display_name = yield self.store.get_profile_displayname(
+ user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
except StoreError:
user_display_name = user_id
- renewal_token = yield self._get_renewal_token(user_id)
+ renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
logger.info("Sending renewal email to %s", address)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
)
)
- yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+ await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
- @defer.inlineCallbacks
- def _get_email_addresses_for_user(self, user_id):
+ async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
- user_id (str): ID of the user to lookup email addresses for.
+ user_id: ID of the user to look up email addresses for.
Returns:
- defer.Deferred[list[str]]: Email addresses for this account.
+ Email addresses for this account.
"""
- threepids = yield self.store.user_get_threepids(user_id)
+ threepids = await self.store.user_get_threepids(user_id)
addresses = []
for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
return addresses
- @defer.inlineCallbacks
- def _get_renewal_token(self, user_id):
+ async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
- user_id (str): ID of the user to generate a string for.
+ user_id: ID of the user to generate a string for.
Returns:
- defer.Deferred[str]: The generated string.
+ The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
- yield self.store.set_renewal_token_for_user(user_id, renewal_token)
+ await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- @defer.inlineCallbacks
- def renew_account(self, renewal_token):
+ async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
+ renewal_token: Token sent with the renewal request.
Returns:
- bool: Whether the provided token is valid.
+ Whether the provided token is valid.
"""
try:
- user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- defer.returnValue(False)
+ return False
logger.debug("Renewing an account for user %s", user_id)
- yield self.renew_account_for_user(user_id)
+ await self.renew_account_for_user(user_id)
- defer.returnValue(True)
+ return True
- @defer.inlineCallbacks
- def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+ async def renew_account_for_user(
+ self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+ ) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
- expiration_ts (int): New expiration date. Defaults to now + validity period.
- email_sent (bool): Whether an email has been sent for this validity period.
+ user_id: ID of the user whose account to renew.
+ expiration_ts: New expiration date. Defaults to now + validity period.
+ email_sent: Whether an email has been sent for this validity period.
Defaults to False.
Returns:
- defer.Deferred[int]: New expiration date for this account, as a timestamp
- in milliseconds since epoch.
+ New expiration date for this account, as a timestamp in
+ milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period
- yield self.store.set_account_validity_for_user(
+ await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
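
For orientation, the renewal link built in `_send_renewal_email` above ends up with this shape (hostname and token invented; the format string assumes `public_baseurl` ends in a trailing slash):

    public_baseurl = "https://matrix.example.org/"
    renewal_token = "Jb5ZflTgPZv0qQ8mGklNcXaZDMRRpA2V"  # 32-character random string
    url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
        public_baseurl,
        renewal_token,
    )
    # -> https://matrix.example.org/_matrix/client/unstable/account_validity/renew?token=Jb5Zf...
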
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 14449b9a1e..1a4ba12385 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import Membership
from synapse.types import RoomStreamToken
from synapse.visibility import filter_events_for_client
@@ -33,11 +31,10 @@ class AdminHandler(BaseHandler):
self.storage = hs.get_storage()
self.state_store = self.storage.state
- @defer.inlineCallbacks
- def get_whois(self, user):
+ async def get_whois(self, user):
connections = []
- sessions = yield self.store.get_user_ip_and_agents(user)
+ sessions = await self.store.get_user_ip_and_agents(user)
for session in sessions:
connections.append(
{
@@ -54,20 +51,18 @@ class AdminHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def get_users(self):
+ async def get_users(self):
"""Function to retrieve a list of users in users table.
Args:
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
- ret = yield self.store.get_users()
+ ret = await self.store.get_users()
return ret
- @defer.inlineCallbacks
- def get_users_paginate(self, start, limit, name, guests, deactivated):
+ async def get_users_paginate(self, start, limit, name, guests, deactivated):
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users.
@@ -80,14 +75,13 @@ class AdminHandler(BaseHandler):
Returns:
defer.Deferred: resolves to json list[dict[str, Any]]
"""
- ret = yield self.store.get_users_paginate(
+ ret = await self.store.get_users_paginate(
start, limit, name, guests, deactivated
)
return ret
- @defer.inlineCallbacks
- def search_users(self, term):
+ async def search_users(self, term):
"""Function to search users list for one or more users with
the matched term.
@@ -96,7 +90,7 @@ class AdminHandler(BaseHandler):
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
- ret = yield self.store.search_users(term)
+ ret = await self.store.search_users(term)
return ret
@@ -119,8 +113,7 @@ class AdminHandler(BaseHandler):
"""
return self.store.set_server_admin(user, admin)
- @defer.inlineCallbacks
- def export_user_data(self, user_id, writer):
+ async def export_user_data(self, user_id, writer):
"""Write all data we have on the user to the given writer.
Args:
@@ -132,7 +125,7 @@ class AdminHandler(BaseHandler):
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
- rooms = yield self.store.get_rooms_for_user_where_membership_is(
+ rooms = await self.store.get_rooms_for_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,
@@ -145,7 +138,7 @@ class AdminHandler(BaseHandler):
# We only try and fetch events for rooms the user has been in. If
# they've been e.g. invited to a room without joining then we handle
# those separately.
- rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id)
+ rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
for index, room in enumerate(rooms):
room_id = room.room_id
@@ -154,7 +147,7 @@ class AdminHandler(BaseHandler):
"[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
)
- forgotten = yield self.store.did_forget(user_id, room_id)
+ forgotten = await self.store.did_forget(user_id, room_id)
if forgotten:
logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
continue
@@ -166,7 +159,7 @@ class AdminHandler(BaseHandler):
if room.membership == Membership.INVITE:
event_id = room.event_id
- invite = yield self.store.get_event(event_id, allow_none=True)
+ invite = await self.store.get_event(event_id, allow_none=True)
if invite:
invited_state = invite.unsigned["invite_room_state"]
writer.write_invite(room_id, invite, invited_state)
@@ -177,7 +170,7 @@ class AdminHandler(BaseHandler):
# were joined. We estimate that point by looking at the
# stream_ordering of the last membership if it wasn't a join.
if room.membership == Membership.JOIN:
- stream_ordering = yield self.store.get_room_max_stream_ordering()
+ stream_ordering = self.store.get_room_max_stream_ordering()
else:
stream_ordering = room.stream_ordering
@@ -203,7 +196,7 @@ class AdminHandler(BaseHandler):
# events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything.
while True:
- events, _ = yield self.store.paginate_room_events(
+ events, _ = await self.store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction="f"
)
if not events:
@@ -211,7 +204,7 @@ class AdminHandler(BaseHandler):
from_key = events[-1].internal_metadata.after
- events = yield filter_events_for_client(self.storage, user_id, events)
+ events = await filter_events_for_client(self.storage, user_id, events)
writer.write_events(room_id, events)
@@ -247,7 +240,7 @@ class AdminHandler(BaseHandler):
for event_id in extremities:
if not event_to_unseen_prevs[event_id]:
continue
- state = yield self.state_store.get_state_for_event(event_id)
+ state = await self.state_store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state)
return writer.finished()
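
`export_user_data` drives all output through a `writer` object. The calls made above imply at least the interface sketched below; this is a hypothetical stand-in, and the real writer classes live elsewhere in the codebase and may define more:

    class InMemoryExfiltrationWriter:
        """Hypothetical minimal writer satisfying the calls made by export_user_data."""

        def __init__(self):
            self.events = {}   # room_id -> list of events
            self.state = {}    # (room_id, event_id) -> state map
            self.invites = []  # (room_id, invite event, invited_state)

        def write_events(self, room_id, events):
            self.events.setdefault(room_id, []).extend(events)

        def write_state(self, room_id, event_id, state):
            self.state[(room_id, event_id)] = state

        def write_invite(self, room_id, event, invited_state):
            self.invites.append((room_id, event, invited_state))

        def finished(self):
            return self.events
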
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 6dedaaff8d..4426967f88 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -15,8 +15,6 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester
@@ -46,8 +44,7 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled
- @defer.inlineCallbacks
- def deactivate_account(self, user_id, erase_data, id_server=None):
+ async def deactivate_account(self, user_id, erase_data, id_server=None):
"""Deactivate a user's account
Args:
@@ -74,11 +71,11 @@ class DeactivateAccountHandler(BaseHandler):
identity_server_supports_unbinding = True
# Retrieve the 3PIDs this user has bound to an identity server
- threepids = yield self.store.user_get_bound_threepids(user_id)
+ threepids = await self.store.user_get_bound_threepids(user_id)
for threepid in threepids:
try:
- result = yield self._identity_handler.try_unbind_threepid(
+ result = await self._identity_handler.try_unbind_threepid(
user_id,
{
"medium": threepid["medium"],
@@ -91,33 +88,33 @@ class DeactivateAccountHandler(BaseHandler):
# Do we want this to be a fatal error or should we carry on?
logger.exception("Failed to remove threepid from ID server")
raise SynapseError(400, "Failed to remove threepid from ID server")
- yield self.store.user_delete_threepid(
+ await self.store.user_delete_threepid(
user_id, threepid["medium"], threepid["address"]
)
# Remove all 3PIDs this user has bound to the homeserver
- yield self.store.user_delete_threepids(user_id)
+ await self.store.user_delete_threepids(user_id)
# delete any devices belonging to the user, which will also
# delete corresponding access tokens.
- yield self._device_handler.delete_all_devices_for_user(user_id)
+ await self._device_handler.delete_all_devices_for_user(user_id)
# then delete any remaining access tokens which weren't associated with
# a device.
- yield self._auth_handler.delete_access_tokens_for_user(user_id)
+ await self._auth_handler.delete_access_tokens_for_user(user_id)
- yield self.store.user_set_password_hash(user_id, None)
+ await self.store.user_set_password_hash(user_id, None)
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
- yield self.store.add_user_pending_deactivation(user_id)
+ await self.store.add_user_pending_deactivation(user_id)
# delete from user directory
- yield self.user_directory_handler.handle_user_deactivated(user_id)
+ await self.user_directory_handler.handle_user_deactivated(user_id)
# Mark the user as erased, if they asked for that
if erase_data:
logger.info("Marking %s as erased", user_id)
- yield self.store.mark_user_erased(user_id)
+ await self.store.mark_user_erased(user_id)
# Now start the process that goes through that list and
# parts users from rooms (if it isn't already running)
@@ -125,30 +122,29 @@ class DeactivateAccountHandler(BaseHandler):
# Reject all pending invites for the user, so that the user doesn't show up in the
# "invited" section of rooms' members list.
- yield self._reject_pending_invites_for_user(user_id)
+ await self._reject_pending_invites_for_user(user_id)
# Remove all information on the user from the account_validity table.
if self._account_validity_enabled:
- yield self.store.delete_account_validity_for_user(user_id)
+ await self.store.delete_account_validity_for_user(user_id)
# Mark the user as deactivated.
- yield self.store.set_user_deactivated_status(user_id, True)
+ await self.store.set_user_deactivated_status(user_id, True)
return identity_server_supports_unbinding
- @defer.inlineCallbacks
- def _reject_pending_invites_for_user(self, user_id):
+ async def _reject_pending_invites_for_user(self, user_id):
"""Reject pending invites addressed to a given user ID.
Args:
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
- pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
+ pending_invites = await self.store.get_invited_rooms_for_user(user_id)
for room in pending_invites:
try:
- yield self._room_member_handler.update_membership(
+ await self._room_member_handler.update_membership(
create_requester(user),
user,
room.room_id,
@@ -180,8 +176,7 @@ class DeactivateAccountHandler(BaseHandler):
if not self._user_parter_running:
run_as_background_process("user_parter_loop", self._user_parter_loop)
- @defer.inlineCallbacks
- def _user_parter_loop(self):
+ async def _user_parter_loop(self):
"""Loop that parts deactivated users from rooms
Returns:
@@ -191,19 +186,18 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("Starting user parter")
try:
while True:
- user_id = yield self.store.get_user_pending_deactivation()
+ user_id = await self.store.get_user_pending_deactivation()
if user_id is None:
break
logger.info("User parter parting %r", user_id)
- yield self._part_user(user_id)
- yield self.store.del_user_pending_deactivation(user_id)
+ await self._part_user(user_id)
+ await self.store.del_user_pending_deactivation(user_id)
logger.info("User parter finished parting %r", user_id)
logger.info("User parter finished: stopping")
finally:
self._user_parter_running = False
- @defer.inlineCallbacks
- def _part_user(self, user_id):
+ async def _part_user(self, user_id):
"""Causes the given user_id to leave all the rooms they're joined to
Returns:
@@ -211,11 +205,11 @@ class DeactivateAccountHandler(BaseHandler):
"""
user = UserID.from_string(user_id)
- rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ rooms_for_user = await self.store.get_rooms_for_user(user_id)
for room_id in rooms_for_user:
logger.info("User parter parting %r from %r", user_id, room_id)
try:
- yield self._room_member_handler.update_membership(
+ await self._room_member_handler.update_membership(
create_requester(user),
user,
room_id,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 57a10daefd..2d889364d4 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -264,6 +264,7 @@ class E2eKeysHandler(object):
return ret
+ @defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database
@@ -283,14 +284,32 @@ class E2eKeysHandler(object):
self_signing_keys = {}
user_signing_keys = {}
- # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
- return defer.succeed(
- {
- "master_keys": master_keys,
- "self_signing_keys": self_signing_keys,
- "user_signing_keys": user_signing_keys,
- }
- )
+ user_ids = list(query)
+
+ keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
+
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ if "master" in user_info:
+ master_keys[user_id] = user_info["master"]
+ if "self_signing" in user_info:
+ self_signing_keys[user_id] = user_info["self_signing"]
+
+ if (
+ from_user_id in keys
+ and keys[from_user_id] is not None
+ and "user_signing" in keys[from_user_id]
+ ):
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+
+ return {
+ "master_keys": master_keys,
+ "self_signing_keys": self_signing_keys,
+ "user_signing_keys": user_signing_keys,
+ }
@trace
@defer.inlineCallbacks
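
The dict now assembled by `get_cross_signing_keys_from_cache` has roughly the shape below (user IDs and key contents are invented placeholders): other users contribute master and self-signing keys, while user-signing keys are returned only for the requesting user.

    example_result = {
        "master_keys": {
            "@alice:example.org": {"usage": ["master"], "keys": {"ed25519:abc": "base64+key"}},
        },
        "self_signing_keys": {
            "@alice:example.org": {"usage": ["self_signing"], "keys": {"ed25519:def": "base64+key"}},
        },
        # only the requester's own user-signing key is ever included
        "user_signing_keys": {
            "@bob:example.org": {"usage": ["user_signing"], "keys": {"ed25519:ghi": "base64+key"}},
        },
    }
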
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 6fb453ce60..72a0febc2b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
import itertools
import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
@@ -63,6 +63,7 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
@@ -163,8 +164,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- @defer.inlineCallbacks
- def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -174,17 +174,15 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
-
- Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
+ logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.store.get_event(
+ existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -228,7 +226,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -243,12 +241,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -268,7 +266,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@@ -276,13 +274,19 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)
- yield self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ )
# Update the set of things we've seen after trying to
# fetch the missing stuff
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -290,14 +294,6 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
- elif missing_prevs:
- logger.info(
- "[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id,
- event_id,
- len(missing_prevs),
- shortstr(missing_prevs),
- )
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -342,12 +338,18 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
+ logger.info(
+ "Event %s is missing prev_events: calculating state for a "
+ "backwards extremity",
+ event_id,
+ )
+
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -361,17 +363,14 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
- "[%s %s] Requesting state at missing prev_event %s",
- room_id,
- event_id,
- p,
+                    "Requesting state at missing prev_event %s", p,
)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- (remote_state, _,) = yield self._get_state_for_room(
+ (remote_state, _,) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
)
@@ -383,8 +382,8 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- room_version = yield self.store.get_room_version(room_id)
- state_map = yield resolve_events_with_store(
+ room_version = await self.store.get_room_version(room_id)
+ state_map = await resolve_events_with_store(
room_id,
room_version,
state_maps,
@@ -397,10 +396,10 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = yield self.store.get_events(
+ evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
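
The branch above handles a backwards extremity: the state after each prev_event is gathered (from our own state groups where we have them, from the remote origin where we do not) and the collected state maps are resolved into a single state for the new event. A rough, self-contained sketch of that shape; the callables here are placeholders, not the real Synapse APIs:

from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Tuple

StateMap = Dict[Tuple[str, str], str]  # (event type, state_key) -> event_id


async def state_for_backwards_extremity(
    prev_event_ids: Iterable[str],
    local_state: Callable[[str], Awaitable[Optional[StateMap]]],
    remote_state: Callable[[str], Awaitable[StateMap]],
    resolve: Callable[[List[StateMap]], StateMap],
) -> StateMap:
    state_maps: List[StateMap] = []
    for prev_id in prev_event_ids:
        known = await local_state(prev_id)
        if known is not None:
            state_maps.append(known)
        else:
            # we know nothing about this prev_event: ask the origin for the
            # state of the room at that point
            state_maps.append(await remote_state(prev_id))
    # resolve the per-prev-event state maps into one state for the new event
    return resolve(state_maps)
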
@@ -420,10 +419,9 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(origin, pdu, state=state)
+ await self._process_received_pdu(origin, pdu, state=state)
- @defer.inlineCallbacks
- def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -435,12 +433,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -504,7 +502,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens.
try:
- missing_events = yield self.federation_client.get_missing_events(
+ missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -543,7 +541,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+ await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@@ -555,29 +553,30 @@ class FederationHandler(BaseHandler):
else:
raise
- @defer.inlineCallbacks
- @log_function
- def _get_state_for_room(
- self, destination, room_id, event_id, include_event_in_state
- ):
+ async def _get_state_for_room(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ include_event_in_state: bool = False,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
- destination (str): The remote homeserver to query for the state.
- room_id (str): The id of the room we're interested in.
- event_id (str): The id of the event we want the state at.
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
include_event_in_state: if true, the event itself will be included in the
returned state event list.
Returns:
- Deferred[Tuple[List[EventBase], List[EventBase]]]:
- A list of events in the state, and a list of events in the auth chain
- for the given event.
+ A list of events in the state, possibly including the event itself, and
+ a list of events in the auth chain for the given event.
"""
(
state_event_ids,
auth_event_ids,
- ) = yield self.federation_client.get_room_state_ids(
+ ) = await self.federation_client.get_room_state_ids(
destination, room_id, event_id=event_id
)
@@ -586,15 +585,15 @@ class FederationHandler(BaseHandler):
if include_event_in_state:
desired_events.add(event_id)
- event_map = yield self._get_events_from_store_or_dest(
+ event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
- "Failed to fetch missing state/auth events for %s: %s",
- room_id,
+            "Failed to fetch missing state/auth events for %s: %s",
+ event_id,
failed_to_fetch,
)
@@ -614,15 +613,11 @@ class FederationHandler(BaseHandler):
return remote_state, auth_chain
- @defer.inlineCallbacks
- def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+ async def _get_events_from_store_or_dest(
+ self, destination: str, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[str, EventBase]:
"""Fetch events from a remote destination, checking if we already have them.
- Args:
- destination (str)
- room_id (str)
- event_ids (Iterable[str])
-
Persists any events we don't already have as outliers.
If we fail to fetch any of the events, a warning will be logged, and the event
@@ -630,10 +625,9 @@ class FederationHandler(BaseHandler):
be in the given room.
Returns:
- Deferred[dict[str, EventBase]]: A deferred resolving to a map
- from event_id to event
+ map from event_id to event
"""
- fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+ fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
missing_events = set(event_ids) - fetched_events.keys()
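
_get_events_from_store_or_dest follows a local-first shape: read whatever is already persisted, fetch the remainder from the destination (persisting them as outliers), then re-read so that rejection status comes from the database. A generic sketch of that flow; `store` and `fetch_and_persist_from_remote` are assumed stand-ins, not real Synapse objects:

from typing import Dict, Iterable


async def get_events_local_then_remote(
    store, fetch_and_persist_from_remote, event_ids: Iterable[str]
) -> Dict[str, object]:
    event_ids = list(event_ids)
    fetched = dict(await store.get_events(event_ids))  # whatever we already have
    missing = set(event_ids) - fetched.keys()
    if missing:
        await fetch_and_persist_from_remote(missing)   # persisted as outliers
        # re-read from the store so rejected events are reported correctly
        fetched.update(await store.get_events(missing))
    return fetched
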
@@ -644,14 +638,14 @@ class FederationHandler(BaseHandler):
room_id,
)
- yield self._get_events_and_persist(
+ await self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
- (yield self.store.get_events(missing_events, allow_rejected=True))
+ (await self.store.get_events(missing_events, allow_rejected=True))
)
# check for events which were in the wrong room.
@@ -677,12 +671,14 @@ class FederationHandler(BaseHandler):
bad_room_id,
room_id,
)
+
del fetched_events[bad_event_id]
return fetched_events
- @defer.inlineCallbacks
- def _process_received_pdu(self, origin, event, state):
+ async def _process_received_pdu(
+ self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+ ):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@@ -701,15 +697,15 @@ class FederationHandler(BaseHandler):
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
try:
- context = yield self._handle_new_event(origin, event, state=state)
+ context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if not room:
try:
- yield self.store.store_room(
+ await self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
@@ -722,11 +718,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
- prev_state = yield self.store.get_event(
+ prev_state = await self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -734,11 +730,10 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, room_id)
+ await self.user_joined_room(user, room_id)
@log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -755,7 +750,7 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- events = yield self.federation_client.backfill(
+ events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -770,7 +765,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
- seen_events = yield self.store.have_events_in_timeline(
+ seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
@@ -796,7 +791,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self._get_state_for_room(
+ state, auth = await self._get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id,
@@ -843,7 +838,7 @@ class FederationHandler(BaseHandler):
)
)
- yield self._handle_new_events(dest, ev_infos, backfilled=True)
+ await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -859,16 +854,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(dest, event, backfilled=True)
+ await self._handle_new_event(dest, event, backfilled=True)
return events
- @defer.inlineCallbacks
- def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+ extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -900,15 +894,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(list(extremities))
+ forward_events = await self.store.get_successor_events(list(extremities))
- extremities_events = yield self.store.get_events(
- forward_events, check_redacted=False, get_prev_content=False
+ extremities_events = await self.store.get_events(
+ forward_events,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
- filtered_extremities = yield filter_events_for_server(
+ filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
@@ -938,7 +934,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = yield self.state_handler.get_current_state(room_id)
+ curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -977,12 +973,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- @defer.inlineCallbacks
- def try_backfill(domains):
+ async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- yield self.backfill(
+ await self.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1013,7 +1008,7 @@ class FederationHandler(BaseHandler):
return False
- success = yield try_backfill(likely_domains)
+ success = await try_backfill(likely_domains)
if success:
return True
@@ -1027,7 +1022,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
- states = yield make_deferred_yieldable(
+ states = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1037,7 +1032,7 @@ class FederationHandler(BaseHandler):
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False,
)
@@ -1053,7 +1048,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill(
+ success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
@@ -1063,8 +1058,7 @@ class FederationHandler(BaseHandler):
return False
- @defer.inlineCallbacks
- def _get_events_and_persist(
+ async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str]
):
"""Fetch the given events from a server, and persist them as outliers.
@@ -1072,7 +1066,7 @@ class FederationHandler(BaseHandler):
Logs a warning if we can't find the given event.
"""
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
event_infos = []
@@ -1108,9 +1102,9 @@ class FederationHandler(BaseHandler):
e,
)
- yield concurrently_execute(get_event, events, 5)
+ await concurrently_execute(get_event, events, 5)
- yield self._handle_new_events(
+ await self._handle_new_events(
destination, event_infos,
)
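
`concurrently_execute(get_event, events, 5)` runs the fetch over the missing event ids with a bounded number of in-flight requests. A generic asyncio sketch of that idea (not the Synapse implementation, which is Twisted-based):

import asyncio
from typing import Awaitable, Callable, Iterable, TypeVar

T = TypeVar("T")


async def run_with_limit(
    func: Callable[[T], Awaitable[None]], args: Iterable[T], limit: int
) -> None:
    sem = asyncio.Semaphore(limit)

    async def one(arg: T) -> None:
        async with sem:
            await func(arg)

    # schedule everything, but at most `limit` calls run at once
    await asyncio.gather(*(one(a) for a in args))
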
@@ -1253,7 +1247,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id)
- if not predecessor:
+ if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
old_room_id = predecessor["room_id"]
logger.debug(
@@ -1281,8 +1275,7 @@ class FederationHandler(BaseHandler):
return True
- @defer.inlineCallbacks
- def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1298,7 +1291,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
- yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1428,7 +1421,7 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield self.user_joined_room(user, event.room_id)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids)
@@ -1496,7 +1489,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts, room_id, user_id, "leave", content=content,
+ target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -1937,7 +1930,7 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
@@ -2346,12 +2339,12 @@ class FederationHandler(BaseHandler):
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
}
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
current_state_ids = dict(current_state_ids)
current_state_ids.update(state_updates)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
@@ -2635,7 +2628,7 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
@@ -2683,7 +2676,7 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
@@ -2857,7 +2850,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
- return user_joined_room(self.distributor, user, room_id)
+ return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 81dce96f4b..44ec3e66ae 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
- self.snapshot_cache = SnapshotCache()
+ self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler):
as_client_event,
include_archived,
)
- now_ms = self.clock.time_msec()
- result = self.snapshot_cache.get(now_ms, key)
- if result is not None:
- return result
- return self.snapshot_cache.set(
- now_ms,
+ return self.snapshot_cache.wrap(
key,
- self._snapshot_all_rooms(
- user_id, pagin_config, as_client_event, include_archived
- ),
+ self._snapshot_all_rooms,
+ user_id,
+ pagin_config,
+ as_client_event,
+ include_archived,
)
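
ResponseCache.wrap(key, callback, *args) deduplicates concurrent requests: the first caller for a given key starts the callback, and later callers with the same key share the in-flight result (the real class is Twisted-based and also handles log contexts and result expiry). A toy asyncio illustration of the key-based deduplication only:

import asyncio
from typing import Any, Awaitable, Callable, Dict


class TinyResponseCache:
    """Toy stand-in for the wrap() pattern above; not Synapse's ResponseCache."""

    def __init__(self) -> None:
        self._pending: Dict[Any, asyncio.Task] = {}

    def wrap(self, key: Any, callback: Callable[..., Awaitable], *args: Any) -> Awaitable:
        # must be called from within a running event loop
        task = self._pending.get(key)
        if task is None:
            # first caller for this key: start the work and remember it
            task = asyncio.ensure_future(callback(*args))
            self._pending[key] = task
            task.add_done_callback(lambda _t: self._pending.pop(key, None))
        # later callers with the same key await the same task
        return task
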
- @defer.inlineCallbacks
- def _snapshot_all_rooms(
+ async def _snapshot_all_rooms(
self,
user_id=None,
pagin_config=None,
@@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived:
memberships.append(Membership.LEAVE)
- room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ room_list = await self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
@@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler):
rooms_ret = []
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
- presence, _ = yield presence_stream.get_pagination_rows(
+ presence, _ = await presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
- receipt, _ = yield receipt_stream.get_pagination_rows(
+ receipt, _ = await receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
- tags_by_room = yield self.store.get_tags_for_user(user_id)
+ tags_by_room = await self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = yield self.store.get_account_data_for_user(
+ account_data, account_data_by_room = await self.store.get_account_data_for_user(
user_id
)
- public_room_ids = yield self.store.get_public_room_ids()
+ public_room_ids = await self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
- @defer.inlineCallbacks
- def handle_room(event):
+ async def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
d["inviter"] = event.sender
- invite_event = yield self.store.get_event(event.event_id)
- d["invite"] = yield self._event_serializer.serialize_event(
+ invite_event = await self.store.get_event(event.event_id)
+ d["invite"] = await self._event_serializer.serialize_event(
invite_event, time_now, as_client_event
)
@@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler):
lambda states: states[event.event_id]
)
- (messages, token), current_state = yield make_deferred_yieldable(
+ (messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages
)
@@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = {
"chunk": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
messages, time_now=time_now, as_client_event=as_client_event
)
),
@@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler):
"end": end_token.to_string(),
}
- d["state"] = yield self._event_serializer.serialize_events(
+ d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event,
@@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler):
except Exception:
logger.exception("Failed to get snapshot")
- yield concurrently_execute(handle_room, room_list, 10)
+ await concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
@@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def room_initial_sync(self, requester, room_id, pagin_config=None):
+ async def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
@@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room.
"""
- blocked = yield self.store.is_room_blocked(room_id)
+ blocked = await self.store.is_room_blocked(room_id)
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
user_id = requester.user.to_string()
- membership, member_event_id = yield self._check_in_room_or_world_readable(
+ membership, member_event_id = await self._check_in_room_or_world_readable(
room_id, user_id
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
- result = yield self._room_initial_sync_joined(
+ result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
- result = yield self._room_initial_sync_parted(
+ result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
- tags = yield self.store.get_tags_for_room(user_id, room_id)
+ tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
- account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+ account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({"type": account_data_type, "content": content})
@@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def _room_initial_sync_parted(
+ async def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
- room_state = yield self.state_store.get_state_for_events([member_event_id])
+ room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- stream_token = yield self.store.get_stream_token_for_event(member_event_id)
+ stream_token = await self.store.get_stream_token_for_event(member_event_id)
- messages, token = yield self.store.get_recent_events_for_room(
+ messages, token = await self.store.get_recent_events_for_room(
room_id, limit=limit, end_token=stream_token
)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
)
@@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id,
"messages": {
"chunk": (
- yield self._event_serializer.serialize_events(messages, time_now)
+ await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
room_state.values(), time_now
)
),
@@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler):
"receipts": [],
}
- @defer.inlineCallbacks
- def _room_initial_sync_joined(
+ async def _room_initial_sync_joined(
self, user_id, room_id, pagin_config, membership, is_peeking
):
- current_state = yield self.state.get_current_state(room_id=room_id)
+ current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
- state = yield self._event_serializer.serialize_events(
+ state = await self._event_serializer.serialize_events(
current_state.values(), time_now
)
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
@@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
- @defer.inlineCallbacks
- def get_presence():
+ async def get_presence():
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
return []
- states = yield presence_handler.get_states(
+ states = await presence_handler.get_states(
[m.user_id for m in room_members], as_event=True
)
return states
- @defer.inlineCallbacks
- def get_receipts():
- receipts = yield self.store.get_linearized_receipts_for_room(
+ async def get_receipts():
+ receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)
if not receipts:
receipts = []
return receipts
- presence, receipts, (messages, token) = yield make_deferred_yieldable(
+ presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(get_presence),
@@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler):
).addErrback(unwrapFirstError)
)
- messages = yield filter_events_for_client(
+ messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
)
@@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id,
"messages": {
"chunk": (
- yield self._event_serializer.serialize_events(messages, time_now)
+ await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
@@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def _check_in_room_or_world_readable(self, room_id, user_id):
+ async def _check_in_room_or_world_readable(self, room_id, user_id):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
- member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+ member_event = await self.auth.check_user_was_in_room(room_id, user_id)
return member_event.membership, member_event.event_id
except AuthError:
- visibility = yield self.state_handler.get_current_state(
+ visibility = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 54fa216d83..4ad752205f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
@@ -514,7 +515,7 @@ class EventCreationHandler(object):
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = (
yield self.store.get_event(prev_event_id, allow_none=True)
@@ -664,7 +665,7 @@ class EventCreationHandler(object):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
@@ -913,7 +914,7 @@ class EventCreationHandler(object):
def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
state_to_include_ids = [
e_id
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
@@ -966,7 +967,7 @@ class EventCreationHandler(object):
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
@@ -988,7 +989,7 @@ class EventCreationHandler(object):
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
await self.storage.purge_events.purge_room(room_id)
- @defer.inlineCallbacks
- def get_messages(
+ async def get_messages(
self,
requester,
room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_pagination()
+ await self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -319,11 +318,11 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (yield self.pagination_lock.read(room_id)):
+ with (await self.pagination_lock.read(room_id)):
(
membership,
member_event_id,
- ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+ ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
if room_token.topological:
max_topo = room_token.topological
else:
- max_topo = yield self.store.get_max_topological_token(
+ max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream
)
@@ -339,18 +338,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
- leave_token = yield self.store.get_topological_token_for_event(
+ leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
- events, next_key = yield self.store.paginate_room_events(
+ events, next_key = await self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
if event_filter:
events = event_filter.filter(events)
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None)
)
@@ -385,19 +384,19 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
- state_ids = yield self.state_store.get_state_ids_for_event(
+ state_ids = await self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
if state_ids:
- state = yield self.store.get_events(list(state_ids.values()))
+ state = await self.store.get_events(list(state_ids.values()))
state = state.values()
time_now = self.clock.time_msec()
chunk = {
"chunk": (
- yield self._event_serializer.serialize_events(
+ await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event
)
),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = yield self._event_serializer.serialize_events(
+ chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index eda15bc623..240c4add12 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -230,7 +230,7 @@ class PresenceHandler(object):
is some spurious presence changes that will self-correct.
"""
# If the DB pool has already terminated, don't try updating
- if not self.hs.get_db_pool().running:
+ if not self.store.database.is_running():
return
logger.info(
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 1e5a4613c9..f9579d69ee 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -295,12 +295,16 @@ class BaseProfileHandler(BaseHandler):
be found to be in any room the server is in, and therefore the query
is denied.
"""
+
# Implementation of MSC1301: don't allow looking up profiles if the
# requester isn't in the same room as the target. We expect requester to
# be None when this function is called outside of a profile query, e.g.
# when building a membership event. In this case, we must allow the
# lookup.
- if not self.hs.config.require_auth_for_profile_requests or not requester:
+ if (
+ not self.hs.config.limit_profile_requests_to_users_who_share_rooms
+ or not requester
+ ):
return
# Always allow the user to query their own profile.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 60b8bbc7a5..89c9118b26 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler):
requester, tombstone_event, tombstone_context
)
- old_room_state = yield tombstone_context.get_current_state_ids(self.store)
+ old_room_state = yield tombstone_context.get_current_state_ids()
# update any aliases
yield self._move_aliases_to_new_room(
@@ -1011,15 +1011,3 @@ class RoomEventSource(object):
def get_current_key_for_room(self, room_id):
return self.store.get_room_events_max_id(room_id)
-
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
- events, next_key = yield self.store.paginate_room_events(
- room_id=key,
- from_key=config.from_key,
- to_key=config.to_key,
- direction=config.direction,
- limit=config.limit,
- )
-
- return (events, next_key)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7b7270fc61..44c5e3239c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -193,7 +193,7 @@ class RoomMemberHandler(object):
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@@ -601,7 +601,7 @@ class RoomMemberHandler(object):
if prev_event is not None:
return
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(prev_state_ids)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index cc9e6b9bd0..0082f85c26 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
+from typing import Tuple
import attr
import saml2
+import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import (
+ UserID,
+ map_username_to_mxid_localpart,
+ mxid_localpart_allowed_characters,
+)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
+@attr.s
+class Saml2SessionData:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib()
+
+
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
- self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
- self._mxid_mapper = hs.config.saml2_mxid_mapper
+
+ # plugin to do custom mapping from saml response to mxid
+ self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+ hs.config.saml2_user_mapping_provider_config
+ )
# identifier for the external_ids table
self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
- raise SynapseError(400, "uid not in SAML2 response")
-
- try:
- mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
- except KeyError:
- logger.warning(
- "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
- )
- raise SynapseError(
- 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
- )
+ raise SynapseError(400, "'uid' not in SAML2 response")
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
- displayName = saml2_auth.ava.get("displayName", [None])[0]
-
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
)
return registered_user_id
- # figure out a new mxid for this user
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ # Map saml response to user attributes using the configured mapping provider
+ for i in range(1000):
+ attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, i
+ )
+
+ logger.debug(
+ "Retrieved SAML attributes from user mapping provider: %s "
+ "(attempt %d)",
+ attribute_dict,
+ i,
+ )
+
+ localpart = attribute_dict.get("mxid_localpart")
+ if not localpart:
+ logger.error(
+ "SAML mapping provider plugin did not return a "
+ "mxid_localpart object"
+ )
+ raise SynapseError(500, "Error parsing SAML2 response")
- suffix = 0
- while True:
- localpart = base_mxid_localpart + (str(suffix) if suffix else "")
+ displayname = attribute_dict.get("displayname")
+
+ # Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
+ # This mxid is free
break
- suffix += 1
- logger.info("Allocating mxid for new user with localpart %s", localpart)
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise SynapseError(
+ 500, "Unable to generate a Matrix ID from the SAML response"
+ )
registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=displayName
+ localpart=localpart, default_display_name=displayname
)
+
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
@@ -205,9 +236,120 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
+DOT_REPLACE_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+ username = username.lower()
+ username = DOT_REPLACE_PATTERN.sub(".", username)
+
+ # regular mxids aren't allowed to start with an underscore either
+ username = re.sub("^_", "", username)
+ return username
+
+
+MXID_MAPPER_MAP = {
+ "hexencode": map_username_to_mxid_localpart,
+ "dotreplace": dot_replace_for_mxid,
+}
+
+
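
The `dotreplace` mapper lower-cases the input, replaces anything outside the allowed mxid localpart characters with `.`, and strips a leading underscore. A self-contained restatement with example inputs; the allowed-character set below is an assumption standing in for synapse.types.mxid_localpart_allowed_characters:

import re
import string

# assumed stand-in for mxid_localpart_allowed_characters
ALLOWED = set("_-./=" + string.ascii_lowercase + string.digits)
PATTERN = re.compile("[^%s]" % re.escape("".join(ALLOWED)))


def dot_replace(username: str) -> str:
    username = PATTERN.sub(".", username.lower())
    return re.sub("^_", "", username)


assert dot_replace("John Smith") == "john.smith"  # space becomes a dot
assert dot_replace("_Admin") == "admin"           # leading underscore stripped
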
@attr.s
-class Saml2SessionData:
- """Data we track about SAML2 sessions"""
+class SamlConfig(object):
+ mxid_source_attribute = attr.ib()
+ mxid_mapper = attr.ib()
- # time the session was created, in milliseconds
- creation_time = attr.ib()
+
+class DefaultSamlMappingProvider(object):
+ __version__ = "0.0.1"
+
+ def __init__(self, parsed_config: SamlConfig):
+ """The default SAML user mapping provider
+
+ Args:
+ parsed_config: Module configuration
+ """
+ self._mxid_source_attribute = parsed_config.mxid_source_attribute
+ self._mxid_mapper = parsed_config.mxid_mapper
+
+ def saml_response_to_user_attributes(
+ self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+ ) -> dict:
+ """Maps some text from a SAML response to attributes of a new user
+
+ Args:
+ saml_response: A SAML auth response object
+
+ failures: How many times a call to this function with this
+ saml_response has resulted in a failure
+
+ Returns:
+ dict: A dict containing new user attributes. Possible keys:
+ * mxid_localpart (str): Required. The localpart of the user's mxid
+ * displayname (str): The displayname of the user
+ """
+ try:
+ mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
+ # Use the configured mapper for this mxid_source
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid
+ localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+ # Retrieve the display name from the saml response
+ # If displayname is None, the mxid_localpart will be used instead
+ displayname = saml_response.ava.get("displayName", [None])[0]
+
+ return {
+ "mxid_localpart": localpart,
+ "displayname": displayname,
+ }
+
+ @staticmethod
+ def parse_config(config: dict) -> SamlConfig:
+ """Parse the dict provided by the homeserver's config
+ Args:
+ config: A dictionary containing configuration options for this provider
+ Returns:
+ SamlConfig: A custom config object for this module
+ """
+ # Parse config options and use defaults where necessary
+ mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+ mapping_type = config.get("mxid_mapping", "hexencode")
+
+ # Retrieve the associating mapping function
+ try:
+ mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+ except KeyError:
+ raise ConfigError(
+ "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+ "mxid_mapping value" % (mapping_type,)
+ )
+
+ return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+ @staticmethod
+ def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+        """Returns the required attributes of a SAML response
+
+ Args:
+ config: A SamlConfig object containing configuration params for this provider
+
+ Returns:
+ tuple[set,set]: The first set equates to the saml auth response
+ attributes that are required for the module to function, whereas the
+ second set consists of those attributes which can be used if
+ available, but are not necessary
+ """
+ return {"uid", config.mxid_source_attribute}, {"displayName"}
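
The class above doubles as the reference implementation of the new pluggable interface: a deployment-specific provider only needs to supply the same entry points. A hedged sketch of a custom provider that derives the localpart from an assumed `email` SAML attribute (the attribute name and config key are illustrative, not part of Synapse):

from typing import Tuple

import attr
import saml2.response

from synapse.api.errors import SynapseError


@attr.s
class EmailProviderConfig(object):
    # which SAML attribute to read; illustrative default only
    source_attribute = attr.ib(default="email")


class EmailSamlMappingProvider(object):
    __version__ = "0.0.1"

    def __init__(self, parsed_config: EmailProviderConfig):
        self._source_attribute = parsed_config.source_attribute

    def saml_response_to_user_attributes(
        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
    ) -> dict:
        try:
            email = saml_response.ava[self._source_attribute][0]
        except KeyError:
            raise SynapseError(
                400, "%s not in SAML2 response" % (self._source_attribute,)
            )
        # take the part before the '@'; real code should sanitise further.
        # Append the failure count so a retry produces a different localpart.
        localpart = email.split("@")[0].lower() + (str(failures) if failures else "")
        return {"mxid_localpart": localpart, "displayname": None}

    @staticmethod
    def parse_config(config: dict) -> EmailProviderConfig:
        return EmailProviderConfig(config.get("source_attribute", "email"))

    @staticmethod
    def get_saml_attributes(config: EmailProviderConfig) -> Tuple[set, set]:
        # 'uid' is still read by the handler itself to identify the remote user
        return {"uid", config.source_attribute}, set()
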
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 56ed262a1f..ef750d1497 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through.
Returns:
- Deferred[iterable[unicode]]: predecessor room ids
+ Deferred[iterable[str]]: predecessor room ids
"""
historical_room_ids = []
- while True:
- predecessor = yield self.store.get_room_predecessor(room_id)
+ # The initial room must have been known for us to get this far
+ predecessor = yield self.store.get_room_predecessor(room_id)
- # If no predecessor, assume we've hit a dead end
+ while True:
if not predecessor:
+ # We have reached the end of the chain of predecessors
+ break
+
+ if not isinstance(predecessor.get("room_id"), str):
+ # This predecessor object is malformed. Exit here
+ break
+
+ predecessor_room_id = predecessor["room_id"]
+
+            # Don't add it to the list until we have checked that it is a known room
+ try:
+ next_predecessor_room = yield self.store.get_room_predecessor(
+ predecessor_room_id
+ )
+ except NotFoundError:
+ # The predecessor is not a known room, so we are done here
break
- # Add predecessor's room ID
- historical_room_ids.append(predecessor["room_id"])
+ historical_room_ids.append(predecessor_room_id)
- # Scan through the old room for further predecessors
- room_id = predecessor["room_id"]
+ # And repeat
+ predecessor = next_predecessor_room
return historical_room_ids
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 6f78454322..b635c339ed 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -317,6 +317,3 @@ class TypingNotificationEventSource(object):
def get_current_key(self):
return self.get_typing_handler()._latest_room_serial
-
- def get_pagination_rows(self, user, pagination_config, key):
- return [], pagination_config.from_key
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 03934956f4..c0b9384189 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -171,7 +171,7 @@ class LogProducer(object):
def stopProducing(self):
self._paused = True
- self._buffer = None
+ self._buffer = deque()
def resumeProducing(self):
self._paused = False
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 2c1fb9ddac..33b322209d 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -23,6 +23,7 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
+import inspect
import logging
import threading
import types
@@ -404,6 +405,9 @@ class LoggingContext(object):
"""
current = get_thread_resource_usage()
+ # Indicate to mypy that we know that self.usage_start is None.
+ assert self.usage_start is not None
+
utime_delta = current.ru_utime - self.usage_start.ru_utime
stime_delta = current.ru_stime - self.usage_start.ru_stime
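
The added assert is the usual way to narrow an Optional for mypy: after the assertion, the checker treats `self.usage_start` as non-None. A minimal, unrelated illustration of the same idiom:

from typing import Optional


class Timer:
    def __init__(self) -> None:
        self.usage_start: Optional[float] = None

    def start(self, now: float) -> None:
        self.usage_start = now

    def elapsed(self, now: float) -> float:
        # mypy sees Optional[float]; the assert narrows it to float
        assert self.usage_start is not None
        return now - self.usage_start
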
@@ -612,7 +616,8 @@ def run_in_background(f, *args, **kwargs):
def make_deferred_yieldable(deferred):
- """Given a deferred, make it follow the Synapse logcontext rules:
+ """Given a deferred (or coroutine), make it follow the Synapse logcontext
+ rules:
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
@@ -624,6 +629,13 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
+ if inspect.isawaitable(deferred):
+ # If we're given a coroutine we convert it to a deferred so that we
+        # run it and find out if it immediately finishes; if it does then we
+ # don't need to fiddle with log contexts at all and can return
+ # immediately.
+ deferred = defer.ensureDeferred(deferred)
+
if not isinstance(deferred, defer.Deferred):
return deferred
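
With this change make_deferred_yieldable accepts coroutines as well as Deferreds, converting them via defer.ensureDeferred. A standalone illustration of that conversion, assuming only Twisted:

from twisted.internet import defer, task


async def compute() -> int:
    # a native coroutine that awaits a Deferred
    return (await defer.succeed(20)) + 1


def main(reactor):
    # ensureDeferred turns the coroutine into a Deferred the reactor can drive
    d = defer.ensureDeferred(compute())
    d.addCallback(print)  # prints 21
    return d


# task.react(main)  # uncomment to run under a real reactor
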
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7881780760..7d9f5a38d9 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object):
@defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context):
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
@@ -304,7 +304,7 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index f277aeb131..8ad0bf5936 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -80,9 +80,11 @@ class PusherFactory(object):
return EmailPusher(self.hs, pusherdict, mailer)
def _app_name_from_pusherdict(self, pusherdict):
- if "data" in pusherdict and "brand" in pusherdict["data"]:
- app_name = pusherdict["data"]["brand"]
- else:
- app_name = self.config.email_app_name
+ data = pusherdict["data"]
- return app_name
+ if isinstance(data, dict):
+ brand = data.get("brand")
+ if isinstance(brand, str):
+ return brand
+
+ return self.config.email_app_name
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 0f6992202d..b9dca5bc63 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -232,7 +232,6 @@ class PusherPool:
Deferred
"""
pushers = yield self.store.get_all_pushers()
- logger.info("Starting %d pushers", len(pushers))
# Stagger starting up the pushers so we don't completely drown the
# process on start up.
@@ -245,7 +244,7 @@ class PusherPool:
"""Start the given pusher
Args:
- pusherdict (dict):
+ pusherdict (dict): dict with the values pulled from the db table
Returns:
Deferred[EmailPusher|HttpPusher]
@@ -254,7 +253,8 @@ class PusherPool:
p = self.pusher_factory.create_pusher(pusherdict)
except PusherConfigException as e:
logger.warning(
- "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s",
+ "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
+ pusherdict["id"],
pusherdict.get("user_name"),
pusherdict.get("app_id"),
pusherdict.get("pushkey"),
@@ -262,7 +262,9 @@ class PusherPool:
)
return
except Exception:
- logger.exception("Couldn't start a pusher: caught Exception")
+ logger.exception(
+ "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+ )
return
if not p:
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 9af4e7e173..49a3251372 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.clock = hs.get_clock()
self.federation_handler = hs.get_handlers().federation_handler
@@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
EventType = event_type_from_format_version(format_ver)
event = EventType(event_dict, internal_metadata, rejected_reason)
- context = EventContext.deserialize(self.store, event_payload["context"])
+ context = EventContext.deserialize(
+ self.storage, event_payload["context"]
+ )
event_and_contexts.append((event, context))
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 9bafd60b14..84b92f16ad 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.clock = hs.get_clock()
@staticmethod
@@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event = EventType(event_dict, internal_metadata, rejected_reason)
requester = Requester.deserialize(self.store, content["requester"])
- context = EventContext.deserialize(self.store, content["context"])
+ context = EventContext.deserialize(self.storage, content["context"])
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 0791866f55..6f6b7aed6e 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
+ALLOWED_KEYS = {
+ "app_display_name",
+ "app_id",
+ "data",
+ "device_display_name",
+ "kind",
+ "lang",
+ "profile_tag",
+ "pushkey",
+}
+
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
@@ -43,23 +54,11 @@ class PushersRestServlet(RestServlet):
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- allowed_keys = [
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
- ]
-
- for p in pushers:
- for k, v in list(p.items()):
- if k not in allowed_keys:
- del p[k]
-
- return 200, {"pushers": pushers}
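+        # Only expose the ALLOWED_KEYS subset of each pusher dict to the client.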
+ filtered_pushers = list(
+ {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
+ )
+
+ return 200, {"pushers": filtered_pushers}
def on_OPTIONS(self, _):
return 200, {}
diff --git a/synapse/server.py b/synapse/server.py
index 2db3dab221..7926867b77 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -25,7 +25,6 @@ import abc
import logging
import os
-from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
from twisted.web.client import BrowserLikePolicyForHTTPS
@@ -34,6 +33,7 @@ from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
+from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
@@ -132,7 +132,6 @@ class HomeServer(object):
DEPENDENCIES = [
"http_client",
- "db_pool",
"federation_client",
"federation_server",
"handlers",
@@ -209,16 +208,18 @@ class HomeServer(object):
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
- def __init__(self, hostname, reactor=None, **kwargs):
+ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
"""
Args:
hostname : The hostname for the server.
+ config: The full config for the homeserver.
"""
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self.hostname = hostname
+ self.config = config
self._building = {}
self._listening_services = []
self.start_time = None
@@ -237,10 +238,8 @@ class HomeServer(object):
def setup(self):
logger.info("Setting up.")
- with self.get_db_conn() as conn:
- self.datastores = DataStores(self.DATASTORE_CLASS, conn, self)
- conn.commit()
self.start_time = int(self.get_clock().time())
+ self.datastores = DataStores(self.DATASTORE_CLASS, self)
logger.info("Finished setting up.")
def setup_master(self):
@@ -274,6 +273,9 @@ class HomeServer(object):
def get_datastore(self):
return self.datastores.main
+ def get_datastores(self):
+ return self.datastores
+
def get_config(self):
return self.config
@@ -423,31 +425,6 @@ class HomeServer(object):
)
return MatrixFederationHttpClient(self, tls_client_options_factory)
- def build_db_pool(self):
- name = self.db_config["name"]
-
- return adbapi.ConnectionPool(
- name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {})
- )
-
- def get_db_conn(self, run_new_connection=True):
- """Makes a new connection to the database, skipping the db pool
-
- Returns:
- Connection: a connection object implementing the PEP-249 spec
- """
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v
- for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 0e75e94c6f..5accc071ab 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,6 +32,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@@ -655,7 +656,7 @@ class StateResolutionStore(object):
return self.store.get_events(
event_ids,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=allow_rejected,
)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b7637b5dc0..88546ad614 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -40,7 +40,7 @@ class SQLBaseStore(object):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
- self.database_engine = hs.database_engine
+ self.database_engine = database.engine
self.db = database
self.rand = random.SystemRandom()
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cafedd5c0d..d20df5f076 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,24 +13,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.database import Database
+import logging
+
+from synapse.storage.data_stores.state import StateGroupDataStore
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
+logger = logging.getLogger(__name__)
+
class DataStores(object):
"""The various data stores.
These are low level interfaces to physical databases.
+
+ Attributes:
+ main (DataStore)
"""
- def __init__(self, main_store_class, db_conn, hs):
+ def __init__(self, main_store_class, hs):
# Note we pass in the main store class here as workers use a different main
# store.
- database = Database(hs)
- # Check that db is correctly configured.
- database.engine.check_database(db_conn.cursor())
+ self.databases = []
+
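+        # Each configured database may host one or more data stores (e.g. "main",
+        # "state"); instantiate whichever stores the config assigns to it.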
+ for database_config in hs.config.database.databases:
+ db_name = database_config.name
+ engine = create_engine(database_config.config)
+
+ with make_conn(database_config, engine) as db_conn:
+ logger.info("Preparing database %r...", db_name)
+
+ engine.check_database(db_conn.cursor())
+ prepare_database(
+ db_conn, engine, hs.config, data_stores=database_config.data_stores,
+ )
+
+ database = Database(hs, database_config, engine)
+
+ if "main" in database_config.data_stores:
+ logger.info("Starting 'main' data store")
+ self.main = main_store_class(database, db_conn, hs)
+
+ if "state" in database_config.data_stores:
+ logger.info("Starting 'state' data store")
+ self.state = StateGroupDataStore(database, db_conn, hs)
+
+ db_conn.commit()
- prepare_database(db_conn, database.engine, config=hs.config)
+ self.databases.append(database)
- self.main = main_store_class(database, db_conn, hs)
+ logger.info("Database %r prepared", db_name)
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..13f4c9c72e 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
- if not self.hs.get_db_pool().running:
+ if not self.db.is_running():
return
to_update = self._batch_row_update
@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- self.db.simple_upsert_txn(
+ # this is always an update rather than an upsert: the row should
+ # already exist, and if it doesn't, that may be because it has been
+ # deleted, and we don't want to re-create it.
+ self.db.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
- values={
+ updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
- lock=False,
)
except Exception as e:
# Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 85cfa16850..0613b49f4a 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -358,21 +358,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
- # Compatible method of performing an upsert
- sql = "SELECT stream_id FROM device_max_stream_id"
-
- txn.execute(sql)
- rows = txn.fetchone()
- if rows:
- db_stream_id = rows[0]
- if db_stream_id < stream_id:
- # Insert the new stream_id
- sql = "UPDATE device_max_stream_id SET stream_id = ?"
- else:
- # No rows, perform an insert
- sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)"
-
- txn.execute(sql, (stream_id,))
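+        # Only advance the stored maximum stream id; never move it backwards.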
+        sql = "UPDATE device_max_stream_id SET stream_id = ? WHERE stream_id < ?"
+ txn.execute(sql, (stream_id, stream_id))
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 38cd0ca9b8..e551606f9d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,15 +14,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict, List, Optional
+
from six import iteritems
from canonicaljson import encode_canonical_json, json
+from twisted.enterprise.adbapi import Connection
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being set: either 'master'
+ key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
@@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""Returns a user's cross-signing key.
Args:
- user_id (str): the user whose self-signing key is being requested
- key_type (str): the type of cross-signing key to get
+ user_id (str): the user whose key is being requested
+ key_type (str): the type of key that is being requested: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
the self-signing key will be included in the result
@@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id,
)
+ @cached(num_args=1)
+ def _get_bare_e2e_cross_signing_keys(self, user_id):
+ """Dummy function. Only used to make a cache for
+ _get_bare_e2e_cross_signing_keys_bulk.
+ """
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_bare_e2e_cross_signing_keys",
+ list_name="user_ids",
+ num_args=1,
+ )
+ def _get_bare_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str]
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+
+ """
+ return self.db.runInteraction(
+ "get_bare_e2e_cross_signing_keys_bulk",
+ self._get_bare_e2e_cross_signing_keys_bulk_txn,
+ user_ids,
+ )
+
+ def _get_bare_e2e_cross_signing_keys_bulk_txn(
+ self, txn: Connection, user_ids: List[str],
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, their user
+ ID will not be in the dict.
+
+ """
+ result = {}
+
+ batch_size = 100
+ chunks = [
+ user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
+ ]
+ for user_chunk in chunks:
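+            # Pick the most recent key of each type for each user, by joining
+            # against the maximum stream_id per (user_id, keytype).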
+ sql = """
+ SELECT k.user_id, k.keytype, k.keydata, k.stream_id
+ FROM e2e_cross_signing_keys k
+ INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
+ FROM e2e_cross_signing_keys
+ GROUP BY user_id, keytype) s
+ USING (user_id, stream_id, keytype)
+ WHERE k.user_id IN (%s)
+ """ % (
+ ",".join("?" for u in user_chunk),
+ )
+ query_params = []
+ query_params.extend(user_chunk)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ for row in rows:
+ user_id = row["user_id"]
+ key_type = row["keytype"]
+ key = json.loads(row["keydata"])
+ user_info = result.setdefault(user_id, {})
+ user_info[key_type] = key
+
+ return result
+
+ def _get_e2e_cross_signing_signatures_txn(
+ self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing signatures made by a user on a set of keys.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ keys (dict[str, dict[str, dict]]): a map of user ID to key type to
+ key data. This dict will be modified to add signatures.
+ from_user_id (str): fetch the signatures made by this user
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. The return value will be the same as the keys argument,
+ with the modifications included.
+ """
+
+ # find out what cross-signing keys (a.k.a. devices) we need to get
+ # signatures for. This is a map of (user_id, device_id) to key type
+ # (device_id is the key's public part).
+ devices = {}
+
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ for key_type, key in user_info.items():
+ device_id = None
+ for k in key["keys"].values():
+ device_id = k
+ devices[(user_id, device_id)] = key_type
+
+ device_list = list(devices)
+
+ # split into batches
+ batch_size = 100
+ chunks = [
+ device_list[i : i + batch_size]
+ for i in range(0, len(device_list), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT target_user_id, target_device_id, key_id, signature
+ FROM e2e_cross_signing_signatures
+ WHERE user_id = ?
+ AND (%s)
+ """ % (
+ " OR ".join(
+                    "(target_user_id = ? AND target_device_id = ?)"
+                    for d in user_chunk
+ )
+ )
+ query_params = [from_user_id]
+            for item in user_chunk:
+ # item is a (user_id, device_id) tuple
+ query_params.extend(item)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ # and add the signatures to the appropriate keys
+ for row in rows:
+ key_id = row["key_id"]
+ target_user_id = row["target_user_id"]
+ target_device_id = row["target_device_id"]
+ key_type = devices[(target_user_id, target_device_id)]
+ # We need to copy everything, because the result may have come
+ # from the cache. dict.copy only does a shallow copy, so we
+ # need to recursively copy the dicts that will be modified.
+ user_info = keys[target_user_id] = keys[target_user_id].copy()
+ target_user_key = user_info[key_type] = user_info[key_type].copy()
+ if "signatures" in target_user_key:
+ signatures = target_user_key["signatures"] = target_user_key[
+ "signatures"
+ ].copy()
+ if from_user_id in signatures:
+ user_sigs = signatures[from_user_id] = signatures[from_user_id]
+ user_sigs[key_id] = row["signature"]
+ else:
+ signatures[from_user_id] = {key_id: row["signature"]}
+ else:
+ target_user_key["signatures"] = {
+ from_user_id: {key_id: row["signature"]}
+ }
+
+ return keys
+
+ @defer.inlineCallbacks
+ def get_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str], from_user_id: Optional[str] = None
+ ) -> defer.Deferred:
+ """Returns the cross-signing keys for a set of users.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+ from_user_id (str): if specified, signatures made by this user on
+ the self-signing keys will be included in the result
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
+ key data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
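+
+        Example (a sketch only; the user IDs here are hypothetical):
+
+            keys = yield store.get_e2e_cross_signing_keys_bulk(
+                ["@alice:example.com"], from_user_id="@bob:example.com"
+            )
+            # keys["@alice:example.com"]["master"] is Alice's master key, with
+            # any signatures Bob has made on it included under "signatures".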
+ """
+
+ result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+
+ if from_user_id:
+ result = yield self.db.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_txn,
+ result,
+ from_user_id,
+ )
+
+ return result
+
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
@@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
},
)
+ self._invalidate_cache_and_stream(
+ txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
+ )
+
def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 998bba1aad..58f35d7f56 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1757,163 +1757,6 @@ class EventsStore(
return state_groups
- def purge_unreferenced_state_groups(
- self, room_id: str, state_groups_to_delete
- ) -> defer.Deferred:
- """Deletes no longer referenced state groups and de-deltas any state
- groups that reference them.
-
- Args:
- room_id: The room the state groups belong to (must all be in the
- same room).
- state_groups_to_delete (Collection[int]): Set of all state groups
- to delete.
- """
-
- return self.db.runInteraction(
- "purge_unreferenced_state_groups",
- self._purge_unreferenced_state_groups,
- room_id,
- state_groups_to_delete,
- )
-
- def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
- logger.info(
- "[purge] found %i state groups to delete", len(state_groups_to_delete)
- )
-
- rows = self.db.simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- retcols=("state_group",),
- )
-
- remaining_state_groups = set(
- row["state_group"]
- for row in rows
- if row["state_group"] not in state_groups_to_delete
- )
-
- logger.info(
- "[purge] de-delta-ing %i remaining state groups",
- len(remaining_state_groups),
- )
-
- # Now we turn the state groups that reference to-be-deleted state
- # groups to non delta versions.
- for sg in remaining_state_groups:
- logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
-
- self.db.simple_delete_txn(
- txn, table="state_groups_state", keyvalues={"state_group": sg}
- )
-
- self.db.simple_delete_txn(
- txn, table="state_group_edges", keyvalues={"state_group": sg}
- )
-
- self.db.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": sg,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(curr_state)
- ],
- )
-
- logger.info("[purge] removing redundant state groups")
- txn.executemany(
- "DELETE FROM state_groups_state WHERE state_group = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
- txn.executemany(
- "DELETE FROM state_groups WHERE id = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
-
- @defer.inlineCallbacks
- def get_previous_state_groups(self, state_groups):
- """Fetch the previous groups of the given state groups.
-
- Args:
- state_groups (Iterable[int])
-
- Returns:
- Deferred[dict[int, int]]: mapping from state group to previous
- state group.
- """
-
- rows = yield self.db.simple_select_many_batch(
- table="state_group_edges",
- column="prev_state_group",
- iterable=state_groups,
- keyvalues={},
- retcols=("prev_state_group", "state_group"),
- desc="get_previous_state_groups",
- )
-
- return {row["state_group"]: row["prev_state_group"] for row in rows}
-
- def purge_room_state(self, room_id, state_groups_to_delete):
- """Deletes all record of a room from state tables
-
- Args:
- room_id (str):
- state_groups_to_delete (list[int]): State groups to delete
- """
-
- return self.db.runInteraction(
- "purge_room_state",
- self._purge_room_state_txn,
- room_id,
- state_groups_to_delete,
- )
-
- def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
- # first we have to delete the state groups states
- logger.info("[purge] removing %s from state_groups_state", room_id)
-
- self.db.simple_delete_many_txn(
- txn,
- table="state_groups_state",
- column="state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- )
-
- # ... and the state group edges
- logger.info("[purge] removing %s from state_group_edges", room_id)
-
- self.db.simple_delete_many_txn(
- txn,
- table="state_group_edges",
- column="state_group",
- iterable=state_groups_to_delete,
- keyvalues={},
- )
-
- # ... and the state groups
- logger.info("[purge] removing %s from state_groups", room_id)
-
- self.db.simple_delete_many_txn(
- txn,
- table="state_groups",
- column="id",
- iterable=state_groups_to_delete,
- keyvalues={},
- )
-
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 9ee117ce0f..2c9142814c 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
+from typing import List, Optional
from canonicaljson import json
+from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+class EventRedactBehaviour(Names):
+ """
+    What to do when retrieving a redacted event from the database:
+    AS_IS returns the event unredacted, REDACT returns it with its content
+    redacted, and BLOCK does not return it at all.
+ """
+
+ AS_IS = NamedConstant()
+ REDACT = NamedConstant()
+ BLOCK = NamedConstant()
+
+
class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
@@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_event(
self,
- event_id,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=False,
- check_room_id=None,
+        event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: bool = False,
+ check_room_id: Optional[str] = None,
):
"""Get an event from the database by event_id.
Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_id: The event_id of the event to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
+ allow_rejected: If True return rejected events.
+ allow_none: If True, return None if no event found, if
False throw a NotFoundError
- check_room_id (str|None): if not None, check the room of the found event.
+ check_room_id: if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns:
@@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
events = yield self.get_events_as_list(
[event_id],
- check_redacted=check_redacted,
+ redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events(
self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
):
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_ids: The event_ids of the events to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible
+ values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ allow_rejected: If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self.get_events_as_list(
event_ids,
- check_redacted=check_redacted,
+ redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events_as_list(
self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
+ event_ids: The event_ids of the events to fetch
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+                * BLOCK - Do not return redacted events
+ get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
+ allow_rejected: If True, return rejected events.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
# Update the cache to save doing the checks again.
entry.event.internal_metadata.recheck_redaction = False
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
+ event = entry.event
+
+ if entry.redacted_event:
+ if redact_behaviour == EventRedactBehaviour.BLOCK:
+ # Skip this event
+ continue
+ elif redact_behaviour == EventRedactBehaviour.REDACT:
+ event = entry.redacted_event
events.append(event)
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 5ba13aa973..e2673ae073 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -244,7 +244,7 @@ class PushRulesWorkerStore(
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids(self)
+ current_state_ids = yield context.get_current_state_ids()
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index f07309ef09..6b03233262 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,8 +15,7 @@
# limitations under the License.
import logging
-
-import six
+from typing import Iterable, Iterator
from canonicaljson import encode_canonical_json, json
@@ -27,21 +26,16 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
logger = logging.getLogger(__name__)
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
-
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows):
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ """JSON-decode the data in the rows returned from the `pushers` table
+
+ Drops any rows whose data cannot be decoded
+ """
for r in rows:
dataJson = r["data"]
- r["data"] = None
try:
- if isinstance(dataJson, db_binary_type):
- dataJson = str(dataJson).decode("UTF8")
-
r["data"] = json.loads(dataJson)
except Exception as e:
logger.warning(
@@ -50,12 +44,9 @@ class PusherWorkerStore(SQLBaseStore):
dataJson,
e.args[0],
)
- pass
-
- if isinstance(r["pushkey"], db_binary_type):
- r["pushkey"] = str(r["pushkey"]).decode("UTF8")
+ continue
- return rows
+ yield r
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 92e3b9c512..70ff5751b6 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield context.get_current_state_ids(self)
+ current_state_ids = yield context.get_current_state_ids()
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
index 4219cdd06a..2de50d408c 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
@@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
-DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
new file mode 100644
index 0000000000..c2f557fde9
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This line already existed in deltas/35/device_stream_id but was not included in the
+-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist
+INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS (
+ SELECT * from device_max_stream_id
+);
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
new file mode 100644
index 0000000000..4f24c1405d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
@@ -0,0 +1,29 @@
+/* Copyright 2019 Werner Sembach
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made.
+DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 4ad2929f32..889a9a0ce4 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -975,40 +975,6 @@ CREATE TABLE state_events (
-CREATE TABLE state_group_edges (
- state_group bigint NOT NULL,
- prev_state_group bigint NOT NULL
-);
-
-
-
-CREATE SEQUENCE state_group_id_seq
- START WITH 1
- INCREMENT BY 1
- NO MINVALUE
- NO MAXVALUE
- CACHE 1;
-
-
-
-CREATE TABLE state_groups (
- id bigint NOT NULL,
- room_id text NOT NULL,
- event_id text NOT NULL
-);
-
-
-
-CREATE TABLE state_groups_state (
- state_group bigint NOT NULL,
- room_id text NOT NULL,
- type text NOT NULL,
- state_key text NOT NULL,
- event_id text NOT NULL
-);
-
-
-
CREATE TABLE stats_stream_pos (
lock character(1) DEFAULT 'X'::bpchar NOT NULL,
stream_id bigint,
@@ -1482,12 +1448,6 @@ ALTER TABLE ONLY state_events
ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id);
-
-ALTER TABLE ONLY state_groups
- ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id);
-
-
-
ALTER TABLE ONLY stats_stream_pos
ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock);
@@ -1928,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts);
-CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group);
-
-
-
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group);
-
-
-
-CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key);
-
-
-
CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index bad33291e7..a0411ede7e 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id);
CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) );
CREATE INDEX room_depth_room ON room_depth(room_id);
-CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
-CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL );
CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) );
CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) );
CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) );
@@ -120,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL );
CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT);
CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id );
CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id );
-CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL );
-CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group);
CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering );
CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering );
@@ -254,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen);
CREATE INDEX users_creation_ts ON users (creation_ts);
CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group);
CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id);
-CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key);
CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id);
CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip);
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
index c265fd20e2..91d21b2921 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
@@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales
INSERT INTO user_directory_stream_pos (stream_id) VALUES (0);
INSERT INTO stats_stream_pos (stream_id) VALUES (0);
INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
+-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..bbd3f18604
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@
+# Building full schema dumps
+
+These schemas need to be made from a database that has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce
+`full.sql.postgres` and `full.sql.sqlite` files.
+
+Ensure postgres is installed and that your user is able to run postgres
+command-line tools such as `createdb`.
+
+```
+./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+```
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 4eec2fae5e..47ebb8a214 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
@@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
@@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
@@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 9ef7b48c74..0dc39f139c 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -17,8 +17,7 @@ import logging
from collections import namedtuple
from typing import Iterable, Tuple
-from six import iteritems, itervalues
-from six.moves import range
+from six import iteritems
from twisted.internet import defer
@@ -29,11 +28,9 @@ from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
-from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
-from synapse.util.caches import get_cache_factor_for, intern_string
+from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@@ -55,207 +52,14 @@ class _GetStateGroupDelta(
return len(self.delta_ids) if self.delta_ids else 0
-class StateGroupBackgroundUpdateStore(SQLBaseStore):
- """Defines functions related to state groups needed to run the state backgroud
- updates.
- """
-
- def _count_state_group_hops_txn(self, txn, state_group):
- """Given a state group, count how many hops there are in the tree.
-
- This is used to ensure the delta chains don't get too long.
- """
- if isinstance(self.database_engine, PostgresEngine):
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT count(*) FROM state;
- """
-
- txn.execute(sql, (state_group,))
- row = txn.fetchone()
- if row and row[0]:
- return row[0]
- else:
- return 0
- else:
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
- count = 0
-
- while next_group:
- next_group = self.db.simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
- if next_group:
- count += 1
-
- return count
-
- def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all()
- ):
- results = {group: {} for group in groups}
-
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- # Unless the filter clause is empty, we're going to append it after an
- # existing where clause
- if where_clause:
- where_clause = " AND (%s)" % (where_clause,)
-
- if isinstance(self.database_engine, PostgresEngine):
- # Temporarily disable sequential scans in this transaction. This is
- # a temporary hack until we can add the right indices in
- txn.execute("SET LOCAL enable_seqscan=off")
-
- # The below query walks the state_group tree so that the "state"
- # table includes all state_groups in the tree. It then joins
- # against `state_groups_state` to fetch the latest state.
- # It assumes that previous state groups are always numerically
- # lesser.
- # The PARTITION is used to get the event_id in the greatest state
- # group for the given type, state_key.
- # This may return multiple rows per (type, state_key), but last_value
- # should be the same.
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT DISTINCT type, state_key, last_value(event_id) OVER (
- PARTITION BY type, state_key ORDER BY state_group ASC
- ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
- ) AS event_id FROM state_groups_state
- WHERE state_group IN (
- SELECT state_group FROM state
- )
- """
-
- for group in groups:
- args = [group]
- args.extend(where_args)
-
- txn.execute(sql + where_clause, args)
- for row in txn:
- typ, state_key, event_id = row
- key = (typ, state_key)
- results[group][key] = event_id
- else:
- max_entries_returned = state_filter.max_entries_returned()
-
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- for group in groups:
- next_group = group
-
- while next_group:
- # We did this before by getting the list of group ids, and
- # then passing that list to sqlite to get latest event for
- # each (type, state_key). However, that was terribly slow
- # without the right indices (which we can't add until
- # after we finish deduping state, which requires this func)
- args = [next_group]
- args.extend(where_args)
-
- txn.execute(
- "SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ? " + where_clause,
- args,
- )
- results[group].update(
- ((typ, state_key), event_id)
- for typ, state_key, event_id in txn
- if (typ, state_key) not in results[group]
- )
-
- # If the number of entries in the (type,state_key)->event_id dict
- # matches the number of (type,state_keys) types we were searching
- # for, then we must have found them all, so no need to go walk
- # further down the tree... UNLESS our types filter contained
- # wildcards (i.e. Nones) in which case we have to do an exhaustive
- # search
- if (
- max_entries_returned is not None
- and len(results[group]) == max_entries_returned
- ):
- break
-
- next_group = self.db.simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- return results
-
-
# this inherits from EventsWorkerStore because it calls self.get_events
-class StateGroupWorkerStore(
- EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
-):
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
- CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
def __init__(self, database: Database, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
- # Originally the state store used a single DictionaryCache to cache the
- # event IDs for the state types in a given state group to avoid hammering
- # on the state_group* tables.
- #
- # The point of using a DictionaryCache is that it can cache a subset
- # of the state events for a given state group (i.e. a subset of the keys for a
- # given dict which is an entry in the cache for a given state group ID).
- #
- # However, this poses problems when performing complicated queries
- # on the store - for instance: "give me all the state for this group, but
- # limit members to this subset of users", as DictionaryCache's API isn't
- # rich enough to say "please cache any of these fields, apart from this subset".
- # This is problematic when lazy loading members, which requires this behaviour,
- # as without it the cache has no choice but to speculatively load all
- # state events for the group, which negates the efficiency being sought.
- #
- # Rather than overcomplicating DictionaryCache's API, we instead split the
- # state_group_cache into two halves - one for tracking non-member events,
- # and the other for tracking member_events. This means that lazy loading
- # queries can be made in a cache-friendly manner by querying both caches
- # separately and then merging the result. So for the example above, you
- # would query the members cache for a specific subset of state keys
- # (which DictionaryCache will handle efficiently and fine) and the non-members
- # cache for all state (which DictionaryCache will similarly handle fine)
- # and then just merge the results together.
- #
- # We size the non-members cache to be smaller than the members cache as the
- # vast majority of state in Matrix (today) is member events.
-
- self._state_group_cache = DictionaryCache(
- "*stateGroupCache*",
- # TODO: this hasn't been tuned yet
- 50000 * get_cache_factor_for("stateGroupCache"),
- )
- self._state_group_members_cache = DictionaryCache(
- "*stateGroupMembersCache*",
- 500000 * get_cache_factor_for("stateGroupMembersCache"),
- )
-
@defer.inlineCallbacks
def get_room_version(self, room_id):
"""Get the room_version of a given room
@@ -278,7 +82,7 @@ class StateGroupWorkerStore(
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
- """Get the predecessor room of an upgraded room if one exists.
+ """Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
@@ -291,14 +95,22 @@ class StateGroupWorkerStore(
* room_id (str): The room ID of the predecessor room
* event_id (str): The ID of the tombstone event in the predecessor room
+ None if a predecessor key is not found, or is not a dictionary.
+
Raises:
- NotFoundError if the room is unknown
+ NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id)
- # Return predecessor if present
- return create_event.content.get("predecessor", None)
+ # Retrieve the predecessor key of the create event
+ predecessor = create_event.content.get("predecessor", None)
+
+ # Ensure the key is a dictionary
+ if not isinstance(predecessor, dict):
+ return None
+
+ return predecessor
@defer.inlineCallbacks
def get_create_event_for_room(self, room_id):
@@ -318,7 +130,7 @@ class StateGroupWorkerStore(
# If we can't find the create event, assume we've hit a dead end
if not create_id:
- raise NotFoundError("Unknown room %s" % (room_id))
+ raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
@@ -423,229 +235,6 @@ class StateGroupWorkerStore(
return event.content.get("canonical_alias")
- @cached(max_entries=10000, iterable=True)
- def get_state_group_delta(self, state_group):
- """Given a state group try to return a previous group and a delta between
- the old and the new.
-
- Returns:
- (prev_group, delta_ids), where both may be None.
- """
-
- def _get_state_group_delta_txn(txn):
- prev_group = self.db.simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": state_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- if not prev_group:
- return _GetStateGroupDelta(None, None)
-
- delta_ids = self.db.simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
- )
-
- return _GetStateGroupDelta(
- prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
- )
-
- return self.db.runInteraction(
- "get_state_group_delta", _get_state_group_delta_txn
- )
-
- @defer.inlineCallbacks
- def get_state_groups_ids(self, _room_id, event_ids):
- """Get the event IDs of all the state for the state groups for the given events
-
- Args:
- _room_id (str): id of the room for these events
- event_ids (iterable[str]): ids of the events
-
- Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- """
- if not event_ids:
- return {}
-
- event_to_groups = yield self._get_state_group_for_events(event_ids)
-
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups)
-
- return group_to_state
-
- @defer.inlineCallbacks
- def get_state_ids_for_group(self, state_group):
- """Get the event IDs of all the state in the given state group
-
- Args:
- state_group (int)
-
- Returns:
- Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
- """
- group_to_state = yield self._get_state_for_groups((state_group,))
-
- return group_to_state[state_group]
-
- @defer.inlineCallbacks
- def get_state_groups(self, room_id, event_ids):
- """ Get the state groups for the given list of event_ids
-
- Returns:
- Deferred[dict[int, list[EventBase]]]:
- dict of state_group_id -> list of state events.
- """
- if not event_ids:
- return {}
-
- group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
-
- state_event_map = yield self.get_events(
- [
- ev_id
- for group_ids in itervalues(group_to_ids)
- for ev_id in itervalues(group_ids)
- ],
- get_prev_content=False,
- )
-
- return {
- group: [
- state_event_map[v]
- for v in itervalues(event_id_map)
- if v in state_event_map
- ]
- for group, event_id_map in iteritems(group_to_ids)
- }
-
- @defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, state_filter):
- """Returns the state groups for a given set of groups, filtering on
- types of state events.
-
- Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- """
- results = {}
-
- chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
- for chunk in chunks:
- res = yield self.db.runInteraction(
- "_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn,
- chunk,
- state_filter,
- )
- results.update(res)
-
- return results
-
- @defer.inlineCallbacks
- def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
- """Given a list of event_ids and type tuples, return a list of state
- dicts for each event.
-
- Args:
- event_ids (list[string])
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
- """
- event_to_groups = yield self._get_state_group_for_events(event_ids)
-
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, state_filter)
-
- state_event_map = yield self.get_events(
- [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
- get_prev_content=False,
- )
-
- event_to_state = {
- event_id: {
- k: state_event_map[v]
- for k, v in iteritems(group_to_state[group])
- if v in state_event_map
- }
- for event_id, group in iteritems(event_to_groups)
- }
-
- return {event: event_to_state[event] for event in event_ids}
-
- @defer.inlineCallbacks
- def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
- """
- Get the state dicts corresponding to a list of events, containing the event_ids
- of the state events (as opposed to the events themselves)
-
- Args:
- event_ids(list(str)): events whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- A deferred dict from event_id -> (type, state_key) -> event_id
- """
- event_to_groups = yield self._get_state_group_for_events(event_ids)
-
- groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, state_filter)
-
- event_to_state = {
- event_id: group_to_state[group]
- for event_id, group in iteritems(event_to_groups)
- }
-
- return {event: event_to_state[event] for event in event_ids}
-
- @defer.inlineCallbacks
- def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
- """
- Get the state dict corresponding to a particular event
-
- Args:
- event_id(str): event whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- A deferred dict from (type, state_key) -> state_event
- """
- state_map = yield self.get_state_for_events([event_id], state_filter)
- return state_map[event_id]
-
- @defer.inlineCallbacks
- def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
- """
- Get the state dict corresponding to a particular event
-
- Args:
- event_id(str): event whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- A deferred dict from (type, state_key) -> state_event
- """
- state_map = yield self.get_state_ids_for_events([event_id], state_filter)
- return state_map[event_id]
-
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
return self.db.simple_select_one_onecol(
@@ -676,329 +265,6 @@ class StateGroupWorkerStore(
return {row["event_id"]: row["state_group"] for row in rows}
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
- """Checks if group is in cache. See `_get_state_for_groups`
-
- Args:
- cache(DictionaryCache): the state group cache to use
- group(int): The state group to lookup
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns 2-tuple (`state_dict`, `got_all`).
- `got_all` is a bool indicating if we successfully retrieved all
- requests state from the cache, if False we need to query the DB for the
- missing state.
- """
- is_all, known_absent, state_dict_ids = cache.get(group)
-
- if is_all or state_filter.is_full():
- # Either we have everything or want everything, either way
- # `is_all` tells us whether we've gotten everything.
- return state_filter.filter_state(state_dict_ids), is_all
-
- # tracks whether any of our requested types are missing from the cache
- missing_types = False
-
- if state_filter.has_wildcards():
- # We don't know if we fetched all the state keys for the types in
- # the filter that are wildcards, so we have to assume that we may
- # have missed some.
- missing_types = True
- else:
- # There aren't any wild cards, so `concrete_types()` returns the
- # complete list of event types we're wanting.
- for key in state_filter.concrete_types():
- if key not in state_dict_ids and key not in known_absent:
- missing_types = True
- break
-
- return state_filter.filter_state(state_dict_ids), not missing_types
-
- @defer.inlineCallbacks
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
- """Gets the state at each of a list of state groups, optionally
- filtering by type/state_key
-
- Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- """
-
- member_filter, non_member_filter = state_filter.get_member_split()
-
- # Now we look them up in the member and non-member caches
- (
- non_member_state,
- incomplete_groups_nm,
- ) = yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, state_filter=non_member_filter
- )
-
- (
- member_state,
- incomplete_groups_m,
- ) = yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache, state_filter=member_filter
- )
-
- state = dict(non_member_state)
- for group in groups:
- state[group].update(member_state[group])
-
- # Now fetch any missing groups from the database
-
- incomplete_groups = incomplete_groups_m | incomplete_groups_nm
-
- if not incomplete_groups:
- return state
-
- cache_sequence_nm = self._state_group_cache.sequence
- cache_sequence_m = self._state_group_members_cache.sequence
-
- # Help the cache hit ratio by expanding the filter a bit
- db_state_filter = state_filter.return_expanded()
-
- group_to_state_dict = yield self._get_state_groups_from_groups(
- list(incomplete_groups), state_filter=db_state_filter
- )
-
- # Now lets update the caches
- self._insert_into_cache(
- group_to_state_dict,
- db_state_filter,
- cache_seq_num_members=cache_sequence_m,
- cache_seq_num_non_members=cache_sequence_nm,
- )
-
- # And finally update the result dict, by filtering out any extra
- # stuff we pulled out of the database.
- for group, group_state_dict in iteritems(group_to_state_dict):
- # We just replace any existing entries, as we will have loaded
- # everything we need from the database anyway.
- state[group] = state_filter.filter_state(group_state_dict)
-
- return state
-
- def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
- """Gets the state at each of a list of state groups, optionally
- filtering by type/state_key, querying from a specific cache.
-
- Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- cache (DictionaryCache): the cache of group ids to state dicts which
- we will pass through - either the normal state cache or the specific
- members state cache.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- of entries in the cache, and the state group ids either missing
- from the cache or incomplete.
- """
- results = {}
- incomplete_groups = set()
- for group in set(groups):
- state_dict_ids, got_all = self._get_state_for_group_using_cache(
- cache, group, state_filter
- )
- results[group] = state_dict_ids
-
- if not got_all:
- incomplete_groups.add(group)
-
- return results, incomplete_groups
-
- def _insert_into_cache(
- self,
- group_to_state_dict,
- state_filter,
- cache_seq_num_members,
- cache_seq_num_non_members,
- ):
- """Inserts results from querying the database into the relevant cache.
-
- Args:
- group_to_state_dict (dict): The new entries pulled from database.
- Map from state group to state dict
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- cache_seq_num_members (int): Sequence number of member cache since
- last lookup in cache
- cache_seq_num_non_members (int): Sequence number of member cache since
- last lookup in cache
- """
-
- # We need to work out which types we've fetched from the DB for the
- # member vs non-member caches. This should be as accurate as possible,
- # but can be an underestimate (e.g. when we have wild cards)
-
- member_filter, non_member_filter = state_filter.get_member_split()
- if member_filter.is_full():
- # We fetched all member events
- member_types = None
- else:
- # `concrete_types()` will only return a subset when there are wild
- # cards in the filter, but that's fine.
- member_types = member_filter.concrete_types()
-
- if non_member_filter.is_full():
- # We fetched all non member events
- non_member_types = None
- else:
- non_member_types = non_member_filter.concrete_types()
-
- for group, group_state_dict in iteritems(group_to_state_dict):
- state_dict_members = {}
- state_dict_non_members = {}
-
- for k, v in iteritems(group_state_dict):
- if k[0] == EventTypes.Member:
- state_dict_members[k] = v
- else:
- state_dict_non_members[k] = v
-
- self._state_group_members_cache.update(
- cache_seq_num_members,
- key=group,
- value=state_dict_members,
- fetched_keys=member_types,
- )
-
- self._state_group_cache.update(
- cache_seq_num_non_members,
- key=group,
- value=state_dict_non_members,
- fetched_keys=non_member_types,
- )
-
- def store_state_group(
- self, event_id, room_id, prev_group, delta_ids, current_state_ids
- ):
- """Store a new set of state, returning a newly assigned state group.
-
- Args:
- event_id (str): The event ID for which the state was calculated
- room_id (str)
- prev_group (int|None): A previous state group for the room, optional.
- delta_ids (dict|None): The delta between state at `prev_group` and
- `current_state_ids`, if `prev_group` was given. Same format as
- `current_state_ids`.
- current_state_ids (dict): The state to store. Map of (type, state_key)
- to event_id.
-
- Returns:
- Deferred[int]: The state group ID
- """
-
- def _store_state_group_txn(txn):
- if current_state_ids is None:
- # AFAIK, this can never happen
- raise Exception("current_state_ids cannot be None")
-
- state_group = self.database_engine.get_next_state_group_id(txn)
-
- self.db.simple_insert_txn(
- txn,
- table="state_groups",
- values={"id": state_group, "room_id": room_id, "event_id": event_id},
- )
-
- # We persist as a delta if we can, while also ensuring the chain
- # of deltas isn't tooo long, as otherwise read performance degrades.
- if prev_group:
- is_in_db = self.db.simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
- )
- if not is_in_db:
- raise Exception(
- "Trying to persist state with unpersisted prev_group: %r"
- % (prev_group,)
- )
-
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self.db.simple_insert_txn(
- txn,
- table="state_group_edges",
- values={"state_group": state_group, "prev_state_group": prev_group},
- )
-
- self.db.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(delta_ids)
- ],
- )
- else:
- self.db.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(current_state_ids)
- ],
- )
-
- # Prefill the state group caches with this group.
- # It's fine to use the sequence like this as the state group map
- # is immutable. (If the map wasn't immutable then this prefill could
- # race with another update)
-
- current_member_state_ids = {
- s: ev
- for (s, ev) in iteritems(current_state_ids)
- if s[0] == EventTypes.Member
- }
- txn.call_after(
- self._state_group_members_cache.update,
- self._state_group_members_cache.sequence,
- key=state_group,
- value=dict(current_member_state_ids),
- )
-
- current_non_member_state_ids = {
- s: ev
- for (s, ev) in iteritems(current_state_ids)
- if s[0] != EventTypes.Member
- }
- txn.call_after(
- self._state_group_cache.update,
- self._state_group_cache.sequence,
- key=state_group,
- value=dict(current_non_member_state_ids),
- )
-
- return state_group
-
- return self.db.runInteraction("store_state_group", _store_state_group_txn)
-
@defer.inlineCallbacks
def get_referenced_state_groups(self, state_groups):
"""Check if the state groups are referenced by events.
@@ -1023,22 +289,14 @@ class StateGroupWorkerStore(
return set(row["state_group"] for row in rows)
-class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+class MainStateBackgroundUpdateStore(SQLBaseStore):
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, database: Database, db_conn, hs):
- super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
- self._background_deduplicate_state,
- )
- self.db.updates.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
- )
+ super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
self.db.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
@@ -1053,181 +311,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["state_group"],
)
- @defer.inlineCallbacks
- def _background_deduplicate_state(self, progress, batch_size):
- """This background update will slowly deduplicate state by reencoding
- them as deltas.
- """
- last_state_group = progress.get("last_state_group", 0)
- rows_inserted = progress.get("rows_inserted", 0)
- max_group = progress.get("max_group", None)
-
- BATCH_SIZE_SCALE_FACTOR = 100
- batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
-
- if max_group is None:
- rows = yield self.db.execute(
- "_background_deduplicate_state",
- None,
- "SELECT coalesce(max(id), 0) FROM state_groups",
- )
- max_group = rows[0][0]
-
- def reindex_txn(txn):
- new_last_state_group = last_state_group
- for count in range(batch_size):
- txn.execute(
- "SELECT id, room_id FROM state_groups"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC"
- " LIMIT 1",
- (new_last_state_group, max_group),
- )
- row = txn.fetchone()
- if row:
- state_group, room_id = row
-
- if not row or not state_group:
- return True, count
-
- txn.execute(
- "SELECT state_group FROM state_group_edges"
- " WHERE state_group = ?",
- (state_group,),
- )
-
- # If we reach a point where we've already started inserting
- # edges we should stop.
- if txn.fetchall():
- return True, count
-
- txn.execute(
- "SELECT coalesce(max(id), 0) FROM state_groups"
- " WHERE id < ? AND room_id = ?",
- (state_group, room_id),
- )
- (prev_group,) = txn.fetchone()
- new_last_state_group = state_group
-
- if prev_group:
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if potential_hops >= MAX_STATE_DELTA_HOPS:
- # We want to ensure chains are at most this long,#
- # otherwise read performance degrades.
- continue
-
- prev_state = self._get_state_groups_from_groups_txn(
- txn, [prev_group]
- )
- prev_state = prev_state[prev_group]
-
- curr_state = self._get_state_groups_from_groups_txn(
- txn, [state_group]
- )
- curr_state = curr_state[state_group]
-
- if not set(prev_state.keys()) - set(curr_state.keys()):
- # We can only do a delta if the current has a strict super set
- # of keys
-
- delta_state = {
- key: value
- for key, value in iteritems(curr_state)
- if prev_state.get(key, None) != value
- }
-
- self.db.simple_delete_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": state_group},
- )
-
- self.db.simple_insert_txn(
- txn,
- table="state_group_edges",
- values={
- "state_group": state_group,
- "prev_state_group": prev_group,
- },
- )
-
- self.db.simple_delete_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- )
-
- self.db.simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(delta_state)
- ],
- )
-
- progress = {
- "last_state_group": state_group,
- "rows_inserted": rows_inserted + batch_size,
- "max_group": max_group,
- }
-
- self.db.updates._background_update_progress_txn(
- txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
- )
-
- return False, batch_size
-
- finished, result = yield self.db.runInteraction(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
- )
-
- if finished:
- yield self.db.updates._end_background_update(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
- )
-
- return result * BATCH_SIZE_SCALE_FACTOR
-
- @defer.inlineCallbacks
- def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
- conn.rollback()
- if isinstance(self.database_engine, PostgresEngine):
- # postgres insists on autocommit for the index
- conn.set_session(autocommit=True)
- try:
- txn = conn.cursor()
- txn.execute(
- "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
- " ON state_groups_state(state_group, type, state_key)"
- )
- txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- finally:
- conn.set_session(autocommit=False)
- else:
- txn = conn.cursor()
- txn.execute(
- "CREATE INDEX state_groups_state_type_idx"
- " ON state_groups_state(state_group, type, state_key)"
- )
- txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-
- yield self.db.runWithConnection(reindex_txn)
-
- yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
-
- return 1
-
-
-class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
+class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py
new file mode 100644
index 0000000000..86e09f6229
--- /dev/null
+++ b/synapse/storage/data_stores/state/__init__.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
new file mode 100644
index 0000000000..e8edaf9f7b
--- /dev/null
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.state import StateFilter
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class StateGroupBackgroundUpdateStore(SQLBaseStore):
+    """Defines functions related to state groups needed to run the state background
+ updates.
+ """
+
+ def _count_state_group_hops_txn(self, txn, state_group):
+ """Given a state group, count how many hops there are in the tree.
+
+ This is used to ensure the delta chains don't get too long.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT count(*) FROM state;
+ """
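+            # Illustration (made-up IDs): with edge rows 103 -> 102 and
+            # 102 -> 101, starting from group 103 the recursive CTE yields the
+            # rows 103, 102 and 101, so count(*) returns 3.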
+
+ txn.execute(sql, (state_group,))
+ row = txn.fetchone()
+ if row and row[0]:
+ return row[0]
+ else:
+ return 0
+ else:
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ next_group = state_group
+ count = 0
+
+ while next_group:
+ next_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+ if next_group:
+ count += 1
+
+ return count
+
+ def _get_state_groups_from_groups_txn(
+ self, txn, groups, state_filter=StateFilter.all()
+ ):
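+        """Fetch the filtered state for each of the given state groups from the
+        database, walking each group's chain of previous groups as needed.
+
+        Returns:
+            dict[int, dict[tuple[str, str], str]]: map from state_group_id to
+            a dict of (type, state_key) -> event_id.
+        """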
+ results = {group: {} for group in groups}
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # Temporarily disable sequential scans in this transaction. This is
+ # a temporary hack until we can add the right indices in
+            # a temporary hack until we can add the right indices.
+
+ # The below query walks the state_group tree so that the "state"
+ # table includes all state_groups in the tree. It then joins
+ # against `state_groups_state` to fetch the latest state.
+ # It assumes that previous state groups are always numerically
+ # lesser.
+ # The PARTITION is used to get the event_id in the greatest state
+ # group for the given type, state_key.
+ # This may return multiple rows per (type, state_key), but last_value
+ # should be the same.
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT DISTINCT type, state_key, last_value(event_id) OVER (
+ PARTITION BY type, state_key ORDER BY state_group ASC
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS event_id FROM state_groups_state
+ WHERE state_group IN (
+ SELECT state_group FROM state
+ )
+ """
+
+ for group in groups:
+ args = [group]
+ args.extend(where_args)
+
+ txn.execute(sql + where_clause, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
+ else:
+ max_entries_returned = state_filter.max_entries_returned()
+
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ for group in groups:
+ next_group = group
+
+ while next_group:
+ # We did this before by getting the list of group ids, and
+ # then passing that list to sqlite to get latest event for
+ # each (type, state_key). However, that was terribly slow
+ # without the right indices (which we can't add until
+ # after we finish deduping state, which requires this func)
+ args = [next_group]
+ args.extend(where_args)
+
+ txn.execute(
+ "SELECT type, state_key, event_id FROM state_groups_state"
+ " WHERE state_group = ? " + where_clause,
+ args,
+ )
+ results[group].update(
+ ((typ, state_key), event_id)
+ for typ, state_key, event_id in txn
+ if (typ, state_key) not in results[group]
+ )
+
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type,state_keys) types we were searching
+ # for, then we must have found them all, so no need to go walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones) in which case we have to do an exhaustive
+ # search
+ if (
+ max_entries_returned is not None
+ and len(results[group]) == max_entries_returned
+ ):
+ break
+
+ next_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ return results
+
+
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+
+ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+ STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ self.db.updates.register_background_update_handler(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+ self._background_deduplicate_state,
+ )
+ self.db.updates.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
+ )
+ self.db.updates.register_background_index_update(
+ self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
+ index_name="state_groups_room_id_idx",
+ table="state_groups",
+ columns=["room_id"],
+ )
+
+ @defer.inlineCallbacks
+ def _background_deduplicate_state(self, progress, batch_size):
+        """This background update slowly deduplicates state by re-encoding state
+        groups as deltas against their previous groups.
+ """
+ last_state_group = progress.get("last_state_group", 0)
+ rows_inserted = progress.get("rows_inserted", 0)
+ max_group = progress.get("max_group", None)
+
+ BATCH_SIZE_SCALE_FACTOR = 100
+
+ batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
+
+ if max_group is None:
+ rows = yield self.db.execute(
+ "_background_deduplicate_state",
+ None,
+ "SELECT coalesce(max(id), 0) FROM state_groups",
+ )
+ max_group = rows[0][0]
+
+ def reindex_txn(txn):
+ new_last_state_group = last_state_group
+ for count in range(batch_size):
+ txn.execute(
+ "SELECT id, room_id FROM state_groups"
+ " WHERE ? < id AND id <= ?"
+ " ORDER BY id ASC"
+ " LIMIT 1",
+ (new_last_state_group, max_group),
+ )
+ row = txn.fetchone()
+ if row:
+ state_group, room_id = row
+
+ if not row or not state_group:
+ return True, count
+
+ txn.execute(
+ "SELECT state_group FROM state_group_edges"
+ " WHERE state_group = ?",
+ (state_group,),
+ )
+
+ # If we reach a point where we've already started inserting
+ # edges we should stop.
+ if txn.fetchall():
+ return True, count
+
+ txn.execute(
+ "SELECT coalesce(max(id), 0) FROM state_groups"
+ " WHERE id < ? AND room_id = ?",
+ (state_group, room_id),
+ )
+ (prev_group,) = txn.fetchone()
+ new_last_state_group = state_group
+
+ if prev_group:
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if potential_hops >= MAX_STATE_DELTA_HOPS:
+                        # We want to ensure chains are at most this long,
+                        # otherwise read performance degrades.
+ continue
+
+ prev_state = self._get_state_groups_from_groups_txn(
+ txn, [prev_group]
+ )
+ prev_state = prev_state[prev_group]
+
+ curr_state = self._get_state_groups_from_groups_txn(
+ txn, [state_group]
+ )
+ curr_state = curr_state[state_group]
+
+ if not set(prev_state.keys()) - set(curr_state.keys()):
+ # We can only do a delta if the current has a strict super set
+ # of keys
+
+ delta_state = {
+ key: value
+ for key, value in iteritems(curr_state)
+ if prev_state.get(key, None) != value
+ }
+
+ self.db.simple_delete_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": state_group},
+ )
+
+ self.db.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={
+ "state_group": state_group,
+ "prev_state_group": prev_group,
+ },
+ )
+
+ self.db.simple_delete_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_state)
+ ],
+ )
+
+ progress = {
+ "last_state_group": state_group,
+ "rows_inserted": rows_inserted + batch_size,
+ "max_group": max_group,
+ }
+
+ self.db.updates._background_update_progress_txn(
+ txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
+ )
+
+ return False, batch_size
+
+ finished, result = yield self.db.runInteraction(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
+ )
+
+ if finished:
+ yield self.db.updates._end_background_update(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+ )
+
+ return result * BATCH_SIZE_SCALE_FACTOR
+
+ @defer.inlineCallbacks
+ def _background_index_state(self, progress, batch_size):
+ def reindex_txn(conn):
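+            # Make sure no transaction is open on this connection:
+            # CREATE INDEX CONCURRENTLY (used on Postgres below) cannot run
+            # inside a transaction block, so we need to switch to autocommit.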
+ conn.rollback()
+ if isinstance(self.database_engine, PostgresEngine):
+ # postgres insists on autocommit for the index
+ conn.set_session(autocommit=True)
+ try:
+ txn = conn.cursor()
+ txn.execute(
+ "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+ finally:
+ conn.set_session(autocommit=False)
+ else:
+ txn = conn.cursor()
+ txn.execute(
+ "CREATE INDEX state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+
+ yield self.db.runWithConnection(reindex_txn)
+
+ yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+
+ return 1
diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644
--- a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql
+++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
new file mode 100644
index 0000000000..1450313bfa
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- The following indices are redundant, other indices are equivalent or
+-- supersets
+DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
index 33980d02f0..33980d02f0 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/state.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
index 9fd1ccf6f7..9fd1ccf6f7 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
new file mode 100644
index 0000000000..7916ef18b2
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('state_groups_room_id_idx', '{}');
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..35f97d6b3d
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
@@ -0,0 +1,37 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE state_groups (
+ id BIGINT PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_groups_state (
+ state_group BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_group_edges (
+ state_group BIGINT NOT NULL,
+ prev_state_group BIGINT NOT NULL
+);
+
+CREATE INDEX state_group_edges_idx ON state_group_edges (state_group);
+CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group);
+CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key);
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
new file mode 100644
index 0000000000..fcd926c9fb
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE state_group_id_seq
+ START WITH 1
+ INCREMENT BY 1
+ NO MINVALUE
+ NO MAXVALUE
+ CACHE 1;
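+
+-- It is assumed that get_next_state_group_id() on the Postgres engine draws
+-- new state group IDs from this sequence.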
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
new file mode 100644
index 0000000000..d53695f238
--- /dev/null
+++ b/synapse/storage/data_stores/state/store.py
@@ -0,0 +1,640 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import namedtuple
+
+from six import iteritems
+from six.moves import range
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.util.caches import get_cache_factor_for
+from synapse.util.caches.descriptors import cached
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+ namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+ """Return type of get_state_group_delta that implements __len__, which lets
+    us use the iterable flag when caching
+ """
+
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
+class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
+ """A data store for fetching/storing state groups.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+
+ # Originally the state store used a single DictionaryCache to cache the
+ # event IDs for the state types in a given state group to avoid hammering
+ # on the state_group* tables.
+ #
+ # The point of using a DictionaryCache is that it can cache a subset
+ # of the state events for a given state group (i.e. a subset of the keys for a
+ # given dict which is an entry in the cache for a given state group ID).
+ #
+ # However, this poses problems when performing complicated queries
+ # on the store - for instance: "give me all the state for this group, but
+ # limit members to this subset of users", as DictionaryCache's API isn't
+ # rich enough to say "please cache any of these fields, apart from this subset".
+ # This is problematic when lazy loading members, which requires this behaviour,
+ # as without it the cache has no choice but to speculatively load all
+ # state events for the group, which negates the efficiency being sought.
+ #
+ # Rather than overcomplicating DictionaryCache's API, we instead split the
+ # state_group_cache into two halves - one for tracking non-member events,
+ # and the other for tracking member_events. This means that lazy loading
+ # queries can be made in a cache-friendly manner by querying both caches
+ # separately and then merging the result. So for the example above, you
+ # would query the members cache for a specific subset of state keys
+ # (which DictionaryCache will handle efficiently and fine) and the non-members
+ # cache for all state (which DictionaryCache will similarly handle fine)
+ # and then just merge the results together.
+ #
+ # We size the non-members cache to be smaller than the members cache as the
+ # vast majority of state in Matrix (today) is member events.
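+        #
+        # As a rough illustration: a lazy-loading request for "all state, plus
+        # these three members" is served by asking the members cache for just
+        # those three (type, state_key) pairs and the non-members cache for
+        # everything it holds, then merging the two dicts (see
+        # _get_state_for_groups below).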
+
+ self._state_group_cache = DictionaryCache(
+ "*stateGroupCache*",
+ # TODO: this hasn't been tuned yet
+ 50000 * get_cache_factor_for("stateGroupCache"),
+ )
+ self._state_group_members_cache = DictionaryCache(
+ "*stateGroupMembersCache*",
+ 500000 * get_cache_factor_for("stateGroupMembersCache"),
+ )
+
+ @cached(max_entries=10000, iterable=True)
+ def get_state_group_delta(self, state_group):
+        """Given a state group, try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+
+ def _get_state_group_delta_txn(txn):
+ prev_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": state_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return _GetStateGroupDelta(None, None)
+
+ delta_ids = self.db.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ )
+
+ return _GetStateGroupDelta(
+ prev_group,
+ {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ )
+
+ return self.db.runInteraction(
+ "get_state_group_delta", _get_state_group_delta_txn
+ )
+
+ @defer.inlineCallbacks
+ def _get_state_groups_from_groups(self, groups, state_filter):
+ """Returns the state groups for a given set of groups, filtering on
+ types of state events.
+
+ Args:
+ groups(list[int]): list of state group IDs to query
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+ results = {}
+
+ chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
+ for chunk in chunks:
+ res = yield self.db.runInteraction(
+ "_get_state_groups_from_groups",
+ self._get_state_groups_from_groups_txn,
+ chunk,
+ state_filter,
+ )
+ results.update(res)
+
+ return results
+
+ def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ """Checks if group is in cache. See `_get_state_for_groups`
+
+ Args:
+ cache(DictionaryCache): the state group cache to use
+ group(int): The state group to lookup
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns 2-tuple (`state_dict`, `got_all`).
+        `got_all` is a bool indicating whether we successfully retrieved all the
+        requested state from the cache; if False we need to query the DB for the
+        missing state.
+ """
+ is_all, known_absent, state_dict_ids = cache.get(group)
+
+ if is_all or state_filter.is_full():
+ # Either we have everything or want everything, either way
+ # `is_all` tells us whether we've gotten everything.
+ return state_filter.filter_state(state_dict_ids), is_all
+
+ # tracks whether any of our requested types are missing from the cache
+ missing_types = False
+
+ if state_filter.has_wildcards():
+ # We don't know if we fetched all the state keys for the types in
+ # the filter that are wildcards, so we have to assume that we may
+ # have missed some.
+ missing_types = True
+ else:
+ # There aren't any wild cards, so `concrete_types()` returns the
+ # complete list of event types we're wanting.
+ for key in state_filter.concrete_types():
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types = True
+ break
+
+ return state_filter.filter_state(state_dict_ids), not missing_types
+
+ @defer.inlineCallbacks
+ def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[dict[int, dict[tuple[str, str], str]]]:
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ """
+
+ member_filter, non_member_filter = state_filter.get_member_split()
+
+ # Now we look them up in the member and non-member caches
+ (
+ non_member_state,
+ incomplete_groups_nm,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, state_filter=non_member_filter
+ )
+
+ (
+ member_state,
+ incomplete_groups_m,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache, state_filter=member_filter
+ )
+
+ state = dict(non_member_state)
+ for group in groups:
+ state[group].update(member_state[group])
+
+ # Now fetch any missing groups from the database
+
+ incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+ if not incomplete_groups:
+ return state
+
+ cache_sequence_nm = self._state_group_cache.sequence
+ cache_sequence_m = self._state_group_members_cache.sequence
+
+ # Help the cache hit ratio by expanding the filter a bit
+ db_state_filter = state_filter.return_expanded()
+
+ group_to_state_dict = yield self._get_state_groups_from_groups(
+ list(incomplete_groups), state_filter=db_state_filter
+ )
+
+        # Now let's update the caches
+ self._insert_into_cache(
+ group_to_state_dict,
+ db_state_filter,
+ cache_seq_num_members=cache_sequence_m,
+ cache_seq_num_non_members=cache_sequence_nm,
+ )
+
+ # And finally update the result dict, by filtering out any extra
+ # stuff we pulled out of the database.
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ # We just replace any existing entries, as we will have loaded
+ # everything we need from the database anyway.
+ state[group] = state_filter.filter_state(group_state_dict)
+
+ return state
+
+ def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key, querying from a specific cache.
+
+ Args:
+ groups (iterable[int]): list of state groups for which we want
+ to get the state.
+ cache (DictionaryCache): the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the specific
+ members state cache.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns:
+ tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ of entries in the cache, and the state group ids either missing
+ from the cache or incomplete.
+ """
+ results = {}
+ incomplete_groups = set()
+ for group in set(groups):
+ state_dict_ids, got_all = self._get_state_for_group_using_cache(
+ cache, group, state_filter
+ )
+ results[group] = state_dict_ids
+
+ if not got_all:
+ incomplete_groups.add(group)
+
+ return results, incomplete_groups
+
+ def _insert_into_cache(
+ self,
+ group_to_state_dict,
+ state_filter,
+ cache_seq_num_members,
+ cache_seq_num_non_members,
+ ):
+ """Inserts results from querying the database into the relevant cache.
+
+ Args:
+ group_to_state_dict (dict): The new entries pulled from database.
+ Map from state group to state dict
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ cache_seq_num_members (int): Sequence number of member cache since
+ last lookup in cache
+            cache_seq_num_non_members (int): Sequence number of non-member cache since
+ last lookup in cache
+ """
+
+ # We need to work out which types we've fetched from the DB for the
+ # member vs non-member caches. This should be as accurate as possible,
+ # but can be an underestimate (e.g. when we have wild cards)
+
+ member_filter, non_member_filter = state_filter.get_member_split()
+ if member_filter.is_full():
+ # We fetched all member events
+ member_types = None
+ else:
+ # `concrete_types()` will only return a subset when there are wild
+ # cards in the filter, but that's fine.
+ member_types = member_filter.concrete_types()
+
+ if non_member_filter.is_full():
+ # We fetched all non member events
+ non_member_types = None
+ else:
+ non_member_types = non_member_filter.concrete_types()
+
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ state_dict_members = {}
+ state_dict_non_members = {}
+
+ for k, v in iteritems(group_state_dict):
+ if k[0] == EventTypes.Member:
+ state_dict_members[k] = v
+ else:
+ state_dict_non_members[k] = v
+
+ self._state_group_members_cache.update(
+ cache_seq_num_members,
+ key=group,
+ value=state_dict_members,
+ fetched_keys=member_types,
+ )
+
+ self._state_group_cache.update(
+ cache_seq_num_non_members,
+ key=group,
+ value=state_dict_non_members,
+ fetched_keys=non_member_types,
+ )
+
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ Deferred[int]: The state group ID
+ """
+
+ def _store_state_group_txn(txn):
+ if current_state_ids is None:
+ # AFAIK, this can never happen
+ raise Exception("current_state_ids cannot be None")
+
+ state_group = self.database_engine.get_next_state_group_id(txn)
+
+ self.db.simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
+ )
+
+ # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't too long, as otherwise read performance degrades.
+ if prev_group:
+ is_in_db = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ self.db.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={"state_group": state_group, "prev_state_group": prev_group},
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_ids)
+ ],
+ )
+ else:
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(current_state_ids)
+ ],
+ )
+
+ # Prefill the state group caches with this group.
+ # It's fine to use the sequence like this as the state group map
+ # is immutable. (If the map wasn't immutable then this prefill could
+ # race with another update)
+
+ current_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] == EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_members_cache.update,
+ self._state_group_members_cache.sequence,
+ key=state_group,
+ value=dict(current_member_state_ids),
+ )
+
+ current_non_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] != EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_cache.update,
+ self._state_group_cache.sequence,
+ key=state_group,
+ value=dict(current_non_member_state_ids),
+ )
+
+ return state_group
+
+ return self.db.runInteraction("store_state_group", _store_state_group_txn)
+
+ def purge_unreferenced_state_groups(
+ self, room_id: str, state_groups_to_delete
+ ) -> defer.Deferred:
+ """Deletes no longer referenced state groups and de-deltas any state
+ groups that reference them.
+
+ Args:
+ room_id: The room the state groups belong to (must all be in the
+ same room).
+ state_groups_to_delete (Collection[int]): Set of all state groups
+ to delete.
+ """
+
+ return self.db.runInteraction(
+ "purge_unreferenced_state_groups",
+ self._purge_unreferenced_state_groups,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ logger.info(
+ "[purge] found %i state groups to delete", len(state_groups_to_delete)
+ )
+
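+        # Find all state groups that point (via state_group_edges) at one of
+        # the groups we are deleting; any that are not themselves being deleted
+        # are expanded to full, non-delta state below so that they no longer
+        # depend on a deleted group.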
+ rows = self.db.simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ retcols=("state_group",),
+ )
+
+ remaining_state_groups = set(
+ row["state_group"]
+ for row in rows
+ if row["state_group"] not in state_groups_to_delete
+ )
+
+ logger.info(
+ "[purge] de-delta-ing %i remaining state groups",
+ len(remaining_state_groups),
+ )
+
+ # Now we turn the state groups that reference to-be-deleted state
+ # groups into non-delta versions.
+ for sg in remaining_state_groups:
+ logger.info("[purge] de-delta-ing remaining state group %s", sg)
+ curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state[sg]
+
+ self.db.simple_delete_txn(
+ txn, table="state_groups_state", keyvalues={"state_group": sg}
+ )
+
+ self.db.simple_delete_txn(
+ txn, table="state_group_edges", keyvalues={"state_group": sg}
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": sg,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(curr_state)
+ ],
+ )
+
+ logger.info("[purge] removing redundant state groups")
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+
+ @defer.inlineCallbacks
+ def get_previous_state_groups(self, state_groups):
+ """Fetch the previous groups of the given state groups.
+
+ Args:
+ state_groups (Iterable[int])
+
+ Returns:
+ Deferred[dict[int, int]]: mapping from state group to previous
+ state group.
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("prev_state_group", "state_group"),
+ desc="get_previous_state_groups",
+ )
+
+ return {row["state_group"]: row["prev_state_group"] for row in rows}
+
+ def purge_room_state(self, room_id, state_groups_to_delete):
+ """Deletes all record of a room from state tables
+
+ Args:
+ room_id (str):
+ state_groups_to_delete (list[int]): State groups to delete
+ """
+
+ return self.db.runInteraction(
+ "purge_room_state",
+ self._purge_room_state_txn,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ # first we have to delete the state groups states
+ logger.info("[purge] removing %s from state_groups_state", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_groups_state",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state group edges
+ logger.info("[purge] removing %s from state_group_edges", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_group_edges",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state groups
+ logger.info("[purge] removing %s from state_groups", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_groups",
+ column="id",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
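
An aside for readers new to this storage scheme: a state group either stores its full state in state_groups_state, or stores only a delta plus an edge in state_group_edges pointing at its prev_state_group. Reading a group's full state therefore means walking the edge chain back to the root and overlaying each delta on the way forward. The real lookup happens in _get_state_groups_from_groups_txn against the database; the following is only a minimal in-memory sketch of the same idea, with the edges and per-group rows assumed to be pre-loaded into plain dicts.

from typing import Dict, Tuple

StateKey = Tuple[str, str]  # (event type, state key)

def resolve_state_group(
    group: int,
    edges: Dict[int, int],                 # state_group -> prev_state_group
    rows: Dict[int, Dict[StateKey, str]],  # state_group -> its state_groups_state rows
) -> Dict[StateKey, str]:
    # Walk the delta chain up to the root group (the one with no prev edge).
    chain = [group]
    while chain[-1] in edges:
        chain.append(edges[chain[-1]])

    # The root holds a full snapshot; every later group overlays its delta.
    state = {}  # type: Dict[StateKey, str]
    for g in reversed(chain):
        state.update(rows.get(g, {}))
    return state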
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ec19ae1d9d..1003dd84a5 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -24,9 +24,11 @@ from six.moves import intern, range
from prometheus_client import Histogram
+from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
@@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
}
+def make_pool(
+ reactor, db_config: DatabaseConnectionConfig, engine
+) -> adbapi.ConnectionPool:
+ """Get the connection pool for the database.
+ """
+
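+ # The "args" section of the database config is passed straight through
+ # to twisted's adbapi.ConnectionPool, so pool options such as cp_min
+ # and cp_max can be set there alongside the driver arguments.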
+ return adbapi.ConnectionPool(
+ db_config.config["name"],
+ cp_reactor=reactor,
+ cp_openfun=engine.on_new_connection,
+ **db_config.config.get("args", {})
+ )
+
+
+def make_conn(db_config: DatabaseConnectionConfig, engine):
+ """Make a new connection to the database and return it.
+
+ Returns:
+ Connection
+ """
+
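+ # Strip out the cp_* options: they configure the connection pool above
+ # and are not understood by the DB-API driver's connect().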
+ db_params = {
+ k: v
+ for k, v in db_config.config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = engine.module.connect(**db_params)
+ engine.on_new_connection(db_conn)
+ return db_conn
+
+
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
@@ -218,10 +251,11 @@ class Database(object):
_TXN_ID = 0
- def __init__(self, hs):
+ def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
self.hs = hs
self._clock = hs.get_clock()
- self._db_pool = hs.get_db_pool()
+ self._database_config = database_config
+ self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
self.updates = BackgroundUpdater(hs, self)
@@ -234,7 +268,7 @@ class Database(object):
# to watch it
self._txn_perf_counters = PerformanceCounters()
- self.engine = hs.database_engine
+ self.engine = engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
@@ -255,6 +289,11 @@ class Database(object):
self._check_safe_to_upsert,
)
+ def is_running(self):
+ """Is the database pool currently running
+ """
+ return self._db_pool.running
+
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index cbc74cd302..df039a072d 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -16,8 +16,6 @@
import struct
import threading
-from synapse.storage.prepare_database import prepare_database
-
class Sqlite3Engine(object):
single_threaded = True
@@ -62,6 +60,10 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
+
+ # We need to import here to avoid an import loop.
+ from synapse.storage.prepare_database import prepare_database
+
if self._is_in_memory:
# In memory databases need to be rebuilt each time. Ideally we'd
# reuse the same connection as we do when starting up, but that
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index fa03ca9ff7..1ed44925fc 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -183,7 +183,7 @@ class EventsPersistenceStorage(object):
# so we use separate variables here even though they point to the same
# store for now.
self.main_store = stores.main
- self.state_store = stores.main
+ self.state_store = stores.state
self._clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 731e1c9d9c..e70026b80a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,6 +18,7 @@ import imp
import logging
import os
import re
+from collections import Counter
import attr
@@ -41,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -54,11 +55,10 @@ def prepare_database(db_conn, database_engine, config):
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
+ data_stores (list[str]): The names of the data stores that will be used
+ with this database. Defaults to all data stores.
"""
- # For now we only have the one datastore.
- data_stores = ["main"]
-
try:
cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur, database_engine)
@@ -70,7 +70,10 @@ def prepare_database(db_conn, database_engine, config):
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
- raise UpgradeDatabaseException("Database needs to be upgraded")
+ raise UpgradeDatabaseException(
+ "Expected database schema version %i but got %i"
+ % (SCHEMA_VERSION, user_version)
+ )
else:
_upgrade_existing_database(
cur,
@@ -313,6 +316,9 @@ def _upgrade_existing_database(
)
)
+ # Used to check if we have any duplicate file names
+ file_name_counter = Counter()
+
# Now find which directories have anything of interest.
directory_entries = []
for directory in directories:
@@ -323,6 +329,9 @@ def _upgrade_existing_database(
_DirectoryListing(file_name, os.path.join(directory, file_name))
for file_name in file_names
)
+
+ for file_name in file_names:
+ file_name_counter[file_name] += 1
except FileNotFoundError:
# Data stores can have empty entries for a given version delta.
pass
@@ -331,6 +340,17 @@ def _upgrade_existing_database(
"Could not open delta dir for version %d: %s" % (v, directory)
)
+ duplicates = set(
+ file_name for file_name, count in file_name_counter.items() if count > 1
+ )
+ if duplicates:
+ # We don't support using the same file name in the same delta version.
+ raise PrepareDatabaseException(
+ "Found multiple delta files with the same name in v%d: %s",
+ v,
+ duplicates,
+ )
+
# We sort to ensure that we apply the delta files in a consistent
# order (to avoid bugs caused by inconsistent directory listing order)
directory_entries.sort()
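
The duplicate check above works because file_name_counter counts every delta file name seen across all of the per-data-store directories for a given schema version. A tiny illustration of the idea, with invented file names:

from collections import Counter

file_name_counter = Counter(["01add_table.sql", "01add_table.sql", "02add_index.sql"])
duplicates = set(
    file_name for file_name, count in file_name_counter.items() if count > 1
)
# duplicates == {"01add_table.sql"}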
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index a368182034..d6a7bd7834 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -58,7 +58,7 @@ class PurgeEventsStorage(object):
sg_to_delete = yield self._find_unreferenced_groups(state_groups)
- yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete)
+ yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
@defer.inlineCallbacks
def _find_unreferenced_groups(self, state_groups):
@@ -102,7 +102,7 @@ class PurgeEventsStorage(object):
# groups that are referenced.
current_search -= referenced
- edges = yield self.stores.main.get_previous_state_groups(current_search)
+ edges = yield self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values())
# We don't bother re-handling groups we've already seen
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 3735846899..cbeb586014 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -342,7 +342,7 @@ class StateGroupStorage(object):
(prev_group, delta_ids)
"""
- return self.stores.main.get_state_group_delta(state_group)
+ return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@@ -362,7 +362,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.main._get_state_for_groups(groups)
+ group_to_state = yield self.stores.state._get_state_for_groups(groups)
return group_to_state
@@ -423,7 +423,7 @@ class StateGroupStorage(object):
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
- return self.stores.main._get_state_groups_from_groups(groups, state_filter)
+ return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -439,7 +439,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.main._get_state_for_groups(
+ group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter
)
@@ -476,7 +476,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self.stores.main._get_state_for_groups(
+ group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter
)
@@ -532,7 +532,7 @@ class StateGroupStorage(object):
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
- return self.stores.main._get_state_for_groups(groups, state_filter)
+ return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
@@ -552,6 +552,6 @@ class StateGroupStorage(object):
Returns:
Deferred[int]: The state group ID
"""
- return self.stores.main.store_state_group(
+ return self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
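
Since these wrappers keep their public signatures and simply forward to stores.state, callers do not change. A rough usage sketch (the helper name, the `storage` variable, and the argument values are made up for illustration; `storage` stands for the object returned by hs.get_storage()):

from twisted.internet import defer

@defer.inlineCallbacks
def store_group_example(storage, event_id, room_id, prev_group, delta_ids, current_state_ids):
    # store_state_group returns a Deferred[int] resolving to the new group ID;
    # internally it now lands in the dedicated state data store.
    state_group = yield storage.state.store_state_group(
        event_id, room_id, prev_group, delta_ids, current_state_ids
    )
    return state_group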
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index b91fb2db7b..fcd2aaa9c9 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
+
from twisted.internet import defer
from synapse.handlers.account_data import AccountDataEventSource
@@ -35,7 +37,7 @@ class EventSources(object):
def __init__(self, hs):
self.sources = {
name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
- }
+ } # type: Dict[str, Any]
self.store = hs.get_datastore()
@defer.inlineCallbacks
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 84f5ae22c3..2e8f6543e5 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -271,7 +271,7 @@ class _CacheDescriptorBase(object):
else:
self.function_to_call = orig
- arg_spec = inspect.getargspec(orig)
+ arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
if "cache_context" in all_args:
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
- """Cache for snapshots like the response of /initialSync.
- The response of initialSync only has to be a recent snapshot of the
- server state. It shouldn't matter to clients if it is a few minutes out
- of date.
-
- This caches a deferred response. Until the deferred completes it will be
- returned from the cache. This means that if the client retries the request
- while the response is still being computed, that original response will be
- used rather than trying to compute a new response.
-
- Once the deferred completes it will removed from the cache after 5 minutes.
- We delay removing it from the cache because a client retrying its request
- could race with us finishing computing the response.
-
- Rather than tracking precisely how long something has been in the cache we
- keep two generations of completed responses. Every 5 minutes discard the
- old generation, move the new generation to the old generation, and set the
- new generation to be empty. This means that a result will be in the cache
- somewhere between 5 and 10 minutes.
- """
-
- DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes.
-
- def __init__(self):
- self.pending_result_cache = {} # Request that haven't finished yet.
- self.prev_result_cache = {} # The older requests that have finished.
- self.next_result_cache = {} # The newer requests that have finished.
- self.time_last_rotated_ms = 0
-
- def rotate(self, time_now_ms):
- # Rotate once if the cache duration has passed since the last rotation.
- if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
- self.prev_result_cache = self.next_result_cache
- self.next_result_cache = {}
- self.time_last_rotated_ms += self.DURATION_MS
-
- # Rotate again if the cache duration has passed twice since the last
- # rotation.
- if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
- self.prev_result_cache = self.next_result_cache
- self.next_result_cache = {}
- self.time_last_rotated_ms = time_now_ms
-
- def get(self, time_now_ms, key):
- self.rotate(time_now_ms)
- # This cache is intended to deduplicate requests, so we expect it to be
- # missed most of the time. So we just lookup the key in all of the
- # dictionaries rather than trying to short circuit the lookup if the
- # key is found.
- result = self.prev_result_cache.get(key)
- result = self.next_result_cache.get(key, result)
- result = self.pending_result_cache.get(key, result)
- if result is not None:
- return result.observe()
- else:
- return None
-
- def set(self, time_now_ms, key, deferred):
- self.rotate(time_now_ms)
-
- result = ObservableDeferred(deferred)
-
- self.pending_result_cache[key] = result
-
- def shuffle_along(r):
- # When the deferred completes we shuffle it along to the first
- # generation of the result cache. So that it will eventually
- # expire from the rotation of that cache.
- self.next_result_cache[key] = result
- self.pending_result_cache.pop(key, None)
- return r
-
- result.addBoth(shuffle_along)
-
- return result.observe()
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index fdfa2cbbc4..854eb6c024 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -183,10 +183,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
- test_replace_master_key.skip = (
- "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
- )
-
@defer.inlineCallbacks
def test_reupload_signatures(self):
"""re-uploading a signature should not fail"""
@@ -507,7 +503,3 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
],
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
-
- test_upload_signatures.skip = (
- "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
- )
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 92b8726093..596ddc6970 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
+ datastores = Mock()
+ datastores.main = Mock(
+ spec=[
+ # Bits that Federation needs
+ "prep_send_transaction",
+ "delivered_txn",
+ "get_received_txn_response",
+ "set_received_txn_response",
+ "get_destination_retry_timings",
+ "get_devices_by_remote",
+ # Bits that user_directory needs
+ "get_user_directory_stream_pos",
+ "get_current_state_deltas",
+ "get_device_updates_by_remote",
+ ]
+ )
+
hs = self.setup_test_homeserver(
- datastore=(
- Mock(
- spec=[
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_retry_timings",
- "get_device_updates_by_remote",
- # Bits that user_directory needs
- "get_user_directory_stream_pos",
- "get_current_state_deltas",
- ]
- )
- ),
- notifier=Mock(),
- http_client=mock_federation_client,
- keyring=mock_keyring,
+ notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring
)
+ hs.datastores = datastores
+
return hs
def prepare(self, reactor, clock, hs):
diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index 358b593cd4..80187406bc 100644
--- a/tests/push/test_email.py
+++ b/tests/push/test_email.py
@@ -165,6 +165,7 @@ class EmailPusherTests(HomeserverTestCase):
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -175,6 +176,7 @@ class EmailPusherTests(HomeserverTestCase):
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
@@ -192,5 +194,6 @@ class EmailPusherTests(HomeserverTestCase):
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id))
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index af2327fb66..fe3441f081 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -104,6 +104,7 @@ class HTTPPusherTests(HomeserverTestCase):
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -114,6 +115,7 @@ class HTTPPusherTests(HomeserverTestCase):
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
@@ -132,6 +134,7 @@ class HTTPPusherTests(HomeserverTestCase):
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
last_stream_ordering = pushers[0]["last_stream_ordering"]
@@ -151,5 +154,6 @@ class HTTPPusherTests(HomeserverTestCase):
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
)
+ pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 3dae83c543..2a1e7c7166 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -20,7 +20,7 @@ from synapse.replication.tcp.client import (
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.storage.database import Database
+from synapse.storage.database import make_conn
from tests import unittest
from tests.server import FakeTransport
@@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
+ db_config = hs.config.database.get_single_database()
self.master_store = self.hs.get_datastore()
self.storage = hs.get_storage()
+ database = hs.get_datastores().databases[0]
self.slaved_store = self.STORE_TYPE(
- Database(hs), self.hs.get_db_conn(), self.hs
+ database, make_conn(db_config, database.engine), self.hs
)
self.event_id = 0
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 12c5e95cb5..8df58b4a63 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -237,6 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase):
config = self.default_config()
config["require_auth_for_profile_requests"] = True
+ config["limit_profile_requests_to_users_who_share_rooms"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
@@ -309,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["require_auth_for_profile_requests"] = True
+ config["limit_profile_requests_to_users_who_share_rooms"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index c0d0d2b44e..d0c997e385 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
- def sendmail(*args, **kwargs):
+ async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs))
- return
config["email"] = {
"enable_notifs": True,
diff --git a/tests/server.py b/tests/server.py
index 2b7cf4242e..a554dfdd57 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
Set up a synchronous test server, driven by the reactor used by
the homeserver.
"""
- d = _sth(cleanup_func, *args, **kwargs).result
+ server = _sth(cleanup_func, *args, **kwargs)
- if isinstance(d, Failure):
- d.raiseException()
+ database = server.config.database.get_single_database()
# Make the thread pool synchronous.
- clock = d.get_clock()
- pool = d.get_db_pool()
-
- def runWithConnection(func, *args, **kwargs):
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runWithConnection,
- func,
- *args,
- **kwargs
- )
-
- def runInteraction(interaction, *args, **kwargs):
- return threads.deferToThreadPool(
- pool._reactor,
- pool.threadpool,
- pool._runInteraction,
- interaction,
- *args,
- **kwargs
- )
+ clock = server.get_clock()
+
+ for database in server.get_datastores().databases:
+ pool = database._db_pool
+
+ def runWithConnection(func, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runWithConnection,
+ func,
+ *args,
+ **kwargs
+ )
+
+ def runInteraction(interaction, *args, **kwargs):
+ return threads.deferToThreadPool(
+ pool._reactor,
+ pool.threadpool,
+ pool._runInteraction,
+ interaction,
+ *args,
+ **kwargs
+ )
- if pool:
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
pool.threadpool = ThreadPool(clock._reactor)
pool.running = True
- return d
+
+ return server
def get_clock():
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 2e521e9ab7..fd52512696 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_conn
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
- database = Database(hs)
- self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ self.store = ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
hs.config.event_cache_size = 1
hs.config.password_providers = []
- self.db_pool = hs.get_db_pool()
- self.engine = hs.database_engine
-
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
{"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
@@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
- database = Database(hs)
- self.store = TestTransactionStore(database, hs.get_db_conn(), hs)
+ # We assume there is only one database in these tests
+ database = hs.get_datastores().databases[0]
+ self.db_pool = database._db_pool
+ self.engine = database.engine
+
+ db_config = hs.config.get_single_database()
+ self.store = TestTransactionStore(
+ database, make_conn(db_config, self.engine), hs
+ )
def _add_service(self, url, as_token, id):
as_yaml = dict(
@@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.event_cache_size = 1
hs.config.password_providers = []
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
@defer.inlineCallbacks
def test_duplicate_ids(self):
@@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
e = cm.exception
self.assertIn(f1, str(e))
@@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
- ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs)
+ database = hs.get_datastores().databases[0]
+ ApplicationServiceStore(
+ database, make_conn(database._database_config, database.engine), hs
+ )
e = cm.exception
self.assertIn(f1, str(e))
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index 537cfe9f64..cdee0a9e60 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config._disable_native_upserts = True
config.event_cache_size = 1
- config.database_config = {"name": "sqlite3"}
- engine = create_engine(config.database_config)
+ hs = TestHomeServer("test", config=config)
+
+ sqlite_config = {"name": "sqlite3"}
+ engine = create_engine(sqlite_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
- hs = TestHomeServer(
- "test", db_pool=self.db_pool, config=config, database_engine=fake_engine
- )
- self.datastore = SQLBaseStore(Database(hs), None, hs)
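+ # Build a Database against the fake engine, then swap in the test's own
+ # db_pool so the store's queries go through it.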
+ db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+ db._db_pool = self.db_pool
+
+ self.datastore = SQLBaseStore(db, None, hs)
@defer.inlineCallbacks
def test_insert_1col(self):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index fc279340d4..bf674dd184 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(12345678)
user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ user_id, "access_token", "ip", "user_agent", device_id
)
)
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
+ "device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
@@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.store.db.updates.do_next_background_update(100), by=0.1
)
- # Insert a user IP
user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ user_id, "access_token", "ip", "user_agent", device_id
)
)
-
# Force persisting to disk
self.reactor.advance(200)
@@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.get_success(
self.store.db.simple_update(
table="devices",
- keyvalues={"user_id": user_id, "device_id": "device_id"},
+ keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
desc="test_devices_last_seen_bg_update",
)
@@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should now get nulls when querying
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
+ "device_id": device_id,
"ip": None,
"user_agent": None,
"last_seen": None,
@@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should now get the correct result again
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
+ "device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
@@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.store.db.updates.do_next_background_update(100), by=0.1
)
- # Insert a user IP
user_id = "@user:id"
+ device_id = "MY_DEVICE"
+
+ # Insert a user IP
+ self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
- user_id, "access_token", "ip", "user_agent", "device_id"
+ user_id, "access_token", "ip", "user_agent", device_id
)
)
@@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
- "device_id": "device_id",
+ "device_id": device_id,
"last_seen": 0,
}
],
@@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But we should still get the correct values for the device
result = self.get_success(
- self.store.get_last_client_ip_by_device(user_id, "device_id")
+ self.store.get_last_client_ip_by_device(user_id, device_id)
)
- r = result[(user_id, "device_id")]
+ r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
- "device_id": "device_id",
+ "device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 4578cc3b60..ed5786865a 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
- self.db_pool = hs.get_db_pool()
self.store = hs.get_datastore()
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 43200654f1..d6ecf102f8 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -35,7 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
- self.state_datastore = self.store
+ self.state_datastore = self.storage.state.stores.state
self.event_builder_factory = hs.get_event_builder_factory()
self.event_creation_handler = hs.get_event_creation_handler()
diff --git a/tests/test_federation.py b/tests/test_federation.py
index ad165d7295..68684460c6 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,6 +1,6 @@
from mock import Mock
-from twisted.internet.defer import maybeDeferred, succeed
+from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
from synapse.events import FrozenEvent
from synapse.logging.context import LoggingContext
@@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
)
# Send the join, it should return None (which is not an error)
- d = self.handler.on_receive_pdu(
- "test.serv", join_event, sent_to_us_directly=True
+ d = ensureDeferred(
+ self.handler.on_receive_pdu(
+ "test.serv", join_event, sent_to_us_directly=True
+ )
)
self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None)
@@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
)
with LoggingContext(request="lying_event"):
- d = self.handler.on_receive_pdu(
- "test.serv", lying_event, sent_to_us_directly=True
+ d = ensureDeferred(
+ self.handler.on_receive_pdu(
+ "test.serv", lying_event, sent_to_us_directly=True
+ )
)
# Step the reactor, so the database fetches come back
diff --git a/tests/test_state.py b/tests/test_state.py
index 176535947a..e0aae06be4 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -209,7 +209,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -253,7 +253,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
self.assertSetEqual(
{"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
)
@@ -312,7 +312,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
- prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
+ prev_state_ids = yield ctx_e.get_prev_state_ids()
self.assertSetEqual(
{"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
)
@@ -387,7 +387,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
- prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
)
@@ -419,10 +419,10 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@@ -442,10 +442,10 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@@ -479,7 +479,7 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(
set([e.event_id for e in old_state]), set(current_state_ids.values())
@@ -511,7 +511,7 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
self.assertEqual(
set([e.event_id for e in old_state]), set(prev_state_ids.values())
@@ -552,7 +552,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(len(current_state_ids), 6)
@@ -594,7 +594,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(len(current_state_ids), 6)
@@ -649,7 +649,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -677,7 +677,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
diff --git a/tests/test_types.py b/tests/test_types.py
index 9ab5f829b0..8d97c751ea 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest
-from tests.utils import TestHomeServer
-mock_homeserver = TestHomeServer(hostname="my.domain")
-
-class UserIDTestCase(unittest.TestCase):
+class UserIDTestCase(unittest.HomeserverTestCase):
def test_parse(self):
- user = UserID.from_string("@1234abcd:my.domain")
+ user = UserID.from_string("@1234abcd:test")
self.assertEquals("1234abcd", user.localpart)
- self.assertEquals("my.domain", user.domain)
- self.assertEquals(True, mock_homeserver.is_mine(user))
+ self.assertEquals("test", user.domain)
+ self.assertEquals(True, self.hs.is_mine(user))
def test_pase_empty(self):
with self.assertRaises(SynapseError):
@@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase):
self.assertTrue(userA != userB)
-class RoomAliasTestCase(unittest.TestCase):
+class RoomAliasTestCase(unittest.HomeserverTestCase):
def test_parse(self):
- room = RoomAlias.from_string("#channel:my.domain")
+ room = RoomAlias.from_string("#channel:test")
self.assertEquals("channel", room.localpart)
- self.assertEquals("my.domain", room.domain)
- self.assertEquals(True, mock_homeserver.is_mine(room))
+ self.assertEquals("test", room.domain)
+ self.assertEquals(True, self.hs.is_mine(room))
def test_build(self):
room = RoomAlias("channel", "my.domain")
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index 8b8455c8b7..281b32c4b8 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar")
+ @defer.inlineCallbacks
+ def test_make_deferred_yieldable_with_await(self):
+ # an async function which returns an incomplete coroutine, but doesn't
+ # follow the synapse rules.
+
+ async def blocking_function():
+ d = defer.Deferred()
+ reactor.callLater(0, d.callback, None)
+ await d
+
+ sentinel_context = LoggingContext.current_context()
+
+ with LoggingContext() as context_one:
+ context_one.request = "one"
+
+ d1 = make_deferred_yieldable(blocking_function())
+ # make sure that the context was reset by make_deferred_yieldable
+ self.assertIs(LoggingContext.current_context(), sentinel_context)
+
+ yield d1
+
+ # now it should be restored
+ self._check_test_key("one")
+
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py
deleted file mode 100644
index 1a44f72425..0000000000
--- a/tests/util/test_snapshot_cache.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet.defer import Deferred
-
-from synapse.util.caches.snapshot_cache import SnapshotCache
-
-from .. import unittest
-
-
-class SnapshotCacheTestCase(unittest.TestCase):
- def setUp(self):
- self.cache = SnapshotCache()
- self.cache.DURATION_MS = 1
-
- def test_get_set(self):
- # Check that getting a missing key returns None
- self.assertEquals(self.cache.get(0, "key"), None)
-
- # Check that setting a key with a deferred returns
- # a deferred that resolves when the initial deferred does
- d = Deferred()
- set_result = self.cache.set(0, "key", d)
- self.assertIsNotNone(set_result)
- self.assertFalse(set_result.called)
-
- # Check that getting the key before the deferred has resolved
- # returns a deferred that resolves when the initial deferred does.
- get_result_at_10 = self.cache.get(10, "key")
- self.assertIsNotNone(get_result_at_10)
- self.assertFalse(get_result_at_10.called)
-
- # Check that the returned deferreds resolve when the initial deferred
- # does.
- d.callback("v")
- self.assertTrue(set_result.called)
- self.assertTrue(get_result_at_10.called)
-
- # Check that getting the key after the deferred has resolved
- # before the cache expires returns a resolved deferred.
- get_result_at_11 = self.cache.get(11, "key")
- self.assertIsNotNone(get_result_at_11)
- if isinstance(get_result_at_11, Deferred):
- # The cache may return the actual result rather than a deferred
- self.assertTrue(get_result_at_11.called)
-
- # Check that getting the key after the deferred has resolved
- # after the cache expires returns None
- get_result_at_12 = self.cache.get(12, "key")
- self.assertIsNone(get_result_at_12)
diff --git a/tests/utils.py b/tests/utils.py
index c57da59191..e2e9cafd79 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,6 +30,7 @@ from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
+from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.federation.transport import server as federation_server
@@ -177,7 +178,6 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
-@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func,
name="test",
@@ -214,7 +214,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
- config.database_config = {
+ database_config = {
"name": "psycopg2",
"args": {
"database": test_db,
@@ -226,12 +226,15 @@ def setup_test_homeserver(
},
}
else:
- config.database_config = {
+ database_config = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
- db_engine = create_engine(config.database_config)
+ database = DatabaseConnectionConfig("master", database_config)
+ config.database.databases = [database]
+
+ db_engine = create_engine(database.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
@@ -251,39 +254,30 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
- # we need to configure the connection pool to run the on_new_connection
- # function, so that we can test code that uses custom sqlite functions
- # (like rank).
- config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
-
if datastore is None:
hs = homeserverToUse(
name,
config=config,
- db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
**kargs
)
- # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
- # date db
- if not isinstance(db_engine, PostgresEngine):
- db_conn = hs.get_db_conn()
- yield prepare_database(db_conn, db_engine, config)
- db_conn.commit()
- db_conn.close()
+ hs.setup()
+ if homeserverToUse.__name__ == "TestHomeServer":
+ hs.setup_master()
+
+ if isinstance(db_engine, PostgresEngine):
+ database = hs.get_datastores().databases[0]
- else:
# We need to do cleanup on PostgreSQL
def cleanup():
import psycopg2
# Close all the db pools
- hs.get_db_pool().close()
+ database._db_pool.close()
dropped = False
@@ -322,23 +316,12 @@ def setup_test_homeserver(
# Register the cleanup hook
cleanup_func(cleanup)
- hs.setup()
- if homeserverToUse.__name__ == "TestHomeServer":
- hs.setup_master()
else:
- # If we have been given an explicit datastore we probably want to mock
- # out the DataStores somehow too. This all feels a bit wrong, but then
- # mocking the stores feels wrong too.
- datastores = Mock(datastore=datastore)
-
hs = homeserverToUse(
name,
- db_pool=None,
datastore=datastore,
- datastores=datastores,
config=config,
version_string="Synapse/tests",
- database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
diff --git a/tox.ini b/tox.ini
index 903a245fb0..1d6428f64f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -171,11 +171,23 @@ basepython = python3.7
skip_install = True
deps =
{[base]deps}
- mypy==0.730
+ mypy==0.750
mypy-zope
env =
MYPYPATH = stubs/
extras = all
commands = mypy \
+ synapse/config/ \
+ synapse/handlers/ui_auth \
synapse/logging/ \
- synapse/config/
+ synapse/module_api \
+ synapse/rest/consent \
+ synapse/rest/media/v0 \
+ synapse/rest/saml2 \
+ synapse/spam_checker_api \
+ synapse/storage/engines \
+ synapse/streams
+
+# To find all folders that pass mypy you run:
+#
+# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print