diff --git a/.circleci/config.yml b/.circleci/config.yml
index 98c217dd1d..5bd2ab2b76 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -4,18 +4,16 @@ jobs:
machine: true
steps:
- checkout
- - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} -t matrixdotorg/synapse:${CIRCLE_TAG}-py3 .
+ - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:${CIRCLE_TAG} .
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
- run: docker push matrixdotorg/synapse:${CIRCLE_TAG}
- - run: docker push matrixdotorg/synapse:${CIRCLE_TAG}-py3
dockerhubuploadlatest:
machine: true
steps:
- checkout
- - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest -t matrixdotorg/synapse:latest-py3 .
+ - run: docker build -f docker/Dockerfile --label gitsha1=${CIRCLE_SHA1} -t matrixdotorg/synapse:latest .
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
- run: docker push matrixdotorg/synapse:latest
- - run: docker push matrixdotorg/synapse:latest-py3
workflows:
version: 2
diff --git a/.github/ISSUE_TEMPLATE/BUG_REPORT.md b/.github/ISSUE_TEMPLATE/BUG_REPORT.md
index 75c9b2c9fe..978b699886 100644
--- a/.github/ISSUE_TEMPLATE/BUG_REPORT.md
+++ b/.github/ISSUE_TEMPLATE/BUG_REPORT.md
@@ -4,12 +4,12 @@ about: Create a report to help us improve
---
+<!--
+
**THIS IS NOT A SUPPORT CHANNEL!**
**IF YOU HAVE SUPPORT QUESTIONS ABOUT RUNNING OR CONFIGURING YOUR OWN HOME SERVER**,
please ask in **#synapse:matrix.org** (using a matrix.org account if necessary)
-<!--
-
If you want to report a security issue, please see https://matrix.org/security-disclosure-policy/
This is a bug report template. By following the instructions below and
diff --git a/changelog.d/7314.misc b/changelog.d/7314.misc
new file mode 100644
index 0000000000..30720100c2
--- /dev/null
+++ b/changelog.d/7314.misc
@@ -0,0 +1 @@
+Allow guest access to the `GET /_matrix/client/r0/rooms/{room_id}/members` endpoint, according to MSC2689. Contributed by Awesome Technologies Innovationslabor GmbH.
diff --git a/changelog.d/7372.misc b/changelog.d/7372.misc
new file mode 100644
index 0000000000..67a39f0471
--- /dev/null
+++ b/changelog.d/7372.misc
@@ -0,0 +1 @@
+Reduce the amount of whitespace in JSON stored and sent in responses. Contributed by David Vo.
diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature
deleted file mode 100644
index feb02be234..0000000000
--- a/changelog.d/7736.feature
+++ /dev/null
@@ -1 +0,0 @@
-Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).
diff --git a/changelog.d/7977.bugfix b/changelog.d/7977.bugfix
new file mode 100644
index 0000000000..c587f13055
--- /dev/null
+++ b/changelog.d/7977.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse v1.7.2 which caused inaccurate membership counts in the room directory.
diff --git a/changelog.d/7987.misc b/changelog.d/7987.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/7987.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/7989.misc b/changelog.d/7989.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/7989.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/7996.bugfix b/changelog.d/7996.bugfix
new file mode 100644
index 0000000000..1e51f20558
--- /dev/null
+++ b/changelog.d/7996.bugfix
@@ -0,0 +1 @@
+Fix various comments and minor discrepancies in server notices code.
diff --git a/changelog.d/7997.misc b/changelog.d/7997.misc
new file mode 100644
index 0000000000..fd53674bc6
--- /dev/null
+++ b/changelog.d/7997.misc
@@ -0,0 +1 @@
+Implement new experimental push rules for some users.
diff --git a/changelog.d/7999.bugfix b/changelog.d/7999.bugfix
new file mode 100644
index 0000000000..e0b8c4922f
--- /dev/null
+++ b/changelog.d/7999.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where HTTP HEAD requests resulted in a 400 error.
diff --git a/changelog.d/8000.doc b/changelog.d/8000.doc
new file mode 100644
index 0000000000..8d8fd926e9
--- /dev/null
+++ b/changelog.d/8000.doc
@@ -0,0 +1 @@
+Improve workers docs.
diff --git a/changelog.d/8001.misc b/changelog.d/8001.misc
new file mode 100644
index 0000000000..0be4b37d22
--- /dev/null
+++ b/changelog.d/8001.misc
@@ -0,0 +1 @@
+Remove redundant and unreliable signature check for v1 Identity Service lookup responses.
diff --git a/changelog.d/8003.misc b/changelog.d/8003.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8003.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8008.feature b/changelog.d/8008.feature
new file mode 100644
index 0000000000..c6d381809a
--- /dev/null
+++ b/changelog.d/8008.feature
@@ -0,0 +1 @@
+Add rate limiting to users joining rooms.
diff --git a/changelog.d/8009.misc b/changelog.d/8009.misc
new file mode 100644
index 0000000000..3d58a11313
--- /dev/null
+++ b/changelog.d/8009.misc
@@ -0,0 +1 @@
+Improve the performance of the register endpoint.
diff --git a/changelog.d/8010.doc b/changelog.d/8010.doc
new file mode 100644
index 0000000000..fc8b3f0c3d
--- /dev/null
+++ b/changelog.d/8010.doc
@@ -0,0 +1 @@
+Add documentation for how to undo a room shutdown.
diff --git a/changelog.d/8011.bugfix b/changelog.d/8011.bugfix
new file mode 100644
index 0000000000..c673040de9
--- /dev/null
+++ b/changelog.d/8011.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger.
diff --git a/changelog.d/8012.bugfix b/changelog.d/8012.bugfix
new file mode 100644
index 0000000000..c673040de9
--- /dev/null
+++ b/changelog.d/8012.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug which caused two copies of some log lines to be written when synctl was used along with a MemoryHandler logger.
diff --git a/changelog.d/8014.misc b/changelog.d/8014.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8014.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8016.misc b/changelog.d/8016.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8016.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8024.misc b/changelog.d/8024.misc
new file mode 100644
index 0000000000..4bc739502b
--- /dev/null
+++ b/changelog.d/8024.misc
@@ -0,0 +1 @@
+Reduce the amount of less useful output in the newsfragment CI step. Add a link to the changelog section of the contributing guide on error.
\ No newline at end of file
diff --git a/changelog.d/8027.misc b/changelog.d/8027.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8027.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8031.misc b/changelog.d/8031.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8031.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8032.misc b/changelog.d/8032.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8032.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8033.misc b/changelog.d/8033.misc
new file mode 100644
index 0000000000..7a9782d14b
--- /dev/null
+++ b/changelog.d/8033.misc
@@ -0,0 +1 @@
+Rename storage layer objects to be more sensible.
diff --git a/changelog.d/8035.misc b/changelog.d/8035.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8035.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8039.misc b/changelog.d/8039.misc
new file mode 100644
index 0000000000..599933c80e
--- /dev/null
+++ b/changelog.d/8039.misc
@@ -0,0 +1 @@
+Revert MSC2654 implementation because of performance issues. Please delete this line when processing the 1.19 changelog.
diff --git a/changelog.d/8040.misc b/changelog.d/8040.misc
new file mode 100644
index 0000000000..a126151392
--- /dev/null
+++ b/changelog.d/8040.misc
@@ -0,0 +1 @@
+Change the default log config to reduce disk I/O and storage for new servers.
diff --git a/changelog.d/8041.misc b/changelog.d/8041.misc
new file mode 100644
index 0000000000..eefa98d744
--- /dev/null
+++ b/changelog.d/8041.misc
@@ -0,0 +1 @@
+Add an assertion on prev_events in create_new_client_event.
diff --git a/changelog.d/8042.misc b/changelog.d/8042.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8042.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8043.misc b/changelog.d/8043.misc
new file mode 100644
index 0000000000..683d553666
--- /dev/null
+++ b/changelog.d/8043.misc
@@ -0,0 +1 @@
+Add a comment to `ServerContextFactory` about the use of `SSLv23_METHOD`.
diff --git a/changelog.d/8044.misc b/changelog.d/8044.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8044.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8045.misc b/changelog.d/8045.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8045.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8048.feature b/changelog.d/8048.feature
new file mode 100644
index 0000000000..8521d1920e
--- /dev/null
+++ b/changelog.d/8048.feature
@@ -0,0 +1 @@
+Add a `/health` endpoint to every configured HTTP listener that can be used as a health check endpoint by load balancers.
diff --git a/changelog.d/8049.misc b/changelog.d/8049.misc
new file mode 100644
index 0000000000..7fce36215d
--- /dev/null
+++ b/changelog.d/8049.misc
@@ -0,0 +1 @@
+Log `OPTIONS` requests at `DEBUG` rather than `INFO` level to reduce the amount logged at `INFO`.
diff --git a/changelog.d/8050.misc b/changelog.d/8050.misc
new file mode 100644
index 0000000000..cc8d1af7fa
--- /dev/null
+++ b/changelog.d/8050.misc
@@ -0,0 +1 @@
+Reduce the amount of outbound request logging at INFO level.
diff --git a/changelog.d/8051.misc b/changelog.d/8051.misc
new file mode 100644
index 0000000000..9e472cd481
--- /dev/null
+++ b/changelog.d/8051.misc
@@ -0,0 +1 @@
+It is no longer necessary to explicitly define `filters` in the logging configuration. (Continuing to do so is redundant but harmless.)
diff --git a/changelog.d/8052.feature b/changelog.d/8052.feature
new file mode 100644
index 0000000000..6aa020c764
--- /dev/null
+++ b/changelog.d/8052.feature
@@ -0,0 +1 @@
+Allow login to be blocked based on the values of SAML attributes.
diff --git a/changelog.d/8056.docker b/changelog.d/8056.docker
new file mode 100644
index 0000000000..d56734c13a
--- /dev/null
+++ b/changelog.d/8056.docker
@@ -0,0 +1 @@
+We no longer publish Docker images with the `-py3` tag suffix, as announced at https://github.com/matrix-org/synapse/blob/develop/UPGRADE.rst#upgrading-to-v1180.
diff --git a/changelog.d/8058.misc b/changelog.d/8058.misc
new file mode 100644
index 0000000000..41a27e5d72
--- /dev/null
+++ b/changelog.d/8058.misc
@@ -0,0 +1 @@
+Add type hints to `Notifier`.
diff --git a/changelog.d/8060.misc b/changelog.d/8060.misc
new file mode 100644
index 0000000000..a0caf008d7
--- /dev/null
+++ b/changelog.d/8060.misc
@@ -0,0 +1 @@
+Improve typing information on `HomeServer` object.
diff --git a/changelog.d/8061.misc b/changelog.d/8061.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8061.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8062.misc b/changelog.d/8062.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8062.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8063.misc b/changelog.d/8063.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8063.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8064.misc b/changelog.d/8064.misc
new file mode 100644
index 0000000000..41a27e5d72
--- /dev/null
+++ b/changelog.d/8064.misc
@@ -0,0 +1 @@
+Add type hints to `Notifier`.
diff --git a/changelog.d/8066.misc b/changelog.d/8066.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8066.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8067.misc b/changelog.d/8067.misc
new file mode 100644
index 0000000000..f4404b7506
--- /dev/null
+++ b/changelog.d/8067.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.handlers.message` and `synapse.events.builder`.
diff --git a/changelog.d/8069.misc b/changelog.d/8069.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8069.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/changelog.d/8070.misc b/changelog.d/8070.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8070.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/docker/conf/log.config b/docker/conf/log.config
index ed418a57cd..491bbcc87a 100644
--- a/docker/conf/log.config
+++ b/docker/conf/log.config
@@ -4,16 +4,10 @@ formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
-filters:
- context:
- (): synapse.logging.context.LoggingContextFilter
- request: ""
-
handlers:
console:
class: logging.StreamHandler
formatter: precise
- filters: [context]
loggers:
synapse.storage.SQL:
diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md
index 2ff552bcb3..9b1cb1c184 100644
--- a/docs/admin_api/shutdown_room.md
+++ b/docs/admin_api/shutdown_room.md
@@ -79,13 +79,20 @@ Response:
the structure can and does change without notice.
First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it
-never happened - work has to be done to move forward instead of resetting the past.
+never happened - work has to be done to move forward instead of resetting the past. In fact, in some cases it might not be possible
+to recover at all:
-1. For safety reasons, it is recommended to shut down Synapse prior to continuing.
+* If the room was invite-only, your users will need to be re-invited.
+* If the room no longer has any members at all, it'll be impossible to rejoin.
+* The first user to rejoin will have to do so via an alias on a different server.
+
+With all that being said, if you still want to try and recover the room:
+
+1. For safety reasons, shut down Synapse.
2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';`
* For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`.
* The room ID is the same one supplied to the shutdown room API, not the Content Violation room.
-3. Restart Synapse (required).
+3. Restart Synapse.
You will have to manually handle, if you so choose, the following:
diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md
index 7bfb96eff6..fd48ba0874 100644
--- a/docs/reverse_proxy.md
+++ b/docs/reverse_proxy.md
@@ -139,3 +139,10 @@ client IP addresses are recorded correctly.
Having done so, you can then use `https://matrix.example.com` (instead
of `https://matrix.example.com:8448`) as the "Custom server" when
connecting to Synapse from a client.
+
+
+## Health check endpoint
+
+Synapse exposes a health check endpoint for use by reverse proxies.
+Each configured HTTP listener has a `/health` endpoint which always returns
+200 OK (and doesn't get logged).
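+
+For example, a load balancer such as HAProxy could poll `/health` with an
+active health check. A minimal sketch (the backend and server names here are
+illustrative, not something Synapse defines):
+
+```
+backend synapse
+    option httpchk GET /health
+    http-check expect status 200
+    server synapse1 127.0.0.1:8008 check
+```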
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 341bd2f858..9235b89fb1 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -746,6 +746,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
+# - two for ratelimiting the number of rooms a user can join, "local" for when
+# users are joining rooms the server is already in (this is cheap) vs
+# "remote" for when users are trying to join rooms not on the server (which
+# can be more expensive)
#
# The defaults are as shown below.
#
@@ -771,6 +775,14 @@ log_config: "CONFDIR/SERVERNAME.log.config"
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
+#
+#rc_joins:
+# local:
+# per_second: 0.1
+# burst_count: 3
+# remote:
+# per_second: 0.01
+# burst_count: 3
# Ratelimiting settings for incoming federation
@@ -1565,6 +1577,17 @@ saml2_config:
#
#grandfathered_mxid_source_attribute: upn
+ # It is possible to configure Synapse to only allow logins if SAML attributes
+ # match particular values. The requirements can be listed under
+ # `attribute_requirements` as shown below. All of the listed attributes must
+ # match for the login to be permitted.
+ #
+ #attribute_requirements:
+ # - attribute: userGroup
+ # value: "staff"
+ # - attribute: department
+ # value: "sales"
+
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml
index 1a2739455e..55a48a9ed6 100644
--- a/docs/sample_log_config.yaml
+++ b/docs/sample_log_config.yaml
@@ -11,24 +11,33 @@ formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
-filters:
- context:
- (): synapse.logging.context.LoggingContextFilter
- request: ""
-
handlers:
file:
- class: logging.handlers.RotatingFileHandler
+ class: logging.handlers.TimedRotatingFileHandler
formatter: precise
filename: /var/log/matrix-synapse/homeserver.log
- maxBytes: 104857600
- backupCount: 10
- filters: [context]
+ when: midnight
+ backupCount: 3 # Does not include the current log file.
encoding: utf8
+
+ # Default to buffering writes to log file for efficiency. This means that
+  # there will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
+ # logs will still be flushed immediately.
+ buffer:
+ class: logging.handlers.MemoryHandler
+ target: file
+ # The capacity is the number of log lines that are buffered before
+ # being written to disk. Increasing this will lead to better
+    # performance, at the expense of it taking longer for log lines to
+ # be written to disk.
+ capacity: 10
+ flushLevel: 30 # Flush for WARNING logs as well
+
+ # A handler that writes logs to stderr. Unused by default, but can be used
+ # instead of "buffer" and "file" in the logger handlers.
console:
class: logging.StreamHandler
formatter: precise
- filters: [context]
loggers:
synapse.storage.SQL:
@@ -36,8 +45,23 @@ loggers:
# information such as access tokens.
level: INFO
+ twisted:
+ # We send the twisted logging directly to the file handler,
+ # to work around https://github.com/matrix-org/synapse/issues/3471
+    # when using the "buffer" logger. Use "console" to log to stderr instead.
+ handlers: [file]
+ propagate: false
+
root:
level: INFO
- handlers: [file, console]
+
+ # Write logs to the `buffer` handler, which will buffer them together in memory,
+ # then write them to a file.
+ #
+ # Replace "buffer" with "console" to log to stderr instead. (Note that you'll
+  # also need to update the configuration for the `twisted` logger above, in
+ # this case.)
+ #
+ handlers: [buffer]
disable_existing_loggers: false
diff --git a/docs/systemd-with-workers/workers/federation_reader.yaml b/docs/systemd-with-workers/workers/federation_reader.yaml
index 5b65c7040d..13e69e62c9 100644
--- a/docs/systemd-with-workers/workers/federation_reader.yaml
+++ b/docs/systemd-with-workers/workers/federation_reader.yaml
@@ -1,7 +1,7 @@
worker_app: synapse.app.federation_reader
+worker_name: federation_reader1
worker_replication_host: 127.0.0.1
-worker_replication_port: 9092
worker_replication_http_port: 9093
worker_listeners:
diff --git a/docs/user_directory.md b/docs/user_directory.md
index 37dc71e751..872fc21979 100644
--- a/docs/user_directory.md
+++ b/docs/user_directory.md
@@ -7,6 +7,6 @@ who are present in a publicly viewable room present on the server.
The directory info is stored in various tables, which can (typically after
DB corruption) get stale or out of sync. If this happens, for now the
-solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql)
+solution to fix it is to execute the SQL [here](../synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql)
and then restart synapse. This should then start a background task to
flush the current tables and regenerate the directory.
diff --git a/docs/workers.md b/docs/workers.md
index 80b65a0cec..bfec745897 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -23,7 +23,7 @@ The processes communicate with each other via a Synapse-specific protocol called
feeds streams of newly written data between processes so they can be kept in
sync with the database state.
-When configured to do so, Synapse uses a
+When configured to do so, Synapse uses a
[Redis pub/sub channel](https://redis.io/topics/pubsub) to send the replication
stream between all configured Synapse processes. Additionally, processes may
make HTTP requests to each other, primarily for operations which need to wait
@@ -66,23 +66,31 @@ https://hub.docker.com/r/matrixdotorg/synapse/.
To make effective use of the workers, you will need to configure an HTTP
reverse-proxy such as nginx or haproxy, which will direct incoming requests to
-the correct worker, or to the main synapse instance. See
+the correct worker, or to the main synapse instance. See
[reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse
proxy.
-To enable workers you should create a configuration file for each worker
-process. Each worker configuration file inherits the configuration of the shared
-homeserver configuration file. You can then override configuration specific to
-that worker, e.g. the HTTP listener that it provides (if any); logging
-configuration; etc. You should minimise the number of overrides though to
-maintain a usable config.
+When using workers, each worker process has its own configuration file which
+contains settings specific to that worker, such as the HTTP listener that it
+provides (if any), logging configuration, etc.
+Normally, the worker processes are configured to read from a shared
+configuration file as well as the worker-specific configuration files. This
+makes it easier to keep common configuration settings synchronised across all
+the processes.
-### Shared Configuration
+The main process is somewhat special in this respect: it does not normally
+need its own configuration file and can take all of its configuration from the
+shared configuration file.
+
+
+### Shared configuration
+
+Normally, only a couple of changes are needed to make an existing configuration
+file suitable for use with workers. First, you need to enable an "HTTP replication
+listener" for the main process; second, you need to enable redis-based
+replication. For example:
-Next you need to add both a HTTP replication listener, used for HTTP requests
-between processes, and redis config to the shared Synapse configuration file
-(`homeserver.yaml`). For example:
```yaml
# extend the existing `listeners` section. This defines the ports that the
@@ -105,7 +113,7 @@ Under **no circumstances** should the replication listener be exposed to the
public internet; it has no authentication and is unencrypted.
-### Worker Configuration
+### Worker configuration
In the config file for each worker, you must specify the type of worker
application (`worker_app`), and you should specify a unique name for the worker
@@ -145,6 +153,9 @@ plain HTTP endpoint on port 8083 separately serving various endpoints, e.g.
Obviously you should configure your reverse-proxy to route the relevant
endpoints to the worker (`localhost:8083` in the above example).
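+
+For instance, a minimal nginx sketch (the endpoint pattern is illustrative;
+route whichever of the endpoints listed above your worker actually serves):
+
+```
+location ~ ^/_matrix/client/(r0|unstable)/sync$ {
+    proxy_pass http://localhost:8083;
+}
+```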
+
+### Running Synapse with workers
+
Finally, you need to start your worker processes. This can be done with either
`synctl` or your distribution's preferred service manager such as `systemd`. We
recommend the use of `systemd` where available: for information on setting up
@@ -407,6 +418,23 @@ all these to be folded into the `generic_worker` app and to use config to define
which processes handle the various processing such as push notifications.
+## Migration from old config
+
+There are two main independent changes that have been made: introducing Redis
+support and merging apps into `synapse.app.generic_worker`. Both these changes
+are backwards compatible, so no changes to the config are required. However,
+server admins are encouraged to plan to migrate to Redis, as the old-style
+direct TCP replication config is deprecated.
+
+To migrate to Redis, add the `redis` config as above, and optionally remove the
+TCP `replication` listener from the master and `worker_replication_port` from
+the worker config.
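+
+A minimal sketch of that shared-config addition (host and port are the usual
+Redis defaults):
+
+```yaml
+redis:
+  enabled: true
+  host: localhost
+  port: 6379
+```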
+
+To migrate apps to use `synapse.app.generic_worker`, simply update the
+`worker_app` option in the worker configs and wherever the workers are started
+(e.g. in systemd service files; this is not required for synctl).
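+
+For example, in a worker's config file (a sketch; `federation_reader` is just
+one of the apps being replaced):
+
+```yaml
+# Before:
+#worker_app: synapse.app.federation_reader
+# After:
+worker_app: synapse.app.generic_worker
+```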
+
+
## Architectural diagram
The following shows an example setup using Redis and a reverse proxy:
diff --git a/mypy.ini b/mypy.ini
index a61009b197..c69cb5dc40 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -81,3 +81,6 @@ ignore_missing_imports = True
[mypy-rust_python_jaeger_reporter.*]
ignore_missing_imports = True
+
+[mypy-nacl.*]
+ignore_missing_imports = True
diff --git a/scripts-dev/check-newsfragment b/scripts-dev/check-newsfragment
index 98a618f6b2..448cadb829 100755
--- a/scripts-dev/check-newsfragment
+++ b/scripts-dev/check-newsfragment
@@ -3,6 +3,8 @@
# A script which checks that an appropriate news file has been added on this
# branch.
+echo -e "+++ \033[32mChecking newsfragment\033[m"
+
set -e
# make sure that origin/develop is up to date
@@ -16,6 +18,8 @@ pr="$BUILDKITE_PULL_REQUEST"
if ! git diff --quiet FETCH_HEAD... -- debian; then
if git diff --quiet FETCH_HEAD... -- debian/changelog; then
echo "Updates to debian directory, but no update to the changelog." >&2
+ echo "!! Please see the contributing guide for help writing your changelog entry:" >&2
+ echo "https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#debian-changelog" >&2
exit 1
fi
fi
@@ -26,7 +30,12 @@ if ! git diff --name-only FETCH_HEAD... | grep -qv '^debian/'; then
exit 0
fi
-tox -qe check-newsfragment
+# Print a link to the contributing guide if the user makes a mistake
+CONTRIBUTING_GUIDE_TEXT="!! Please see the contributing guide for help writing your changelog entry:
+https://github.com/matrix-org/synapse/blob/develop/CONTRIBUTING.md#changelog"
+
+# If check-newsfragment returns a non-zero exit code, print the contributing guide and exit
+tox -qe check-newsfragment || (echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2 && exit 1)
echo
echo "--------------------------"
@@ -38,6 +47,7 @@ for f in `git diff --name-only FETCH_HEAD... -- changelog.d`; do
lastchar=`tr -d '\n' < $f | tail -c 1`
if [ $lastchar != '.' -a $lastchar != '!' ]; then
echo -e "\e[31mERROR: newsfragment $f does not end with a '.' or '!'\e[39m" >&2
+ echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2
exit 1
fi
@@ -47,5 +57,6 @@ done
if [[ -n "$pr" && "$matched" -eq 0 ]]; then
echo -e "\e[31mERROR: Did not find a news fragment with the right number: expected changelog.d/$pr.*.\e[39m" >&2
+ echo -e "$CONTRIBUTING_GUIDE_TEXT" >&2
exit 1
fi
diff --git a/scripts-dev/update_database b/scripts-dev/update_database
index 94aa8758b4..56365e2b58 100755
--- a/scripts-dev/update_database
+++ b/scripts-dev/update_database
@@ -40,7 +40,7 @@ class MockHomeserver(HomeServer):
config.server_name, reactor=reactor, config=config, **kwargs
)
- self.version_string = "Synapse/"+get_version_string(synapse)
+ self.version_string = "Synapse/" + get_version_string(synapse)
if __name__ == "__main__":
@@ -86,7 +86,7 @@ if __name__ == "__main__":
store = hs.get_datastore()
async def run_background_updates():
- await store.db.updates.run_background_updates(sleep=False)
+ await store.db_pool.updates.run_background_updates(sleep=False)
# Stop the reactor to exit the script once every background update is run.
reactor.stop()
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index bee525197f..a34bdf1830 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -35,31 +35,29 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
-from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore
-from synapse.storage.data_stores.main.deviceinbox import (
- DeviceInboxBackgroundUpdateStore,
-)
-from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore
-from synapse.storage.data_stores.main.events_bg_updates import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
+from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
+from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
+from synapse.storage.databases.main.events_bg_updates import (
EventsBackgroundUpdatesStore,
)
-from synapse.storage.data_stores.main.media_repository import (
+from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
-from synapse.storage.data_stores.main.registration import (
+from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
-from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
-from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
-from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore
-from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore
-from synapse.storage.data_stores.main.stats import StatsStore
-from synapse.storage.data_stores.main.user_directory import (
+from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
+from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
+from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
+from synapse.storage.databases.main.state import MainStateBackgroundUpdateStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.databases.main.user_directory import (
UserDirectoryBackgroundUpdateStore,
)
-from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
-from synapse.storage.database import Database, make_conn
+from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util import Clock
@@ -69,7 +67,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
- "events": ["processed", "outlier", "contains_url", "count_as_unread"],
+ "events": ["processed", "outlier", "contains_url"],
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
@@ -175,14 +173,14 @@ class Store(
StatsStore,
):
def execute(self, f, *args, **kwargs):
- return self.db.runInteraction(f.__name__, f, *args, **kwargs)
+ return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args):
def r(txn):
txn.execute(sql, args)
return txn.fetchall()
- return self.db.runInteraction("execute_sql", r)
+ return self.db_pool.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
@@ -227,7 +225,7 @@ class Porter(object):
async def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
- row = await self.postgres_store.db.simple_select_one(
+ row = await self.postgres_store.db_pool.simple_select_one(
table="port_from_sqlite3",
keyvalues={"table_name": table},
retcols=("forward_rowid", "backward_rowid"),
@@ -244,7 +242,7 @@ class Porter(object):
) = await self._setup_sent_transactions()
backward_chunk = 0
else:
- await self.postgres_store.db.simple_insert(
+ await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3",
values={
"table_name": table,
@@ -274,7 +272,7 @@ class Porter(object):
await self.postgres_store.execute(delete_all)
- await self.postgres_store.db.simple_insert(
+ await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
)
@@ -318,7 +316,7 @@ class Porter(object):
if table == "user_directory_stream_pos":
# We need to make sure there is a single row, `(X, null), as that is
# what synapse expects to be there.
- await self.postgres_store.db.simple_insert(
+ await self.postgres_store.db_pool.simple_insert(
table=table, values={"stream_id": None}
)
self.progress.update(table, table_size) # Mark table as done
@@ -359,7 +357,7 @@ class Porter(object):
return headers, forward_rows, backward_rows
- headers, frows, brows = await self.sqlite_store.db.runInteraction(
+ headers, frows, brows = await self.sqlite_store.db_pool.runInteraction(
"select", r
)
@@ -375,7 +373,7 @@ class Porter(object):
def insert(txn):
self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
- self.postgres_store.db.simple_update_one_txn(
+ self.postgres_store.db_pool.simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
@@ -413,7 +411,7 @@ class Porter(object):
return headers, rows
- headers, rows = await self.sqlite_store.db.runInteraction("select", r)
+ headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r)
if rows:
forward_chunk = rows[-1][0] + 1
@@ -451,7 +449,7 @@ class Porter(object):
],
)
- self.postgres_store.db.simple_update_one_txn(
+ self.postgres_store.db_pool.simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": "event_search"},
@@ -494,7 +492,7 @@ class Porter(object):
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
- store = Store(Database(hs, db_config, engine), db_conn, hs)
+ store = Store(DatabasePool(hs, db_config, engine), db_conn, hs)
db_conn.commit()
return store
@@ -502,7 +500,7 @@ class Porter(object):
async def run_background_updates_on_postgres(self):
# Manually apply all background updates on the PostgreSQL database.
postgres_ready = (
- await self.postgres_store.db.updates.has_completed_background_updates()
+ await self.postgres_store.db_pool.updates.has_completed_background_updates()
)
if not postgres_ready:
@@ -511,9 +509,9 @@ class Porter(object):
self.progress.set_state("Running background updates on PostgreSQL")
while not postgres_ready:
- await self.postgres_store.db.updates.do_next_background_update(100)
+ await self.postgres_store.db_pool.updates.do_next_background_update(100)
postgres_ready = await (
- self.postgres_store.db.updates.has_completed_background_updates()
+ self.postgres_store.db_pool.updates.has_completed_background_updates()
)
async def run(self):
@@ -534,7 +532,7 @@ class Porter(object):
# Check if all background updates are done, abort if not.
updates_complete = (
- await self.sqlite_store.db.updates.has_completed_background_updates()
+ await self.sqlite_store.db_pool.updates.has_completed_background_updates()
)
if not updates_complete:
end_error = (
@@ -576,22 +574,24 @@ class Porter(object):
)
try:
- await self.postgres_store.db.runInteraction("alter_table", alter_table)
+ await self.postgres_store.db_pool.runInteraction(
+ "alter_table", alter_table
+ )
except Exception:
# On Error Resume Next
pass
- await self.postgres_store.db.runInteraction(
+ await self.postgres_store.db_pool.runInteraction(
"create_port_table", create_port_table
)
# Step 2. Get tables.
self.progress.set_state("Fetching tables")
- sqlite_tables = await self.sqlite_store.db.simple_select_onecol(
+ sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol(
table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
)
- postgres_tables = await self.postgres_store.db.simple_select_onecol(
+ postgres_tables = await self.postgres_store.db_pool.simple_select_onecol(
table="information_schema.tables",
keyvalues={},
retcol="distinct table_name",
@@ -692,7 +692,7 @@ class Porter(object):
return headers, [r for r in rows if r[ts_ind] < yesterday]
- headers, rows = await self.sqlite_store.db.runInteraction("select", r)
+ headers, rows = await self.sqlite_store.db_pool.runInteraction("select", r)
rows = self._convert_rows("sent_transactions", headers, rows)
@@ -725,7 +725,7 @@ class Porter(object):
next_chunk = await self.sqlite_store.execute(get_start_id)
next_chunk = max(max_inserted_rowid + 1, next_chunk)
- await self.postgres_store.db.simple_insert(
+ await self.postgres_store.db_pool.simple_insert(
table="port_from_sqlite3",
values={
"table_name": "sent_transactions",
@@ -794,14 +794,14 @@ class Porter(object):
next_id = curr_id + 1
txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
- return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
+ return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
def _setup_user_id_seq(self):
def r(txn):
next_id = find_max_generated_user_id_localpart(txn) + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
- return self.postgres_store.db.runInteraction("setup_user_id_seq", r)
+ return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
##############################################
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 2178e623da..d8190f92ab 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Optional
+from typing import List, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
-from twisted.internet import defer
from twisted.web.server import Request
import synapse.types
@@ -80,13 +79,14 @@ class Auth(object):
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
- @defer.inlineCallbacks
- def check_from_context(self, room_version: str, event, context, do_sig_check=True):
- prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
- auth_events_ids = yield self.compute_auth_events(
+ async def check_from_context(
+ self, room_version: str, event, context, do_sig_check=True
+ ):
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -94,14 +94,13 @@ class Auth(object):
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
)
- @defer.inlineCallbacks
- def check_user_in_room(
+ async def check_user_in_room(
self,
room_id: str,
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
- ):
+ ) -> EventBase:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@@ -119,37 +118,35 @@ class Auth(object):
Raises:
AuthError if the user is/was not in the room.
Returns:
- Deferred[Optional[EventBase]]:
- Membership event for the user if the user was in the
- room. This will be the join event if they are currently joined to
- the room. This will be the leave event if they have left the room.
+ Membership event for the user if the user was in the
+ room. This will be the join event if they are currently joined to
+ the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
- member = yield defer.ensureDeferred(
- self.state.get_current_state(
- room_id=room_id, event_type=EventTypes.Member, state_key=user_id
- )
+ member = await self.state.get_current_state(
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
- membership = member.membership if member else None
- if membership == Membership.JOIN:
- return member
+ if member:
+ membership = member.membership
- # XXX this looks totally bogus. Why do we not allow users who have been banned,
- # or those who were members previously and have been re-invited?
- if allow_departed_users and membership == Membership.LEAVE:
- forgot = yield self.store.did_forget(user_id, room_id)
- if not forgot:
+ if membership == Membership.JOIN:
return member
+ # XXX this looks totally bogus. Why do we not allow users who have been banned,
+ # or those who were members previously and have been re-invited?
+ if allow_departed_users and membership == Membership.LEAVE:
+ forgot = await self.store.did_forget(user_id, room_id)
+ if not forgot:
+ return member
+
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
- @defer.inlineCallbacks
- def check_host_in_room(self, room_id, host):
+ async def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
- latest_event_ids = yield self.store.is_host_joined(room_id, host)
+ latest_event_ids = await self.store.is_host_joined(room_id, host)
return latest_event_ids
def can_federate(self, event, auth_events):
@@ -160,14 +157,13 @@ class Auth(object):
def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event)
- @defer.inlineCallbacks
- def get_user_by_req(
+ async def get_user_by_req(
self,
request: Request,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
- ):
+ ) -> synapse.types.Requester:
""" Get a registered user's ID.
Args:
@@ -180,7 +176,7 @@ class Auth(object):
/login will deliver access tokens regardless of expiration.
Returns:
- defer.Deferred: resolves to a `synapse.types.Requester` object
+ Resolves to the requester
Raises:
InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
@@ -194,14 +190,14 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
- user_id, app_service = yield self._get_appservice_user_id(request)
+ user_id, app_service = await self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips:
- yield self.store.insert_client_ip(
+ await self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
ip=ip_addr,
@@ -211,7 +207,7 @@ class Auth(object):
return synapse.types.create_requester(user_id, app_service=app_service)
- user_info = yield self.get_user_by_access_token(
+ user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired
)
user = user_info["user"]
@@ -221,7 +217,7 @@ class Auth(object):
# Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
+ expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
if (
expiration_ts is not None
and self.clock.time_msec() >= expiration_ts
@@ -235,7 +231,7 @@ class Auth(object):
device_id = user_info.get("device_id")
if user and access_token and ip_addr:
- yield self.store.insert_client_ip(
+ await self.store.insert_client_ip(
user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
@@ -261,8 +257,7 @@ class Auth(object):
except KeyError:
raise MissingClientTokenError()
- @defer.inlineCallbacks
- def _get_appservice_user_id(self, request):
+ async def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
@@ -283,14 +278,13 @@ class Auth(object):
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (yield self.store.get_user_by_id(user_id)):
+ if not (await self.store.get_user_by_id(user_id)):
raise AuthError(403, "Application service has not registered this user")
return user_id, app_service
- @defer.inlineCallbacks
- def get_user_by_access_token(
+ async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
- ):
+ ) -> dict:
""" Validate access token and get user_id from it
Args:
@@ -300,7 +294,7 @@ class Auth(object):
allow_expired: If False, raises an InvalidClientTokenError
if the token is expired
Returns:
- Deferred[dict]: dict that includes:
+ dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
@@ -314,7 +308,7 @@ class Auth(object):
if rights == "access":
# first look in the database
- r = yield self._look_up_user_by_access_token(token)
+ r = await self._look_up_user_by_access_token(token)
if r:
valid_until_ms = r["valid_until_ms"]
if (
@@ -352,7 +346,7 @@ class Auth(object):
# It would of course be much easier to store guest access
# tokens in the database as well, but that would break existing
# guest tokens.
- stored_user = yield self.store.get_user_by_id(user_id)
+ stored_user = await self.store.get_user_by_id(user_id)
if not stored_user:
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
@@ -482,9 +476,8 @@ class Auth(object):
now = self.hs.get_clock().time_msec()
return now < expiry
- @defer.inlineCallbacks
- def _look_up_user_by_access_token(self, token):
- ret = yield self.store.get_user_by_access_token(token)
+ async def _look_up_user_by_access_token(self, token):
+ ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
@@ -507,7 +500,7 @@ class Auth(object):
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
request.authenticated_entity = service.sender
- return defer.succeed(service)
+ return service
async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin.
@@ -522,7 +515,7 @@ class Auth(object):
def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
- ):
+ ) -> List[str]:
"""Given an event and current state return the list of event IDs used
to auth an event.
@@ -530,11 +523,11 @@ class Auth(object):
should be added to the event's `auth_events`.
Returns:
- defer.Deferred(list[str]): List of event IDs.
+ List of event IDs.
"""
if event.type == EventTypes.Create:
- return defer.succeed([])
+ return []
# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
@@ -553,7 +546,7 @@ class Auth(object):
if auth_ev_id:
auth_ids.append(auth_ev_id)
- return defer.succeed(auth_ids)
+ return auth_ids
async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
@@ -636,10 +629,9 @@ class Auth(object):
return query_params[0].decode("ascii")
- @defer.inlineCallbacks
- def check_user_in_room_or_world_readable(
+ async def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False
- ):
+ ) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
@@ -650,10 +642,9 @@ class Auth(object):
members but have now departed
Returns:
- Deferred[tuple[str, str|None]]: Resolves to the current membership of
- the user in the room and the membership event ID of the user. If
- the user is not in the room and never has been, then
- `(Membership.JOIN, None)` is returned.
+ Resolves to the current membership of the user in the room and the
+ membership event ID of the user. If the user is not in the room and
+ never has been, then `(Membership.JOIN, None)` is returned.
"""
try:
@@ -662,15 +653,13 @@ class Auth(object):
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
- member_event = yield self.check_user_in_room(
+ member_event = await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
- visibility = yield defer.ensureDeferred(
- self.state.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
- )
+ visibility = await self.state.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index 5c499b6b4e..49093bf181 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
@@ -36,8 +34,7 @@ class AuthBlocking(object):
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
- @defer.inlineCallbacks
- def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
+ async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
@@ -60,7 +57,7 @@ class AuthBlocking(object):
if user_id is not None:
if user_id == self._server_notices_mxid:
return
- if (yield self.store.is_support_user(user_id)):
+ if await self.store.is_support_user(user_id):
return
if self._hs_disabled:
@@ -76,11 +73,11 @@ class AuthBlocking(object):
# If the user is already part of the MAU cohort or a trial user
if user_id:
- timestamp = yield self.store.user_last_seen_monthly_active(user_id)
+ timestamp = await self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return
- is_trial = yield self.store.is_trial_user(user_id)
+ is_trial = await self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
@@ -93,7 +90,7 @@ class AuthBlocking(object):
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
- current_mau = yield self.store.get_monthly_active_count()
+ current_mau = await self.store.get_monthly_active_count()
if current_mau >= self._max_mau_value:
raise ResourceLimitError(
403,
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index b3bab1aa52..6e40630ab6 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -238,14 +238,16 @@ class InteractiveAuthIncompleteError(Exception):
(This indicates we should return a 401 with 'result' as the body)
Attributes:
+ session_id: The ID of the ongoing interactive auth session.
result: the server response to the request, which should be
passed back to the client
"""
- def __init__(self, result: "JsonDict"):
+ def __init__(self, session_id: str, result: "JsonDict"):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete"
)
+ self.session_id = session_id
self.result = result
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index f988f62a1e..7393d6cb74 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -21,8 +21,6 @@ import jsonschema
from canonicaljson import json
from jsonschema import FormatChecker
-from twisted.internet import defer
-
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
@@ -137,9 +135,8 @@ class Filtering(object):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_user_filter(self, user_localpart, filter_id):
- result = yield self.store.get_user_filter(user_localpart, filter_id)
+ async def get_user_filter(self, user_localpart, filter_id):
+ result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)
def add_user_filter(self, user_localpart, user_filter):
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 373a80a4a7..2b2cd795e0 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import gc
import logging
import os
@@ -22,7 +21,6 @@ import sys
import traceback
from typing import Iterable
-from daemonize import Daemonize
from typing_extensions import NoReturn
from twisted.internet import defer, error, reactor
@@ -34,6 +32,7 @@ from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
from synapse.util.async_helpers import Linearizer
+from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -129,17 +128,8 @@ def start_reactor(
if print_pidfile:
print(pid_file)
- daemon = Daemonize(
- app=appname,
- pid=pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ daemonize_process(pid_file, logger)
+ run()
def quit_with_error(error_string: str) -> NoReturn:
@@ -278,7 +268,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# It is now safe to start your Synapse.
hs.start_listening(listeners)
- hs.get_datastore().db.start_profiling()
+ hs.get_datastore().db_pool.start_profiling()
hs.get_pusherpool().start()
setup_sentry(hs)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index c478df53be..739b013d4c 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -123,17 +123,18 @@ from synapse.rest.client.v2_alpha.account_data import (
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
+from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.server import HomeServer
-from synapse.storage.data_stores.main.censor_events import CensorEventsStore
-from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
-from synapse.storage.data_stores.main.monthly_active_users import (
+from synapse.server import HomeServer, cache_in_self
+from synapse.storage.databases.main.censor_events import CensorEventsStore
+from synapse.storage.databases.main.media_repository import MediaRepositoryStore
+from synapse.storage.databases.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
-from synapse.storage.data_stores.main.presence import UserPresenceState
-from synapse.storage.data_stores.main.search import SearchWorkerStore
-from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
-from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.storage.databases.main.presence import UserPresenceState
+from synapse.storage.databases.main.search import SearchWorkerStore
+from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
+from synapse.storage.databases.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
@@ -493,7 +494,10 @@ class GenericWorkerServer(HomeServer):
site_tag = listener_config.http_options.tag
if site_tag is None:
site_tag = port
- resources = {}
+
+ # We always include a health resource.
+ resources = {"/health": HealthResource()}
+
for res in listener_config.http_options.resources:
for name in res.names:
if name == "metrics":
@@ -631,10 +635,12 @@ class GenericWorkerServer(HomeServer):
async def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
- def build_replication_data_handler(self):
+ @cache_in_self
+ def get_replication_data_handler(self):
return GenericWorkerReplicationHandler(self)
- def build_presence_handler(self):
+ @cache_in_self
+ def get_presence_handler(self):
return GenericWorkerPresence(self)
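The `build_*`/`get_*` method pairs collapse into single getters decorated with `cache_in_self`, now exported from `synapse.server`. A minimal sketch of what such a memoizing decorator can look like (the real one also guards against re-entrant construction):

```python
def cache_in_self(builder):
    """Memoize a zero-argument HomeServer getter on the instance (sketch)."""
    depname = builder.__name__.replace("get_", "")

    def _get(self):
        cached = getattr(self, "_cached_" + depname, None)
        if cached is not None:
            return cached

        # First call: build the dependency and stash it on the instance.
        dep = builder(self)
        setattr(self, "_cached_" + depname, dep)
        return dep

    return _get
```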
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index ec7401f911..98d0d14a12 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -68,6 +68,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
+from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
@@ -98,7 +99,9 @@ class SynapseHomeServer(HomeServer):
if site_tag is None:
site_tag = port
- resources = {}
+ # We always include a health resource.
+ resources = {"/health": HealthResource()}
+
for res in listener_config.http_options.resources:
for name in res.names:
if name == "openid" and "federation" in res.names:
@@ -380,13 +383,12 @@ def setup(config_options):
hs.setup_master()
- @defer.inlineCallbacks
- def do_acme():
+ async def do_acme() -> bool:
"""
Reprovision an ACME certificate, if it's required.
Returns:
- Deferred[bool]: Whether the cert has been updated.
+ Whether the cert has been updated.
"""
acme = hs.get_acme_handler()
@@ -405,7 +407,7 @@ def setup(config_options):
provision = True
if provision:
- yield acme.provision_certificate()
+ await acme.provision_certificate()
return provision
@@ -415,7 +417,7 @@ def setup(config_options):
Provision a certificate from ACME, if required, and reload the TLS
certificate if it's renewed.
"""
- reprovisioned = yield do_acme()
+ reprovisioned = yield defer.ensureDeferred(do_acme())
if reprovisioned:
_base.refresh_certificate(hs)
@@ -427,8 +429,8 @@ def setup(config_options):
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
- yield acme.start_listening()
- yield do_acme()
+ yield defer.ensureDeferred(acme.start_listening())
+ yield defer.ensureDeferred(do_acme())
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@@ -442,7 +444,7 @@ def setup(config_options):
_base.start(hs, config.listeners)
- hs.get_datastore().db.updates.start_doing_background_updates()
+ hs.get_datastore().db_pool.updates.start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
@@ -552,8 +554,8 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
#
# This only reports info about the *main* database.
- stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
- stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+ stats["database_engine"] = hs.get_datastore().db_pool.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db_pool.engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
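`do_acme` is now a native coroutine, but its callers are still `inlineCallbacks` generators, which can only yield Deferreds; hence the `defer.ensureDeferred` wrappers. The bridging pattern, reduced to its essentials (`do_work` and `legacy_caller` are made-up names for illustration):

```python
from twisted.internet import defer


async def do_work() -> bool:
    # stand-in for an async function such as do_acme()
    return True


@defer.inlineCallbacks
def legacy_caller():
    # A bare coroutine is not yieldable from an inlineCallbacks
    # generator; ensureDeferred converts it into a Deferred first.
    result = yield defer.ensureDeferred(do_work())
    return result
```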
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index db578bda79..e72a0b9ac0 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -175,7 +175,7 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- info = yield self.get_json(uri, {})
+ info = yield defer.ensureDeferred(self.get_json(uri, {}))
if not _is_valid_3pe_metadata(info):
logger.warning(
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
new file mode 100644
index 0000000000..cd31b1c3c9
--- /dev/null
+++ b/synapse/config/_util.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List
+
+import jsonschema
+
+from synapse.config._base import ConfigError
+from synapse.types import JsonDict
+
+
+def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None:
+ """Validates a config setting against a JsonSchema definition
+
+ This can be used to validate a section of the config file against a schema
+ definition. If the validation fails, a ConfigError is raised with a textual
+ description of the problem.
+
+ Args:
+ json_schema: the schema to validate against
+ config: the configuration value to be validated
+ config_path: the path within the config file. This will be used as a basis
+ for the error message.
+ """
+ try:
+ jsonschema.validate(config, json_schema)
+ except jsonschema.ValidationError as e:
+ # copy `config_path` before modifying it.
+ path = list(config_path)
+ for p in list(e.path):
+ if isinstance(p, int):
+ path.append("<item %i>" % p)
+ else:
+ path.append(str(p))
+
+ raise ConfigError(
+ "Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
+ )
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 62bccd9ef5..8a18a9ca2a 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -100,7 +100,10 @@ class DatabaseConnectionConfig:
self.name = name
self.config = db_config
- self.data_stores = data_stores
+
+ # The `data_stores` config is actually talking about `databases` (we
+ # changed the name).
+ self.databases = data_stores
class DatabaseConfig(Config):
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index dd775a97e8..c96e6ef62a 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -55,24 +55,33 @@ formatters:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
%(request)s - %(message)s'
-filters:
- context:
- (): synapse.logging.context.LoggingContextFilter
- request: ""
-
handlers:
file:
- class: logging.handlers.RotatingFileHandler
+ class: logging.handlers.TimedRotatingFileHandler
formatter: precise
filename: ${log_file}
- maxBytes: 104857600
- backupCount: 10
- filters: [context]
+ when: midnight
+ backupCount: 3 # Does not include the current log file.
encoding: utf8
+
+ # Default to buffering writes to log file for efficiency. This means that
+ # there will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
+ # logs will still be flushed immediately.
+ buffer:
+ class: logging.handlers.MemoryHandler
+ target: file
+ # The capacity is the number of log lines that are buffered before
+ # being written to disk. Increasing this will lead to better
+ # performance, at the expense of it taking longer for log lines to
+ # be written to disk.
+ capacity: 10
+ flushLevel: 30 # Flush for WARNING logs as well
+
+ # A handler that writes logs to stderr. Unused by default, but can be used
+ # instead of "buffer" and "file" in the logger handlers.
console:
class: logging.StreamHandler
formatter: precise
- filters: [context]
loggers:
synapse.storage.SQL:
@@ -80,9 +89,24 @@ loggers:
# information such as access tokens.
level: INFO
+ twisted:
+ # We send the twisted logging directly to the file handler,
+ # to work around https://github.com/matrix-org/synapse/issues/3471
+ # when using "buffer" logger. Use "console" to log to stderr instead.
+ handlers: [file]
+ propagate: false
+
root:
level: INFO
- handlers: [file, console]
+
+ # Write logs to the `buffer` handler, which will buffer them together in memory,
+ # then write them to a file.
+ #
+ # Replace "buffer" with "console" to log to stderr instead. (Note that you'll
+ # also need to update the configuration for the `twisted` logger above, in
+ # this case.)
+ #
+ handlers: [buffer]
disable_existing_loggers: false
"""
@@ -168,11 +192,26 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
handler = logging.StreamHandler()
handler.setFormatter(formatter)
- handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler)
else:
logging.config.dictConfig(log_config)
+ # We add a log record factory that runs all messages through the
+ # LoggingContextFilter so that we get the context *at the time we log*
+ # rather than when we write to a handler. This can be done in config using
+ # filter options, but care must be taken when using e.g. MemoryHandler to buffer
+ # writes.
+
+ log_filter = LoggingContextFilter(request="")
+ old_factory = logging.getLogRecordFactory()
+
+ def factory(*args, **kwargs):
+ record = old_factory(*args, **kwargs)
+ log_filter.filter(record)
+ return record
+
+ logging.setLogRecordFactory(factory)
+
# Route Twisted's native logging through to the standard library logging
# system.
observer = STDLibLogObserver()
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 2dd94bae2b..b2c78ac40c 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -93,6 +93,15 @@ class RatelimitConfig(Config):
if rc_admin_redaction:
self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
+ self.rc_joins_local = RateLimitConfig(
+ config.get("rc_joins", {}).get("local", {}),
+ defaults={"per_second": 0.1, "burst_count": 3},
+ )
+ self.rc_joins_remote = RateLimitConfig(
+ config.get("rc_joins", {}).get("remote", {}),
+ defaults={"per_second": 0.01, "burst_count": 3},
+ )
+
def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
@@ -118,6 +127,10 @@ class RatelimitConfig(Config):
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
+ # - two for ratelimiting the number of rooms a user can join, "local" for when
+ # users are joining rooms the server is already in (this is cheap) vs
+ # "remote" for when users are trying to join rooms not on the server (which
+ # can be more expensive)
#
# The defaults are as shown below.
#
@@ -143,6 +156,14 @@ class RatelimitConfig(Config):
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
+ #
+ #rc_joins:
+ # local:
+ # per_second: 0.1
+ # burst_count: 3
+ # remote:
+ # per_second: 0.01
+ # burst_count: 3
# Ratelimiting settings for incoming federation
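The `per_second`/`burst_count` pair describes a leaky bucket: a user can perform up to `burst_count` joins back to back, and the allowance drains back at `per_second`. Roughly, and glossing over the per-key bookkeeping of Synapse's real `Ratelimiter` (a sketch only):

```python
import time


class LeakyBucketLimiter:
    """Toy limiter mirroring the per_second/burst_count arithmetic."""

    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self.state = {}  # key -> (pending_count, last_update_ts)

    def can_do_action(self, key: str):
        now = time.time()
        count, last = self.state.get(key, (0.0, now))

        # Actions drain out of the bucket at rate_hz.
        count = max(0.0, count - (now - last) * self.rate_hz)

        if count >= self.burst_count:
            # Bucket full: report when the next action would be allowed.
            return False, now + (count - self.burst_count + 1) / self.rate_hz

        self.state[key] = (count + 1, now)
        return True, now
```

With the "local" defaults above (0.1 per second, burst of 3), a user gets three quick joins and then roughly one every ten seconds.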
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 293643b2de..9277b5f342 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -15,7 +15,9 @@
# limitations under the License.
import logging
+from typing import Any, List
+import attr
import jinja2
import pkg_resources
@@ -23,6 +25,7 @@ from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
+from ._util import validate_config
logger = logging.getLogger(__name__)
@@ -80,6 +83,11 @@ class SAML2Config(Config):
self.saml2_enabled = True
+ attribute_requirements = saml2_config.get("attribute_requirements") or []
+ self.attribute_requirements = _parse_attribute_requirements_def(
+ attribute_requirements
+ )
+
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
@@ -341,6 +349,17 @@ class SAML2Config(Config):
#
#grandfathered_mxid_source_attribute: upn
+ # It is possible to configure Synapse to only allow logins if SAML attributes
+ # match particular values. The requirements can be listed under
+ # `attribute_requirements` as shown below. All of the listed attributes must
+ # match for the login to be permitted.
+ #
+ #attribute_requirements:
+ # - attribute: userGroup
+ # value: "staff"
+ # - attribute: department
+ # value: "sales"
+
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
@@ -368,3 +387,34 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}
+
+
+@attr.s(frozen=True)
+class SamlAttributeRequirement:
+ """Object describing a single requirement for SAML attributes."""
+
+ attribute = attr.ib(type=str)
+ value = attr.ib(type=str)
+
+ JSON_SCHEMA = {
+ "type": "object",
+ "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
+ "required": ["attribute", "value"],
+ }
+
+
+ATTRIBUTE_REQUIREMENTS_SCHEMA = {
+ "type": "array",
+ "items": SamlAttributeRequirement.JSON_SCHEMA,
+}
+
+
+def _parse_attribute_requirements_def(
+ attribute_requirements: Any,
+) -> List[SamlAttributeRequirement]:
+ validate_config(
+ ATTRIBUTE_REQUIREMENTS_SCHEMA,
+ attribute_requirements,
+ config_path=["saml2_config", "attribute_requirements"],
+ )
+ return [SamlAttributeRequirement(**x) for x in attribute_requirements]
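For instance (hypothetical values), a well-formed requirements list parses into frozen attrs instances, while a malformed entry fails at startup:

```python
reqs = _parse_attribute_requirements_def(
    [{"attribute": "userGroup", "value": "staff"}]
)
assert reqs[0].attribute == "userGroup" and reqs[0].value == "staff"

try:
    # Missing the required "value" key.
    _parse_attribute_requirements_def([{"attribute": "userGroup"}])
except ConfigError as e:
    # Unable to parse configuration: 'value' is a required property
    # at saml2_config.attribute_requirements.<item 0>
    print(e)
```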
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 848587d232..9f15ed109e 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -530,6 +530,21 @@ class ServerConfig(Config):
"request_token_inhibit_3pid_errors", False,
)
+ # List of users trialing the new experimental default push rules. This setting is
+ # not included in the sample configuration file on purpose as it's a temporary
+ # hack, so that some users can trial the new defaults without impacting every
+ # user on the homeserver.
+ users_new_default_push_rules = (
+ config.get("users_new_default_push_rules") or []
+ ) # type: list
+ if not isinstance(users_new_default_push_rules, list):
+ raise ConfigError("'users_new_default_push_rules' must be a list")
+
+ # Turn the list into a set to improve lookup speed.
+ self.users_new_default_push_rules = set(
+ users_new_default_push_rules
+ ) # type: set
+
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index a5a2a7815d..777c0f00b1 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -48,6 +48,14 @@ class ServerContextFactory(ContextFactory):
connections."""
def __init__(self, config):
+ # TODO: once pyOpenSSL exposes TLS_METHOD and SSL_CTX_set_min_proto_version,
+ # switch to those (see https://github.com/pyca/cryptography/issues/5379).
+ #
+ # note that, despite the confusing name, SSLv23_METHOD does *not* enforce SSLv2
+ # or v3, but is a synonym for TLS_METHOD, which allows the client and server
+ # to negotiate an appropriate version of TLS constrained by the version options
+ # set with context.set_options.
+ #
self._context = SSL.Context(SSL.SSLv23_METHOD)
self.configure_context(self._context, config)
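For reference, the combination the comment describes looks like this in isolation (a pyOpenSSL sketch; which `OP_NO_*` flags get set in practice depends on the configured minimum TLS version):

```python
from OpenSSL import SSL

# SSLv23_METHOD, despite the name, means "negotiate any version"; the
# actual floor is imposed by masking out the old protocols explicitly.
context = SSL.Context(SSL.SSLv23_METHOD)
context.set_options(
    SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
)
```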
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 443cde0b6d..28ef7cfdb9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -223,8 +223,7 @@ class Keyring(object):
return results
- @defer.inlineCallbacks
- def _start_key_lookups(self, verify_requests):
+ async def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
@@ -245,7 +244,7 @@ class Keyring(object):
server_to_request_ids.setdefault(server_name, set()).add(request_id)
# Wait for any previous lookups to complete before proceeding.
- yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+ await self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in
# key_downloads
@@ -283,15 +282,14 @@ class Keyring(object):
except Exception:
logger.exception("Error starting key lookups")
- @defer.inlineCallbacks
- def wait_for_previous_lookups(self, server_names):
+ async def wait_for_previous_lookups(self, server_names) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_names (Iterable[str]): list of servers which we want to look up
Returns:
- Deferred[None]: resolves once all key lookups for the given servers have
+ Resolves once all key lookups for the given servers have
completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
@@ -309,7 +307,7 @@ class Keyring(object):
loop_count,
)
with PreserveLoggingContext():
- yield defer.DeferredList((w[1] for w in wait_on))
+ await defer.DeferredList((w[1] for w in wait_on))
loop_count += 1
@@ -326,44 +324,44 @@ class Keyring(object):
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
- @defer.inlineCallbacks
- def do_iterations():
- with Measure(self.clock, "get_server_verify_keys"):
- for f in self._key_fetchers:
- if not remaining_requests:
- return
- yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
+ async def do_iterations():
+ try:
+ with Measure(self.clock, "get_server_verify_keys"):
+ for f in self._key_fetchers:
+ if not remaining_requests:
+ return
+ await self._attempt_key_fetches_with_fetcher(
+ f, remaining_requests
+ )
- # look for any requests which weren't satisfied
+ # look for any requests which weren't satisfied
+ with PreserveLoggingContext():
+ for verify_request in remaining_requests:
+ verify_request.key_ready.errback(
+ SynapseError(
+ 401,
+ "No key for %s with ids in %s (min_validity %i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
+ ),
+ Codes.UNAUTHORIZED,
+ )
+ )
+ except Exception as err:
+ # we don't really expect to get here, because any errors should already
+ # have been caught and logged. But if we do, let's log the error and make
+ # sure that all of the deferreds are resolved.
+ logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
for verify_request in remaining_requests:
- verify_request.key_ready.errback(
- SynapseError(
- 401,
- "No key for %s with ids in %s (min_validity %i)"
- % (
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
- ),
- Codes.UNAUTHORIZED,
- )
- )
-
- def on_err(err):
- # we don't really expect to get here, because any errors should already
- # have been caught and logged. But if we do, let's log the error and make
- # sure that all of the deferreds are resolved.
- logger.error("Unexpected error in _get_server_verify_keys: %s", err)
- with PreserveLoggingContext():
- for verify_request in remaining_requests:
- if not verify_request.key_ready.called:
- verify_request.key_ready.errback(err)
+ if not verify_request.key_ready.called:
+ verify_request.key_ready.errback(err)
- run_in_background(do_iterations).addErrback(on_err)
+ run_in_background(do_iterations)
- @defer.inlineCallbacks
- def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
@@ -390,7 +388,7 @@ class Keyring(object):
verify_request.minimum_valid_until_ts,
)
- results = yield fetcher.get_keys(missing_keys)
+ results = await fetcher.get_keys(missing_keys)
completed = []
for verify_request in remaining_requests:
@@ -423,7 +421,7 @@ class Keyring(object):
class KeyFetcher(object):
- def get_keys(self, keys_to_fetch):
+ async def get_keys(self, keys_to_fetch):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_keys(self, keys_to_fetch):
+ async def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
keys_to_fetch = (
@@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher):
for key_id in keys_for_server.keys()
)
- res = yield self.store.get_server_verify_keys(keys_to_fetch)
+ res = await self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
@@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object):
self.store = hs.get_datastore()
self.config = hs.get_config()
- @defer.inlineCallbacks
- def process_v2_response(self, from_server, response_json, time_added_ms):
+ async def process_v2_response(self, from_server, response_json, time_added_ms):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
@@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object):
key_json_bytes = encode_canonical_json(response_json)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_http_client()
self.key_servers = self.config.key_servers
- @defer.inlineCallbacks
- def get_keys(self, keys_to_fetch):
+ async def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
- @defer.inlineCallbacks
- def get_key(key_server):
+ async def get_key(key_server):
try:
- result = yield self.get_server_verify_key_v2_indirect(
+ result = await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
return result
@@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return {}
- results = yield make_deferred_yieldable(
+ results = await make_deferred_yieldable(
defer.gatherResults(
[run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True,
@@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys
- @defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+ async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
the keys
Returns:
- Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+ dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
from server_name -> key_id -> FetchKeyResult
Raises:
@@ -632,20 +625,18 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
try:
- query_response = yield defer.ensureDeferred(
- self.client.post_json(
- destination=perspective_name,
- path="/_matrix/key/v2/query",
- data={
- "server_keys": {
- server_name: {
- key_id: {"minimum_valid_until_ts": min_valid_ts}
- for key_id, min_valid_ts in server_keys.items()
- }
- for server_name, server_keys in keys_to_fetch.items()
+ query_response = await self.client.post_json(
+ destination=perspective_name,
+ path="/_matrix/key/v2/query",
+ data={
+ "server_keys": {
+ server_name: {
+ key_id: {"minimum_valid_until_ts": min_valid_ts}
+ for key_id, min_valid_ts in server_keys.items()
}
- },
- )
+ for server_name, server_keys in keys_to_fetch.items()
+ }
+ },
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve upon
@@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
try:
self._validate_perspectives_response(key_server, response)
- processed_response = yield self.process_v2_response(
+ processed_response = await self.process_v2_response(
perspective_name, response, time_added_ms=time_now_ms
)
except KeyLookupError as e:
@@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
keys.setdefault(server_name, {}).update(processed_response)
- yield self.store.store_server_verify_keys(
+ await self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys
)
@@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock()
self.client = hs.get_http_client()
- def get_keys(self, keys_to_fetch):
+ async def get_keys(self, keys_to_fetch):
"""
Args:
keys_to_fetch (dict[str, iterable[str]]):
the keys to be fetched. server_name -> key_ids
Returns:
- Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
map from server_name -> key_id -> FetchKeyResult
"""
results = {}
- @defer.inlineCallbacks
- def get_key(key_to_fetch_item):
+ async def get_key(key_to_fetch_item):
server_name, key_ids = key_to_fetch_item
try:
- keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+ keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
except KeyLookupError as e:
logger.warning(
@@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
- return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
- lambda _: results
- )
+ await yieldable_gather_results(get_key, keys_to_fetch.items())
+ return results
- @defer.inlineCallbacks
- def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ async def get_server_verify_key_v2_direct(self, server_name, key_ids):
"""
Args:
@@ -794,25 +783,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
time_now_ms = self.clock.time_msec()
try:
- response = yield defer.ensureDeferred(
- self.client.get_json(
- destination=server_name,
- path="/_matrix/key/v2/server/"
- + urllib.parse.quote(requested_key_id),
- ignore_backoff=True,
- # we only give the remote server 10s to respond. It should be an
- # easy request to handle, so if it doesn't reply within 10s, it's
- # probably not going to.
- #
- # Furthermore, when we are acting as a notary server, we cannot
- # wait all day for all of the origin servers, as the requesting
- # server will otherwise time out before we can respond.
- #
- # (Note that get_json may make 4 attempts, so this can still take
- # almost 45 seconds to fetch the headers, plus up to another 60s to
- # read the response).
- timeout=10000,
- )
+ response = await self.client.get_json(
+ destination=server_name,
+ path="/_matrix/key/v2/server/"
+ + urllib.parse.quote(requested_key_id),
+ ignore_backoff=True,
+ # we only give the remote server 10s to respond. It should be an
+ # easy request to handle, so if it doesn't reply within 10s, it's
+ # probably not going to.
+ #
+ # Furthermore, when we are acting as a notary server, we cannot
+ # wait all day for all of the origin servers, as the requesting
+ # server will otherwise time out before we can respond.
+ #
+ # (Note that get_json may make 4 attempts, so this can still take
+ # almost 45 seconds to fetch the headers, plus up to another 60s to
+ # read the response).
+ timeout=10000,
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve
@@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
% (server_name, response["server_name"])
)
- response_keys = yield self.process_v2_response(
+ response_keys = await self.process_v2_response(
from_server=server_name,
response_json=response,
time_added_ms=time_now_ms,
)
- yield self.store.store_server_verify_keys(
+ await self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
@@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys
-@defer.inlineCallbacks
-def _handle_key_deferred(verify_request):
+async def _handle_key_deferred(verify_request) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
verify_request (VerifyJsonRequest):
- Returns:
- Deferred[None]
-
Raises:
SynapseError if there was a problem performing the verification
"""
server_name = verify_request.server_name
with PreserveLoggingContext():
- _, key_id, verify_key = yield verify_request.key_ready
+ _, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.json_object
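The recurring shape of this file's conversion: an `inlineCallbacks` generator with a detached `addErrback` handler becomes one coroutine that owns its error handling, still launched via `run_in_background` (which accepts coroutine functions). Schematically, with `some_work` as a made-up stand-in for the fetch logic:

```python
import logging

from synapse.logging.context import run_in_background

logger = logging.getLogger(__name__)


async def some_work():
    pass  # stand-in for the real key-fetching logic


# Before: run_in_background(do_iterations).addErrback(on_err)
# After: the coroutine catches its own failures.
async def do_iterations():
    try:
        await some_work()
    except Exception as err:
        # Errors should already have been logged; this is a backstop.
        logger.error("Unexpected error: %s", err)


run_in_background(do_iterations)
```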
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 69b53ca2bc..9ed24380dd 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -17,6 +17,7 @@ from typing import Optional
import attr
from nacl.signing import SigningKey
+from synapse.api.auth import Auth
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
@@ -27,6 +28,8 @@ from synapse.api.room_versions import (
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.state import StateHandler
+from synapse.storage.databases.main import DataStore
from synapse.types import EventID, JsonDict
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -42,45 +45,46 @@ class EventBuilder(object):
Attributes:
room_version: Version of the target room
- room_id (str)
- type (str)
- sender (str)
- content (dict)
- unsigned (dict)
- internal_metadata (_EventInternalMetadata)
-
- _state (StateHandler)
- _auth (synapse.api.Auth)
- _store (DataStore)
- _clock (Clock)
- _hostname (str): The hostname of the server creating the event
+ room_id
+ type
+ sender
+ content
+ unsigned
+ internal_metadata
+
+ _state
+ _auth
+ _store
+ _clock
+ _hostname: The hostname of the server creating the event
_signing_key: The signing key to use to sign the event as the server
"""
- _state = attr.ib()
- _auth = attr.ib()
- _store = attr.ib()
- _clock = attr.ib()
- _hostname = attr.ib()
- _signing_key = attr.ib()
+ _state = attr.ib(type=StateHandler)
+ _auth = attr.ib(type=Auth)
+ _store = attr.ib(type=DataStore)
+ _clock = attr.ib(type=Clock)
+ _hostname = attr.ib(type=str)
+ _signing_key = attr.ib(type=SigningKey)
room_version = attr.ib(type=RoomVersion)
- room_id = attr.ib()
- type = attr.ib()
- sender = attr.ib()
+ room_id = attr.ib(type=str)
+ type = attr.ib(type=str)
+ sender = attr.ib(type=str)
- content = attr.ib(default=attr.Factory(dict))
- unsigned = attr.ib(default=attr.Factory(dict))
+ content = attr.ib(default=attr.Factory(dict), type=JsonDict)
+ unsigned = attr.ib(default=attr.Factory(dict), type=JsonDict)
# These only exist on a subset of events, so they raise AttributeError if
# someone tries to get them when they don't exist.
- _state_key = attr.ib(default=None)
- _redacts = attr.ib(default=None)
- _origin_server_ts = attr.ib(default=None)
+ _state_key = attr.ib(default=None, type=Optional[str])
+ _redacts = attr.ib(default=None, type=Optional[str])
+ _origin_server_ts = attr.ib(default=None, type=Optional[int])
internal_metadata = attr.ib(
- default=attr.Factory(lambda: _EventInternalMetadata({}))
+ default=attr.Factory(lambda: _EventInternalMetadata({})),
+ type=_EventInternalMetadata,
)
@property
@@ -106,7 +110,7 @@ class EventBuilder(object):
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
- auth_ids = await self._auth.compute_auth_events(self, state_ids)
+ auth_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index cca93e3a46..afecafe15c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -23,7 +23,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap
if TYPE_CHECKING:
- from synapse.storage.data_stores.main import DataStore
+ from synapse.storage.databases.main import DataStore
@attr.s(slots=True)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 9c2c6a232d..a66a24b392 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Tuple
from canonicaljson import json
@@ -58,7 +58,10 @@ class TransactionManager(object):
@measure_func("_send_new_transaction")
async def send_new_transaction(
- self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu]
+ self,
+ destination: str,
+ pending_pdus: List[Tuple[EventBase, int]],
+ pending_edus: List[Edu],
):
# Make a transaction-sending opentracing span. This span follows on from
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index a2d7959abe..7666d3abcd 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -17,7 +17,6 @@ import logging
import twisted
import twisted.internet.error
-from twisted.internet import defer
from twisted.web import server, static
from twisted.web.resource import Resource
@@ -41,8 +40,7 @@ class AcmeHandler(object):
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
- @defer.inlineCallbacks
- def start_listening(self):
+ async def start_listening(self):
from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
@@ -82,18 +80,17 @@ class AcmeHandler(object):
self._issuer._registered = False
try:
- yield self._issuer._ensure_registered()
+ await self._issuer._ensure_registered()
except Exception:
logger.error(ACME_REGISTER_FAIL_ERROR)
raise
- @defer.inlineCallbacks
- def provision_certificate(self):
+ async def provision_certificate(self):
logger.warning("Reprovisioning %s", self._acme_domain)
try:
- yield self._issuer.issue_cert(self._acme_domain)
+ await self._issuer.issue_cert(self._acme_domain)
except Exception:
logger.exception("Fail!")
raise
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index fbc56c351b..c9044a5019 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -101,7 +101,7 @@ class ApplicationServicesHandler(object):
async def start_scheduler():
try:
- return self.scheduler.start()
+ return await self.scheduler.start()
except Exception:
logger.error("Application Services Failure")
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7d921c21a..c24e7bafe0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -162,7 +162,7 @@ class AuthHandler(BaseHandler):
request_body: Dict[str, Any],
clientip: str,
description: str,
- ) -> dict:
+ ) -> Tuple[dict, str]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -183,9 +183,14 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account.
Returns:
- The parameters for this request (which may
+ A tuple of (params, session_id).
+
+ 'params' contains the parameters for this request (which may
have been given only in a previous call).
+ 'session_id' is the ID of this session, either passed in by the
+ client or assigned by this call.
+
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
any of the permitted login flows
@@ -207,7 +212,7 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_ui_auth_types]
try:
- result, params, _ = await self.check_auth(
+ result, params, session_id = await self.check_ui_auth(
flows, request, request_body, clientip, description
)
except LoginError:
@@ -230,7 +235,7 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
- return params
+ return params, session_id
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -240,7 +245,7 @@ class AuthHandler(BaseHandler):
"""
return self.checkers.keys()
- async def check_auth(
+ async def check_ui_auth(
self,
flows: List[List[str]],
request: SynapseRequest,
@@ -363,7 +368,7 @@ class AuthHandler(BaseHandler):
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session.session_id)
+ session.session_id, self._auth_dict_for_flows(flows, session.session_id)
)
# check auth type currently being presented
@@ -410,7 +415,7 @@ class AuthHandler(BaseHandler):
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
- raise InteractiveAuthIncompleteError(ret)
+ raise InteractiveAuthIncompleteError(session.session_id, ret)
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 71a89f09c7..1924636c4d 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -57,13 +57,10 @@ class EventStreamHandler(BaseHandler):
timeout=0,
as_client_event=True,
affect_presence=True,
- only_keys=None,
room_id=None,
is_guest=False,
):
"""Fetches the events stream for a given user.
-
- If `only_keys` is not None, events from keys will be sent down.
"""
if room_id:
@@ -93,7 +90,6 @@ class EventStreamHandler(BaseHandler):
auth_user,
pagin_config,
timeout,
- only_keys=only_keys,
is_guest=is_guest,
explicit_room_id=room_id,
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0d7d1adcea..593932adb7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -71,7 +71,7 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
-from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
@@ -2064,7 +2064,7 @@ class FederationHandler(BaseHandler):
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = await self.auth.compute_auth_events(
+ auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 0bd2c3e37a..92b7404706 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -22,14 +22,10 @@ import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from canonicaljson import json
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json
-from unpaddedbase64 import decode_base64
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
- AuthError,
CodeMessageException,
Codes,
HttpResponseException,
@@ -628,9 +624,9 @@ class IdentityHandler(BaseHandler):
)
if "mxid" in data:
- if "signatures" not in data:
- raise AuthError(401, "No signatures on 3pid binding")
- await self._verify_any_signature(data, id_server)
+ # note: we used to verify the identity server's signature here, but no longer
+ # require or validate it. See the following for context:
+ # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
@@ -751,30 +747,6 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
- async def _verify_any_signature(self, data, server_hostname):
- if server_hostname not in data["signatures"]:
- raise AuthError(401, "No signature from server %s" % (server_hostname,))
- for key_name, signature in data["signatures"][server_hostname].items():
- try:
- key_data = await self.blacklisting_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s"
- % (id_server_scheme, server_hostname, key_name)
- )
- except TimeoutError:
- raise SynapseError(500, "Timed out contacting identity server")
- if "public_key" not in key_data:
- raise AuthError(
- 401, "No public key named %s from %s" % (key_name, server_hostname)
- )
- verify_signed_json(
- data,
- server_hostname,
- decode_verify_key_bytes(
- key_name, decode_base64(key_data["public_key"])
- ),
- )
- return
-
async def ask_id_server_for_third_party_invite(
self,
requester: Requester,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index f88bad5f25..ae6bd1d352 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -109,7 +109,7 @@ class InitialSyncHandler(BaseHandler):
rooms_ret = []
- now_token = await self.hs.get_event_sources().get_current_token()
+ now_token = self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
@@ -360,7 +360,7 @@ class InitialSyncHandler(BaseHandler):
current_state.values(), time_now
)
- now_token = await self.hs.get_event_sources().get_current_token()
+ now_token = self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a28068244d..73e787f2f7 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from canonicaljson import encode_canonical_json, json
@@ -45,7 +45,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
-from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
Collection,
@@ -93,11 +93,11 @@ class MessageHandler(object):
async def get_room_data(
self,
- user_id: str = None,
- room_id: str = None,
- event_type: Optional[str] = None,
- state_key: str = "",
- is_guest: bool = False,
+ user_id: str,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ is_guest: bool,
) -> dict:
""" Get data from a room.
@@ -407,7 +407,7 @@ class EventCreationHandler(object):
#
# map from room id to time-of-last-attempt.
#
- self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int]
+ self._rooms_to_exclude_from_dummy_event_insertion = {} # type: Dict[str, int]
# we need to construct a ConsentURIBuilder here, as it checks that the necessary
# config options, but *only* if we have a configuration for which we are
@@ -709,7 +709,7 @@ class EventCreationHandler(object):
async def create_and_send_nonmember_event(
self,
requester: Requester,
- event_dict: EventBase,
+ event_dict: dict,
ratelimit: bool = True,
txn_id: Optional[str] = None,
) -> Tuple[EventBase, int]:
@@ -770,6 +770,15 @@ class EventCreationHandler(object):
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
+ # we now ought to have some prev_events (unless it's a create event).
+ #
+ # do a quick sanity check here, rather than waiting until we've created the
+ # event and then trying to auth it (which fails with a somewhat confusing "No
+ # create event in auth events")
+ assert (
+ builder.type == EventTypes.Create or len(prev_event_ids) > 0
+ ), "Attempting to create an event with no prev_events"
+
event = await builder.build(prev_event_ids=prev_event_ids)
context = await self.state.compute_event_context(event)
if requester:
@@ -964,7 +973,7 @@ class EventCreationHandler(object):
# Validate a newly added alias or newly added alt_aliases.
original_alias = None
- original_alt_aliases = set()
+ original_alt_aliases = [] # type: List[str]
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
@@ -1012,6 +1021,10 @@ class EventCreationHandler(object):
current_state_ids = await context.get_current_state_ids()
+ # We know this event is not an outlier, so this must be
+ # non-None.
+ assert current_state_ids is not None
+
state_to_include_ids = [
e_id
for k, e_id in current_state_ids.items()
@@ -1063,7 +1076,7 @@ class EventCreationHandler(object):
raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = await context.get_prev_state_ids()
- auth_events_ids = await self.auth.compute_auth_events(
+ auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 87f0c5e197..fa5ee5de8f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import json
import logging
-from typing import Dict, Generic, List, Optional, Tuple, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
import attr
@@ -39,9 +39,11 @@ from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.push.mailer import load_jinja2_templates
-from synapse.server import HomeServer
from synapse.types import UserID, map_username_to_mxid_localpart
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
SESSION_COOKIE_NAME = b"oidc_session"
@@ -91,7 +93,7 @@ class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
- def __init__(self, hs: HomeServer):
+ def __init__(self, hs: "HomeServer"):
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
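Moving the `HomeServer` import under `TYPE_CHECKING` breaks an import cycle (`synapse.server` imports the handlers, which in turn want its type) without losing the annotation; the quoted name defers evaluation until type checking. The general pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy only, so no circular import at runtime.
    from synapse.server import HomeServer


class SomeHandler:
    def __init__(self, hs: "HomeServer"):
        # String annotations are never evaluated at runtime.
        self._hostname = hs.hostname
```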
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index da06582d4b..487420bb5d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -309,7 +309,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- await self.hs.get_event_sources().get_current_token_for_pagination()
+ self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index b3a3bb8c3f..5387b3724f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -38,7 +38,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
-from synapse.storage.data_stores.main import DataStore
+from synapse.storage.databases.main import DataStore
from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
@@ -319,7 +319,7 @@ class PresenceHandler(BasePresenceHandler):
is some spurious presence changes that will self-correct.
"""
# If the DB pool has already terminated, don't try updating
- if not self.store.db.is_running():
+ if not self.store.db_pool.is_running():
return
logger.info(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 501f0fe795..c94209ab3d 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -548,7 +548,7 @@ class RegistrationHandler(BaseHandler):
address (str|None): the IP address used to perform the registration.
Returns:
- Deferred
+ Awaitable
"""
if self.hs.config.worker_app:
return self._register_client(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 0c5b99234d..a8545255b1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -22,7 +22,7 @@ import logging
import math
import string
from collections import OrderedDict
-from typing import Optional, Tuple
+from typing import Awaitable, Optional, Tuple
from synapse.api.constants import (
EventTypes,
@@ -1041,7 +1041,7 @@ class RoomEventSource(object):
):
# We just ignore the key for now.
- to_key = await self.get_current_key()
+ to_key = self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
@@ -1081,10 +1081,10 @@ class RoomEventSource(object):
return (events, end_key)
- def get_current_key(self):
- return self.store.get_room_events_max_id()
+ def get_current_key(self) -> str:
+ return "s%d" % (self.store.get_room_max_stream_ordering(),)
- def get_current_key_for_room(self, room_id):
+ def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 39e57a4503..4634f4df9d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -16,13 +16,14 @@
import abc
import logging
from http import HTTPStatus
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from synapse import types
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase
@@ -36,6 +37,10 @@ from synapse.util.distributor import user_joined_room, user_left_room
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
@@ -47,7 +52,7 @@ class RoomMemberHandler(object):
__metaclass__ = abc.ABCMeta
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -78,6 +83,17 @@ class RoomMemberHandler(object):
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
+ self._join_rate_limiter_local = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
+ burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
+ )
+ self._join_rate_limiter_remote = Ratelimiter(
+ clock=self.clock,
+ rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
+ burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
+ )
+
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
@@ -196,7 +212,7 @@ class RoomMemberHandler(object):
return duplicate.event_id, stream_id
stream_id = await self.event_creation_handler.handle_new_client_event(
- requester, event, context, extra_users=[target], ratelimit=ratelimit
+ requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
prev_state_ids = await context.get_prev_state_ids()
@@ -461,7 +477,28 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
- if not is_host_in_room:
+ if is_host_in_room:
+ time_now_s = self.clock.time()
+ allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
+ requester.user.to_string(),
+ )
+
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
+ else:
+ time_now_s = self.clock.time()
+ allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
+ requester.user.to_string(),
+ )
+
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -987,7 +1024,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
check_complexity = self.hs.config.limit_remote_rooms.enabled
if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join:
- check_complexity = not await self.hs.auth.is_server_admin(user)
+ check_complexity = not await self.auth.is_server_admin(user)
if check_complexity:
# Fetch the room complexity
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 2d506dc1f2..c1fcb98454 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,15 +14,16 @@
# limitations under the License.
import logging
import re
-from typing import Callable, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
import attr
import saml2
import saml2.response
from saml2.client import Saml2Client
-from synapse.api.errors import SynapseError
+from synapse.api.errors import AuthError, SynapseError
from synapse.config import ConfigError
+from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@@ -34,6 +35,9 @@ from synapse.types import (
from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
+if TYPE_CHECKING:
+ import synapse.server
+
logger = logging.getLogger(__name__)
@@ -49,7 +53,7 @@ class Saml2SessionData:
class SamlHandler:
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@@ -62,6 +66,7 @@ class SamlHandler:
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
+ self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -73,7 +78,7 @@ class SamlHandler:
self._auth_provider_id = "saml"
# a map from saml session id to Saml2SessionData object
- self._outstanding_requests_dict = {}
+ self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
@@ -165,11 +170,18 @@ class SamlHandler:
saml2.BINDING_HTTP_POST,
outstanding=self._outstanding_requests_dict,
)
+ except saml2.response.UnsolicitedResponse as e:
+ # the pysaml2 library helpfully logs an ERROR here, but neglects to log
+ # the session ID. I don't really want to put the full text of the exception
+ # in the (user-visible) exception message, so let's log the exception here
+ # so we can track down the session IDs later.
+ logger.warning(str(e))
+ raise SynapseError(400, "Unexpected SAML2 login.")
except Exception as e:
- raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,))
+ raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
if saml2_auth.not_signed:
- raise SynapseError(400, "SAML2 response was not signed")
+ raise SynapseError(400, "SAML2 response was not signed.")
logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions:
@@ -188,6 +200,9 @@ class SamlHandler:
saml2_auth.in_response_to, None
)
+ for requirement in self._saml2_attribute_requirements:
+ _check_attribute_requirement(saml2_auth.ava, requirement)
+
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
@@ -294,6 +309,21 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
+def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
+ values = ava.get(req.attribute, [])
+ for v in values:
+ if v == req.value:
+ return
+
+ logger.info(
+ "SAML2 attribute %s did not match required value '%s' (was '%s')",
+ req.attribute,
+ req.value,
+ values,
+ )
+ raise AuthError(403, "You are not authorized to log in here.")
+
+
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
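Here `ava` is pysaml2's attribute-value assertion map: attribute name to list of asserted values. With hypothetical assertion data, the check behaves like so:

```python
ava = {"userGroup": ["staff", "admin"], "department": ["engineering"]}

# Passes: "staff" appears among the asserted userGroup values.
_check_attribute_requirement(
    ava, SamlAttributeRequirement(attribute="userGroup", value="staff")
)

try:
    # Fails: "sales" is not among the asserted department values.
    _check_attribute_requirement(
        ava, SamlAttributeRequirement(attribute="department", value="sales")
    )
except AuthError:
    pass  # login is refused with a 403
```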
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9b312a1558..d58f9788c5 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -340,7 +340,7 @@ class SearchHandler(BaseHandler):
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
- now_token = await self.hs.get_event_sources().get_current_token()
+ now_token = self.hs.get_event_sources().get_current_token()
contexts = {}
for event in allowed_events:
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 149f861239..249ffe2a55 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -232,7 +232,7 @@ class StatsHandler:
if membership == prev_membership:
pass # noop
- if membership == Membership.JOIN:
+ elif membership == Membership.JOIN:
room_stats_delta["joined_members"] += 1
elif membership == Membership.INVITE:
room_stats_delta["invited_members"] += 1
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 67b9f9afbf..e4932a1939 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -104,7 +104,6 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
- unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -964,7 +963,7 @@ class SyncHandler(object):
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder`
- now_token = await self.event_sources.get_current_token()
+ now_token = self.event_sources.get_current_token()
logger.debug(
"Calculating sync response for %r between %s and %s",
@@ -1890,10 +1889,6 @@ class SyncHandler(object):
if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, str]
-
- unread_count = await self.store.get_unread_message_count_for_user(
- room_id, sync_config.user.to_string(),
- )
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -1902,7 +1897,6 @@ class SyncHandler(object):
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
- unread_count=unread_count,
)
if room_sync or always_include:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 155b7460d4..8aeb70cdec 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -284,8 +284,7 @@ class SimpleHttpClient(object):
ip_blacklist=self._ip_blacklist,
)
- @defer.inlineCallbacks
- def request(self, method, uri, data=None, headers=None):
+ async def request(self, method, uri, data=None, headers=None):
"""
Args:
method (str): HTTP method to use.
@@ -298,7 +297,7 @@ class SimpleHttpClient(object):
outgoing_requests_counter.labels(method).inc()
# log request but strip `access_token` (AS requests for example include this)
- logger.info("Sending request %s %s", method, redact_uri(uri))
+ logger.debug("Sending request %s %s", method, redact_uri(uri))
with start_active_span(
"outgoing-client-request",
@@ -330,7 +329,7 @@ class SimpleHttpClient(object):
self.hs.get_reactor(),
cancelled_to_request_timed_out_error,
)
- response = yield make_deferred_yieldable(request_deferred)
+ response = await make_deferred_yieldable(request_deferred)
incoming_responses_counter.labels(method, response.code).inc()
logger.info(
@@ -353,8 +352,7 @@ class SimpleHttpClient(object):
set_tag("error_reason", e.args[0])
raise
- @defer.inlineCallbacks
- def post_urlencoded_get_json(self, uri, args={}, headers=None):
+ async def post_urlencoded_get_json(self, uri, args={}, headers=None):
"""
Args:
uri (str):
@@ -363,7 +361,7 @@ class SimpleHttpClient(object):
header name to a list of values for that header
Returns:
- Deferred[object]: parsed json
+ object: parsed json
Raises:
HttpResponseException: On a non-2xx HTTP response.
@@ -386,11 +384,11 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request(
+ response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
)
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
@@ -399,8 +397,7 @@ class SimpleHttpClient(object):
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- @defer.inlineCallbacks
- def post_json_get_json(self, uri, post_json, headers=None):
+ async def post_json_get_json(self, uri, post_json, headers=None):
"""
Args:
@@ -410,7 +407,7 @@ class SimpleHttpClient(object):
header name to a list of values for that header
Returns:
- Deferred[object]: parsed json
+ object: parsed json
Raises:
HttpResponseException: On a non-2xx HTTP response.
@@ -429,11 +426,11 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request(
+ response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
)
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
@@ -442,8 +439,7 @@ class SimpleHttpClient(object):
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- @defer.inlineCallbacks
- def get_json(self, uri, args={}, headers=None):
+ async def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI.
Args:
@@ -455,7 +451,7 @@ class SimpleHttpClient(object):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
- Deferred: Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
HttpResponseException On a non-2xx HTTP response.
@@ -466,11 +462,10 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- body = yield self.get_raw(uri, args, headers=headers)
+ body = await self.get_raw(uri, args, headers=headers)
return json.loads(body.decode("utf-8"))
- @defer.inlineCallbacks
- def put_json(self, uri, json_body, args={}, headers=None):
+ async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
Args:
@@ -483,7 +478,7 @@ class SimpleHttpClient(object):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
- Deferred: Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
HttpResponseException On a non-2xx HTTP response.
@@ -504,11 +499,11 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request(
+ response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
)
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
@@ -517,8 +512,7 @@ class SimpleHttpClient(object):
response.code, response.phrase.decode("ascii", errors="replace"), body
)
- @defer.inlineCallbacks
- def get_raw(self, uri, args={}, headers=None):
+ async def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI.
Args:
@@ -530,7 +524,7 @@ class SimpleHttpClient(object):
headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
- Deferred: Succeeds when we get *any* 2xx HTTP response, with the
+ Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as bytes.
Raises:
HttpResponseException on a non-2xx HTTP response.
@@ -543,9 +537,9 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request("GET", uri, headers=Headers(actual_headers))
+ response = await self.request("GET", uri, headers=Headers(actual_headers))
- body = yield make_deferred_yieldable(readBody(response))
+ body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return body
@@ -557,8 +551,7 @@ class SimpleHttpClient(object):
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
- @defer.inlineCallbacks
- def get_file(self, url, output_stream, max_size=None, headers=None):
+ async def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
@@ -574,7 +567,7 @@ class SimpleHttpClient(object):
if headers:
actual_headers.update(headers)
- response = yield self.request("GET", url, headers=Headers(actual_headers))
+ response = await self.request("GET", url, headers=Headers(actual_headers))
resp_headers = dict(response.headers.getAllRawHeaders())
@@ -598,7 +591,7 @@ class SimpleHttpClient(object):
# straight back in again
try:
- length = yield make_deferred_yieldable(
+ length = await make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
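The conversions in this file all follow the same mechanical shape: the `@defer.inlineCallbacks` decorator comes off, the generator becomes an `async def`, and each `yield` becomes an `await` (Deferreds are directly awaitable in modern Twisted). A minimal sketch of the pattern, not the actual client code:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def old_style(deferred):
        result = yield deferred  # generator-based coroutine
        return result

    async def new_style(deferred):
        result = await deferred  # native coroutine; Deferreds are awaitable
        return result

    def as_deferred(coro):
        # bridge for call sites that still expect a Deferred
        return defer.ensureDeferred(coro)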
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 0c02648015..369bf9c2fc 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -247,7 +247,7 @@ class MatrixHostnameEndpoint(object):
port = server.port
try:
- logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+ logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2a6373937a..738be43f46 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -29,10 +29,11 @@ from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver
+from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
import synapse.metrics
import synapse.util.retryutils
@@ -74,7 +75,7 @@ MAXINT = sys.maxsize
_next_id = 1
-@attr.s
+@attr.s(frozen=True)
class MatrixFederationRequest(object):
method = attr.ib()
"""HTTP method
@@ -110,26 +111,52 @@ class MatrixFederationRequest(object):
:type: str|None
"""
+ uri = attr.ib(init=False, type=bytes)
+ """The URI of this request
+ """
+
def __attrs_post_init__(self):
global _next_id
- self.txn_id = "%s-O-%s" % (self.method, _next_id)
+ txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1)
+ object.__setattr__(self, "txn_id", txn_id)
+
+ destination_bytes = self.destination.encode("ascii")
+ path_bytes = self.path.encode("ascii")
+ if self.query:
+ query_bytes = encode_query_args(self.query)
+ else:
+ query_bytes = b""
+
+ # The object is frozen so we can pre-compute this.
+ uri = urllib.parse.urlunparse(
+ (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
+ )
+ object.__setattr__(self, "uri", uri)
+
def get_json(self):
if self.json_callback:
return self.json_callback()
return self.json
-async def _handle_json_response(reactor, timeout_sec, request, response):
+async def _handle_json_response(
+ reactor: IReactorTime,
+ timeout_sec: float,
+ request: MatrixFederationRequest,
+ response: IResponse,
+ start_ms: int,
+):
"""
Reads the JSON body of a response, with a timeout
Args:
- reactor (IReactor): twisted reactor, for the timeout
- timeout_sec (float): number of seconds to wait for response to complete
- request (MatrixFederationRequest): the request that triggered the response
- response (IResponse): response to the request
+ reactor: twisted reactor, for the timeout
+ timeout_sec: number of seconds to wait for response to complete
+ request: the request that triggered the response
+ response: response to the request
+ start_ms: Timestamp, in milliseconds, when the request was made
Returns:
dict: parsed JSON response
@@ -143,23 +170,35 @@ async def _handle_json_response(reactor, timeout_sec, request, response):
body = await make_deferred_yieldable(d)
except TimeoutError as e:
logger.warning(
- "{%s} [%s] Timed out reading response", request.txn_id, request.destination,
+ "{%s} [%s] Timed out reading response - %s %s",
+ request.txn_id,
+ request.destination,
+ request.method,
+ request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
- "{%s} [%s] Error reading response: %s",
+ "{%s} [%s] Error reading response %s %s: %s",
request.txn_id,
request.destination,
+ request.method,
+ request.uri.decode("ascii"),
e,
)
raise
+
+ time_taken_secs = reactor.seconds() - start_ms / 1000
+
logger.info(
- "{%s} [%s] Completed: %d %s",
+ "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
+ time_taken_secs,
+ request.method,
+ request.uri.decode("ascii"),
)
return body
@@ -261,7 +300,9 @@ class MatrixFederationHttpClient(object):
# 'M_UNRECOGNIZED' which some endpoints can return when omitting a
# trailing slash on Synapse <= v0.99.3.
logger.info("Retrying request with trailing slash")
- request.path += "/"
+
+ # Request is frozen so we create a new instance
+ request = attr.evolve(request, path=request.path + "/")
response = await self._send_request(request, **send_request_args)
@@ -373,9 +414,7 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
- url_bytes = urllib.parse.urlunparse(
- (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
- )
+ url_bytes = request.uri
url_str = url_bytes.decode("ascii")
url_to_sign_bytes = urllib.parse.urlunparse(
@@ -402,7 +441,7 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
- logger.info(
+ logger.debug(
"{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id,
request.destination,
@@ -436,7 +475,6 @@ class MatrixFederationHttpClient(object):
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
- logger.info("Failed to send request: %s", e)
raise RequestSendFailed(e, can_retry=True) from e
incoming_responses_counter.labels(
@@ -496,7 +534,7 @@ class MatrixFederationHttpClient(object):
break
except RequestSendFailed as e:
- logger.warning(
+ logger.info(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
@@ -654,6 +692,8 @@ class MatrixFederationHttpClient(object):
json=data,
)
+ start_ms = self.clock.time_msec()
+
response = await self._send_request_with_optional_trailing_slash(
request,
try_trailing_slash_on_400,
@@ -664,7 +704,7 @@ class MatrixFederationHttpClient(object):
)
body = await _handle_json_response(
- self.reactor, self.default_timeout, request, response
+ self.reactor, self.default_timeout, request, response, start_ms
)
return body
@@ -720,6 +760,8 @@ class MatrixFederationHttpClient(object):
method="POST", destination=destination, path=path, query=args, json=data
)
+ start_ms = self.clock.time_msec()
+
response = await self._send_request(
request,
long_retries=long_retries,
@@ -733,7 +775,7 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response
+ self.reactor, _sec_timeout, request, response, start_ms,
)
return body
@@ -786,6 +828,8 @@ class MatrixFederationHttpClient(object):
method="GET", destination=destination, path=path, query=args
)
+ start_ms = self.clock.time_msec()
+
response = await self._send_request_with_optional_trailing_slash(
request,
try_trailing_slash_on_400,
@@ -796,7 +840,7 @@ class MatrixFederationHttpClient(object):
)
body = await _handle_json_response(
- self.reactor, self.default_timeout, request, response
+ self.reactor, self.default_timeout, request, response, start_ms
)
return body
@@ -846,6 +890,8 @@ class MatrixFederationHttpClient(object):
method="DELETE", destination=destination, path=path, query=args
)
+ start_ms = self.clock.time_msec()
+
response = await self._send_request(
request,
long_retries=long_retries,
@@ -854,7 +900,7 @@ class MatrixFederationHttpClient(object):
)
body = await _handle_json_response(
- self.reactor, self.default_timeout, request, response
+ self.reactor, self.default_timeout, request, response, start_ms
)
return body
@@ -914,12 +960,14 @@ class MatrixFederationHttpClient(object):
)
raise
logger.info(
- "{%s} [%s] Completed: %d %s [%d bytes]",
+ "{%s} [%s] Completed: %d %s [%d bytes] %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
length,
+ request.method,
+ request.uri.decode("ascii"),
)
return (length, headers)
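Freezing `MatrixFederationRequest` relies on two attrs idioms shown above: derived fields are assigned in `__attrs_post_init__` via `object.__setattr__` (which bypasses the frozen guard), and "mutation" becomes `attr.evolve`, which builds a fresh instance and re-runs the post-init hook. A minimal sketch with hypothetical fields:

    import attr

    @attr.s(frozen=True)
    class Request(object):
        method = attr.ib(type=str)
        path = attr.ib(type=str)
        uri = attr.ib(init=False, type=str)

        def __attrs_post_init__(self):
            # the instance is frozen, so sidestep __setattr__ for the
            # one-time derived value, as done for txn_id and uri above
            object.__setattr__(self, "uri", "matrix://example.org%s" % (self.path,))

    req = Request(method="GET", path="/foo")
    req2 = attr.evolve(req, path=req.path + "/")  # new instance; uri is recomputed
    assert req2.uri.endswith("/foo/")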
diff --git a/synapse/http/server.py b/synapse/http/server.py
index d4f9ad6e67..ffe6cfa09e 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -25,7 +25,7 @@ from io import BytesIO
from typing import Any, Callable, Dict, Tuple, Union
import jinja2
-from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
+from canonicaljson import encode_canonical_json, encode_pretty_printed_json
from twisted.internet import defer
from twisted.python import failure
@@ -46,6 +46,7 @@ from synapse.api.errors import (
from synapse.http.site import SynapseRequest
from synapse.logging.context import preserve_fn
from synapse.logging.opentracing import trace_servlet
+from synapse.util import json_encoder
from synapse.util.caches import intern_dict
logger = logging.getLogger(__name__)
@@ -242,10 +243,12 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
no appropriate method exists. Can be overridden in subclasses for
different routing.
"""
+ # Treat HEAD requests as GET requests.
+ request_method = request.method.decode("ascii")
+ if request_method == "HEAD":
+ request_method = "GET"
- method_handler = getattr(
- self, "_async_render_%s" % (request.method.decode("ascii"),), None
- )
+ method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
if method_handler:
raw_callback_return = method_handler(request)
@@ -362,11 +365,15 @@ class JsonResource(DirectServeJsonResource):
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
+ # Treat HEAD requests as GET requests.
request_path = request.path.decode("ascii")
+ request_method = request.method
+ if request_method == b"HEAD":
+ request_method = b"GET"
# Loop through all the registered callbacks to check if the method
# and path regex match
- for path_entry in self.path_regexs.get(request.method, []):
+ for path_entry in self.path_regexs.get(request_method, []):
m = path_entry.pattern.match(request_path)
if m:
# We found a match!
@@ -532,7 +539,7 @@ def respond_with_json(
# canonicaljson already encodes to bytes
json_bytes = encode_canonical_json(json_object)
else:
- json_bytes = json.dumps(json_object).encode("utf-8")
+ json_bytes = json_encoder.encode(json_object).encode("utf-8")
return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
@@ -579,7 +586,7 @@ def set_cors_headers(request: Request):
"""
request.setHeader(b"Access-Control-Allow-Origin", b"*")
request.setHeader(
- b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS"
+ b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
)
request.setHeader(
b"Access-Control-Allow-Headers",
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 6f3b2258cc..6e79b47828 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -146,10 +146,9 @@ class SynapseRequest(Request):
Returns a context manager; the correct way to use this is:
- @defer.inlineCallbacks
- def handle_request(request):
+ async def handle_request(request):
with request.processing("FooServlet"):
- yield really_handle_the_request()
+ await really_handle_the_request()
Once the context manager is closed, the completion of the request will be logged,
and the various metrics will be updated.
@@ -287,7 +286,9 @@ class SynapseRequest(Request):
# the connection dropped)
code += "!"
- self.site.access_logger.info(
+ log_level = logging.INFO if self._should_log_request() else logging.DEBUG
+ self.site.access_logger.log(
+ log_level,
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]',
@@ -315,6 +316,17 @@ class SynapseRequest(Request):
except Exception as e:
logger.warning("Failed to stop metrics: %r", e)
+ def _should_log_request(self) -> bool:
+ """Whether we should log at INFO that we processed the request.
+ """
+ if self.path == b"/health":
+ return False
+
+ if self.method == b"OPTIONS":
+ return False
+
+ return True
+
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
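`logger.log` takes the level as an argument, which is what lets the access logger above demote health checks and CORS preflights to DEBUG without duplicating the large format string. A sketch of just the level selection, with an illustrative logger name:

    import logging

    access_logger = logging.getLogger("synapse.access")  # name illustrative

    def log_processed(method: bytes, path: bytes) -> None:
        quiet = path == b"/health" or method == b"OPTIONS"
        level = logging.DEBUG if quiet else logging.INFO
        access_logger.log(
            level, "Processed request: %s %s", method.decode("ascii"), path.decode("ascii")
        )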
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index a9269196b3..f766d16db6 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,16 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import threading
-from asyncio import iscoroutine
from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Set
from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
-from twisted.python.failure import Failure
from synapse.logging.context import LoggingContext, PreserveLoggingContext
@@ -167,7 +166,7 @@ class _BackgroundProcess(object):
)
-def run_as_background_process(desc, func, *args, **kwargs):
+def run_as_background_process(desc: str, func, *args, **kwargs):
"""Run the given function in its own logcontext, with resource metrics
This should be used to wrap processes which are fired off to run in the
@@ -179,7 +178,7 @@ def run_as_background_process(desc, func, *args, **kwargs):
normal synapse inlineCallbacks function).
Args:
- desc (str): a description for this background process type
+ desc: a description for this background process type
func: a function, which may return a Deferred or a coroutine
args: positional args for func
kwargs: keyword args for func
@@ -188,8 +187,7 @@ def run_as_background_process(desc, func, *args, **kwargs):
follow the synapse logcontext rules.
"""
- @defer.inlineCallbacks
- def run():
+ async def run():
with _bg_metrics_lock:
count = _background_process_counts.get(desc, 0)
_background_process_counts[desc] = count + 1
@@ -203,29 +201,21 @@ def run_as_background_process(desc, func, *args, **kwargs):
try:
result = func(*args, **kwargs)
- # We probably don't have an ensureDeferred in our call stack to handle
- # coroutine results, so we need to ensureDeferred here.
- #
- # But we need this check because ensureDeferred doesn't like being
- # called on immediate values (as opposed to Deferreds or coroutines).
- if iscoroutine(result):
- result = defer.ensureDeferred(result)
+ if inspect.isawaitable(result):
+ result = await result
- return (yield result)
+ return result
except Exception:
- # failure.Failure() fishes the original Failure out of our stack, and
- # thus gives us a sensible stack trace.
- f = Failure()
- logger.error(
- "Background process '%s' threw an exception",
- desc,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ logger.exception(
+ "Background process '%s' threw an exception", desc,
)
finally:
_background_process_in_flight_count.labels(desc).dec()
with PreserveLoggingContext():
- return run()
+ # Note that we return a Deferred here so that it can be used in a
+ # looping_call and other places that expect a Deferred.
+ return defer.ensureDeferred(run())
def wrap_as_background_process(desc):
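The `inspect.isawaitable` check above normalises the things `func` might return: a plain value passes straight through, while coroutines and Deferreds (both awaitable) get awaited. The same shape in isolation, as a runnable sketch using asyncio:

    import asyncio
    import inspect

    async def run_maybe_async(func, *args, **kwargs):
        result = func(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result  # coroutine or Deferred: wait for the real value
        return result

    async def main():
        async def two():
            return 2
        assert await run_maybe_async(lambda: 1) == 1  # immediate value
        assert await run_maybe_async(two) == 2        # coroutine

    asyncio.run(main())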
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index a7849cefa5..c2fb757d9a 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -194,12 +194,16 @@ class ModuleApi(object):
synapse.api.errors.AuthError: the access token is invalid
"""
# see if the access token corresponds to a device
- user_info = yield self._auth.get_user_by_access_token(access_token)
+ user_info = yield defer.ensureDeferred(
+ self._auth.get_user_by_access_token(access_token)
+ )
device_id = user_info.get("device_id")
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
- yield self._hs.get_device_handler().delete_device(user_id, device_id)
+ yield defer.ensureDeferred(
+ self._hs.get_device_handler().delete_device(user_id, device_id)
+ )
else:
# no associated device. Just delete the access token.
yield defer.ensureDeferred(
@@ -219,7 +223,7 @@ class ModuleApi(object):
Returns:
Deferred[object]: result of func
"""
- return self._store.db.runInteraction(desc, func, *args, **kwargs)
+ return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
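`ModuleApi` is still generator-based, and an `@defer.inlineCallbacks` generator cannot `yield` a coroutine directly, only a Deferred; hence the `defer.ensureDeferred` wrappers above around the now-async auth and device calls. The bridge, as a sketch:

    from twisted.internet import defer

    async def native_coroutine():
        return 42

    @defer.inlineCallbacks
    def legacy_caller():
        # yielding native_coroutine() directly would hand the bare coroutine
        # back unawaited; ensureDeferred turns it into a Deferred first
        value = yield defer.ensureDeferred(native_coroutine())
        return value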
diff --git a/synapse/notifier.py b/synapse/notifier.py
index bd41f77852..dfb096e589 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -15,7 +15,18 @@
import logging
from collections import namedtuple
-from typing import Callable, Iterable, List, TypeVar
+from typing import (
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
from prometheus_client import Counter
@@ -24,12 +35,14 @@ from twisted.internet import defer
import synapse.server
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
+from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import StreamToken
+from synapse.streams.config import PaginationConfig
+from synapse.types import Collection, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -77,7 +90,13 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class.
"""
- def __init__(self, user_id, rooms, current_token, time_now_ms):
+ def __init__(
+ self,
+ user_id: str,
+ rooms: Collection[str],
+ current_token: StreamToken,
+ time_now_ms: int,
+ ):
self.user_id = user_id
self.rooms = set(rooms)
self.current_token = current_token
@@ -93,13 +112,13 @@ class _NotifierUserStream(object):
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
- def notify(self, stream_key, stream_id, time_now_ms):
+ def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
"""Notify any listeners for this user of a new event from an
event source.
Args:
- stream_key(str): The stream the event came from.
- stream_id(str): The new id for the stream the event came from.
- time_now_ms(int): The current time in milliseconds.
+ stream_key: The stream the event came from.
+ stream_id: The new id for the stream the event came from.
+ time_now_ms: The current time in milliseconds.
"""
self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token
@@ -112,7 +131,7 @@ class _NotifierUserStream(object):
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)
- def remove(self, notifier):
+ def remove(self, notifier: "Notifier"):
""" Remove this listener from all the indexes in the Notifier
it knows about.
"""
@@ -123,10 +142,10 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id)
- def count_listeners(self):
+ def count_listeners(self) -> int:
return len(self.notify_deferred.observers())
- def new_listener(self, token):
+ def new_listener(self, token: StreamToken) -> _NotificationListener:
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
@@ -159,14 +178,16 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs: "synapse.server.HomeServer"):
- self.user_to_user_stream = {}
- self.room_to_user_streams = {}
+ self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream]
+ self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]]
self.hs = hs
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
- self.pending_new_room_events = []
+ self.pending_new_room_events = (
+ []
+ ) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@@ -178,10 +199,9 @@ class Notifier(object):
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
+ self.federation_sender = None
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()
- else:
- self.federation_sender = None
self.state_handler = hs.get_state_handler()
@@ -193,12 +213,12 @@ class Notifier(object):
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
def count_listeners():
- all_user_streams = set()
+ all_user_streams = set() # type: Set[_NotifierUserStream]
- for x in list(self.room_to_user_streams.values()):
- all_user_streams |= x
- for x in list(self.user_to_user_stream.values()):
- all_user_streams.add(x)
+ for streams in list(self.room_to_user_streams.values()):
+ all_user_streams |= streams
+ for stream in list(self.user_to_user_stream.values()):
+ all_user_streams.add(stream)
return sum(stream.count_listeners() for stream in all_user_streams)
@@ -223,7 +243,11 @@ class Notifier(object):
self.replication_callbacks.append(cb)
def on_new_room_event(
- self, event, room_stream_id, max_room_stream_id, extra_users=[]
+ self,
+ event: EventBase,
+ room_stream_id: int,
+ max_room_stream_id: int,
+ extra_users: Collection[Union[str, UserID]] = [],
):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
@@ -241,11 +265,11 @@ class Notifier(object):
self.notify_replication()
- def _notify_pending_new_room_events(self, max_room_stream_id):
+ def _notify_pending_new_room_events(self, max_room_stream_id: int):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
- max_room_stream_id(int): The highest stream_id below which all
+ max_room_stream_id: The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
@@ -258,7 +282,12 @@ class Notifier(object):
else:
self._on_new_room_event(event, room_stream_id, extra_users)
- def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
+ def _on_new_room_event(
+ self,
+ event: EventBase,
+ room_stream_id: int,
+ extra_users: Collection[Union[str, UserID]] = [],
+ ):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
run_as_background_process(
@@ -275,13 +304,19 @@ class Notifier(object):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)
- async def _notify_app_services(self, room_stream_id):
+ async def _notify_app_services(self, room_stream_id: int):
try:
await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception:
logger.exception("Error notifying application services of event")
- def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
+ def on_new_event(
+ self,
+ stream_key: str,
+ new_token: int,
+ users: Collection[Union[str, UserID]] = [],
+ rooms: Collection[str] = [],
+ ):
""" Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms.
@@ -307,20 +342,25 @@ class Notifier(object):
self.notify_replication()
- def on_new_replication_data(self):
+ def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
self.notify_replication()
async def wait_for_events(
- self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
- ):
+ self,
+ user_id: str,
+ timeout: int,
+ callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
+ room_ids=None,
+ from_token=StreamToken.START,
+ ) -> T:
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
- current_token = await self.event_sources.get_current_token()
+ current_token = self.event_sources.get_current_token()
if room_ids is None:
room_ids = await self.store.get_rooms_for_user(user_id)
user_stream = _NotifierUserStream(
@@ -377,19 +417,16 @@ class Notifier(object):
async def get_events_for(
self,
- user,
- pagination_config,
- timeout,
- only_keys=None,
- is_guest=False,
- explicit_room_id=None,
- ):
+ user: UserID,
+ pagination_config: PaginationConfig,
+ timeout: int,
+ is_guest: bool = False,
+ explicit_room_id: str = None,
+ ) -> EventStreamResult:
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
- If `only_keys` is not None, events from keys will be sent down.
-
If explicit_room_id is not set, the user's joined rooms will be polled
for events.
If explicit_room_id is set, that room will be polled for events only if
@@ -397,18 +434,20 @@ class Notifier(object):
"""
from_token = pagination_config.from_token
if not from_token:
- from_token = await self.event_sources.get_current_token()
+ from_token = self.event_sources.get_current_token()
limit = pagination_config.limit
room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined
- async def check_for_updates(before_token, after_token):
+ async def check_for_updates(
+ before_token: StreamToken, after_token: StreamToken
+ ) -> EventStreamResult:
if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token))
- events = []
+ events = [] # type: List[EventBase]
end_token = from_token
for name, source in self.event_sources.sources.items():
@@ -417,8 +456,6 @@ class Notifier(object):
after_id = getattr(after_token, keyname)
if before_id == after_id:
continue
- if only_keys and name not in only_keys:
- continue
new_events, new_key = await source.get_new_events(
user=user,
@@ -476,7 +513,9 @@ class Notifier(object):
return result
- async def _get_room_ids(self, user, explicit_room_id):
+ async def _get_room_ids(
+ self, user: UserID, explicit_room_id: Optional[str]
+ ) -> Tuple[Collection[str], bool]:
joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
@@ -486,7 +525,7 @@ class Notifier(object):
raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True
- async def _is_world_readable(self, room_id):
+ async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
@@ -496,7 +535,7 @@ class Notifier(object):
return False
@log_function
- def remove_expired_streams(self):
+ def remove_expired_streams(self) -> None:
time_now_ms = self.clock.time_msec()
expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
@@ -510,21 +549,21 @@ class Notifier(object):
expired_stream.remove(self)
@log_function
- def _register_with_keys(self, user_stream):
+ def _register_with_keys(self, user_stream: _NotifierUserStream):
self.user_to_user_stream[user_stream.user_id] = user_stream
for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
- def _user_joined_room(self, user_id, room_id):
+ def _user_joined_room(self, user_id: str, room_id: str):
new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id)
- def notify_replication(self):
+ def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
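Most of this file's diff is annotation work; note the two styles in play: parameters and return types use inline annotations, while instance attributes use `# type:` comments, since the code base still targets Pythons without variable annotations. A sketch of the attribute style, on a hypothetical class:

    from typing import Dict, List, Set

    class Tracker(object):
        def __init__(self):
            # comment annotations keep older interpreters happy while still
            # giving mypy the full types
            self.user_to_stream = {}   # type: Dict[str, object]
            self.room_to_streams = {}  # type: Dict[str, Set[object]]
            self.pending = []          # type: List[int]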
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 286374d0b5..8047873ff1 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -19,11 +19,13 @@ import copy
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
-def list_with_base_rules(rawrules):
+def list_with_base_rules(rawrules, use_new_defaults=False):
"""Combine the list of rules set by the user with the default push rules
Args:
rawrules(list): The rules the user has modified or set.
+ use_new_defaults(bool): Whether to use the new experimental default rules when
+ appending or prepending default rules.
Returns:
A new list with the rules set by the user combined with the defaults.
@@ -43,7 +45,9 @@ def list_with_base_rules(rawrules):
ruleslist.extend(
make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+ modified_base_rules,
+ use_new_defaults,
)
)
@@ -54,6 +58,7 @@ def list_with_base_rules(rawrules):
make_base_append_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
+ use_new_defaults,
)
)
current_prio_class -= 1
@@ -62,6 +67,7 @@ def list_with_base_rules(rawrules):
make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
+ use_new_defaults,
)
)
@@ -70,27 +76,39 @@ def list_with_base_rules(rawrules):
while current_prio_class > 0:
ruleslist.extend(
make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+ modified_base_rules,
+ use_new_defaults,
)
)
current_prio_class -= 1
if current_prio_class > 0:
ruleslist.extend(
make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+ modified_base_rules,
+ use_new_defaults,
)
)
return ruleslist
-def make_base_append_rules(kind, modified_base_rules):
+def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
rules = []
if kind == "override":
- rules = BASE_APPEND_OVERRIDE_RULES
+ rules = (
+ NEW_APPEND_OVERRIDE_RULES
+ if use_new_defaults
+ else BASE_APPEND_OVERRIDE_RULES
+ )
elif kind == "underride":
- rules = BASE_APPEND_UNDERRIDE_RULES
+ rules = (
+ NEW_APPEND_UNDERRIDE_RULES
+ if use_new_defaults
+ else BASE_APPEND_UNDERRIDE_RULES
+ )
elif kind == "content":
rules = BASE_APPEND_CONTENT_RULES
@@ -105,7 +123,7 @@ def make_base_append_rules(kind, modified_base_rules):
return rules
-def make_base_prepend_rules(kind, modified_base_rules):
+def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
rules = []
if kind == "override":
@@ -270,6 +288,135 @@ BASE_APPEND_OVERRIDE_RULES = [
]
+NEW_APPEND_OVERRIDE_RULES = [
+ {
+ "rule_id": "global/override/.m.rule.encrypted",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.encrypted",
+ "_id": "_encrypted",
+ }
+ ],
+ "actions": ["notify"],
+ },
+ {
+ "rule_id": "global/override/.m.rule.suppress_notices",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.message",
+ "_id": "_suppress_notices_type",
+ },
+ {
+ "kind": "event_match",
+ "key": "content.msgtype",
+ "pattern": "m.notice",
+ "_id": "_suppress_notices",
+ },
+ ],
+ "actions": [],
+ },
+ {
+ "rule_id": "global/underride/.m.rule.suppress_edits",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "m.relates_to.m.rel_type",
+ "pattern": "m.replace",
+ "_id": "_suppress_edits",
+ }
+ ],
+ "actions": [],
+ },
+ {
+ "rule_id": "global/override/.m.rule.invite_for_me",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.member",
+ "_id": "_member",
+ },
+ {
+ "kind": "event_match",
+ "key": "content.membership",
+ "pattern": "invite",
+ "_id": "_invite_member",
+ },
+ {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
+ ],
+ "actions": ["notify", {"set_tweak": "sound", "value": "default"}],
+ },
+ {
+ "rule_id": "global/override/.m.rule.contains_display_name",
+ "conditions": [{"kind": "contains_display_name"}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
+ ],
+ },
+ {
+ "rule_id": "global/override/.m.rule.tombstone",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.tombstone",
+ "_id": "_tombstone",
+ },
+ {
+ "kind": "event_match",
+ "key": "state_key",
+ "pattern": "",
+ "_id": "_tombstone_statekey",
+ },
+ ],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
+ ],
+ },
+ {
+ "rule_id": "global/override/.m.rule.roomnotif",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "content.body",
+ "pattern": "@room",
+ "_id": "_roomnotif_content",
+ },
+ {
+ "kind": "sender_notification_permission",
+ "key": "room",
+ "_id": "_roomnotif_pl",
+ },
+ ],
+ "actions": [
+ "notify",
+ {"set_tweak": "highlight"},
+ {"set_tweak": "sound", "value": "default"},
+ ],
+ },
+ {
+ "rule_id": "global/override/.m.rule.call",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.call.invite",
+ "_id": "_call",
+ }
+ ],
+ "actions": ["notify", {"set_tweak": "sound", "value": "ring"}],
+ },
+]
+
+
BASE_APPEND_UNDERRIDE_RULES = [
{
"rule_id": "global/underride/.m.rule.call",
@@ -354,6 +501,36 @@ BASE_APPEND_UNDERRIDE_RULES = [
]
+NEW_APPEND_UNDERRIDE_RULES = [
+ {
+ "rule_id": "global/underride/.m.rule.room_one_to_one",
+ "conditions": [
+ {"kind": "room_member_count", "is": "2", "_id": "member_count"},
+ {
+ "kind": "event_match",
+ "key": "content.body",
+ "pattern": "*",
+ "_id": "body",
+ },
+ ],
+ "actions": ["notify", {"set_tweak": "sound", "value": "default"}],
+ },
+ {
+ "rule_id": "global/underride/.m.rule.message",
+ "conditions": [
+ {
+ "kind": "event_match",
+ "key": "content.body",
+ "pattern": "*",
+ "_id": "body",
+ },
+ ],
+ "actions": ["notify"],
+ "enabled": False,
+ },
+]
+
+
BASE_RULE_IDS = set()
for r in BASE_APPEND_CONTENT_RULES:
@@ -375,3 +552,26 @@ for r in BASE_APPEND_UNDERRIDE_RULES:
r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
r["default"] = True
BASE_RULE_IDS.add(r["rule_id"])
+
+
+NEW_RULE_IDS = set()
+
+for r in BASE_APPEND_CONTENT_RULES:
+ r["priority_class"] = PRIORITY_CLASS_MAP["content"]
+ r["default"] = True
+ NEW_RULE_IDS.add(r["rule_id"])
+
+for r in BASE_PREPEND_OVERRIDE_RULES:
+ r["priority_class"] = PRIORITY_CLASS_MAP["override"]
+ r["default"] = True
+ NEW_RULE_IDS.add(r["rule_id"])
+
+for r in NEW_APPEND_OVERRIDE_RULES:
+ r["priority_class"] = PRIORITY_CLASS_MAP["override"]
+ r["default"] = True
+ NEW_RULE_IDS.add(r["rule_id"])
+
+for r in NEW_APPEND_UNDERRIDE_RULES:
+ r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
+ r["default"] = True
+ NEW_RULE_IDS.add(r["rule_id"])
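`make_base_append_rules` now selects a whole rule list per kind, keyed on the experimental flag; content rules have no new variant and the prepend rules are shared. Reduced to a sketch with illustrative data:

    BASE = {"override": ["base-override"], "underride": ["base-underride"], "content": ["content"]}
    NEW = {"override": ["new-override"], "underride": ["new-underride"]}

    def pick_append_rules(kind: str, use_new_defaults: bool = False):
        if use_new_defaults and kind in NEW:
            return NEW[kind]
        return BASE.get(kind, [])

    assert pick_append_rules("override") == ["base-override"]
    assert pick_append_rules("override", use_new_defaults=True) == ["new-override"]
    assert pick_append_rules("content", use_new_defaults=True) == ["content"]  # no new variant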
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 04b9d8ac82..e7fcee0e87 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -120,7 +120,7 @@ class BulkPushRuleEvaluator(object):
pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event}
else:
- auth_events_ids = await self.auth.compute_auth_events(
+ auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events = await self.store.get_events(auth_events_ids)
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index bc8f71916b..d0145666bf 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -21,13 +21,22 @@ async def get_badge_count(store, user_id):
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
+ my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
+
badge = len(invites)
for room_id in joins:
- unread_count = await store.get_unread_message_count_for_user(room_id, user_id)
- # return one badge count per conversation, as count per
- # message is so noisy as to be almost useless
- badge += 1 if unread_count else 0
+ if room_id in my_receipts_by_room:
+ last_unread_event_id = my_receipts_by_room[room_id]
+
+ notifs = await (
+ store.get_unread_event_push_actions_by_room_for_user(
+ room_id, user_id, last_unread_event_id
+ )
+ )
+ # return one badge count per conversation, as count per
+ # message is so noisy as to be almost useless
+ badge += 1 if notifs["notify_count"] else 0
return badge
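The badge arithmetic above deliberately counts conversations, not messages: one unit per invite, plus one unit per joined room with any unread notifications. In isolation:

    def badge_count(num_invites: int, notify_counts_by_room: dict) -> int:
        # one badge unit per room with unread notifications, however many
        return num_invites + sum(1 for count in notify_counts_by_room.values() if count)

    assert badge_count(2, {"!a:hs": 0, "!b:hs": 5, "!c:hs": 1}) == 4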
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index abea2be4ef..e5f22fb858 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -59,7 +59,6 @@ REQUIREMENTS = [
"pyyaml>=3.11",
"pyasn1>=0.1.9",
"pyasn1-modules>=0.0.7",
- "daemonize>=2.3.1",
"bcrypt>=3.1.0",
"pillow>=4.3.0",
"sortedcontainers>=1.4.4",
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index fb0dd04f88..6a28c2db9d 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -20,8 +20,6 @@ import urllib
from inspect import signature
from typing import Dict, List, Tuple
-from twisted.internet import defer
-
from synapse.api.errors import (
CodeMessageException,
HttpResponseException,
@@ -101,7 +99,7 @@ class ReplicationEndpoint(object):
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
- def _serialize_payload(**kwargs):
+ async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than
@@ -110,9 +108,8 @@ class ReplicationEndpoint(object):
argument list.
Returns:
- Deferred[dict]|dict: If POST/PUT request then dictionary must be
- JSON serialisable, otherwise must be appropriate for adding as
- query args.
+ dict: If POST/PUT request then dictionary must be JSON serialisable,
+ otherwise must be appropriate for adding as query args.
"""
return {}
@@ -144,8 +141,7 @@ class ReplicationEndpoint(object):
instance_map = hs.config.worker.instance_map
@trace(opname="outgoing_replication_request")
- @defer.inlineCallbacks
- def send_request(instance_name="master", **kwargs):
+ async def send_request(instance_name="master", **kwargs):
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
@@ -159,7 +155,7 @@ class ReplicationEndpoint(object):
"Instance %r not in 'instance_map' config" % (instance_name,)
)
- data = yield cls._serialize_payload(**kwargs)
+ data = await cls._serialize_payload(**kwargs)
url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
@@ -197,7 +193,7 @@ class ReplicationEndpoint(object):
headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
- result = yield request_func(uri, data, headers=headers)
+ result = await request_func(uri, data, headers=headers)
break
except CodeMessageException as e:
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
@@ -207,7 +203,7 @@ class ReplicationEndpoint(object):
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
- yield clock.sleep(1)
+ await clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
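`send_request` above retries replication calls that fail with a gateway timeout (HTTP 504), now pausing with `await clock.sleep(1)` between attempts. The control flow, sketched with asyncio and a generic exception in place of the 504 `CodeMessageException` check:

    import asyncio

    async def send_with_retry(request_func, uri, attempts: int = 3):
        for attempt in range(attempts):
            try:
                return await request_func(uri)
            except TimeoutError:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(1)  # brief fixed pause before retrying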
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index e32aac0a25..20f3ba76c0 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(user_id):
+ async def _serialize_payload(user_id):
return {}
async def _handle_request(self, request, user_id):
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index ca065e819e..6b56315148 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
@@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
- @defer.inlineCallbacks
- def _serialize_payload(store, event_and_contexts, backfilled):
+ async def _serialize_payload(store, event_and_contexts, backfilled):
"""
Args:
store
@@ -78,9 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""
event_payloads = []
for event, context in event_and_contexts:
- serialized_context = yield defer.ensureDeferred(
- context.serialize(event, store)
- )
+ serialized_context = await context.serialize(event, store)
event_payloads.append(
{
@@ -156,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry()
@staticmethod
- def _serialize_payload(edu_type, origin, content):
+ async def _serialize_payload(edu_type, origin, content):
return {"origin": origin, "content": content}
async def _handle_request(self, request, edu_type):
@@ -199,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry()
@staticmethod
- def _serialize_payload(query_type, args):
+ async def _serialize_payload(query_type, args):
"""
Args:
query_type (str)
@@ -240,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
@staticmethod
- def _serialize_payload(room_id, args):
+ async def _serialize_payload(room_id, args):
"""
Args:
room_id (str)
@@ -275,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
@staticmethod
- def _serialize_payload(room_id, room_version):
+ async def _serialize_payload(room_id, room_version):
return {"room_version": room_version.identifier}
async def _handle_request(self, request, room_id):
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 798b9d3af5..fb326bb869 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+ async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 63ef6eb7be..741329ab5f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+ async def _serialize_payload(
+ requester, room_id, user_id, remote_room_hosts, content
+ ):
"""
Args:
requester(Requester)
@@ -112,7 +114,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.member_handler = hs.get_room_member_handler()
@staticmethod
- def _serialize_payload( # type: ignore
+ async def _serialize_payload( # type: ignore
invite_event_id: str,
txn_id: Optional[str],
requester: Requester,
@@ -174,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
- def _serialize_payload(room_id, user_id, change):
+ async def _serialize_payload(room_id, user_id, change):
"""
Args:
room_id (str)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index ea1b33331b..bc9aa82cb4 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
- def _serialize_payload(user_id):
+ async def _serialize_payload(user_id):
return {}
async def _handle_request(self, request, user_id):
@@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
- def _serialize_payload(user_id, state, ignore_status_msg=False):
+ async def _serialize_payload(user_id, state, ignore_status_msg=False):
return {
"state": state,
"ignore_status_msg": ignore_status_msg,
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 0c4aca1291..ce9420aa69 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(
+ async def _serialize_payload(
user_id,
password_hash,
was_guest,
@@ -105,7 +105,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(user_id, auth_result, access_token):
+ async def _serialize_payload(user_id, auth_result, access_token):
"""
Args:
user_id (str): The user ID that consented
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index b30e4d5039..f13d452426 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
@@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- @defer.inlineCallbacks
- def _serialize_payload(
+ async def _serialize_payload(
event_id, store, event, context, requester, ratelimit, extra_users
):
"""
@@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
extra_users (list(UserID)): Any extra users to notify about event
"""
- serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
+ serialized_context = await context.serialize(event, store)
payload = {
"event": event.get_pdu_json(),
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index bde97eef32..309159e304 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self.streams = hs.get_replication_streams()
@staticmethod
- def _serialize_payload(stream_name, from_token, upto_token):
+ async def _serialize_payload(stream_name, from_token, upto_token):
return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request(self, request, stream_name):
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f9e2533e96..60f2e1245f 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -16,8 +16,8 @@
import logging
from typing import Optional
-from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = MultiWriterIdGenerator(
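The storage hunks that follow are mechanical: `synapse.storage.data_stores` becomes `synapse.storage.databases`, the `Database` class becomes `DatabasePool`, and `self.db` becomes `self.db_pool`. A sketch of what a worker store looks like under the new naming (`ExampleSlavedStore` is invented for illustration):

```python
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.storage.database import DatabasePool  # renamed from Database


class ExampleSlavedStore(BaseSlavedStore):
    # Hypothetical store; the constructor signature mirrors the pattern
    # repeated throughout this patch.
    def __init__(self, database: DatabasePool, db_conn, hs):
        super().__init__(database, db_conn, hs)
```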
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 525b94fd87..154f0e687c 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -17,13 +17,13 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
-from synapse.storage.data_stores.main.tags import TagsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"account_data",
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a67fbeffb7..0f8d7037bd 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.databases.main.appservice import (
ApplicationServiceTransactionWorkerStore,
ApplicationServiceWorkerStore,
)
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1a38f53dfb..a6fdedde63 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.descriptors import Cache
from ._base import BaseSlavedStore
class SlavedClientIpStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
- def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+ async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index a8a16dbc71..ee7f69a918 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -16,14 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id"
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 9d8067342f..722f3745e9 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,14 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.data_stores.main.devices import DeviceWorkerStore
-from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.devices import DeviceWorkerStore
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 8b9717c46f..1945bcf9a8 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 1a1a50a24f..da1cc836cf 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,18 +15,18 @@
# limitations under the License.
import logging
-from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
-from synapse.storage.data_stores.main.event_push_actions import (
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
+from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.relations import RelationsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
-from synapse.storage.data_stores.main.stream import StreamWorkerStore
-from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.relations import RelationsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
@@ -55,11 +55,11 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index bcb0688954..2562b6fc38 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.filtering import FilteringStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 5d210fa3a1..3291558c7a 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -16,13 +16,13 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream
-from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index 3def367ae9..961579751c 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.keys import KeyStore
+from synapse.storage.databases.main.keys import KeyStore
# KeyStore isn't really safe to use from a worker, but for now we do so and hope that
# the races it creates aren't too bad.
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 2938cb8e43..a912c04360 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -15,8 +15,8 @@
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore
-from synapse.storage.data_stores.main.presence import PresenceStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.presence import PresenceStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
@@ -24,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 28c508aad3..f85b20a071 100644
--- a/synapse/replication/slave/storage/profile.py
+++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.data_stores.main.profile import ProfileWorkerStore
+from synapse.storage.databases.main.profile import ProfileWorkerStore
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 23ec1c5b11..590187df46 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -15,7 +15,7 @@
# limitations under the License.
from synapse.replication.tcp.streams import PushRulesStream
-from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index ff449f3658..63300e5da6 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,15 +15,15 @@
# limitations under the License.
from synapse.replication.tcp.streams import PushersStream
-from synapse.storage.data_stores.main.pusher import PusherWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.pusher import PusherWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPusherStore, self).__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6982686eb5..17ba1f22ac 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -15,15 +15,15 @@
# limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 4b8553e250..a40f064e2b 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8710207ada..427c81772b 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -14,15 +14,15 @@
# limitations under the License.
from synapse.replication.tcp.streams import PublicRoomsStream
-from synapse.storage.data_stores.main.room import RoomWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.room import RoomWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index ac88e6b8c3..2091ac0df6 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.transactions import TransactionStore
+from synapse.storage.databases.main.transactions import TransactionStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f33801f883..d853e4447e 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -18,11 +18,12 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
allowed to be sent by which side.
"""
import abc
-import json
import logging
from typing import Tuple, Type
-_json_encoder = json.JSONEncoder()
+from canonicaljson import json
+
+from synapse.util import json_encoder as _json_encoder
logger = logging.getLogger(__name__)
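The shared `synapse.util.json_encoder` is not defined in this patch; assuming it is a `json.JSONEncoder` configured with compact separators, the whitespace saving over the default encoder looks like this:

```python
import json

# Assumption: synapse.util.json_encoder is configured roughly like this.
compact_encoder = json.JSONEncoder(separators=(",", ":"))

data = {"stream_name": "events", "token": 42}
print(json.dumps(data))              # {"stream_name": "events", "token": 42}
print(compact_encoder.encode(data))  # {"stream_name":"events","token":42}
```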
diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html
index f8a5fccd38..01cd9bdaf3 100644
--- a/synapse/res/templates/saml_error.html
+++ b/synapse/res/templates/saml_error.html
@@ -2,10 +2,17 @@
<html lang="en">
<head>
<meta charset="UTF-8">
- <title>SSO error</title>
+ <title>SSO login error</title>
</head>
<body>
- <p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p>
+{# a 403 means we have actively rejected their login #}
+{% if code == 403 %}
+ <p>You are not allowed to log in here.</p>
+{% else %}
+ <p>
+ There was an error during authentication:
+ </p>
+ <div id="errormsg" style="margin:20px 80px">{{ msg }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
@@ -37,9 +44,9 @@
// to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) {
-
- document.getElementById("errormsg").innerText = ` ("${errorDesc}")`;
+ document.getElementById("errormsg").innerText = errorDesc;
}
</script>
+{% endif %}
</body>
</html>
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index a8364d9793..7c292ef3f9 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -31,7 +31,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
historical_admin_path_patterns,
)
-from synapse.storage.data_stores.main.room import RoomSortOrder
+from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 5934b1fe8b..b210015173 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
try:
- service = await self.auth.get_appservice_by_req(request)
+ service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await dir_handler.delete_appservice_association(service, room_alias)
logger.info(
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 9fd4908136..00831879f3 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -25,7 +25,7 @@ from synapse.http.servlet import (
parse_json_value_from_request,
parse_string,
)
-from synapse.push.baserules import BASE_RULE_IDS
+from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -45,6 +45,8 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
+ self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+
async def on_PUT(self, request, path):
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
@@ -179,7 +181,12 @@ class PushRuleRestServlet(RestServlet):
rule_id = spec["rule_id"]
is_default_rule = rule_id.startswith(".")
if is_default_rule:
- if namespaced_rule_id not in BASE_RULE_IDS:
+ if user_id in self._users_new_default_push_rules:
+ rule_ids = NEW_RULE_IDS
+ else:
+ rule_ids = BASE_RULE_IDS
+
+ if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
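The new check selects which set of default rule IDs a given user is validated against. A sketch with invented data (the real `BASE_RULE_IDS`/`NEW_RULE_IDS` live in `synapse.push.baserules`):

```python
# Placeholder contents; only the selection logic mirrors the patch.
BASE_RULE_IDS = {"global/override/.m.rule.master"}
NEW_RULE_IDS = {"global/override/.m.rule.master", "global/override/.m.rule.example"}

users_new_default_push_rules = {"@alice:example.com"}


def default_rule_ids_for(user_id: str) -> set:
    # Users opted into the new default push rules are validated against
    # NEW_RULE_IDS; everyone else keeps the base set.
    if user_id in users_new_default_push_rules:
        return NEW_RULE_IDS
    return BASE_RULE_IDS


assert default_rule_ids_for("@alice:example.com") is NEW_RULE_IDS
assert default_rule_ids_for("@bob:example.com") is BASE_RULE_IDS
```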
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 26d5a51cb2..2ab30ce897 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -444,7 +444,7 @@ class RoomMemberListRestServlet(RestServlet):
async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3767a809a4..fead85074b 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -18,7 +18,12 @@ import logging
from http import HTTPStatus
from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+ Codes,
+ InteractiveAuthIncompleteError,
+ SynapseError,
+ ThreepidValidationError,
+)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
@@ -239,18 +244,12 @@ class PasswordRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the new password provided to us.
- if "new_password" in body:
- new_password = body.pop("new_password")
+ new_password = body.pop("new_password", None)
+ if new_password is not None:
if not isinstance(new_password, str) or len(new_password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(new_password)
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "new_password_hash" in body:
- raise SynapseError(400, "Unexpected property: new_password_hash")
- body["new_password_hash"] = await self.auth_handler.hash(new_password)
-
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -263,23 +262,49 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- params = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ try:
+ params, session_id = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
user_id = requester.user.to_string()
else:
requester = None
- result, params, _ = await self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY]],
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ try:
+ result, params, session_id = await self.auth_handler.check_ui_auth(
+ [[LoginType.EMAIL_IDENTITY]],
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
@@ -304,12 +329,21 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- assert_params_in_dict(params, ["new_password_hash"])
- new_password_hash = params["new_password_hash"]
+ # If we have a password in this request, prefer it. Otherwise, there
+ # must be a password hash from an earlier request.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ else:
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
+
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
- user_id, new_password_hash, logout_devices, requester
+ user_id, password_hash, logout_devices, requester
)
return 200, {}
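Both branches above follow the same retry-friendly shape: hash the password the first time it is seen, stash the hash against the UI-auth session, and fall back to the stashed hash when the client retries without re-sending the password. A condensed sketch (`auth_handler` stands in for Synapse's `AuthHandler`; error types simplified):

```python
async def resolve_password_hash(auth_handler, session_id, new_password):
    # Prefer a password supplied in this request...
    if new_password:
        return await auth_handler.hash(new_password)
    # ...otherwise fall back to a hash stashed by an earlier, incomplete
    # UI-auth attempt.
    password_hash = await auth_handler.get_session_data(
        session_id, "password_hash", None
    )
    if not password_hash:
        raise ValueError("Missing params: password")
    return password_hash
```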
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 370742ce59..f808175698 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -24,6 +24,7 @@ import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
+ InteractiveAuthIncompleteError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
@@ -387,6 +388,7 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self._registration_enabled = self.hs.config.enable_registration
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -412,20 +414,8 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
- # we do basic sanity checks here because the auth layer will store these
- # in sessions. Pull out the username/password provided to us.
- if "password" in body:
- password = body.pop("password")
- if not isinstance(password, str) or len(password) > 512:
- raise SynapseError(400, "Invalid password")
- self.password_policy_handler.validate_password(password)
-
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "password_hash" in body:
- raise SynapseError(400, "Unexpected property: password_hash")
- body["password_hash"] = await self.auth_handler.hash(password)
-
+ # Pull out the provided username and do basic sanity checks early since
+ # the auth layer will store these in sessions.
desired_username = None
if "username" in body:
if not isinstance(body["username"], str) or len(body["username"]) > 512:
@@ -434,7 +424,7 @@ class RegisterRestServlet(RestServlet):
appservice = None
if self.auth.has_access_token(request):
- appservice = await self.auth.get_appservice_by_req(request)
+ appservice = self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -459,22 +449,35 @@ class RegisterRestServlet(RestServlet):
)
return 200, result # we throw for non 200 responses
- # for regular registration, downcase the provided username before
- # attempting to register it. This should mean
- # that people who try to register with upper-case in their usernames
- # don't get a nasty surprise. (Note that we treat username
- # case-insenstively in login, so they are free to carry on imagining
- # that their username is CrAzYh4cKeR if that keeps them happy)
- if desired_username is not None:
- desired_username = desired_username.lower()
-
# == Normal User Registration == (everyone else)
- if not self.hs.config.enable_registration:
+ if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
+ # For regular registration, convert the provided username to lowercase
+ # before attempting to register it. This should mean that people who try
+ # to register with upper-case in their usernames don't get a nasty surprise.
+ #
+ # Note that we treat usernames case-insensitively in login, so they are
+ # free to carry on imagining that their username is CrAzYh4cKeR if that
+ # keeps them happy.
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
+ # Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
- if "initial_device_display_name" in body and "password_hash" not in body:
+ # Pull out the provided password and do basic sanity checks early.
+ #
+ # Note that we remove the password from the body since the auth layer
+ # will store the body in the session and we don't want a plaintext
+ # password stored there.
+ password = body.pop("password", None)
+ if password is not None:
+ if not isinstance(password, str) or len(password) > 512:
+ raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(password)
+
+ if "initial_device_display_name" in body and password is None:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
@@ -484,6 +487,7 @@ class RegisterRestServlet(RestServlet):
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
+ password_hash = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
@@ -492,7 +496,12 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
+ # Extract the previously-hashed password from the session.
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
+ # Ensure that the username is valid.
if desired_username is not None:
await self.registration_handler.check_username(
desired_username,
@@ -500,20 +509,38 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id,
)
- auth_result, params, session_id = await self.auth_handler.check_auth(
- self._registration_flows,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "register a new account",
- )
+ # Check if the user-interactive authentication flows are complete, if
+ # not this will raise a user-interactive auth error.
+ try:
+ auth_result, params, session_id = await self.auth_handler.check_ui_auth(
+ self._registration_flows,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "register a new account",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth.
+ #
+ # Hash the password and store it with the session since the client
+ # is not required to provide the password again.
+ #
+ # If a password hash was previously stored we will not attempt to
+ # re-hash and store it for efficiency. This assumes the password
+ # does not change throughout the authentication flow, but this
+ # should be fine since the data is meant to be consistent.
+ if not password_hash and password:
+ password_hash = await self.auth_handler.hash(password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
-
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
@@ -535,12 +562,15 @@ class RegisterRestServlet(RestServlet):
# don't re-register the threepids
registered = False
else:
- # NB: This may be from the auth handler and NOT from the POST
- assert_params_in_dict(params, ["password_hash"])
+ # If we have a password in this request, prefer it. Otherwise, there
+ # might be a password hash from an earlier request.
+ if password:
+ password_hash = await self.auth_handler.hash(password)
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
desired_username = params.get("username", None)
guest_access_token = params.get("guest_access_token", None)
- new_password_hash = params.get("password_hash", None)
if desired_username is not None:
desired_username = desired_username.lower()
@@ -582,7 +612,7 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
- password_hash=new_password_hash,
+ password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
address=client_addr,
@@ -595,8 +625,8 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
- # remember that we've now registered that user account, and with
- # what user ID (since the user may not have specified)
+ # Remember that the user account has been registered (and the user
+ # ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
@@ -635,7 +665,7 @@ class RegisterRestServlet(RestServlet):
(object) params: registration parameters, from which we pull
device_id, initial_device_name and inhibit_login
Returns:
- defer.Deferred: (object) dictionary for response from /register
+ (object) dictionary for response from /register
"""
result = {"user_id": user_id, "home_server": self.hs.hostname}
if not params.get("inhibit_login", False):
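From the client's point of view, this means the password only has to cross the wire once. Illustrative (invented) request bodies for a two-step /register flow:

```python
# First attempt: carries the password; the server hashes it and, if UI auth
# is incomplete, stores the hash against the returned session.
first_attempt = {
    "username": "alice",
    "password": "s3cret",
}

# Retry after the 401: only the auth dict is required; the server falls back
# to the password_hash stored in the session. The session id and auth type
# are made up for the example.
retry_attempt = {
    "username": "alice",
    "auth": {"session": "xyzabc", "type": "m.login.dummy"},
}
```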
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 3f5bf75e59..a5c24fbd63 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -426,7 +426,6 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
- result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4386eb4e72..b3e4d5612e 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -22,8 +22,6 @@ from os import path
import jinja2
from jinja2 import TemplateNotFound
-from twisted.internet import defer
-
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
@@ -135,7 +133,7 @@ class ConsentResource(DirectServeHtmlResource):
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
- u = await defer.maybeDeferred(self.store.get_user_by_id, qualified_user_id)
+ u = await self.store.get_user_by_id(qualified_user_id)
if u is None:
raise NotFoundError("Unknown user")
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
new file mode 100644
index 0000000000..0170950bf3
--- /dev/null
+++ b/synapse/rest/health.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+
+
+class HealthResource(Resource):
+ """A resource that does nothing except return a 200 with a body of `OK`,
+ which can be used as a health check.
+
+ Note: `SynapseRequest._should_log_request` ensures that requests to
+ `/health` do not get logged at INFO.
+ """
+
+ isLeaf = 1
+
+ def render_GET(self, request):
+ request.setHeader(b"Content-Type", b"text/plain")
+ return b"OK"
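A minimal way to exercise the new resource outside Synapse's own listener plumbing (standalone Twisted site; the port number is arbitrary):

```python
from twisted.internet import reactor
from twisted.web.resource import Resource
from twisted.web.server import Site

from synapse.rest.health import HealthResource

# Mount the health check at /health on a bare resource tree.
root = Resource()
root.putChild(b"health", HealthResource())

reactor.listenTCP(8008, Site(root))
reactor.run()
# curl http://localhost:8008/health  ->  OK
```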
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 858b6d3005..ab1fa705bf 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
-import inspect
import logging
import os
import shutil
@@ -30,7 +29,7 @@ from .filepath import MediaFilePaths
if TYPE_CHECKING:
from synapse.server import HomeServer
- from .storage_provider import StorageProvider
+ from .storage_provider import StorageProviderWrapper
logger = logging.getLogger(__name__)
@@ -50,7 +49,7 @@ class MediaStorage(object):
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
- storage_providers: Sequence["StorageProvider"],
+ storage_providers: Sequence["StorageProviderWrapper"],
):
self.hs = hs
self.local_media_directory = local_media_directory
@@ -115,11 +114,7 @@ class MediaStorage(object):
async def finish():
for provider in self.storage_providers:
- # store_file is supposed to return an Awaitable, but guard
- # against improper implementations.
- result = provider.store_file(path, file_info)
- if inspect.isawaitable(result):
- await result
+ await provider.store_file(path, file_info)
finished_called[0] = True
@@ -153,11 +148,7 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
- res = provider.fetch(path, file_info) # type: Any
- # Fetch is supposed to return an Awaitable[Responder], but guard
- # against improper implementations.
- if inspect.isawaitable(res):
- res = await res
+ res = await provider.fetch(path, file_info) # type: Any
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
@@ -184,11 +175,7 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
- res = provider.fetch(path, file_info) # type: Any
- # Fetch is supposed to return an Awaitable[Responder], but guard
- # against improper implementations.
- if inspect.isawaitable(res):
- res = await res
+ res = await provider.fetch(path, file_info) # type: Any
if res:
with res:
consumer = BackgroundFileConsumer(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index e12f65a206..cd8c246594 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,9 +27,7 @@ from typing import Dict, Optional
from urllib import parse as urlparse
import attr
-from canonicaljson import json
-from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError
@@ -43,6 +41,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.stringutils import random_string
@@ -228,7 +227,7 @@ class PreviewUrlResource(DirectServeJsonResource):
else:
logger.info("Returning cached response")
- og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
+ og = await make_deferred_yieldable(observable.observe())
respond_with_json_bytes(request, 200, og, send_cors=True)
async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
@@ -355,7 +354,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Calculated OG for %s as %s", url, og)
- jsonog = json.dumps(og)
+ jsonog = json_encoder.encode(og)
# store OG in history-aware DB cache
await self.store.store_url_cache(
@@ -586,7 +585,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Running url preview cache expiry")
- if not (await self.store.db.updates.has_completed_background_updates()):
+ if not (await self.store.db_pool.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index a33f56e806..18c9ed48d6 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import os
import shutil
@@ -88,12 +89,18 @@ class StorageProviderWrapper(StorageProvider):
return None
if self.store_synchronous:
- return await self.backend.store_file(path, file_info)
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = self.backend.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
else:
# TODO: Handle errors.
- def store():
+ async def store():
try:
- return self.backend.store_file(path, file_info)
+ result = self.backend.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
except Exception:
logger.exception("Error storing file")
@@ -101,7 +108,11 @@ class StorageProviderWrapper(StorageProvider):
return None
async def fetch(self, path, file_info):
- return await self.backend.fetch(path, file_info)
+        # fetch is supposed to return an Awaitable[Responder], but guard
+        # against improper implementations.
+        result = self.backend.fetch(path, file_info)
+        if inspect.isawaitable(result):
+            return await result
+        return result
class FileStorageProviderBackend(StorageProvider):
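The guard that used to live at each call site in `MediaStorage` now lives once in `StorageProviderWrapper`. The underlying pattern is a generic "maybe await" helper; a sketch:

```python
import inspect


async def call_maybe_awaitable(fn, *args, **kwargs):
    # Accept both sync and async implementations: await only when the call
    # actually returned an awaitable.
    result = fn(*args, **kwargs)
    if inspect.isawaitable(result):
        return await result
    return result
```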
diff --git a/synapse/secrets.py b/synapse/secrets.py
index 5f43f81eb0..ff86950a54 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -25,8 +25,12 @@ import sys
if sys.version_info[0:2] >= (3, 6):
import secrets
- def Secrets():
- return secrets
+ class Secrets:
+ def token_bytes(self, nbytes=32):
+ return secrets.token_bytes(nbytes)
+
+ def token_hex(self, nbytes=32):
+ return secrets.token_hex(nbytes)
else:
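A quick check of the assumed behaviour of the new class, which should be indistinguishable from the old shim that returned the `secrets` module directly:

```python
from synapse.secrets import Secrets

s = Secrets()
assert len(s.token_bytes()) == 32  # default nbytes=32
assert len(s.token_hex()) == 64    # hex encoding doubles the length
```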
diff --git a/synapse/server.py b/synapse/server.py
index 8e41112530..9055b97ac3 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -22,10 +22,14 @@
# Imports required for the default HomeServer() implementation
import abc
+import functools
import logging
import os
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+import twisted
from twisted.mail.smtp import sendmail
+from twisted.web.iweb import IPolicyForHTTPS
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
@@ -93,7 +97,7 @@ from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
@@ -105,32 +109,74 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import DataStore, DataStores, Storage
+from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources
+from synapse.types import DomainSpecificString
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from synapse.handlers.oidc_handler import OidcHandler
+ from synapse.handlers.saml_handler import SamlHandler
-class HomeServer(object):
+
+T = TypeVar("T", bound=Callable[..., Any])
+
+
+def cache_in_self(builder: T) -> T:
+ """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and
+ returning it if so. If not, calls the given function and sets `self.foo` to it.
+
+ Also ensures that dependency cycles throw an exception correctly, rather
+ than overflowing the stack.
+ """
+
+ if not builder.__name__.startswith("get_"):
+ raise Exception(
+ "@cache_in_self can only be used on functions starting with `get_`"
+ )
+
+ depname = builder.__name__[len("get_") :]
+
+ building = [False]
+
+ @functools.wraps(builder)
+ def _get(self):
+ try:
+ return getattr(self, depname)
+ except AttributeError:
+ pass
+
+ # Prevent cyclic dependencies from deadlocking
+ if building[0]:
+ raise ValueError("Cyclic dependency while building %s" % (depname,))
+
+ building[0] = True
+ try:
+ dep = builder(self)
+ setattr(self, depname, dep)
+ finally:
+ building[0] = False
+
+ return dep
+
+ # We cast here as we need to tell mypy that `_get` has the same signature as
+ # `builder`.
+ return cast(T, _get)
+
+
+class HomeServer(metaclass=abc.ABCMeta):
"""A basic homeserver object without lazy component builders.
This will need all of the components it requires to either be passed as
constructor arguments, or the relevant methods overridden to create them.
Typically this would only be used for unit tests.
- For every dependency in the DEPENDENCIES list below, this class creates one
- method,
- def get_DEPENDENCY(self)
- which returns the value of that dependency. If no value has yet been set
- nor was provided to the constructor, it will attempt to call a lazy builder
- method called
- def build_DEPENDENCY(self)
- which must be implemented by the subclass. This code may call any of the
- required "get" methods on the instance to obtain the sub-dependencies that
- one requires.
+ Dependencies should be added by creating a `def get_<depname>(self)`
+ function, wrapping it in `@cache_in_self`.
Attributes:
config (synapse.config.homeserver.HomeserverConfig):
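Usage of the new `cache_in_self` decorator, as a hedged sketch (`Widget` and `get_widget` are invented): the first call builds and caches the dependency, later calls return the cached instance, and a builder that re-enters itself raises instead of overflowing the stack.

```python
from synapse.server import cache_in_self


class Widget:
    pass


class ExampleServer:
    @cache_in_self
    def get_widget(self) -> Widget:
        # Runs once; the result is stored on self.widget thereafter.
        return Widget()


server = ExampleServer()
assert server.get_widget() is server.get_widget()  # built once, then cached
```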
@@ -138,86 +184,6 @@ class HomeServer(object):
we are listening on to provide HTTP services.
"""
- __metaclass__ = abc.ABCMeta
-
- DEPENDENCIES = [
- "http_client",
- "federation_client",
- "federation_server",
- "handlers",
- "auth",
- "room_creation_handler",
- "room_shutdown_handler",
- "state_handler",
- "state_resolution_handler",
- "presence_handler",
- "sync_handler",
- "typing_handler",
- "room_list_handler",
- "acme_handler",
- "auth_handler",
- "device_handler",
- "stats_handler",
- "e2e_keys_handler",
- "e2e_room_keys_handler",
- "event_handler",
- "event_stream_handler",
- "initial_sync_handler",
- "application_service_api",
- "application_service_scheduler",
- "application_service_handler",
- "device_message_handler",
- "profile_handler",
- "event_creation_handler",
- "deactivate_account_handler",
- "set_password_handler",
- "notifier",
- "event_sources",
- "keyring",
- "pusherpool",
- "event_builder_factory",
- "filtering",
- "http_client_context_factory",
- "simple_http_client",
- "proxied_http_client",
- "media_repository",
- "media_repository_resource",
- "federation_transport_client",
- "federation_sender",
- "receipts_handler",
- "macaroon_generator",
- "tcp_replication",
- "read_marker_handler",
- "action_generator",
- "user_directory_handler",
- "groups_local_handler",
- "groups_server_handler",
- "groups_attestation_signing",
- "groups_attestation_renewer",
- "secrets",
- "spam_checker",
- "third_party_event_rules",
- "room_member_handler",
- "federation_registry",
- "server_notices_manager",
- "server_notices_sender",
- "message_handler",
- "pagination_handler",
- "room_context_handler",
- "sendmail",
- "registration_handler",
- "account_validity_handler",
- "cas_handler",
- "saml_handler",
- "oidc_handler",
- "event_client_serializer",
- "password_policy_handler",
- "storage",
- "replication_streamer",
- "replication_data_handler",
- "replication_streams",
- ]
-
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
# This is overridden in derived application classes
@@ -232,16 +198,17 @@ class HomeServer(object):
config: The full config for the homeserver.
"""
if not reactor:
- from twisted.internet import reactor
+ from twisted.internet import reactor as _reactor
+
+ reactor = _reactor
self._reactor = reactor
self.hostname = hostname
# the key we use to sign events and requests
self.signing_key = config.key.signing_key[0]
self.config = config
- self._building = {}
- self._listening_services = []
- self.start_time = None
+ self._listening_services = [] # type: List[twisted.internet.tcp.Port]
+ self.start_time = None # type: Optional[int]
self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master"
@@ -255,13 +222,13 @@ class HomeServer(object):
burst_count=config.rc_registration.burst_count,
)
- self.datastores = None
+ self.datastores = None # type: Optional[Databases]
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
- def get_instance_id(self):
+ def get_instance_id(self) -> str:
"""A unique ID for this synapse process instance.
This is used to distinguish running instances in worker-based
@@ -277,13 +244,13 @@ class HomeServer(object):
"""
return self._instance_name
- def setup(self):
+ def setup(self) -> None:
logger.info("Setting up.")
self.start_time = int(self.get_clock().time())
- self.datastores = DataStores(self.DATASTORE_CLASS, self)
+ self.datastores = Databases(self.DATASTORE_CLASS, self)
logger.info("Finished setting up.")
- def setup_master(self):
+ def setup_master(self) -> None:
"""
Some handlers have side effects on instantiation (like registering
background updates). This function causes them to be fetched, and
@@ -292,192 +259,242 @@ class HomeServer(object):
for i in self.REQUIRED_ON_MASTER_STARTUP:
getattr(self, "get_" + i)()
- def get_reactor(self):
+ def get_reactor(self) -> twisted.internet.base.ReactorBase:
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
return self._reactor
- def get_ip_from_request(self, request):
+ def get_ip_from_request(self, request) -> str:
# X-Forwarded-For is handled by our custom request type.
return request.getClientIP()
- def is_mine(self, domain_specific_string):
+ def is_mine(self, domain_specific_string: DomainSpecificString) -> bool:
return domain_specific_string.domain == self.hostname
- def is_mine_id(self, string):
+ def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname
- def get_clock(self):
+ def get_clock(self) -> Clock:
return self.clock
def get_datastore(self) -> DataStore:
+ if not self.datastores:
+ raise Exception("HomeServer.setup must be called before getting datastores")
+
return self.datastores.main
- def get_datastores(self):
+ def get_datastores(self) -> Databases:
+ if not self.datastores:
+ raise Exception("HomeServer.setup must be called before getting datastores")
+
return self.datastores
- def get_config(self):
+ def get_config(self) -> HomeServerConfig:
return self.config
- def get_distributor(self):
+ def get_distributor(self) -> Distributor:
return self.distributor
def get_registration_ratelimiter(self) -> Ratelimiter:
return self.registration_ratelimiter
- def build_federation_client(self):
+ @cache_in_self
+ def get_federation_client(self) -> FederationClient:
return FederationClient(self)
- def build_federation_server(self):
+ @cache_in_self
+ def get_federation_server(self) -> FederationServer:
return FederationServer(self)
- def build_handlers(self):
+ @cache_in_self
+ def get_handlers(self) -> Handlers:
return Handlers(self)
- def build_notifier(self):
+ @cache_in_self
+ def get_notifier(self) -> Notifier:
return Notifier(self)
- def build_auth(self):
+ @cache_in_self
+ def get_auth(self) -> Auth:
return Auth(self)
- def build_http_client_context_factory(self):
+ @cache_in_self
+ def get_http_client_context_factory(self) -> IPolicyForHTTPS:
return (
InsecureInterceptableContextFactory()
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else RegularPolicyForHTTPS()
)
- def build_simple_http_client(self):
+ @cache_in_self
+ def get_simple_http_client(self) -> SimpleHttpClient:
return SimpleHttpClient(self)
- def build_proxied_http_client(self):
+ @cache_in_self
+ def get_proxied_http_client(self) -> SimpleHttpClient:
return SimpleHttpClient(
self,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
- def build_room_creation_handler(self):
+ @cache_in_self
+ def get_room_creation_handler(self) -> RoomCreationHandler:
return RoomCreationHandler(self)
- def build_room_shutdown_handler(self):
+ @cache_in_self
+ def get_room_shutdown_handler(self) -> RoomShutdownHandler:
return RoomShutdownHandler(self)
- def build_sendmail(self):
+ @cache_in_self
+ def get_sendmail(self) -> sendmail:
return sendmail
- def build_state_handler(self):
+ @cache_in_self
+ def get_state_handler(self) -> StateHandler:
return StateHandler(self)
- def build_state_resolution_handler(self):
+ @cache_in_self
+ def get_state_resolution_handler(self) -> StateResolutionHandler:
return StateResolutionHandler(self)
- def build_presence_handler(self):
+ @cache_in_self
+ def get_presence_handler(self) -> PresenceHandler:
return PresenceHandler(self)
- def build_typing_handler(self):
+ @cache_in_self
+ def get_typing_handler(self):
if self.config.worker.writers.typing == self.get_instance_name():
return TypingWriterHandler(self)
else:
return FollowerTypingHandler(self)
- def build_sync_handler(self):
+ @cache_in_self
+ def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)
- def build_room_list_handler(self):
+ @cache_in_self
+ def get_room_list_handler(self) -> RoomListHandler:
return RoomListHandler(self)
- def build_auth_handler(self):
+ @cache_in_self
+ def get_auth_handler(self) -> AuthHandler:
return AuthHandler(self)
- def build_macaroon_generator(self):
+ @cache_in_self
+ def get_macaroon_generator(self) -> MacaroonGenerator:
return MacaroonGenerator(self)
- def build_device_handler(self):
+ @cache_in_self
+ def get_device_handler(self):
if self.config.worker_app:
return DeviceWorkerHandler(self)
else:
return DeviceHandler(self)
- def build_device_message_handler(self):
+ @cache_in_self
+ def get_device_message_handler(self) -> DeviceMessageHandler:
return DeviceMessageHandler(self)
- def build_e2e_keys_handler(self):
+ @cache_in_self
+ def get_e2e_keys_handler(self) -> E2eKeysHandler:
return E2eKeysHandler(self)
- def build_e2e_room_keys_handler(self):
+ @cache_in_self
+ def get_e2e_room_keys_handler(self) -> E2eRoomKeysHandler:
return E2eRoomKeysHandler(self)
- def build_acme_handler(self):
+ @cache_in_self
+ def get_acme_handler(self) -> AcmeHandler:
return AcmeHandler(self)
- def build_application_service_api(self):
+ @cache_in_self
+ def get_application_service_api(self) -> ApplicationServiceApi:
return ApplicationServiceApi(self)
- def build_application_service_scheduler(self):
+ @cache_in_self
+ def get_application_service_scheduler(self) -> ApplicationServiceScheduler:
return ApplicationServiceScheduler(self)
- def build_application_service_handler(self):
+ @cache_in_self
+ def get_application_service_handler(self) -> ApplicationServicesHandler:
return ApplicationServicesHandler(self)
- def build_event_handler(self):
+ @cache_in_self
+ def get_event_handler(self) -> EventHandler:
return EventHandler(self)
- def build_event_stream_handler(self):
+ @cache_in_self
+ def get_event_stream_handler(self) -> EventStreamHandler:
return EventStreamHandler(self)
- def build_initial_sync_handler(self):
+ @cache_in_self
+ def get_initial_sync_handler(self) -> InitialSyncHandler:
return InitialSyncHandler(self)
- def build_profile_handler(self):
+ @cache_in_self
+ def get_profile_handler(self):
if self.config.worker_app:
return BaseProfileHandler(self)
else:
return MasterProfileHandler(self)
- def build_event_creation_handler(self):
+ @cache_in_self
+ def get_event_creation_handler(self) -> EventCreationHandler:
return EventCreationHandler(self)
- def build_deactivate_account_handler(self):
+ @cache_in_self
+ def get_deactivate_account_handler(self) -> DeactivateAccountHandler:
return DeactivateAccountHandler(self)
- def build_set_password_handler(self):
+ @cache_in_self
+ def get_set_password_handler(self) -> SetPasswordHandler:
return SetPasswordHandler(self)
- def build_event_sources(self):
+ @cache_in_self
+ def get_event_sources(self) -> EventSources:
return EventSources(self)
- def build_keyring(self):
+ @cache_in_self
+ def get_keyring(self) -> Keyring:
return Keyring(self)
- def build_event_builder_factory(self):
+ @cache_in_self
+ def get_event_builder_factory(self) -> EventBuilderFactory:
return EventBuilderFactory(self)
- def build_filtering(self):
+ @cache_in_self
+ def get_filtering(self) -> Filtering:
return Filtering(self)
- def build_pusherpool(self):
+ @cache_in_self
+ def get_pusherpool(self) -> PusherPool:
return PusherPool(self)
- def build_http_client(self):
+ @cache_in_self
+ def get_http_client(self) -> MatrixFederationHttpClient:
tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
self.config
)
return MatrixFederationHttpClient(self, tls_client_options_factory)
- def build_media_repository_resource(self):
+ @cache_in_self
+ def get_media_repository_resource(self) -> MediaRepositoryResource:
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of the media repo.
return MediaRepositoryResource(self)
- def build_media_repository(self):
+ @cache_in_self
+ def get_media_repository(self) -> MediaRepository:
return MediaRepository(self)
- def build_federation_transport_client(self):
+ @cache_in_self
+ def get_federation_transport_client(self) -> TransportLayerClient:
return TransportLayerClient(self)
- def build_federation_sender(self):
+ @cache_in_self
+ def get_federation_sender(self):
if self.should_send_federation():
return FederationSender(self)
elif not self.config.worker_app:
@@ -485,156 +502,152 @@ class HomeServer(object):
else:
raise Exception("Workers cannot send federation traffic")
- def build_receipts_handler(self):
+ @cache_in_self
+ def get_receipts_handler(self) -> ReceiptsHandler:
return ReceiptsHandler(self)
- def build_read_marker_handler(self):
+ @cache_in_self
+ def get_read_marker_handler(self) -> ReadMarkerHandler:
return ReadMarkerHandler(self)
- def build_tcp_replication(self):
+ @cache_in_self
+ def get_tcp_replication(self) -> ReplicationCommandHandler:
return ReplicationCommandHandler(self)
- def build_action_generator(self):
+ @cache_in_self
+ def get_action_generator(self) -> ActionGenerator:
return ActionGenerator(self)
- def build_user_directory_handler(self):
+ @cache_in_self
+ def get_user_directory_handler(self) -> UserDirectoryHandler:
return UserDirectoryHandler(self)
- def build_groups_local_handler(self):
+ @cache_in_self
+ def get_groups_local_handler(self):
if self.config.worker_app:
return GroupsLocalWorkerHandler(self)
else:
return GroupsLocalHandler(self)
- def build_groups_server_handler(self):
+ @cache_in_self
+ def get_groups_server_handler(self):
if self.config.worker_app:
return GroupsServerWorkerHandler(self)
else:
return GroupsServerHandler(self)
- def build_groups_attestation_signing(self):
+ @cache_in_self
+ def get_groups_attestation_signing(self) -> GroupAttestationSigning:
return GroupAttestationSigning(self)
- def build_groups_attestation_renewer(self):
+ @cache_in_self
+ def get_groups_attestation_renewer(self) -> GroupAttestionRenewer:
return GroupAttestionRenewer(self)
- def build_secrets(self):
+ @cache_in_self
+ def get_secrets(self) -> Secrets:
return Secrets()
- def build_stats_handler(self):
+ @cache_in_self
+ def get_stats_handler(self) -> StatsHandler:
return StatsHandler(self)
- def build_spam_checker(self):
+ @cache_in_self
+ def get_spam_checker(self):
return SpamChecker(self)
- def build_third_party_event_rules(self):
+ @cache_in_self
+ def get_third_party_event_rules(self) -> ThirdPartyEventRules:
return ThirdPartyEventRules(self)
- def build_room_member_handler(self):
+ @cache_in_self
+ def get_room_member_handler(self):
if self.config.worker_app:
return RoomMemberWorkerHandler(self)
return RoomMemberMasterHandler(self)
- def build_federation_registry(self):
+ @cache_in_self
+ def get_federation_registry(self) -> FederationHandlerRegistry:
return FederationHandlerRegistry(self)
- def build_server_notices_manager(self):
+ @cache_in_self
+ def get_server_notices_manager(self):
if self.config.worker_app:
raise Exception("Workers cannot send server notices")
return ServerNoticesManager(self)
- def build_server_notices_sender(self):
+ @cache_in_self
+ def get_server_notices_sender(self):
if self.config.worker_app:
return WorkerServerNoticesSender(self)
return ServerNoticesSender(self)
- def build_message_handler(self):
+ @cache_in_self
+ def get_message_handler(self) -> MessageHandler:
return MessageHandler(self)
- def build_pagination_handler(self):
+ @cache_in_self
+ def get_pagination_handler(self) -> PaginationHandler:
return PaginationHandler(self)
- def build_room_context_handler(self):
+ @cache_in_self
+ def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self)
- def build_registration_handler(self):
+ @cache_in_self
+ def get_registration_handler(self) -> RegistrationHandler:
return RegistrationHandler(self)
- def build_account_validity_handler(self):
+ @cache_in_self
+ def get_account_validity_handler(self) -> AccountValidityHandler:
return AccountValidityHandler(self)
- def build_cas_handler(self):
+ @cache_in_self
+ def get_cas_handler(self) -> CasHandler:
return CasHandler(self)
- def build_saml_handler(self):
+ @cache_in_self
+ def get_saml_handler(self) -> "SamlHandler":
from synapse.handlers.saml_handler import SamlHandler
return SamlHandler(self)
- def build_oidc_handler(self):
+ @cache_in_self
+ def get_oidc_handler(self) -> "OidcHandler":
from synapse.handlers.oidc_handler import OidcHandler
return OidcHandler(self)
- def build_event_client_serializer(self):
+ @cache_in_self
+ def get_event_client_serializer(self) -> EventClientSerializer:
return EventClientSerializer(self)
- def build_password_policy_handler(self):
+ @cache_in_self
+ def get_password_policy_handler(self) -> PasswordPolicyHandler:
return PasswordPolicyHandler(self)
- def build_storage(self) -> Storage:
- return Storage(self, self.datastores)
+ @cache_in_self
+ def get_storage(self) -> Storage:
+ return Storage(self, self.get_datastores())
- def build_replication_streamer(self) -> ReplicationStreamer:
+ @cache_in_self
+ def get_replication_streamer(self) -> ReplicationStreamer:
return ReplicationStreamer(self)
- def build_replication_data_handler(self):
+ @cache_in_self
+ def get_replication_data_handler(self) -> ReplicationDataHandler:
return ReplicationDataHandler(self)
- def build_replication_streams(self):
+ @cache_in_self
+ def get_replication_streams(self) -> Dict[str, Stream]:
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
- def remove_pusher(self, app_id, push_key, user_id):
- return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
+ async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
+ return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
- def should_send_federation(self):
+ def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
return self.config.send_federation and (
not self.config.worker_app
or self.config.worker_app == "synapse.app.federation_sender"
)
-
-
-def _make_dependency_method(depname):
- def _get(hs):
- try:
- return getattr(hs, depname)
- except AttributeError:
- pass
-
- try:
- builder = getattr(hs, "build_%s" % (depname))
- except AttributeError:
- raise NotImplementedError(
- "%s has no %s nor a builder for it" % (type(hs).__name__, depname)
- )
-
- # Prevent cyclic dependencies from deadlocking
- if depname in hs._building:
- raise ValueError("Cyclic dependency while building %s" % (depname,))
-
- hs._building[depname] = 1
- try:
- dep = builder()
- setattr(hs, depname, dep)
- finally:
- del hs._building[depname]
-
- return dep
-
- setattr(HomeServer, "get_%s" % (depname), _get)
-
-
-# Build magic accessors for every dependency
-for depname in HomeServer.DEPENDENCIES:
- _make_dependency_method(depname)
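The deleted block above is the import-time machinery that synthesised a `get_<dep>` accessor for every name in `HomeServer.DEPENDENCIES`. The `@cache_in_self` decorator applied throughout this diff replaces that with one explicit, type-annotated getter per dependency. The decorator's definition is elsewhere in synapse/server.py and not part of this hunk; a minimal sketch of such a decorator, assuming per-instance attribute caching and the same cycle guard as the removed code:

    import functools

    def cache_in_self(builder):
        # Sketch only: the attribute naming and cycle detection here are
        # assumptions, not a quotation of the real helper.
        depname = builder.__name__.replace("get_", "")
        cache_name = "_" + depname

        @functools.wraps(builder)
        def _get(self):
            try:
                return getattr(self, cache_name)
            except AttributeError:
                pass

            # Prevent cyclic dependencies from deadlocking, as the removed
            # _make_dependency_method did.
            if depname in self._building:
                raise ValueError("Cyclic dependency while building %s" % (depname,))

            self._building[depname] = 1
            try:
                dep = builder(self)
                setattr(self, cache_name, dep)
            finally:
                del self._building[depname]

            return dep

        return _get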
diff --git a/synapse/server.pyi b/synapse/server.pyi
deleted file mode 100644
index 1aba408c21..0000000000
--- a/synapse/server.pyi
+++ /dev/null
@@ -1,155 +0,0 @@
-from typing import Dict
-
-import twisted.internet
-
-import synapse.api.auth
-import synapse.config.homeserver
-import synapse.crypto.keyring
-import synapse.federation.federation_server
-import synapse.federation.sender
-import synapse.federation.transport.client
-import synapse.handlers
-import synapse.handlers.auth
-import synapse.handlers.deactivate_account
-import synapse.handlers.device
-import synapse.handlers.e2e_keys
-import synapse.handlers.message
-import synapse.handlers.presence
-import synapse.handlers.register
-import synapse.handlers.room
-import synapse.handlers.room_member
-import synapse.handlers.set_password
-import synapse.http.client
-import synapse.http.matrixfederationclient
-import synapse.notifier
-import synapse.push.pusherpool
-import synapse.replication.tcp.client
-import synapse.replication.tcp.handler
-import synapse.rest.media.v1.media_repository
-import synapse.server_notices.server_notices_manager
-import synapse.server_notices.server_notices_sender
-import synapse.state
-import synapse.storage
-from synapse.events.builder import EventBuilderFactory
-from synapse.handlers.typing import FollowerTypingHandler
-from synapse.replication.tcp.streams import Stream
-
-class HomeServer(object):
- @property
- def config(self) -> synapse.config.homeserver.HomeServerConfig:
- pass
- @property
- def hostname(self) -> str:
- pass
- def get_auth(self) -> synapse.api.auth.Auth:
- pass
- def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
- pass
- def get_datastore(self) -> synapse.storage.DataStore:
- pass
- def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
- pass
- def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
- pass
- def get_handlers(self) -> synapse.handlers.Handlers:
- pass
- def get_state_handler(self) -> synapse.state.StateHandler:
- pass
- def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
- pass
- def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
- """Fetch an HTTP client implementation which doesn't do any blacklisting
- or support any HTTP_PROXY settings"""
- pass
- def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
- """Fetch an HTTP client implementation which doesn't do any blacklisting
- but does support HTTP_PROXY settings"""
- pass
- def get_deactivate_account_handler(
- self,
- ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
- pass
- def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
- pass
- def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
- pass
- def get_room_shutdown_handler(self) -> synapse.handlers.room.RoomShutdownHandler:
- pass
- def get_event_creation_handler(
- self,
- ) -> synapse.handlers.message.EventCreationHandler:
- pass
- def get_set_password_handler(
- self,
- ) -> synapse.handlers.set_password.SetPasswordHandler:
- pass
- def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
- pass
- def get_federation_transport_client(
- self,
- ) -> synapse.federation.transport.client.TransportLayerClient:
- pass
- def get_media_repository_resource(
- self,
- ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
- pass
- def get_media_repository(
- self,
- ) -> synapse.rest.media.v1.media_repository.MediaRepository:
- pass
- def get_server_notices_manager(
- self,
- ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
- pass
- def get_server_notices_sender(
- self,
- ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
- pass
- def get_notifier(self) -> synapse.notifier.Notifier:
- pass
- def get_presence_handler(self) -> synapse.handlers.presence.BasePresenceHandler:
- pass
- def get_clock(self) -> synapse.util.Clock:
- pass
- def get_reactor(self) -> twisted.internet.base.ReactorBase:
- pass
- def get_keyring(self) -> synapse.crypto.keyring.Keyring:
- pass
- def get_tcp_replication(
- self,
- ) -> synapse.replication.tcp.handler.ReplicationCommandHandler:
- pass
- def get_replication_data_handler(
- self,
- ) -> synapse.replication.tcp.client.ReplicationDataHandler:
- pass
- def get_federation_registry(
- self,
- ) -> synapse.federation.federation_server.FederationHandlerRegistry:
- pass
- def is_mine_id(self, domain_id: str) -> bool:
- pass
- def get_instance_id(self) -> str:
- pass
- def get_instance_name(self) -> str:
- pass
- def get_event_builder_factory(self) -> EventBuilderFactory:
- pass
- def get_storage(self) -> synapse.storage.Storage:
- pass
- def get_registration_handler(self) -> synapse.handlers.register.RegistrationHandler:
- pass
- def get_macaroon_generator(self) -> synapse.handlers.auth.MacaroonGenerator:
- pass
- def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool:
- pass
- def get_replication_streams(self) -> Dict[str, Stream]:
- pass
- def get_http_client(
- self,
- ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
- pass
- def should_send_federation(self) -> bool:
- pass
- def get_typing_handler(self) -> FollowerTypingHandler:
- pass
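Deleting this stub works because the return types now live as inline annotations on the `get_*` methods themselves, so mypy reads them straight from synapse/server.py. A hypothetical call site (the function name here is invented for illustration) that type-checks with no .pyi involved:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from synapse.server import HomeServer

    def describe(hs: "HomeServer") -> str:
        # mypy infers the return types of these getters from the inline
        # annotations on HomeServer, with no stub file involved.
        return "%s (instance %s)" % (hs.hostname, hs.get_instance_name())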
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 3bfc8d7278..089cfef0b3 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Any
from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
@@ -55,14 +56,11 @@ class ConsentServerNotices(object):
self._consent_uri_builder = ConsentURIBuilder(hs.config)
- async def maybe_send_server_notice_to_user(self, user_id):
+ async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
"""Check if we need to send a notice to this user, and does so if so
Args:
- user_id (str): user to check
-
- Returns:
- Deferred
+ user_id: user to check
"""
if self._server_notice_content is None:
# not enabled
@@ -105,7 +103,7 @@ class ConsentServerNotices(object):
self._users_in_progress.remove(user_id)
-def copy_with_str_subst(x, substitutions):
+def copy_with_str_subst(x: Any, substitutions: Any) -> Any:
"""Deep-copy a structure, carrying out string substitions on any strings
Args:
@@ -121,7 +119,7 @@ def copy_with_str_subst(x, substitutions):
if isinstance(x, dict):
return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()}
if isinstance(x, (list, tuple)):
- return [copy_with_str_subst(y) for y in x]
+ return [copy_with_str_subst(y, substitutions) for y in x]
# assume it's uninteresting and can be shallow-copied.
return x
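The one-argument recursive call being fixed above meant that any list or tuple nested in the notice content raised a TypeError before the substitution ever ran. A standalone reproduction of the corrected behaviour, assuming percent-style substitution for the string case (which this hunk does not show):

    def copy_with_str_subst(x, substitutions):
        if isinstance(x, str):
            return x % substitutions  # assumed: the str branch is not in this hunk
        if isinstance(x, dict):
            return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()}
        if isinstance(x, (list, tuple)):
            # The fixed line: previously copy_with_str_subst(y), missing an argument.
            return [copy_with_str_subst(y, substitutions) for y in x]
        return x

    print(copy_with_str_subst(
        {"body": "hello %(user)s", "pinned": ["note for %(user)s"]},
        {"user": "alice"},
    ))
    # -> {'body': 'hello alice', 'pinned': ['note for alice']}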
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index 4404ceff93..c2faef6eab 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple
from synapse.api.constants import (
EventTypes,
@@ -52,7 +53,7 @@ class ResourceLimitsServerNotices(object):
and not hs.config.hs_disabled
)
- async def maybe_send_server_notice_to_user(self, user_id):
+ async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
"""Check if we need to send a notice to this user, this will be true in
two cases.
1. The server has reached its limit does not reflect this
@@ -60,10 +61,7 @@ class ResourceLimitsServerNotices(object):
actually the server is fine
Args:
- user_id (str): user to check
-
- Returns:
- Deferred
+ user_id: user to check
"""
if not self._enabled:
return
@@ -115,19 +113,21 @@ class ResourceLimitsServerNotices(object):
elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be.
await self._apply_limit_block_notification(
- user_id, limit_msg, limit_type
+ user_id, limit_msg, limit_type # type: ignore
)
except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e)
- async def _remove_limit_block_notification(self, user_id, ref_events):
+ async def _remove_limit_block_notification(
+ self, user_id: str, ref_events: List[str]
+ ) -> None:
"""Utility method to remove limit block notifications from the server
notices room.
Args:
- user_id (str): user to notify
- ref_events (list[str]): The event_ids of pinned events that are unrelated to
- limit blocking and need to be preserved.
+ user_id: user to notify
+ ref_events: The event_ids of pinned events that are unrelated to
+ limit blocking and need to be preserved.
"""
content = {"pinned": ref_events}
await self._server_notices_manager.send_notice(
@@ -135,16 +135,16 @@ class ResourceLimitsServerNotices(object):
)
async def _apply_limit_block_notification(
- self, user_id, event_body, event_limit_type
- ):
+ self, user_id: str, event_body: str, event_limit_type: str
+ ) -> None:
"""Utility method to apply limit block notifications in the server
notices room.
Args:
- user_id (str): user to notify
- event_body(str): The human readable text that describes the block.
- event_limit_type(str): Specifies the type of block e.g. monthly active user
- limit has been exceeded.
+ user_id: user to notify
+ event_body: The human readable text that describes the block.
+ event_limit_type: Specifies the type of block e.g. monthly active user
+ limit has been exceeded.
"""
content = {
"body": event_body,
@@ -162,7 +162,7 @@ class ResourceLimitsServerNotices(object):
user_id, content, EventTypes.Pinned, ""
)
- async def _check_and_set_tags(self, user_id, room_id):
+ async def _check_and_set_tags(self, user_id: str, room_id: str) -> None:
"""
Since server notices rooms originally did not have tags,
it is important to check that tags have been set correctly
@@ -182,17 +182,16 @@ class ResourceLimitsServerNotices(object):
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
- async def _is_room_currently_blocked(self, room_id):
+ async def _is_room_currently_blocked(self, room_id: str) -> Tuple[bool, List[str]]:
"""
Determines if the room is currently blocked
Args:
- room_id(str): The room id of the server notices room
+ room_id: The room id of the server notices room
Returns:
- Deferred[Tuple[bool, List]]:
bool: Is the room currently blocked
- list: The list of pinned events that are unrelated to limit blocking
+ list: The list of pinned event IDs that are unrelated to limit blocking
This list can be used as a convenience in the case where the block
is to be lifted and the remaining pinned event references need to be
preserved
@@ -207,7 +206,7 @@ class ResourceLimitsServerNotices(object):
# The user has yet to join the server notices room
pass
- referenced_events = []
+ referenced_events = [] # type: List[str]
if pinned_state_event is not None:
referenced_events = list(pinned_state_event.content.get("pinned", []))
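The `# type: List[str]` comment added above annotates a local variable without PEP 526 syntax; it is just a comment, so it parses on any Python 3 while still telling mypy what the list holds. The two spellings are equivalent to the type checker:

    from typing import List

    referenced_events = []  # type: List[str]  # comment form, any Python 3
    referenced_events_new: List[str] = []      # PEP 526 form, Python 3.6+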
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index bf2454c01c..ed96aa8571 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
+from synapse.events import EventBase
from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cached
@@ -50,20 +52,21 @@ class ServerNoticesManager(object):
return self._config.server_notices_mxid is not None
async def send_notice(
- self, user_id, event_content, type=EventTypes.Message, state_key=None
- ):
+ self,
+ user_id: str,
+ event_content: dict,
+ type: str = EventTypes.Message,
+ state_key: Optional[str] = None,
+ ) -> EventBase:
"""Send a notice to the given user
Creates the server notices room, if none exists.
Args:
- user_id (str): mxid of user to send event to.
- event_content (dict): content of event to send
- type(EventTypes): type of event
- is_state_event(bool): Is the event a state event
-
- Returns:
- Deferred[FrozenEvent]
+ user_id: mxid of user to send event to.
+ event_content: content of event to send
+ type: type of event
+ state_key: The state key of a state event, or None for a non-state event
"""
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
@@ -89,17 +92,17 @@ class ServerNoticesManager(object):
return event
@cached()
- async def get_or_create_notice_room_for_user(self, user_id):
+ async def get_or_create_notice_room_for_user(self, user_id: str) -> str:
"""Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't
invite the user to it.
Args:
- user_id (str): complete user id for the user we want a room for
+ user_id: complete user id for the user we want a room for
Returns:
- str: room id of notice room.
+ room id of notice room.
"""
if not self.is_enabled():
raise Exception("Server notices not enabled")
@@ -163,7 +166,7 @@ class ServerNoticesManager(object):
logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id
- async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
+ async def maybe_invite_user_to_room(self, user_id: str, room_id: str) -> None:
"""Invite the given user to the given server room, unless the user has already
joined or been invited to it.
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index be74e86641..a754f75db4 100644
--- a/synapse/server_notices/server_notices_sender.py
+++ b/synapse/server_notices/server_notices_sender.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Iterable, Union
+
from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
@@ -32,22 +34,22 @@ class ServerNoticesSender(object):
self._server_notices = (
ConsentServerNotices(hs),
ResourceLimitsServerNotices(hs),
- )
+ ) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]]
- async def on_user_syncing(self, user_id):
+ async def on_user_syncing(self, user_id: str) -> None:
"""Called when the user performs a sync operation.
Args:
- user_id (str): mxid of user who synced
+ user_id: mxid of user who synced
"""
for sn in self._server_notices:
await sn.maybe_send_server_notice_to_user(user_id)
- async def on_user_ip(self, user_id):
+ async def on_user_ip(self, user_id: str) -> None:
"""Called on the master when a worker process saw a client request.
Args:
- user_id (str): mxid
+ user_id: mxid
"""
# The synchrotrons use a stubbed version of ServerNoticesSender, so
# we check for notices to send to the user in on_user_ip as well as
diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
index 245ec7c64f..e9390b19da 100644
--- a/synapse/server_notices/worker_server_notices_sender.py
+++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
class WorkerServerNoticesSender(object):
@@ -24,24 +23,18 @@ class WorkerServerNoticesSender(object):
hs (synapse.server.HomeServer):
"""
- def on_user_syncing(self, user_id):
+ async def on_user_syncing(self, user_id: str) -> None:
"""Called when the user performs a sync operation.
Args:
- user_id (str): mxid of user who synced
-
- Returns:
- Deferred
+ user_id: mxid of user who synced
"""
- return defer.succeed(None)
+ return None
- def on_user_ip(self, user_id):
+ async def on_user_ip(self, user_id: str) -> None:
"""Called on the master when a worker process saw a client request.
Args:
- user_id (str): mxid
-
- Returns:
- Deferred
+ user_id: mxid
"""
raise AssertionError("on_user_ip unexpectedly called on worker")
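Dropping the twisted import works because a plain `return None` from an `async def` is the native-coroutine analogue of `defer.succeed(None)`: the caller awaits either and gets None back. A minimal side-by-side, for illustration:

    from twisted.internet import defer

    def on_user_syncing_old(user_id):
        # Callers received an already-fired Deferred and drove it via Twisted.
        return defer.succeed(None)

    async def on_user_syncing_new(user_id: str) -> None:
        # Callers simply `await` this; no Deferred is constructed at all.
        return None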
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 25ccef5aa5..a1d3884667 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -28,7 +28,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
-from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.util import Clock
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index ec89f645d4..5ef3853559 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -17,18 +17,19 @@
"""
The storage layer is split up into multiple parts to allow Synapse to run
against different configurations of databases (e.g. single or multiple
-databases). The `Database` class represents a single physical database. The
-`data_stores` are classes that talk directly to a `Database` instance and have
-associated schemas, background updates, etc. On top of those there are classes
-that provide high level interfaces that combine calls to multiple `data_stores`.
+databases). The `DatabasePool` class represents connections to a single physical
+database. The `databases` are classes that talk directly to a `DatabasePool`
+instance and have associated schemas, background updates, etc. On top of those
+there are classes that provide high level interfaces that combine calls to
+multiple `databases`.
There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
-from synapse.storage.data_stores import DataStores
-from synapse.storage.data_stores.main import DataStore
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main import DataStore
from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage
@@ -40,7 +41,7 @@ class Storage(object):
"""The high level interfaces for talking to various storage layers.
"""
- def __init__(self, hs, stores: DataStores):
+ def __init__(self, hs, stores: Databases):
# We include the main data store here mainly so that we don't have to
# rewrite all the existing code to split it into high vs low level
# interfaces.
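The reworded docstring above describes three layers. Schematically (these classes are illustrative stand-ins, not Synapse's real ones): one pool per physical database, database classes bound to a pool, and a high-level facade combining them:

    class PoolSketch:
        """Stands in for DatabasePool: connections to one physical database."""

    class MainStoreSketch:
        """Stands in for a `databases` class: schema and queries against a pool."""
        def __init__(self, pool: PoolSketch):
            self.pool = pool

    class StorageSketch:
        """Stands in for the high-level layer combining multiple databases."""
        def __init__(self, main: MainStoreSketch):
            self.main = main

    storage = StorageSketch(MainStoreSketch(PoolSketch()))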
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 985a042869..6814bf5fcf 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -23,7 +23,7 @@ from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
logger = logging.getLogger(__name__)
@@ -37,11 +37,11 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
- self.db = database
+ self.db_pool = database
self.rand = random.SystemRandom()
def process_replication_rows(self, stream_name, instance_name, token, rows):
@@ -58,7 +58,6 @@ class SQLBaseStore(metaclass=ABCMeta):
"""
for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
- self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 018826ef69..f43463df53 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -88,7 +88,7 @@ class BackgroundUpdater(object):
def __init__(self, hs, database):
self._clock = hs.get_clock()
- self.db = database
+ self.db_pool = database
# if a background update is currently running, its name.
self._current_background_update = None # type: Optional[str]
@@ -139,7 +139,7 @@ class BackgroundUpdater(object):
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
- updates = await self.db.simple_select_onecol(
+ updates = await self.db_pool.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
@@ -160,7 +160,7 @@ class BackgroundUpdater(object):
if update_name == self._current_background_update:
return False
- update_exists = await self.db.simple_select_one_onecol(
+ update_exists = await self.db_pool.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="1",
@@ -189,10 +189,10 @@ class BackgroundUpdater(object):
ORDER BY ordering, update_name
"""
)
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
if not self._current_background_update:
- all_pending_updates = await self.db.runInteraction(
+ all_pending_updates = await self.db_pool.runInteraction(
"background_updates", get_background_updates_txn,
)
if not all_pending_updates:
@@ -243,7 +243,7 @@ class BackgroundUpdater(object):
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
- progress_json = await self.db.simple_select_one_onecol(
+ progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json",
@@ -402,7 +402,7 @@ class BackgroundUpdater(object):
logger.debug("[SQL] %s", sql)
c.execute(sql)
- if isinstance(self.db.engine, engines.PostgresEngine):
+ if isinstance(self.db_pool.engine, engines.PostgresEngine):
runner = create_index_psql
elif psql_only:
runner = None
@@ -413,7 +413,7 @@ class BackgroundUpdater(object):
def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
- yield self.db.runWithConnection(runner)
+ yield self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name)
return 1
@@ -433,7 +433,7 @@ class BackgroundUpdater(object):
% update_name
)
self._current_background_update = None
- return self.db.simple_delete_one(
+ return self.db_pool.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
@@ -445,7 +445,7 @@ class BackgroundUpdater(object):
progress: The progress of the update.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
@@ -463,7 +463,7 @@ class BackgroundUpdater(object):
progress_json = json.dumps(progress)
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
deleted file mode 100644
index 531b532c73..0000000000
--- a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
+++ /dev/null
@@ -1,18 +0,0 @@
-/* Copyright 2020 The Matrix.org Foundation C.I.C
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
--- Store a boolean value in the events table for whether the event should be counted in
--- the unread_count property of sync responses.
-ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ce8757a400..4ada6f5563 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -279,7 +279,7 @@ class PerformanceCounters(object):
return top_n_counters
-class Database(object):
+class DatabasePool(object):
"""Wraps a single physical database and connection pool.
A single database may be used by multiple data stores.
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/databases/__init__.py
index 599ee470d4..4406e58273 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -15,17 +15,17 @@
import logging
-from synapse.storage.data_stores.main.events import PersistEventsStore
-from synapse.storage.data_stores.state import StateGroupDataStore
-from synapse.storage.database import Database, make_conn
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.events import PersistEventsStore
+from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger(__name__)
-class DataStores(object):
- """The various data stores.
+class Databases(object):
+ """The various databases.
These are low level interfaces to physical databases.
@@ -38,9 +38,9 @@ class DataStores(object):
# store.
self.databases = []
- self.main = None
- self.state = None
- self.persist_events = None
+ main = None
+ state = None
+ persist_events = None
for database_config in hs.config.database.databases:
db_name = database_config.name
@@ -51,37 +51,35 @@ class DataStores(object):
engine.check_database(db_conn)
prepare_database(
- db_conn, engine, hs.config, data_stores=database_config.data_stores,
+ db_conn, engine, hs.config, databases=database_config.databases,
)
- database = Database(hs, database_config, engine)
+ database = DatabasePool(hs, database_config, engine)
- if "main" in database_config.data_stores:
+ if "main" in database_config.databases:
logger.info("Starting 'main' data store")
# Sanity check we don't try and configure the main store on
# multiple databases.
- if self.main:
+ if main:
raise Exception("'main' data store already configured")
- self.main = main_store_class(database, db_conn, hs)
+ main = main_store_class(database, db_conn, hs)
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
if hs.config.worker.writers.events == hs.get_instance_name():
- self.persist_events = PersistEventsStore(
- hs, database, self.main
- )
+ persist_events = PersistEventsStore(hs, database, main)
- if "state" in database_config.data_stores:
+ if "state" in database_config.databases:
logger.info("Starting 'state' data store")
# Sanity check we don't try and configure the state store on
# multiple databases.
- if self.state:
+ if state:
raise Exception("'state' data store already configured")
- self.state = StateGroupDataStore(database, db_conn, hs)
+ state = StateGroupDataStore(database, db_conn, hs)
db_conn.commit()
@@ -90,8 +88,14 @@ class DataStores(object):
logger.info("Database %r prepared", db_name)
# Sanity check that we have actually configured all the required stores.
- if not self.main:
+ if not main:
raise Exception("No 'main' data store configured")
- if not self.state:
+ if not state:
raise Exception("No 'main' data store configured")
+
+ # We use local variables here to ensure that the databases do not have
+ # optional types.
+ self.main = main
+ self.state = state
+ self.persist_events = persist_events
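The closing comment is about mypy: had the loop assigned to `self.main`/`self.state` directly, the attributes would be typed `Optional[...]` and every user would need a None check. Assigning locals first lets the `if not main:` / `if not state:` guards narrow the types before the final attribute assignments. Distilled into a self-contained example:

    from typing import Optional

    def configure() -> str:
        store: Optional[str] = None
        for name in ["main"]:
            store = "the-" + name + "-store"

        if store is None:
            raise Exception("No 'main' data store configured")

        # mypy has narrowed `store` from Optional[str] to str here, so an
        # attribute assigned from it needs no Optional in its type.
        return store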
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 932458f651..17fa470919 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -21,7 +21,7 @@ import time
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
IdGenerator,
@@ -119,7 +119,7 @@ class DataStore(
CacheInvalidationWorkerStore,
ServerMetricsStore,
):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
@@ -174,7 +174,7 @@ class DataStore(
self._presence_on_startup = self._get_active_presence(db_conn)
- presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
+ presence_cache_prefill, min_presence_val = self.db_pool.get_cache_dict(
db_conn,
"presence_stream",
entity_column="user_id",
@@ -188,7 +188,7 @@ class DataStore(
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
- device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
+ device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
db_conn,
"device_inbox",
entity_column="user_id",
@@ -203,7 +203,7 @@ class DataStore(
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
- device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
+ device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
db_conn,
"device_federation_outbox",
entity_column="destination",
@@ -229,7 +229,7 @@ class DataStore(
)
events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
@@ -243,7 +243,7 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill,
)
- _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
+ _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
db_conn,
"local_group_updates",
entity_column="user_id",
@@ -282,7 +282,7 @@ class DataStore(
txn = db_conn.cursor()
txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
txn.close()
for row in rows:
@@ -295,7 +295,9 @@ class DataStore(
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
- return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
+ return self.db_pool.runInteraction(
+ "count_daily_users", self._count_users, yesterday
+ )
def count_monthly_users(self):
"""
@@ -305,7 +307,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
@@ -405,7 +407,7 @@ class DataStore(
return results
- return self.db.runInteraction("count_r30_users", _count_r30_users)
+ return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
@@ -470,7 +472,7 @@ class DataStore(
# frequently
self._last_user_visit_update = now
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
@@ -481,7 +483,7 @@ class DataStore(
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
- return self.db.simple_select_list(
+ return self.db_pool.simple_select_list(
table="users",
keyvalues={},
retcols=[
@@ -543,10 +545,12 @@ class DataStore(
where_clause
)
txn.execute(sql, args)
- users = self.db.cursor_to_dict(txn)
+ users = self.db_pool.cursor_to_dict(txn)
return users, count
- return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
+ return self.db_pool.runInteraction(
+ "get_users_paginate_txn", get_users_paginate_txn
+ )
def search_users(self, term):
"""Function to search users list for one or more users with
@@ -558,7 +562,7 @@ class DataStore(
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
- return self.db.simple_search_list(
+ return self.db_pool.simple_search_list(
table="users",
term=term,
col="name",
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 33cc372dfd..82aac2bbf3 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -16,16 +16,16 @@
import abc
import logging
-from typing import List, Tuple
-
-from canonicaljson import json
+from typing import List, Optional, Tuple
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.types import JsonDict
+from synapse.util import json_encoder
+from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
@@ -69,7 +69,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_user_txn(txn):
- rows = self.db.simple_select_list_txn(
+ rows = self.db_pool.simple_select_list_txn(
txn,
"account_data",
{"user_id": user_id},
@@ -80,7 +80,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- rows = self.db.simple_select_list_txn(
+ rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id},
@@ -94,17 +94,19 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
- @cachedInlineCallbacks(num_args=2, max_entries=5000)
- def get_global_account_data_by_type_for_user(self, data_type, user_id):
+ @cached(num_args=2, max_entries=5000)
+ async def get_global_account_data_by_type_for_user(
+ self, data_type: str, user_id: str
+ ) -> Optional[JsonDict]:
"""
Returns:
- Deferred: A dict
+ The account data.
"""
- result = yield self.db.simple_select_one_onecol(
+ result = await self.db_pool.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@@ -129,7 +131,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_txn(txn):
- rows = self.db.simple_select_list_txn(
+ rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
@@ -140,7 +142,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@@ -158,7 +160,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_and_type_txn(txn):
- content_json = self.db.simple_select_one_onecol_txn(
+ content_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
@@ -172,7 +174,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -202,7 +204,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn
)
@@ -232,7 +234,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn
)
@@ -277,13 +279,15 @@ class AccountDataWorkerStore(SQLBaseStore):
if not changed:
return defer.succeed(({}, {}))
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
- @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
- def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
- ignored_account_data = yield self.get_global_account_data_by_type_for_user(
+ @cached(num_args=2, cache_context=True, max_entries=5000)
+ async def is_ignored_by(
+ self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
+ ) -> bool:
+ ignored_account_data = await self.get_global_account_data_by_type_for_user(
"m.ignored_user_list",
ignorer_user_id,
on_invalidate=cache_context.invalidate,
@@ -295,7 +299,7 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"account_data_max_stream_id",
@@ -308,32 +312,35 @@ class AccountDataStore(AccountDataWorkerStore):
super(AccountDataStore, self).__init__(database, db_conn, hs)
- def get_max_account_data_stream_id(self):
+ def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream
Returns:
- A deferred int.
+ The maximum stream ID.
"""
return self._account_data_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
+ async def add_account_data_to_room(
+ self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
"""Add some account_data to a room for a user.
+
Args:
- user_id(str): The user to add a tag for.
- room_id(str): The room to add a tag for.
- account_data_type(str): The type of account_data to add.
- content(dict): A json object to associate with the tag.
+ user_id: The user to add account data for.
+ room_id: The room to add account data to.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the account data.
+
Returns:
- A deferred that completes once the account_data has been added.
+ The maximum stream ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
- yield self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@@ -351,7 +358,7 @@ class AccountDataStore(AccountDataWorkerStore):
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
- yield self._update_max_stream_id(next_id)
+ await self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
@@ -360,26 +367,28 @@ class AccountDataStore(AccountDataWorkerStore):
(user_id, room_id, account_data_type), content
)
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_account_data_for_user(self, user_id, account_data_type, content):
+ async def add_account_data_for_user(
+ self, user_id: str, account_data_type: str, content: JsonDict
+ ) -> int:
"""Add some account_data to a room for a user.
+
Args:
- user_id(str): The user to add a tag for.
- account_data_type(str): The type of account_data to add.
- content(dict): A json object to associate with the tag.
+ user_id: The user to add account data for.
+ account_data_type: The type of account_data to add.
+ content: A json object to associate with the account data.
+
Returns:
- A deferred that completes once the account_data has been added.
+ The maximum stream ID.
"""
- content_json = json.dumps(content)
+ content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
- yield self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -397,7 +406,7 @@ class AccountDataStore(AccountDataWorkerStore):
# Note: This is only here for backwards compat to allow admins to
# roll back to a previous Synapse version. Next time we update the
# database version we can remove this table.
- yield self._update_max_stream_id(next_id)
+ await self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
@@ -405,14 +414,13 @@ class AccountDataStore(AccountDataWorkerStore):
(account_data_type, user_id)
)
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- def _update_max_stream_id(self, next_id):
+ def _update_max_stream_id(self, next_id: int):
"""Update the max stream_id
Args:
- next_id(int): The the revision to advance to.
+ next_id: The revision to advance to.
"""
# Note: This is only here for backwards compat to allow admins to
@@ -427,4 +435,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.db.runInteraction("update_account_data_max_stream_id", _update)
+ return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
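The switch from `json.dumps(content)` to `json_encoder.encode(content)` in this file reuses one module-level encoder and, with compact separators, drops the spaces `json.dumps` inserts after `,` and `:`. A sketch of what such a shared encoder could be (the actual definition lives in synapse.util and is not part of this diff):

    import json

    # Compact separators: no whitespace in the serialised JSON.
    json_encoder = json.JSONEncoder(separators=(",", ":"))

    assert json.dumps({"a": 1, "b": [2, 3]}) == '{"a": 1, "b": [2, 3]}'
    assert json_encoder.encode({"a": 1, "b": [2, 3]}) == '{"a":1,"b":[2,3]}'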
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 56659fed37..5cf1a88399 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -18,13 +18,11 @@ import re
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
logger = logging.getLogger(__name__)
@@ -49,7 +47,7 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
@@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
- @defer.inlineCallbacks
- def get_appservices_by_state(self, state):
+ async def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
Args:
state(ApplicationServiceState): The state to filter on.
Returns:
- A Deferred which resolves to a list of ApplicationServices, which
- may be empty.
+ A list of ApplicationServices, which may be empty.
"""
- results = yield self.db.simple_select_list(
+ results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
@@ -147,16 +143,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service)
return services
- @defer.inlineCallbacks
- def get_appservice_state(self, service):
+ async def get_appservice_state(self, service):
"""Get the application service state.
Args:
service(ApplicationService): The service whose state to get.
Returns:
- A Deferred which resolves to ApplicationServiceState.
+ An ApplicationServiceState.
"""
- result = yield self.db.simple_select_one(
+ result = await self.db_pool.simple_select_one(
"application_services_state",
{"as_id": service.id},
["state"],
@@ -176,7 +171,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves when the state was set successfully.
"""
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
)
@@ -217,7 +212,9 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
- return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
+ return self.db_pool.runInteraction(
+ "create_appservice_txn", _create_appservice_txn
+ )
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@@ -250,7 +247,7 @@ class ApplicationServiceTransactionWorkerStore(
)
# Set current txn_id for AS to 'txn_id'
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
"application_services_state",
{"as_id": service.id},
@@ -258,26 +255,24 @@ class ApplicationServiceTransactionWorkerStore(
)
# Delete txn
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
"application_services_txns",
{"txn_id": txn_id, "as_id": service.id},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"complete_appservice_txn", _complete_appservice_txn
)
- @defer.inlineCallbacks
- def get_oldest_unsent_txn(self, service):
+ async def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
service.
Args:
service(ApplicationService): The app service to get the oldest txn.
Returns:
- A Deferred which resolves to an AppServiceTransaction or
- None.
+ An AppServiceTransaction or None.
"""
def _get_oldest_unsent_txn(txn):
@@ -288,7 +283,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return None
@@ -296,7 +291,7 @@ class ApplicationServiceTransactionWorkerStore(
return entry
- entry = yield self.db.runInteraction(
+ entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
@@ -305,7 +300,7 @@ class ApplicationServiceTransactionWorkerStore(
event_ids = db_to_json(entry["event_ids"])
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
@@ -326,12 +321,11 @@ class ApplicationServiceTransactionWorkerStore(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
- @defer.inlineCallbacks
- def get_new_events_for_appservice(self, current_id, limit):
+ async def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""
def get_new_events_for_appservice_txn(txn):
@@ -355,11 +349,11 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.db.runInteraction(
+ upper_bound, event_ids = await self.db_pool.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
- events = yield self.get_events_as_list(event_ids)
+ events = await self.get_events_as_list(event_ids)
return upper_bound, events
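
The appservice store above is a representative example of the mechanical conversion this patch applies everywhere: `@defer.inlineCallbacks` generators become native coroutines, `yield` becomes `await`, and "Deferred which resolves to X" docstrings become plain "X". A self-contained sketch of the before/after shapes (`FakePool` is an illustrative stand-in, not `DatabasePool`):

```python
from twisted.internet import defer, task


class FakePool:
    """Illustrative stand-in for DatabasePool; returns canned rows via a Deferred."""

    def simple_select_list(self, table, keyvalues, retcols):
        return defer.succeed([{"as_id": "as_1"}, {"as_id": "as_2"}])


@defer.inlineCallbacks
def old_style(db):
    # Pre-patch shape: generator-based coroutine, yields Deferreds.
    rows = yield db.simple_select_list("application_services_state", {}, ["as_id"])
    return [r["as_id"] for r in rows]


async def new_style(db):
    # Post-patch shape: native coroutine, awaits the same Deferred.
    rows = await db.simple_select_list("application_services_state", {}, ["as_id"])
    return [r["as_id"] for r in rows]


def main(reactor):
    async def check():
        assert (await old_style(FakePool())) == (await new_style(FakePool()))

    return defer.ensureDeferred(check())


if __name__ == "__main__":
    task.react(main)
```

Both shapes produce the same result; the native form is what the renamed `db_pool` attribute is awaited through from here on.
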
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/databases/main/cache.py
index edc3624fed..10de446065 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -26,7 +26,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -39,7 +39,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -92,7 +92,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return updates, upto_token, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
@@ -172,7 +172,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,))
- self.get_unread_message_count_for_user.invalidate_many((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
@@ -203,7 +202,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return
cache_func.invalidate(keys)
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
@@ -288,7 +287,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if keys is not None:
keys = list(keys)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="cache_invalidation_stream_by_instance",
values={
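
For context on the cache store above: invalidations are applied to the local cache immediately and also written to `cache_invalidation_stream_by_instance` so other workers can replay them. A toy model of that pattern (plain Python, not Synapse code):

```python
from typing import Any, Dict, List, Tuple

Key = Tuple[str, ...]


class Worker:
    def __init__(self) -> None:
        self.cache: Dict[Key, Any] = {}

    def invalidate(self, keys: Key) -> None:
        self.cache.pop(keys, None)


# Stand-in for the cache_invalidation_stream_by_instance table.
stream: List[Tuple[str, Key]] = []


def invalidate_cache_and_stream(worker: Worker, cache_name: str, keys: Key) -> None:
    worker.invalidate(keys)            # local invalidation, immediate
    stream.append((cache_name, keys))  # durable record for the other workers


def replay(worker: Worker, from_pos: int) -> int:
    # Other workers poll the stream and apply the same invalidations locally.
    for cache_name, keys in stream[from_pos:]:
        worker.invalidate(keys)
    return len(stream)


writer, reader = Worker(), Worker()
writer.cache[("!room:a",)] = reader.cache[("!room:a",)] = "cached"
invalidate_cache_and_stream(writer, "get_latest_event_ids_in_room", ("!room:a",))
replay(reader, 0)
assert ("!room:a",) not in reader.cache
```
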
diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index 2d48261724..f211ddbaf8 100644
--- a/synapse/storage/data_stores/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -16,15 +16,13 @@
import logging
from typing import TYPE_CHECKING
-from twisted.internet import defer
-
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.data_stores.main.events import encode_json
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.events import encode_json
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -34,7 +32,7 @@ logger = logging.getLogger(__name__)
class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs: "HomeServer"):
+ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
def _censor_redactions():
@@ -56,7 +54,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
if not (
- await self.db.updates.has_completed_background_update(
+ await self.db_pool.updates.has_completed_background_update(
"redactions_have_censored_ts_idx"
)
):
@@ -85,7 +83,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
LIMIT ?
"""
- rows = await self.db.execute(
+ rows = await self.db_pool.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
)
@@ -123,14 +121,14 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
if pruned_json:
self._censor_event_txn(txn, event_id, pruned_json)
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="redactions",
keyvalues={"event_id": redaction_id},
updatevalues={"have_censored": True},
)
- await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+ await self.db_pool.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json):
"""Censor an event by replacing its JSON in the event_json table with the
@@ -141,24 +139,23 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
"""
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="event_json",
keyvalues={"event_id": event_id},
updatevalues={"json": pruned_json},
)
- @defer.inlineCallbacks
- def expire_event(self, event_id):
+ async def expire_event(self, event_id: str) -> None:
"""Retrieve and expire an event that has expired, and delete its associated
expiry timestamp. If the event can't be retrieved, delete its associated
timestamp so we don't try to expire it again in the future.
Args:
- event_id (str): The ID of the event to delete.
+ event_id: The ID of the event to delete.
"""
# Try to retrieve the event's content from the database or the event cache.
- event = yield self.get_event(event_id)
+ event = await self.get_event(event_id)
def delete_expired_event_txn(txn):
# Delete the expiry timestamp associated with this event from the database.
@@ -193,7 +190,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
txn, "_get_event_cache", (event.event_id,)
)
- yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+ await self.db_pool.runInteraction(
+ "delete_expired_event", delete_expired_event_txn
+ )
def _delete_event_expiry_txn(self, txn, event_id):
"""Delete the expiry timestamp associated with an event ID without deleting the
@@ -203,6 +202,6 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
"""
- return self.db.simple_delete_txn(
+ return self.db_pool.simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)
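
On the censor path above: expiring an event replaces its stored JSON with a pruned copy rather than deleting the row. Roughly what pruning amounts to is sketched below; the key list is illustrative, and the real rules live in `synapse.events.utils.prune_event_dict`:

```python
from typing import Any, Dict

# Illustrative set of keys to retain; not the authoritative prune rules.
KEEP_KEYS = {"event_id", "type", "room_id", "sender", "origin_server_ts", "depth"}


def prune_event_dict_sketch(event: Dict[str, Any]) -> Dict[str, Any]:
    # Keep routing/ordering metadata, drop the payload.
    pruned = {k: v for k, v in event.items() if k in KEEP_KEYS}
    pruned["content"] = {}  # the content is what gets censored
    return pruned


event = {
    "event_id": "$abc",
    "type": "m.room.message",
    "room_id": "!room:example.org",
    "sender": "@alice:example.org",
    "origin_server_ts": 1_596_000_000_000,
    "content": {"body": "self-destructing message"},
}
print(prune_event_dict_sketch(event))
```
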
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 1c035d51cb..4e2b2a85ee 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -14,12 +14,11 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database, make_tuple_comparison_clause
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__)
@@ -31,40 +30,40 @@ LAST_SEEN_GRANULARITY = 10 * 60 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"user_ips_last_seen_index",
index_name="user_ips_last_seen",
table="user_ips",
columns=["user_id", "last_seen"],
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"user_ips_last_seen_only_index",
index_name="user_ips_last_seen_only",
table="user_ips",
columns=["last_seen"],
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"user_ips_analyze", self._analyze_user_ip
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"user_ips_remove_dupes", self._remove_user_ip_dupes
)
# Register a unique index
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"user_ips_device_unique_index",
index_name="user_ips_user_token_ip_unique_index",
table="user_ips",
@@ -73,28 +72,28 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
)
# Drop the old non-unique index
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
# Update the last seen info in devices.
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"devices_last_seen", self._devices_last_seen_update
)
- @defer.inlineCallbacks
- def _remove_user_ip_nonunique(self, progress, batch_size):
+ async def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
- yield self.db.runWithConnection(f)
- yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
+ await self.db_pool.runWithConnection(f)
+ await self.db_pool.updates._end_background_update(
+ "user_ips_drop_nonunique_index"
+ )
return 1
- @defer.inlineCallbacks
- def _analyze_user_ip(self, progress, batch_size):
+ async def _analyze_user_ip(self, progress, batch_size):
# Background update to analyze user_ips table before we run the
# deduplication background update. The table may not have been analyzed
# for ages due to the table locks.
@@ -104,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
- yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
+ await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
- yield self.db.updates._end_background_update("user_ips_analyze")
+ await self.db_pool.updates._end_background_update("user_ips_analyze")
return 1
- @defer.inlineCallbacks
- def _remove_user_ip_dupes(self, progress, batch_size):
+ async def _remove_user_ip_dupes(self, progress, batch_size):
# This function works by scanning the user_ips table in batches
# based on `last_seen`. For each row in a batch it searches the rest of
# the table to see if there are any duplicates; if there are, then they
@@ -138,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
- end_last_seen = yield self.db.runInteraction(
+ end_last_seen = await self.db_pool.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
@@ -269,19 +267,18 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
(user_id, access_token, ip, device_id, user_agent, last_seen),
)
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
- yield self.db.runInteraction("user_ips_dups_remove", remove)
+ await self.db_pool.runInteraction("user_ips_dups_remove", remove)
if last:
- yield self.db.updates._end_background_update("user_ips_remove_dupes")
+ await self.db_pool.updates._end_background_update("user_ips_remove_dupes")
return batch_size
- @defer.inlineCallbacks
- def _devices_last_seen_update(self, progress, batch_size):
+ async def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
@@ -336,7 +333,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
@@ -344,18 +341,18 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return len(rows)
- updated = yield self.db.runInteraction(
+ updated = await self.db_pool.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
- yield self.db.updates._end_background_update("devices_last_seen")
+ await self.db_pool.updates._end_background_update("devices_last_seen")
return updated
class ClientIpStore(ClientIpBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
@@ -378,8 +375,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
- @defer.inlineCallbacks
- def insert_client_ip(
+ async def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
if not now:
@@ -390,7 +386,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
- yield self.populate_monthly_active_users(user_id)
+ await self.populate_monthly_active_users(user_id)
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
@@ -403,18 +399,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
- if not self.db.is_running():
+ if not self.db_pool.is_running():
return
to_update = self._batch_row_update
self._batch_row_update = {}
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
def _update_client_ips_batch_txn(self, txn, to_update):
- if "user_ips" in self.db._unsafe_to_upsert_tables or (
+ if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
@@ -423,7 +419,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
@@ -445,7 +441,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -459,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
- @defer.inlineCallbacks
- def get_last_client_ip_by_device(self, user_id, device_id):
+ async def get_last_client_ip_by_device(
+ self, user_id: str, device_id: Optional[str]
+ ) -> Dict[Tuple[str, str], dict]:
"""For each device_id listed, give the user_ip it was last seen on
Args:
- user_id (str)
- device_id (str): If None fetches all devices for the user
+ user_id: The user to fetch devices for.
+ device_id: If None fetches all devices for the user
Returns:
- defer.Deferred: resolves to a dict, where the keys
- are (user_id, device_id) tuples. The values are also dicts, with
- keys giving the column names
+ A dictionary mapping a tuple of (user_id, device_id) to dicts, with
+ keys giving the column names from the devices table.
"""
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
- res = yield self.db.simple_select_list(
+ res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -499,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
}
return ret
- @defer.inlineCallbacks
- def get_user_ip_and_agents(self, user):
+ async def get_user_ip_and_agents(self, user):
user_id = user.to_string()
results = {}
@@ -510,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
- rows = yield self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
@@ -540,7 +535,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Nothing to do
return
- if not await self.db.updates.has_completed_background_update(
+ if not await self.db_pool.updates.has_completed_background_update(
"devices_last_seen"
):
# Only start pruning if we have finished populating the devices
@@ -573,4 +568,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def _prune_old_user_ips_txn(txn):
txn.execute(sql, (timestamp,))
- await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
+ await self.db_pool.runInteraction(
+ "_prune_old_user_ips", _prune_old_user_ips_txn
+ )
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index da297b31fb..1f6e995c4f 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -16,13 +16,10 @@
import logging
from typing import List, Tuple
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__)
@@ -32,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
- def get_new_messages_for_device(
- self, user_id, device_id, last_stream_id, current_stream_id, limit=100
- ):
+ async def get_new_messages_for_device(
+ self,
+ user_id: str,
+ device_id: str,
+ last_stream_id: int,
+ current_stream_id: int,
+ limit: int = 100,
+ ) -> Tuple[List[dict], int]:
"""
Args:
- user_id(str): The recipient user_id.
- device_id(str): The recipient device_id.
- current_stream_id(int): The current position of the to device
+ user_id: The recipient user_id.
+ device_id: The recipient device_id.
+ last_stream_id: The last stream ID checked.
+ current_stream_id: The current position of the to device
message stream.
+ limit: The maximum number of messages to retrieve.
+
Returns:
- Deferred ([dict], int): List of messages for the device and where
- in the stream the messages got to.
+ A list of messages for the device and where in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
- return defer.succeed(([], current_stream_id))
+ return ([], current_stream_id)
def get_new_messages_for_device_txn(txn):
sql = (
@@ -70,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
@trace
- @defer.inlineCallbacks
- def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
+ async def delete_messages_for_device(
+ self, user_id: str, device_id: str, up_to_stream_id: int
+ ) -> int:
"""
Args:
- user_id(str): The recipient user_id.
- device_id(str): The recipient device_id.
- up_to_stream_id(int): Where to delete messages up to.
+ user_id: The recipient user_id.
+ device_id: The recipient device_id.
+ up_to_stream_id: Where to delete messages up to.
+
Returns:
- A deferred that resolves to the number of messages deleted.
+ The number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
@@ -110,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
- count = yield self.db.runInteraction(
+ count = await self.db_pool.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
@@ -129,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return count
@trace
- def get_new_device_msgs_for_remote(
+ async def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
- ):
+ ) -> Tuple[List[dict], int]:
"""
Args:
destination(str): The name of the remote server.
@@ -140,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
current_stream_id(int|long): The current position of the device
message stream.
Returns:
- Deferred ([dict], int|long): List of messages for the device and where
- in the stream the messages got to.
+ A list of messages for the device and where in the stream the messages got to.
"""
set_tag("destination", destination)
@@ -154,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
if not has_changed or last_stream_id == current_stream_id:
log_kv({"message": "No new messages in stream"})
- return defer.succeed(([], current_stream_id))
+ return ([], current_stream_id)
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
- return defer.succeed(([], last_stream_id))
+ return ([], last_stream_id)
@trace
def get_new_messages_for_remote_destination_txn(txn):
@@ -179,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
@@ -204,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
@@ -269,7 +274,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return updates, upto_token, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
@@ -277,30 +282,29 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
- @defer.inlineCallbacks
- def _background_drop_index_device_inbox(self, progress, batch_size):
+ async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
- yield self.db.runWithConnection(reindex_txn)
+ await self.db_pool.runWithConnection(reindex_txn)
- yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+ await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
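
`_background_drop_index_device_inbox` above follows the standard background-update contract: the updater hands the handler its stored progress and a batch size, the handler does one slice of work and reports how much it did, and it calls `_end_background_update` once nothing remains. A hedged, asyncio-based sketch of that contract (the names are stand-ins, not the Synapse updater API):

```python
import asyncio
from typing import Dict, List, Set

finished: Set[str] = set()
work_queue: List[int] = list(range(25))  # stand-in for rows needing migration


async def _end_background_update(name: str) -> None:
    finished.add(name)  # the real method deletes the background_updates row


async def drop_index_sketch(progress: Dict, batch_size: int) -> int:
    done = min(batch_size, len(work_queue))
    del work_queue[:done]  # one slice of the migration
    if not work_queue:
        await _end_background_update("device_inbox_stream_drop")
    return max(done, 1)  # items processed; the updater uses this to pace batches


async def run() -> None:
    while "device_inbox_stream_drop" not in finished:
        await drop_index_sketch({}, batch_size=10)


asyncio.run(run())
```
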
@@ -308,7 +312,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceInboxStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
@@ -321,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
)
@trace
- @defer.inlineCallbacks
- def add_messages_to_device_inbox(
- self, local_messages_by_user_then_device, remote_messages_by_destination
- ):
+ async def add_messages_to_device_inbox(
+ self,
+ local_messages_by_user_then_device: dict,
+ remote_messages_by_destination: dict,
+ ) -> int:
"""Used to send messages from this server.
Args:
- sender_user_id(str): The ID of the user sending these messages.
- local_messages_by_user_and_device(dict):
+ local_messages_by_user_then_device:
Dictionary of user_id to device_id to message.
- remote_messages_by_destination(dict):
+ remote_messages_by_destination:
Dictionary of destination server_name to the EDU JSON to send.
+
Returns:
- A deferred stream_id that resolves when the messages have been
- inserted.
+ The new stream_id.
"""
def add_messages_txn(txn, now_ms, stream_id):
@@ -354,13 +358,13 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
)
rows = []
for destination, edu in remote_messages_by_destination.items():
- edu_json = json.dumps(edu)
+ edu_json = json_encoder.encode(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@@ -372,15 +376,14 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
return self._device_inbox_id_gen.get_current_token()
- @defer.inlineCallbacks
- def add_messages_from_remote_to_device_inbox(
- self, origin, message_id, local_messages_by_user_then_device
- ):
+ async def add_messages_from_remote_to_device_inbox(
+ self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+ ) -> int:
def add_messages_txn(txn, now_ms, stream_id):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
- already_inserted = self.db.simple_select_one_txn(
+ already_inserted = self.db_pool.simple_select_one_txn(
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
@@ -392,7 +395,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add an entry for this message_id so that we know we've processed
# it.
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="device_federation_inbox",
values={
@@ -410,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
@@ -432,7 +435,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Handle wildcard device_ids.
sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
- message_json = json.dumps(messages_by_device["*"])
+ message_json = json_encoder.encode(messages_by_device["*"])
for row in txn:
# Add the message for all devices for this user on this
# server.
@@ -454,7 +457,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
- message_json = json.dumps(messages_by_device[device])
+ message_json = json_encoder.encode(messages_by_device[device])
messages_json_for_user[device] = message_json
if messages_json_for_user:
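
The `json.dumps` → `json_encoder.encode` swaps above tie into the "reduce whitespace in stored JSON" change earlier in this diff. `synapse.util.json_encoder` is assumed to be a shared compact encoder along these lines; the definition below is illustrative, not the actual object:

```python
import json

# Compact separators drop the spaces json.dumps inserts after "," and ":".
json_encoder = json.JSONEncoder(separators=(",", ":"))

edu = {"type": "m.direct_to_device", "content": {"body": "hi"}}
print(json.dumps(edu))           # {"type": "m.direct_to_device", ...} with spaces
print(json_encoder.encode(edu))  # {"type":"m.direct_to_device",...} no spaces
```

Small per-message, but these payloads are stored and sent in bulk, so the saved bytes add up.
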
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/databases/main/devices.py
index 45581a6500..2b33060480 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,11 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Optional, Set, Tuple
-
-from canonicaljson import json
-
-from twisted.internet import defer
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -31,17 +27,13 @@ from synapse.logging.opentracing import (
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
- Database,
+ DatabasePool,
LoggingTransaction,
make_tuple_comparison_clause,
)
-from synapse.types import Collection, get_verify_key_from_cross_signing_key
-from synapse.util.caches.descriptors import (
- Cache,
- cached,
- cachedInlineCallbacks,
- cachedList,
-)
+from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
+from synapse.util import json_encoder
+from synapse.util.caches.descriptors import Cache, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -55,38 +47,36 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
- def get_device(self, user_id, device_id):
+ def get_device(self, user_id: str, device_id: str):
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to retrieve
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
- @defer.inlineCallbacks
- def get_devices_by_user(self, user_id):
+ async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
Args:
- user_id (str):
+ user_id:
Returns:
- defer.Deferred: resolves to a dict from device_id to a dict
- containing "device_id", "user_id" and "display_name" for each
- device.
+ A mapping from device_id to a dict containing "device_id", "user_id"
+ and "display_name" for each device.
"""
- devices = yield self.db.simple_select_list(
+ devices = await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -96,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
@trace
- @defer.inlineCallbacks
- def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ async def get_device_updates_by_remote(
+ self, destination: str, from_stream_id: int, limit: int
+ ) -> Tuple[int, List[Tuple[str, dict]]]:
"""Get a stream of device updates to send to the given remote server.
Args:
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- limit (int): Maximum number of device updates to return
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
+ limit: Maximum number of device updates to return
+
Returns:
- Deferred[tuple[int, list[tuple[string,dict]]]]:
- current stream id (ie, the stream id of the last update included in the
- response), and the list of updates, where each update is a pair of EDU
- type and EDU contents
+ A tuple containing the current stream id (ie, the stream id of the last
+ update included in the response) and the list of updates, where
+ each update is a pair of EDU type and EDU contents.
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -118,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- updates = yield self.db.runInteraction(
+ updates = await self.db_pool.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
@@ -137,7 +128,7 @@ class DeviceWorkerStore(SQLBaseStore):
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
- cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
@@ -150,7 +141,7 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- cross_signing_key = yield self.get_e2e_cross_signing_key(
+ cross_signing_key = await self.get_e2e_cross_signing_key(
user, "self_signing"
)
if cross_signing_key:
@@ -201,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- results = yield self._get_device_update_edus_by_remote(
+ results = await self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
@@ -214,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, results
def _get_device_updates_by_remote_txn(
- self, txn, destination, from_stream_id, now_stream_id, limit
+ self,
+ txn: LoggingTransaction,
+ destination: str,
+ from_stream_id: int,
+ now_stream_id: int,
+ limit: int,
):
"""Return device update information for a given remote destination
Args:
- txn (LoggingTransaction): The transaction to execute
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- now_stream_id (int): The maximum stream_id to filter updates by, inclusive
- limit (int): Maximum number of device updates to return
+ txn: The transaction to execute
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
+ now_stream_id: The maximum stream_id to filter updates by, inclusive
+ limit: Maximum number of device updates to return
Returns:
List: List of device updates
@@ -239,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore):
return list(txn)
- @defer.inlineCallbacks
- def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
+ async def _get_device_update_edus_by_remote(
+ self,
+ destination: str,
+ from_stream_id: int,
+ query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
+ ) -> List[Tuple[str, dict]]:
"""Returns a list of device update EDUs as well as E2EE keys
Args:
- destination (str): The host the device updates are intended for
- from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ destination: The host the device updates are intended for
+ from_stream_id: The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
- user_id/device_id to update stream_id and the relevent json-encoded
+ user_id/device_id to update stream_id and the relevant json-encoded
opentracing context
Returns:
- List[Dict]: List of objects representing an device update EDU
-
+ List of objects representing a device update EDU
"""
devices = (
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(),
@@ -270,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore):
for user_id, user_devices in devices.items():
# The prev_id for the first row is always the last row before
# `from_stream_id`
- prev_id = yield self._get_last_device_update_for_remote_user(
+ prev_id = await self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id
)
@@ -314,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore):
return results
def _get_last_device_update_for_remote_user(
- self, destination, user_id, from_stream_id
+ self, destination: str, user_id: str, from_stream_id: int
):
def f(txn):
prev_sent_id_sql = """
@@ -326,19 +325,21 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.db.runInteraction("get_last_device_update_for_remote_user", f)
+ return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
- def mark_as_sent_devices_by_remote(self, destination, stream_id):
+ def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
"""Mark that updates have successfully been sent to the destination.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
stream_id,
)
- def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
+ def _mark_as_sent_devices_by_remote_txn(
+ self, txn: LoggingTransaction, destination: str, stream_id: int
+ ) -> None:
# We update the device_lists_outbound_last_success with the successfully
# poked users.
sql = """
@@ -350,7 +351,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
- self.db.simple_upsert_many_txn(
+ self.db_pool.simple_upsert_many_txn(
txn=txn,
table="device_lists_outbound_last_success",
key_names=("destination", "user_id"),
@@ -366,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, stream_id))
- @defer.inlineCallbacks
- def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+ async def add_user_signature_change_to_streams(
+ self, from_user_id: str, user_ids: List[str]
+ ) -> int:
"""Persist that a user has made new signatures
Args:
- from_user_id (str): the user who made the signatures
- user_ids (list[str]): the users who were signed
+ from_user_id: the user who made the signatures
+ user_ids: the users who were signed
+
+ Returns:
+ The new stream ID.
"""
with self._device_list_id_gen.get_next() as stream_id:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
from_user_id,
@@ -385,45 +390,52 @@ class DeviceWorkerStore(SQLBaseStore):
)
return stream_id
- def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+ def _add_user_signature_change_txn(
+ self,
+ txn: LoggingTransaction,
+ from_user_id: str,
+ user_ids: List[str],
+ stream_id: int,
+ ) -> None:
txn.call_after(
self._user_signature_stream_cache.entity_has_changed,
from_user_id,
stream_id,
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"user_signature_stream",
values={
"stream_id": stream_id,
"from_user_id": from_user_id,
- "user_ids": json.dumps(user_ids),
+ "user_ids": json_encoder.encode(user_ids),
},
)
- def get_device_stream_token(self):
+ def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@trace
- @defer.inlineCallbacks
- def get_user_devices_from_cache(self, query_list):
+ async def get_user_devices_from_cache(
+ self, query_list: List[Tuple[str, str]]
+ ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
- query_list(list): List of (user_id, device_ids), if device_ids is
+ query_list: List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
- (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
- a set of user_ids and results_map is a mapping of
- user_id -> device_id -> device_info
+ A tuple of (user_ids_not_in_cache, results_map), where
+ user_ids_not_in_cache is a set of user_ids and results_map is a
+ mapping of user_id -> device_id -> device_info.
"""
user_ids = {user_id for user_id, _ in query_list}
- user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
- users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
+ users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
user_ids
)
user_ids_in_cache = {
@@ -437,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore):
continue
if device_id:
- device = yield self._get_cached_user_device(user_id, device_id)
+ device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
- results[user_id] = yield self.get_cached_devices_for_user(user_id)
+ results[user_id] = await self.get_cached_devices_for_user(user_id)
set_tag("in_cache", results)
set_tag("not_in_cache", user_ids_not_in_cache)
return user_ids_not_in_cache, results
- @cachedInlineCallbacks(num_args=2, tree=True)
- def _get_cached_user_device(self, user_id, device_id):
- content = yield self.db.simple_select_one_onecol(
+ @cached(num_args=2, tree=True)
+ async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
+ content = await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
@@ -457,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore):
)
return db_to_json(content)
- @cachedInlineCallbacks()
- def get_cached_devices_for_user(self, user_id):
- devices = yield self.db.simple_select_list(
+ @cached()
+ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+ devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
@@ -469,19 +481,21 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
- def get_devices_with_keys_by_user(self, user_id):
+ def get_devices_with_keys_by_user(self, user_id: str):
"""Get all devices (with any device keys) for a user
Returns:
- (stream_id, devices)
+ Deferred which resolves to (stream_id, devices)
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
)
- def _get_devices_with_keys_by_user_txn(self, txn, user_id):
+ def _get_devices_with_keys_by_user_txn(
+ self, txn: LoggingTransaction, user_id: str
+ ) -> Tuple[int, List[JsonDict]]:
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
@@ -514,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, []
- def get_users_whose_devices_changed(self, from_key, user_ids):
+ async def get_users_whose_devices_changed(
+ self, from_key: str, user_ids: Iterable[str]
+ ) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
- from_key (str): The device lists stream token
- user_ids (Iterable[str])
+ from_key: The device lists stream token
+ user_ids: The user IDs to query for devices.
Returns:
- Deferred[set[str]]: The set of user_ids whose devices have changed
- since `from_key`
+ The set of user_ids whose devices have changed since `from_key`
"""
from_key = int(from_key)
@@ -535,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
if not to_check:
- return defer.succeed(set())
+ return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
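
`get_users_whose_devices_changed` above first consults `_device_list_stream_cache`, so most "changed since X?" queries never reach the database; only positions older than the cache's horizon fall through to the transaction. A toy version of that pre-filter (an assumption-laden model, not Synapse's `StreamChangeCache`):

```python
from typing import Dict, Iterable, Set


class ToyStreamChangeCache:
    def __init__(self, earliest_known: int) -> None:
        self._earliest = earliest_known
        self._last_changed: Dict[str, int] = {}

    def entity_has_changed(self, user_id: str, pos: int) -> None:
        self._last_changed[user_id] = max(pos, self._last_changed.get(user_id, 0))

    def get_entities_changed(self, users: Iterable[str], from_pos: int) -> Set[str]:
        if from_pos < self._earliest:
            return set(users)  # too old to know: caller must check the DB
        return {u for u in users if self._last_changed.get(u, 0) > from_pos}


cache = ToyStreamChangeCache(earliest_known=100)
cache.entity_has_changed("@bob:hs", 150)
assert cache.get_entities_changed(["@bob:hs", "@amy:hs"], 120) == {"@bob:hs"}
assert cache.get_entities_changed(["@amy:hs"], 50) == {"@amy:hs"}
```
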
@@ -555,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore):
return changes
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
- @defer.inlineCallbacks
- def get_users_whose_signatures_changed(self, user_id, from_key):
+ async def get_users_whose_signatures_changed(
+ self, user_id: str, from_key: str
+ ) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
Args:
- user_id (str): the user who made the signatures
- from_key (str): The device lists stream token
+ user_id: the user who made the signatures
+ from_key: The device lists stream token
+
+ Returns:
+ A set of user IDs with updated signatures.
"""
from_key = int(from_key)
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
@@ -574,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ?
"""
- rows = yield self.db.execute(
+ rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
@@ -600,7 +619,7 @@ class DeviceWorkerStore(SQLBaseStore):
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
- function to get further updatees.
+ function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
@@ -631,17 +650,17 @@ class DeviceWorkerStore(SQLBaseStore):
return updates, upto_token, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_device_list_changes_for_remotes",
_get_all_device_list_changes_for_remotes,
)
@cached(max_entries=10000)
- def get_device_list_last_stream_id_for_remote(self, user_id):
+ def get_device_list_last_stream_id_for_remote(self, user_id: str):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -654,8 +673,8 @@ class DeviceWorkerStore(SQLBaseStore):
list_name="user_ids",
inlineCallbacks=True,
)
- def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self.db.simple_select_many_batch(
+ def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
+ rows = yield self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -668,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore):
return results
- @defer.inlineCallbacks
- def get_user_ids_requiring_device_list_resync(
+ async def get_user_ids_requiring_device_list_resync(
self, user_ids: Optional[Collection[str]] = None,
) -> Set[str]:
"""Given a list of remote users return the list of users that we
@@ -680,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore):
The IDs of users whose device lists need resync.
"""
if user_ids:
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
@@ -688,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable",
)
else:
- rows = yield self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
retcols=("user_id",),
@@ -701,7 +719,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Records that the server has reason to believe the cache of the devices
for the remote users is out of date.
"""
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
values={},
@@ -709,12 +727,12 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
- def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+ def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
"""Mark that we no longer track device lists for remote user.
"""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -723,17 +741,17 @@ class DeviceWorkerStore(SQLBaseStore):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
class DeviceBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
table="device_lists_stream",
@@ -741,7 +759,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
# create a unique index on device_lists_remote_cache
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"device_lists_remote_cache_unique_idx",
index_name="device_lists_remote_cache_unique_id",
table="device_lists_remote_cache",
@@ -750,7 +768,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
# And one on device_lists_remote_extremeties
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"device_lists_remote_extremeties_unique_idx",
index_name="device_lists_remote_extremeties_unique_idx",
table="device_lists_remote_extremeties",
@@ -759,35 +777,34 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
# once they complete, we can remove the old non-unique indexes.
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
self._drop_device_list_streams_non_unique_indexes,
)
# clear out duplicate device list outbound pokes
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
)
# a pair of background updates that were added during the 1.14 release cycle,
# but replaced with 58/06dlols_unique_idx.py
- self.db.updates.register_noop_background_update(
+ self.db_pool.updates.register_noop_background_update(
"device_lists_outbound_last_success_unique_idx",
)
- self.db.updates.register_noop_background_update(
+ self.db_pool.updates.register_noop_background_update(
"drop_device_lists_outbound_last_success_non_unique_idx",
)
- @defer.inlineCallbacks
- def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+ async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
- yield self.db.runWithConnection(f)
- yield self.db.updates._end_background_update(
+ await self.db_pool.runWithConnection(f)
+ await self.db_pool.updates._end_background_update(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
)
return 1
@@ -807,7 +824,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
def _txn(txn):
clause, args = make_tuple_comparison_clause(
- self.db.engine, [(x, last_row[x]) for x in KEY_COLS]
+ self.db_pool.engine, [(x, last_row[x]) for x in KEY_COLS]
)
sql = """
SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
@@ -823,30 +840,32 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
",".join(KEY_COLS), # ORDER BY
)
txn.execute(sql, args + [batch_size])
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
row = None
for row in rows:
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
)
row["sent"] = False
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn, "device_lists_outbound_pokes", row,
)
if row:
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
)
return len(rows)
- rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn)
+ rows = await self.db_pool.runInteraction(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
+ )
if not rows:
- await self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
)
@@ -854,7 +873,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(DeviceStore, self).__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
@@ -865,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
- @defer.inlineCallbacks
- def store_device(self, user_id, device_id, initial_device_display_name):
+ async def store_device(
+ self, user_id: str, device_id: str, initial_device_display_name: str
+ ) -> bool:
"""Ensure the given device is known; add it to the store if not
Args:
- user_id (str): id of user associated with the device
- device_id (str): id of device
- initial_device_display_name (str): initial displayname of the
- device. Ignored if device exists.
+ user_id: id of user associated with the device
+ device_id: id of device
+ initial_device_display_name: initial displayname of the device.
+ Ignored if device exists.
+
Returns:
- defer.Deferred: boolean whether the device was inserted or an
- existing device existed with that ID.
+ Whether the device was inserted or an existing device existed with that ID.
+
Raises:
StoreError: if the device is already in use
"""
@@ -885,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False
try:
- inserted = yield self.db.simple_insert(
+ inserted = await self.db_pool.simple_insert(
"devices",
values={
"user_id": user_id,
@@ -899,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
- hidden = yield self.db.simple_select_one_onecol(
+ hidden = await self.db_pool.simple_select_one_onecol(
"devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden",
@@ -924,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
raise StoreError(500, "Problem storing device.")
- @defer.inlineCallbacks
- def delete_device(self, user_id, device_id):
+ async def delete_device(self, user_id: str, device_id: str) -> None:
"""Delete a device.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to delete
- Returns:
- defer.Deferred
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to delete
"""
- yield self.db.simple_delete_one(
+ await self.db_pool.simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
@@ -942,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.device_id_exists_cache.invalidate((user_id, device_id))
- @defer.inlineCallbacks
- def delete_devices(self, user_id, device_ids):
+ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
"""Deletes several devices.
Args:
- user_id (str): The ID of the user which owns the devices
- device_ids (list): The IDs of the devices to delete
- Returns:
- defer.Deferred
+ user_id: The ID of the user which owns the devices
+ device_ids: The IDs of the devices to delete
"""
- yield self.db.simple_delete_many(
+ await self.db_pool.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
@@ -962,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
- def update_device(self, user_id, device_id, new_display_name=None):
+ async def update_device(
+ self, user_id: str, device_id: str, new_display_name: Optional[str] = None
+ ) -> None:
"""Update a device. Only updates the device if it is not marked as
hidden.
Args:
- user_id (str): The ID of the user which owns the device
- device_id (str): The ID of the device to update
- new_display_name (str|None): new displayname for device; None
- to leave unchanged
+ user_id: The ID of the user which owns the device
+ device_id: The ID of the device to update
+ new_display_name: new displayname for device; None to leave unchanged
Raises:
StoreError: if the device is not found
- Returns:
- defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
- return defer.succeed(None)
- return self.db.simple_update_one(
+ return None
+ await self.db_pool.simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
@@ -989,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def update_remote_device_list_cache_entry(
- self, user_id, device_id, content, stream_id
+ self, user_id: str, device_id: str, content: JsonDict, stream_id: int
):
"""Updates a single device in the cache of a remote user's devicelist.
@@ -997,15 +1011,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device list.
Args:
- user_id (str): User to update device list for
- device_id (str): ID of decivice being updated
- content (dict): new data on this device
- stream_id (int): the version of the device list
+ user_id: User to update device list for
+ device_id: ID of device being updated
+ content: new data on this device
+ stream_id: the version of the device list
Returns:
Deferred[None]
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -1015,10 +1029,15 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _update_remote_device_list_cache_entry_txn(
- self, txn, user_id, device_id, content, stream_id
- ):
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ content: JsonDict,
+ stream_id: int,
+ ) -> None:
if content.get("deleted"):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -1026,11 +1045,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
- values={"content": json.dumps(content)},
+ values={"content": json_encoder.encode(content)},
# we don't need to lock, because we assume we are the only thread
# updating this user's devices.
lock=False,
@@ -1042,7 +1061,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -1052,21 +1071,23 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False,
)
- def update_remote_device_list_cache(self, user_id, devices, stream_id):
+ def update_remote_device_list_cache(
+ self, user_id: str, devices: List[dict], stream_id: int
+ ):
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
device list.
Args:
- user_id (str): User to update device list for
- devices (list[dict]): list of device objects supplied over federation
- stream_id (int): the version of the device list
+ user_id: User to update device list for
+ devices: list of device objects supplied over federation
+ stream_id: the version of the device list
Returns:
Deferred[None]
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -1074,19 +1095,21 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_id,
)
- def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
- self.db.simple_delete_txn(
+ def _update_remote_device_list_cache_txn(
+ self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
+ ):
+ self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
{
"user_id": user_id,
"device_id": content["device_id"],
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
}
for content in devices
],
@@ -1098,7 +1121,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -1111,12 +1134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# If we're replacing the remote user's device list cache presumably
# we've done a full resync, so we remove the entry that says we need
# to resync
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
)
- @defer.inlineCallbacks
- def add_device_change_to_streams(self, user_id, device_ids, hosts):
+ async def add_device_change_to_streams(
+ self, user_id: str, device_ids: Collection[str], hosts: List[str]
+ ):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
@@ -1124,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
user_id,
@@ -1139,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
user_id,
@@ -1174,7 +1198,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[(user_id, device_id, min_stream_id) for device_id in device_ids],
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
@@ -1184,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _add_device_outbound_poke_to_stream_txn(
- self, txn, user_id, device_ids, hosts, stream_ids, context,
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: List[str],
+ stream_ids: List[int],
+ context: Dict[str, str],
):
for host in hosts:
txn.call_after(
@@ -1196,7 +1226,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
now = self._clock.time_msec()
next_stream_id = iter(stream_ids)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
@@ -1207,7 +1237,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"device_id": device_id,
"sent": False,
"ts": now,
- "opentracing_context": json.dumps(context)
+ "opentracing_context": json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
}
@@ -1216,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
- def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
+ def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers.
@@ -1303,7 +1333,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return run_as_background_process(
"prune_old_outbound_device_pokes",
- self.db.runInteraction,
+ self.db_pool.runInteraction,
"_prune_old_outbound_device_pokes",
_prune_txn,
)
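The bulk of the changes above are a mechanical translation from Twisted's generator-based coroutines to native ones: drop `@defer.inlineCallbacks`, declare the method `async def`, replace each `yield deferred` with `await`, and rename `self.db` to `self.db_pool` along the way. A minimal sketch of the pattern, using hypothetical names rather than Synapse's actual storage classes:

    from twisted.internet import defer

    class OldStyleStore:
        @defer.inlineCallbacks
        def store_thing(self, thing_id):
            # Each asynchronous step is a `yield` on a Deferred.
            # `self.db` stands in for the database helper layer.
            inserted = yield self.db.simple_insert("things", {"id": thing_id})
            return inserted

    class NewStyleStore:
        async def store_thing(self, thing_id: str) -> bool:
            # The same step as a native coroutine: `await` replaces `yield`
            # and the decorator disappears.
            inserted = await self.db_pool.simple_insert("things", {"id": thing_id})
            return inserted

The translation also lets the `Returns: Deferred[...]` docstring boilerplate move into ordinary return annotations, which is why so many of these hunks rewrite docstrings as well.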
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/databases/main/directory.py
index e1d1bc3e05..037e02603c 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -14,30 +14,29 @@
# limitations under the License.
from collections import namedtuple
-from typing import Optional
-
-from twisted.internet import defer
+from typing import Iterable, Optional
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
+from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
class DirectoryWorkerStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_association_from_room_alias(self, room_alias):
- """ Get's the room_id and server list for a given room_alias
+ async def get_association_from_room_alias(
+ self, room_alias: RoomAlias
+ ) -> Optional[RoomAliasMapping]:
+ """Gets the room_id and server list for a given room_alias
Args:
- room_alias (RoomAlias)
+ room_alias: The alias to translate to an ID.
Returns:
- Deferred: results in namedtuple with keys "room_id" and
- "servers" or None if no association can be found
+ The room alias mapping or None if no association can be found.
"""
- room_id = yield self.db.simple_select_one_onecol(
+ room_id = await self.db_pool.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id:
return None
- servers = yield self.db.simple_select_onecol(
+ servers = await self.db_pool.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@@ -61,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -70,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore):
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@@ -79,22 +78,24 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
- @defer.inlineCallbacks
- def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
+ async def create_room_alias_association(
+ self,
+ room_alias: RoomAlias,
+ room_id: str,
+ servers: Iterable[str],
+ creator: Optional[str] = None,
+ ) -> None:
""" Creates an association between a room alias and room_id/servers
Args:
- room_alias (RoomAlias)
- room_id (str)
- servers (list)
- creator (str): Optional user_id of creator.
-
- Returns:
- Deferred
+ room_alias: The alias to create.
+ room_id: The target of the alias.
+ servers: A list of servers through which it may be possible to join the room.
+ creator: Optional user_id of creator.
"""
def alias_txn(txn):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"room_aliases",
{
@@ -104,7 +105,7 @@ class DirectoryStore(DirectoryWorkerStore):
},
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[
@@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
- ret = yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"create_room_alias_association", alias_txn
)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
- return ret
- @defer.inlineCallbacks
- def delete_room_alias(self, room_alias):
- room_id = yield self.db.runInteraction(
+ async def delete_room_alias(self, room_alias: RoomAlias) -> str:
+ room_id = await self.db_pool.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
return room_id
- def _delete_room_alias_txn(self, txn, room_alias):
+ def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),),
@@ -190,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
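With `get_association_from_room_alias` now a coroutine returning `Optional[RoomAliasMapping]`, callers await it and branch on `None` rather than unwrapping a Deferred. A hedged sketch of such a caller (the `store` argument and the error handling are illustrative, not Synapse's actual directory handler):

    async def resolve_alias(store, room_alias) -> str:
        mapping = await store.get_association_from_room_alias(room_alias)
        if mapping is None:
            # No association was found for this alias.
            raise LookupError("Unknown room alias: %s" % (room_alias,))
        # RoomAliasMapping is a namedtuple of (room_id, room_alias, servers).
        return mapping.room_id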
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 615364f018..2eeb9f97dc 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -14,18 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import json
-
-from twisted.internet import defer
-
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.util import json_encoder
class EndToEndRoomKeyStore(SQLBaseStore):
- @defer.inlineCallbacks
- def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+ async def update_e2e_room_key(
+ self, user_id, version, room_id, session_id, room_key
+ ):
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
@@ -38,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@@ -50,13 +48,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
- "session_data": json.dumps(room_key["session_data"]),
+ "session_data": json_encoder.encode(room_key["session_data"]),
},
desc="update_e2e_room_key",
)
- @defer.inlineCallbacks
- def add_e2e_room_keys(self, user_id, version, room_keys):
+ async def add_e2e_room_keys(self, user_id, version, room_keys):
"""Bulk add room keys to a given backup.
Args:
@@ -77,7 +74,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
- "session_data": json.dumps(room_key["session_data"]),
+ "session_data": json_encoder.encode(room_key["session_data"]),
}
)
log_kv(
@@ -89,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
}
)
- yield self.db.simple_insert_many(
+ await self.db_pool.simple_insert_many(
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@trace
- @defer.inlineCallbacks
- def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
@@ -110,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
the backup (or for the specified room)
Returns:
- A deferred list of dicts giving the session_data and message metadata for
+ A list of dicts giving the session_data and message metadata for
these room keys.
"""
@@ -125,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = yield self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@@ -171,7 +167,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
@@ -235,7 +231,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version (str): the version ID of the backup we're querying about
"""
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
@@ -243,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- @defer.inlineCallbacks
- def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+ async def delete_e2e_room_keys(
+ self, user_id, version, room_id=None, session_id=None
+ ):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
@@ -259,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
the backup (or for the specified room)
Returns:
- A deferred of the deletion transaction
+ The deletion transaction
"""
keyvalues = {"user_id": user_id, "version": int(version)}
@@ -268,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- yield self.db.simple_delete(
+ await self.db_pool.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@@ -313,7 +310,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there.
raise StoreError(404, "No row found")
- result = self.db.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
@@ -325,7 +322,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result["etag"] = 0
return result
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@@ -353,20 +350,20 @@ class EndToEndRoomKeyStore(SQLBaseStore):
new_version = str(int(current_version) + 1)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
"user_id": user_id,
"version": new_version,
"algorithm": info["algorithm"],
- "auth_data": json.dumps(info["auth_data"]),
+ "auth_data": json_encoder.encode(info["auth_data"]),
},
)
return new_version
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@@ -387,12 +384,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues = {}
if info is not None and "auth_data" in info:
- updatevalues["auth_data"] = json.dumps(info["auth_data"])
+ updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
if version_etag is not None:
updatevalues["etag"] = version_etag
if updatevalues:
- return self.db.simple_update(
+ return self.db_pool.simple_update(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues,
@@ -421,19 +418,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
else:
this_version = version
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": this_version},
)
- return self.db.simple_update_one_txn(
+ return self.db_pool.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)
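Every `json.dumps` call in this file is swapped for a shared `json_encoder.encode`. Assuming that encoder is configured with compact separators (the stand-in below is local to the example, not Synapse's actual object), the difference is purely in the serialized whitespace:

    import json

    # Local stand-in for the shared encoder; illustrates compact separators only.
    json_encoder = json.JSONEncoder(separators=(",", ":"))

    data = {"first_message_index": 0, "forwarded_count": 1, "is_verified": True}

    print(json.dumps(data))           # {"first_message_index": 0, "forwarded_count": 1, ...}
    print(json_encoder.encode(data))  # {"first_message_index":0,"forwarded_count":1,...}

Both encodings parse back to the same dict; only the stored and transmitted byte count changes.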
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 317c07a829..f93e0d320d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -14,24 +14,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
-from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
- @defer.inlineCallbacks
- def get_e2e_device_keys(
+ async def get_e2e_device_keys(
self, query_list, include_all_devices=False, include_deleted_devices=False
):
"""Fetch a list of device keys.
@@ -51,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
- results = yield self.db.runInteraction(
+ results = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
@@ -128,7 +127,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
result = {}
for row in rows:
@@ -146,7 +145,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(signature_sql, signature_query_params)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
# add each cross-signing signature to the correct device in the result dict.
for row in rows:
@@ -174,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
log_kv(result)
return result
- @defer.inlineCallbacks
- def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ async def get_e2e_one_time_keys(
+ self, user_id: str, device_id: str, key_ids: List[str]
+ ) -> Dict[Tuple[str, str], str]:
"""Retrieve a number of one-time keys for a user
Args:
@@ -185,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
retrieve
Returns:
- deferred resolving to Dict[(str, str), str]: map from (algorithm,
- key_id) to json string for key
+ A map from (algorithm, key_id) to json string for key
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
@@ -201,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
- @defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ async def add_e2e_one_time_keys(
+ self,
+ user_id: str,
+ device_id: str,
+ time_now: int,
+ new_keys: Iterable[Tuple[str, str, str]],
+ ) -> None:
"""Insert some new one time keys for a device. Errors if any of the
keys already exist.
Args:
- user_id(str): id of user to get keys for
- device_id(str): id of device to get keys for
- time_now(long): insertion time to record (ms since epoch)
- new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
- (algorithm, key_id, key json)
+ user_id: id of user to get keys for
+ device_id: id of device to get keys for
+ time_now: insertion time to record (ms since epoch)
+ new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
"""
def _add_e2e_one_time_keys(txn):
@@ -222,7 +225,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
values=[
@@ -241,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@@ -264,26 +267,27 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
- @defer.inlineCallbacks
- def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+ async def get_e2e_cross_signing_key(
+ self, user_id: str, key_type: str, from_user_id: Optional[str] = None
+ ) -> Optional[dict]:
"""Returns a user's cross-signing key.
Args:
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being requested: either 'master'
+ user_id: the user whose key is being requested
+ key_type: the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
+ from_user_id: if specified, signatures made by this user on
the self-signing key will be included in the result
Returns:
dict of the key data or None if not found
"""
- res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+ res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
user_keys = res.get(user_id)
if not user_keys:
return None
@@ -318,7 +322,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
to None.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@@ -361,7 +365,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, params)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
for row in rows:
user_id = row["user_id"]
@@ -420,7 +424,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
query_params.extend(item)
txn.execute(sql, query_params)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
for row in rows:
@@ -449,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return keys
- @defer.inlineCallbacks
- def get_e2e_cross_signing_keys_bulk(
- self, user_ids: List[str], from_user_id: str = None
- ) -> defer.Deferred:
+ async def get_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str], from_user_id: Optional[str] = None
+ ) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users.
Args:
- user_ids (list[str]): the users whose keys are being requested
- from_user_id (str): if specified, signatures made by this user on
+ user_ids: the users whose keys are being requested
+ from_user_id: if specified, signatures made by this user on
the self-signing keys will be included in the result
Returns:
- Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
- key data. If a user's cross-signing keys were not found, either
- their user ID will not be in the dict, or their user ID will map
- to None.
+ A map of user ID to key type to key data. If a user's cross-signing
+ keys were not found, either their user ID will not be in the dict,
+ or their user ID will map to None.
"""
- result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+ result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
@@ -531,7 +533,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return updates, upto_token, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_user_signature_changes_for_remotes",
_get_all_user_signature_changes_for_remotes_txn,
)
@@ -549,7 +551,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
- old_key_json = self.db.simple_select_one_onecol_txn(
+ old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -565,7 +567,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"Message": "Device key already stored."})
return False
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -574,7 +576,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
- return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
+ return self.db_pool.runInteraction(
+ "set_e2e_device_keys", _set_e2e_device_keys_txn
+ )
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
@@ -613,7 +617,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
@@ -626,12 +630,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"user_id": user_id,
}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -640,7 +644,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
@@ -679,7 +683,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# We only need to do this for local users, since remote servers should be
# responsible for checking this for their own users.
if self.hs.is_mine_id(user_id):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"devices",
values={
@@ -692,13 +696,13 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"e2e_cross_signing_keys",
values={
"user_id": user_id,
"keytype": key_type,
- "keydata": json.dumps(key),
+ "keydata": json_encoder.encode(key),
"stream_id": stream_id,
},
)
@@ -715,7 +719,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
user_id,
@@ -730,7 +734,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
user_id (str): the user who made the signatures
signatures (iterable[SignatureListItem]): signatures to add
"""
- return self.db.simple_insert_many(
+ return self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
[
{
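A quieter fix threaded through these hunks: implicit-Optional defaults such as `from_user_id: str = None` become an explicit `Optional[str]`. Under mypy's `no_implicit_optional` behaviour the former is rejected, since `None` is not a `str`. A short illustration (the function names are hypothetical):

    from typing import Optional

    def before(from_user_id: str = None):  # flagged: the default None is not a str
        ...

    def after(from_user_id: Optional[str] = None):  # explicit about the None default
        ...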
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index a6bb3221ff..484875f989 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,16 +15,14 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Dict, List, Optional, Set, Tuple
-
-from twisted.internet import defer
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter
@@ -65,7 +63,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
list of event_ids
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids,
@@ -114,7 +112,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Deferred[Set[str]]
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
@@ -260,12 +258,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_in_room(self, room_id):
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return dict(txn)
- @defer.inlineCallbacks
- def get_max_depth_of(self, event_ids):
+ async def get_max_depth_of(self, event_ids: List[str]) -> int:
"""Returns the max depth of a set of event IDs
Args:
- event_ids (list[str])
-
- Returns
- Deferred[int]
+ event_ids: The event IDs to calculate the max depth of.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@@ -310,7 +304,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self.db.simple_select_onecol_txn(
+ return self.db_pool.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
@@ -332,7 +326,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
@@ -387,13 +381,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args)
return [room_id for room_id, in txn]
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
@@ -403,12 +397,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
- min_depth = self.db.simple_select_one_onecol_txn(
+ min_depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -474,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
@@ -489,7 +483,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit (int)
"""
return (
- self.db.runInteraction(
+ self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
@@ -520,7 +514,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue()
for event_id in event_list:
- depth = self.db.simple_select_one_onecol_txn(
+ depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return event_results
- @defer.inlineCallbacks
- def get_missing_events(self, room_id, earliest_events, latest_events, limit):
- ids = yield self.db.runInteraction(
+ async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
+ ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = yield self.get_events_as_list(ids)
+ events = await self.get_events_as_list(ids)
return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.reverse()
return event_results
- @defer.inlineCallbacks
- def get_successor_events(self, event_ids):
+ async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
"""Fetch all events that have the given events as a prev event
Args:
- event_ids (iterable[str])
-
- Returns:
- Deferred[list[str]]
+ event_ids: The events to use as the previous events.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@@ -628,10 +617,10 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventFederationStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
@@ -658,13 +647,13 @@ class EventFederationStore(EventFederationWorkerStore):
return run_as_background_process(
"delete_old_forward_extrem_cache",
- self.db.runInteraction,
+ self.db_pool.runInteraction,
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
@@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- @defer.inlineCallbacks
- def _background_delete_non_state_event_auth(self, progress, batch_size):
+ async def _background_delete_non_state_event_auth(self, progress, batch_size):
def delete_event_auth(txn):
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")
@@ -708,17 +696,19 @@ class EventFederationStore(EventFederationWorkerStore):
"max_stream_id_exclusive": min_stream_id,
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_AUTH_STATE_ONLY, new_progress
)
return min_stream_id >= target_min_stream_id
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
- yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+ await self.db_pool.updates._end_background_update(
+ self.EVENT_AUTH_STATE_ONLY
+ )
return batch_size
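`_background_delete_non_state_event_auth` follows the general background-update contract visible throughout this diff: an async handler receives the stored `progress` dict and a `batch_size`, does one bounded round of work, and calls `_end_background_update` once nothing remains so the updater stops scheduling it. A schematic version, with `store` and the update name standing in for a real registration:

    async def my_background_update(store, progress: dict, batch_size: int) -> int:
        def _txn(txn):
            # Process at most `batch_size` rows, persisting progress as we go.
            return 0  # number of rows handled this round

        handled = await store.db_pool.runInteraction("my_background_update", _txn)
        if not handled:
            # Finished: deregister so the update is no longer scheduled.
            await store.db_pool.updates._end_background_update("my_background_update")
        return handled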
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index ad82838901..7c246d3e4c 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -17,11 +17,10 @@
import logging
from typing import List
-from canonicaljson import json
-
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -50,7 +49,7 @@ def _serialize_action(actions, is_highlight):
else:
if actions == DEFAULT_NOTIF_ACTION:
return ""
- return json.dumps(actions)
+ return json_encoder.encode(actions)
def _deserialize_action(actions, is_highlight):
@@ -66,7 +65,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
@@ -91,7 +90,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
- ret = yield self.db.runInteraction(
+ ret = yield self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
@@ -176,7 +175,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = await self.db.runInteraction("get_push_action_users_in_range", f)
+ ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f)
return ret
async def get_unread_push_actions_for_user_in_range_for_http(
@@ -230,7 +229,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = await self.db.runInteraction(
+ after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@@ -258,7 +257,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = await self.db.runInteraction(
+ no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@@ -332,7 +331,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = await self.db.runInteraction(
+ after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@@ -360,7 +359,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = await self.db.runInteraction(
+ no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@@ -410,7 +409,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
@@ -461,7 +460,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
@@ -471,7 +470,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
- res = await self.db.simple_delete(
+ res = await self.db_pool.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
@@ -488,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _find_stream_orderings_for_times(self):
return run_as_background_process(
"event_push_action_stream_orderings",
- self.db.runInteraction,
+ self.db_pool.runInteraction,
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
@@ -524,7 +523,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Deferred[int]: stream ordering of the first event received on/after
the timestamp
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@@ -619,24 +618,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
- result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
+ result = await self.db_pool.runInteraction(
+ "get_time_of_last_push_action_before", f
+ )
return result[0] if result else None
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventPushActionsStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
index_name="event_push_actions_u_highlight",
table="event_push_actions",
columns=["user_id", "stream_ordering"],
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index",
table="event_push_actions",
@@ -678,9 +679,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
+ push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
@@ -690,7 +691,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
- result = await self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_latest_push_action_stream_ordering", f
)
return result[0] or 0
@@ -753,7 +754,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
- caught_up = await self.db.runInteraction(
+ caught_up = await self.db_pool.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
@@ -767,7 +768,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not.
"""
- old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -803,7 +804,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
- old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -835,7 +836,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_push_summary",
values=[
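For context on the `_serialize_action` hunk earlier in this file: only the encoder call changes there, but the surrounding scheme is what makes storage cheap. The most common action list is stored as a sentinel empty string (the highlight bit lives in its own column) and anything unusual falls through to JSON. A reconstruction of the round trip, simplified to one default; only the `return ""` branch is actually visible in this diff, so the default value and the deserializer body are assumptions:

    import json

    # Assumed default; the hunk above only shows that it serializes to "".
    DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]

    def serialize_action(actions) -> str:
        # Store the overwhelmingly common default as an empty string,
        # cutting per-row storage in event_push_actions.
        if actions == DEFAULT_NOTIF_ACTION:
            return ""
        return json.dumps(actions, separators=(",", ":"))

    def deserialize_action(stored: str):
        # An empty string means "the default".
        return json.loads(stored) if stored else DEFAULT_NOTIF_ACTION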
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/databases/main/events.py
index 0c9c02afa1..1a68bf32cb 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -32,8 +32,8 @@ from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
-from synapse.storage.data_stores.main.search import SearchEntry
-from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
@@ -41,7 +41,7 @@ from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.storage.data_stores.main import DataStore
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -53,47 +53,6 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
-STATE_EVENT_TYPES_TO_MARK_UNREAD = {
- EventTypes.Topic,
- EventTypes.Name,
- EventTypes.RoomAvatar,
- EventTypes.Tombstone,
-}
-
-
-def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
- # Exclude rejected and soft-failed events.
- if context.rejected or event.internal_metadata.is_soft_failed():
- return False
-
- # Exclude notices.
- if (
- not event.is_state()
- and event.type == EventTypes.Message
- and event.content.get("msgtype") == "m.notice"
- ):
- return False
-
- # Exclude edits.
- relates_to = event.content.get("m.relates_to", {})
- if relates_to.get("rel_type") == RelationTypes.REPLACE:
- return False
-
- # Mark events that have a non-empty string body as unread.
- body = event.content.get("body")
- if isinstance(body, str) and body:
- return True
-
- # Mark some state events as unread.
- if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
- return True
-
- # Mark encrypted events as unread.
- if not event.is_state() and event.type == EventTypes.Encrypted:
- return True
-
- return False
-
def encode_json(json_object):
"""
@@ -132,9 +91,11 @@ class PersistEventsStore:
Note: This is not part of the `DataStore` mixin.
"""
- def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"):
+ def __init__(
+ self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
+ ):
self.hs = hs
- self.db = db
+ self.db_pool = db
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
@@ -207,7 +168,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
@@ -237,10 +198,6 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
- self.store.get_unread_message_count_for_user.invalidate_many(
- (event.room_id,),
- )
-
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -283,7 +240,7 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
@@ -347,7 +304,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
@@ -421,7 +378,7 @@ class PersistEventsStore:
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -484,7 +441,7 @@ class PersistEventsStore:
"""
txn.execute(sql, (stream_id, room_id))
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="current_state_events", keyvalues={"room_id": room_id},
)
else:
@@ -632,7 +589,7 @@ class PersistEventsStore:
creator = content.get("creator")
room_version_id = content.get("room_version", RoomVersions.V1.identifier)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
@@ -644,14 +601,14 @@ class PersistEventsStore:
self, txn, new_forward_extremities, max_stream_order
):
for room_id, new_extrem in new_forward_extremities.items():
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
txn.call_after(
self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
@@ -664,7 +621,7 @@ class PersistEventsStore:
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
@@ -788,7 +745,7 @@ class PersistEventsStore:
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
@@ -826,7 +783,7 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
values=[
@@ -843,7 +800,7 @@ class PersistEventsStore:
],
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="events",
values=[
@@ -862,9 +819,8 @@ class PersistEventsStore:
"contains_url": (
"url" in event.content and isinstance(event.content["url"], str)
),
- "count_as_unread": should_count_as_unread(event, context),
}
- for event, context in events_and_contexts
+ for event, _ in events_and_contexts
],
)
@@ -873,7 +829,7 @@ class PersistEventsStore:
# If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as
# requiring censoring.
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="redactions",
keyvalues={"redacts": event.event_id},
@@ -1015,7 +971,9 @@ class PersistEventsStore:
state_values.append(vals)
- self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
+ self.db_pool.simple_insert_many_txn(
+ txn, table="state_events", values=state_values
+ )
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1046,7 +1004,7 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
@@ -1066,7 +1024,7 @@ class PersistEventsStore:
# invalidate the cache for the redacted event
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="redactions",
values={
@@ -1089,7 +1047,7 @@ class PersistEventsStore:
room_id (str): The ID of the room the event was sent to.
topological_ordering (int): The position of the event in the room's topology.
"""
- return self.db.simple_insert_many_txn(
+ return self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@@ -1111,7 +1069,7 @@ class PersistEventsStore:
event_id (str): The event ID the expiry timestamp is associated with.
expiry_ts (int): The timestamp at which to expire (delete) the event.
"""
- return self.db.simple_insert_txn(
+ return self.db_pool.simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
@@ -1135,12 +1093,14 @@ class PersistEventsStore:
}
)
- self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+ self.db_pool.simple_insert_many_txn(
+ txn, table="event_reference_hashes", values=vals
+ )
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
values=[
@@ -1180,7 +1140,7 @@ class PersistEventsStore:
and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership()
):
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": event.room_id, "user_id": event.state_key},
@@ -1218,7 +1178,7 @@ class PersistEventsStore:
aggregation_key = relation.get("key")
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="event_relations",
values={
@@ -1246,7 +1206,7 @@ class PersistEventsStore:
redacted_event_id (str): The event that was redacted.
"""
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
@@ -1282,7 +1242,7 @@ class PersistEventsStore:
# Ignore the event if one of the values isn't an integer.
return
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
@@ -1363,7 +1323,7 @@ class PersistEventsStore:
)
for event, _ in events_and_contexts:
- user_ids = self.db.simple_select_onecol_txn(
+ user_ids = self.db_pool.simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
@@ -1395,7 +1355,7 @@ class PersistEventsStore:
)
def _store_rejections_txn(self, txn, event_id, reason):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="rejections",
values={
@@ -1421,7 +1381,7 @@ class PersistEventsStore:
state_groups[event.event_id] = context.state_group
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
@@ -1443,7 +1403,7 @@ class PersistEventsStore:
if min_depth is not None and depth >= min_depth:
return
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -1455,7 +1415,7 @@ class PersistEventsStore:
For the given event, update the event edges table and forward and
backward extremities tables.
"""
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="event_edges",
values=[
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 663c94b24f..35a0e09e3c 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
logger = logging.getLogger(__name__)
@@ -30,18 +30,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"event_contains_url_index",
index_name="event_contains_url_index",
table="events",
@@ -52,7 +52,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# an event_id index on event_search is useful for the purge_history
# api. Plus it means we get to enforce some integrity with a UNIQUE
# clause
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"event_search_event_id_idx",
index_name="event_search_event_id_idx",
table="event_search",
@@ -61,16 +61,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
psql_only=True,
)
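As the comment above notes, `register_background_index_update` builds the index in the background rather than blocking startup, and the UNIQUE constraint enforces some integrity for free. The registration is roughly equivalent to the following DDL on PostgreSQL (a sketch, given `psql_only=True`; the background updater manages the build itself):

    # Approximate DDL for the registered index (illustrative only):
    CREATE_EVENT_SEARCH_EVENT_ID_IDX = """
        CREATE UNIQUE INDEX CONCURRENTLY event_search_event_id_idx
        ON event_search (event_id)
    """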
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"redactions_received_ts", self._redactions_received_ts
)
# This index gets deleted in `event_fix_redactions_bytes` update
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"event_fix_redactions_bytes_create_index",
index_name="redactions_censored_redacts",
table="redactions",
@@ -78,15 +78,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="have_censored",
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"event_fix_redactions_bytes", self._event_fix_redactions_bytes
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"event_store_labels", self._event_store_labels
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"redactions_have_censored_ts_idx",
index_name="redactions_have_censored_ts",
table="redactions",
@@ -149,18 +149,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"rows_inserted": rows_inserted + len(rows),
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
- result = yield self.db.runInteraction(
+ result = yield self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
- yield self.db.updates._end_background_update(
+ yield self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
@@ -195,7 +195,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self.db.simple_select_many_txn(
+ ev_rows = self.db_pool.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
@@ -228,18 +228,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
"rows_inserted": rows_inserted + len(rows_to_update),
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows_to_update)
- result = yield self.db.runInteraction(
+ result = yield self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
- yield self.db.updates._end_background_update(
+ yield self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME
)
@@ -374,7 +374,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
to_delete.intersection_update(original_set)
- deleted = self.db.simple_delete_many_txn(
+ deleted = self.db_pool.simple_delete_many_txn(
txn=txn,
table="event_forward_extremities",
column="event_id",
@@ -390,7 +390,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self.db.simple_select_many_txn(
+ rows = self.db_pool.simple_select_many_txn(
txn,
table="events",
column="event_id",
@@ -404,7 +404,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn=txn,
table="_extremities_to_check",
column="event_id",
@@ -414,19 +414,19 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set)
- num_handled = yield self.db.runInteraction(
+ num_handled = yield self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
- yield self.db.updates._end_background_update(
+ yield self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
@@ -474,18 +474,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "redactions_received_ts", {"last_event_id": upper_event_id}
)
return len(rows)
- count = yield self.db.runInteraction(
+ count = yield self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn
)
if not count:
- yield self.db.updates._end_background_update("redactions_received_ts")
+ yield self.db_pool.updates._end_background_update("redactions_received_ts")
return count
@@ -511,11 +511,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts")
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
- yield self.db.updates._end_background_update("event_fix_redactions_bytes")
+ yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1
@@ -543,7 +543,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
try:
event_json = db_to_json(event_json_raw)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@@ -569,17 +569,17 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
nbrows += 1
last_row_event_id = event_id
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "event_store_labels", {"last_event_id": last_row_event_id}
)
return nbrows
- num_rows = yield self.db.runInteraction(
+ num_rows = yield self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn
)
if not num_rows:
- yield self.db.updates._end_background_update("event_store_labels")
+ yield self.db_pool.updates._end_background_update("event_store_labels")
return num_rows
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b03b259636..755b7a2a85 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -40,16 +40,10 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
-from synapse.storage.types import Cursor
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import (
- Cache,
- _CacheContext,
- cached,
- cachedInlineCallbacks,
-)
+from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -80,7 +74,7 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.writers.events == hs.get_instance_name():
@@ -136,7 +130,7 @@ class EventsWorkerStore(SQLBaseStore):
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
"""
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
@@ -175,7 +169,7 @@ class EventsWorkerStore(SQLBaseStore):
return ts
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_approximate_received_ts", _get_approximate_received_ts_txn
)
@@ -543,7 +537,7 @@ class EventsWorkerStore(SQLBaseStore):
event_id for events, _ in event_list for event_id in events
}
- row_dict = self.db.new_transaction(
+ row_dict = self.db_pool.new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
@@ -720,7 +714,7 @@ class EventsWorkerStore(SQLBaseStore):
if should_start:
run_as_background_process(
- "fetch_events", self.db.runWithConnection, self._do_fetch
+ "fetch_events", self.db_pool.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events: %s", len(events), events)
@@ -889,7 +883,7 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = yield self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@@ -924,7 +918,7 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
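The loop above uses the two-argument form of `iter()`: the lambda is called repeatedly until it returns the sentinel (`[]`), i.e. until `islice` finds nothing left. A self-contained sketch of the same idiom, equivalent in effect to the `batch_iter` helper used elsewhere in this patch:

    import itertools

    def chunks_of(iterable, size):
        it = iter(iterable)
        # iter(callable, sentinel): call the lambda until it returns [].
        return iter(lambda: list(itertools.islice(it, size)), [])

    assert list(chunks_of(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]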
@@ -953,7 +947,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[int]
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_total_state_event_counts",
self._get_total_state_event_counts_txn,
room_id,
@@ -978,7 +972,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[int]
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
@@ -1043,7 +1037,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
@@ -1077,7 +1071,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
@@ -1151,7 +1145,7 @@ class EventsWorkerStore(SQLBaseStore):
return new_event_updates, upper_bound, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@@ -1199,7 +1193,7 @@ class EventsWorkerStore(SQLBaseStore):
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
- rows = await self.db.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
) # type: List[Tuple]
@@ -1222,7 +1216,7 @@ class EventsWorkerStore(SQLBaseStore):
# stream id. let's run the query again, without a row limit, but for
# just one stream id.
to_token += 1
- rows = await self.db.runInteraction(
+ rows = await self.db_pool.runInteraction(
"get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
)
@@ -1317,7 +1311,7 @@ class EventsWorkerStore(SQLBaseStore):
backward_ex_outliers,
)
- return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
+ return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
@@ -1328,7 +1322,7 @@ class EventsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(max_entries=5000)
def get_event_ordering(self, event_id):
- res = yield self.db.simple_select_one(
+ res = yield self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -1360,88 +1354,10 @@ class EventsWorkerStore(SQLBaseStore):
return txn.fetchone()
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
- @cached(tree=True, cache_context=True)
- async def get_unread_message_count_for_user(
- self, room_id: str, user_id: str, cache_context: _CacheContext,
- ) -> int:
- """Retrieve the count of unread messages for the given room and user.
-
- Args:
- room_id: The ID of the room to count unread messages in.
- user_id: The ID of the user to count unread messages for.
-
- Returns:
- The number of unread messages for the given user in the given room.
- """
- with Measure(self._clock, "get_unread_message_count_for_user"):
- last_read_event_id = await self.get_last_receipt_event_id_for_user(
- user_id=user_id,
- room_id=room_id,
- receipt_type="m.read",
- on_invalidate=cache_context.invalidate,
- )
-
- return await self.db.runInteraction(
- "get_unread_message_count_for_user",
- self._get_unread_message_count_for_user_txn,
- user_id,
- room_id,
- last_read_event_id,
- )
-
- def _get_unread_message_count_for_user_txn(
- self,
- txn: Cursor,
- user_id: str,
- room_id: str,
- last_read_event_id: Optional[str],
- ) -> int:
- if last_read_event_id:
- # Get the stream ordering for the last read event.
- stream_ordering = self.db.simple_select_one_onecol_txn(
- txn=txn,
- table="events",
- keyvalues={"room_id": room_id, "event_id": last_read_event_id},
- retcol="stream_ordering",
- )
- else:
- # If there's no read receipt for that room, it probably means the user hasn't
- # opened it yet, in which case use the stream ID of their join event.
- # We can't just set it to 0 otherwise messages from other local users from
- # before this user joined will be counted as well.
- txn.execute(
- """
- SELECT stream_ordering FROM local_current_membership
- LEFT JOIN events USING (event_id, room_id)
- WHERE membership = 'join'
- AND user_id = ?
- AND room_id = ?
- """,
- (user_id, room_id),
- )
- row = txn.fetchone()
-
- if row is None:
- return 0
-
- stream_ordering = row[0]
-
- # Count the messages that qualify as unread after the stream ordering we've just
- # retrieved.
- sql = """
- SELECT COUNT(*) FROM events
- WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
- """
-
- txn.execute(sql, (user_id, room_id, stream_ordering))
- row = txn.fetchone()
-
- return row[0] if row else 0
-
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 342d6622a4..45a1760170 100644
--- a/synapse/storage/data_stores/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,12 +17,12 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore):
- @cachedInlineCallbacks(num_args=2)
- def get_user_filter(self, user_localpart, filter_id):
+ @cached(num_args=2)
+ async def get_user_filter(self, user_localpart, filter_id):
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
- def_json = yield self.db.simple_select_one_onecol(
+ def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
@@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
- return self.db.runInteraction("add_user_filter", _do_txn)
+ return self.db_pool.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 01ff561e1a..380db3a3f3 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,14 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import List, Tuple
-
-from canonicaljson import json
-
-from twisted.internet import defer
+from typing import List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.types import JsonDict
+from synapse.util import json_encoder
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -31,7 +29,7 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id):
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -53,7 +51,7 @@ class GroupServerWorkerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self.db.simple_select_list(
+ return self.db_pool.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
@@ -63,7 +61,7 @@ class GroupServerWorkerStore(SQLBaseStore):
def get_invited_users_in_group(self, group_id):
# TODO: Pagination
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
@@ -117,7 +115,9 @@ class GroupServerWorkerStore(SQLBaseStore):
for room_id, is_public in txn
]
- return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+ return self.db_pool.runInteraction(
+ "get_rooms_in_group", _get_rooms_in_group_txn
+ )
def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False,
@@ -205,13 +205,12 @@ class GroupServerWorkerStore(SQLBaseStore):
return rooms, categories
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
- @defer.inlineCallbacks
- def get_group_categories(self, group_id):
- rows = yield self.db.simple_select_list(
+ async def get_group_categories(self, group_id):
+ rows = await self.db_pool.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"),
@@ -226,9 +225,8 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- @defer.inlineCallbacks
- def get_group_category(self, group_id, category_id):
- category = yield self.db.simple_select_one(
+ async def get_group_category(self, group_id, category_id):
+ category = await self.db_pool.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"),
@@ -239,9 +237,8 @@ class GroupServerWorkerStore(SQLBaseStore):
return category
- @defer.inlineCallbacks
- def get_group_roles(self, group_id):
- rows = yield self.db.simple_select_list(
+ async def get_group_roles(self, group_id):
+ rows = await self.db_pool.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"),
@@ -256,9 +253,8 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows
}
- @defer.inlineCallbacks
- def get_group_role(self, group_id, role_id):
- role = yield self.db.simple_select_one(
+ async def get_group_role(self, group_id, role_id):
+ role = await self.db_pool.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"),
@@ -277,7 +273,7 @@ class GroupServerWorkerStore(SQLBaseStore):
Deferred[list[str]]: A twisted.Deferred containing a list of group ids
containing this room
"""
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
@@ -341,12 +337,12 @@ class GroupServerWorkerStore(SQLBaseStore):
return users, roles
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
def is_user_in_group(self, user_id, group_id):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -355,7 +351,7 @@ class GroupServerWorkerStore(SQLBaseStore):
).addCallback(lambda r: bool(r))
def is_user_admin_in_group(self, group_id, user_id):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -366,7 +362,7 @@ class GroupServerWorkerStore(SQLBaseStore):
def is_user_invited_to_local_group(self, group_id, user_id):
"""Has the group server invited a user?
"""
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -389,7 +385,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
def _get_users_membership_in_group_txn(txn):
- row = self.db.simple_select_one_txn(
+ row = self.db_pool.simple_select_one_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -404,7 +400,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_privileged": row["is_admin"],
}
- row = self.db.simple_select_one_onecol_txn(
+ row = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -417,14 +413,14 @@ class GroupServerWorkerStore(SQLBaseStore):
return {}
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
def get_publicised_groups_for_user(self, user_id):
"""Get all groups a user is publicising
"""
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
@@ -441,18 +437,17 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE valid_until_ms <= ?
"""
txn.execute(sql, (valid_until_ms,))
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
- @defer.inlineCallbacks
- def get_remote_attestation(self, group_id, user_id):
+ async def get_remote_attestation(self, group_id, user_id):
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
- row = yield self.db.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"),
@@ -467,7 +462,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return None
def get_joined_groups(self, user_id):
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
@@ -494,17 +489,17 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in txn
]
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
- def get_groups_changes_for_user(self, user_id, from_token, to_token):
+ async def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
user_id, from_token
)
if not has_changed:
- return defer.succeed([])
+ return []
def _get_groups_changes_for_user_txn(txn):
sql = """
@@ -524,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for group_id, membership, gtype, content_json in txn
]
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
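`get_groups_changes_for_user` consults an in-memory stream-change cache before touching the database: if nothing for this user has changed since `from_token`, it can answer `[]` without a transaction. The pattern, sketched with illustrative names (the cache object is assumed to behave like the `has_entity_changed` call above):

    async def get_changes_for_entity(self, entity_id, from_token):
        # Cheap in-memory check first: no change since from_token means
        # the database query can be skipped entirely.
        if not self._stream_cache.has_entity_changed(entity_id, from_token):
            return []
        return await self.db_pool.runInteraction(
            "get_changes_for_entity", self._get_changes_txn, entity_id, from_token
        )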
@@ -579,7 +574,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return updates, upto_token, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn
)
@@ -592,7 +587,7 @@ class GroupServerStore(GroupServerWorkerStore):
* "invite"
* "open"
"""
- return self.db.simple_update_one(
+ return self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
@@ -600,7 +595,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@@ -624,7 +619,7 @@ class GroupServerStore(GroupServerWorkerStore):
an order of 1 will put the room first. Otherwise, the room gets
added to the end.
"""
- room_in_group = self.db.simple_select_one_onecol_txn(
+ room_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
@@ -637,7 +632,7 @@ class GroupServerStore(GroupServerWorkerStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
else:
- cat_exists = self.db.simple_select_one_onecol_txn(
+ cat_exists = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -648,7 +643,7 @@ class GroupServerStore(GroupServerWorkerStore):
raise SynapseError(400, "Category doesn't exist")
# TODO: Check category is part of summary already
- cat_exists = self.db.simple_select_one_onecol_txn(
+ cat_exists = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_summary_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -668,7 +663,7 @@ class GroupServerStore(GroupServerWorkerStore):
(group_id, category_id, group_id, category_id),
)
- existing = self.db.simple_select_one_txn(
+ existing = self.db_pool.simple_select_one_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -701,7 +696,7 @@ class GroupServerStore(GroupServerWorkerStore):
to_update["room_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -715,7 +710,7 @@ class GroupServerStore(GroupServerWorkerStore):
if is_public is None:
is_public = True
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_summary_rooms",
values={
@@ -731,7 +726,7 @@ class GroupServerStore(GroupServerWorkerStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
table="group_summary_rooms",
keyvalues={
"group_id": group_id,
@@ -750,14 +745,14 @@ class GroupServerStore(GroupServerWorkerStore):
if profile is None:
insertion_values["profile"] = "{}"
else:
- update_values["profile"] = json.dumps(profile)
+ update_values["profile"] = json_encoder.encode(profile)
if is_public is None:
insertion_values["is_public"] = True
else:
update_values["is_public"] = is_public
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
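`json_encoder.encode` replaces `json.dumps` throughout this file; `json_encoder` is presumably a shared `JSONEncoder` configured for compact output, along these lines (a sketch of the assumed configuration, not the actual definition):

    import json

    # Compact separators drop the spaces json.dumps emits by default.
    json_encoder = json.JSONEncoder(separators=(",", ":"))

    profile = {"name": "My Role", "avatar_url": None}
    json.dumps(profile)          # '{"name": "My Role", "avatar_url": null}'
    json_encoder.encode(profile) # '{"name":"My Role","avatar_url":null}'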
@@ -766,7 +761,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
def remove_group_category(self, group_id, category_id):
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
@@ -781,14 +776,14 @@ class GroupServerStore(GroupServerWorkerStore):
if profile is None:
insertion_values["profile"] = "{}"
else:
- update_values["profile"] = json.dumps(profile)
+ update_values["profile"] = json_encoder.encode(profile)
if is_public is None:
insertion_values["is_public"] = True
else:
update_values["is_public"] = is_public
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@@ -797,14 +792,14 @@ class GroupServerStore(GroupServerWorkerStore):
)
def remove_group_role(self, group_id, role_id):
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
)
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@@ -828,7 +823,7 @@ class GroupServerStore(GroupServerWorkerStore):
an order of 1 will put the user first. Otherwise, the user gets
added to the end.
"""
- user_in_group = self.db.simple_select_one_onecol_txn(
+ user_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -841,7 +836,7 @@ class GroupServerStore(GroupServerWorkerStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
else:
- role_exists = self.db.simple_select_one_onecol_txn(
+ role_exists = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -852,7 +847,7 @@ class GroupServerStore(GroupServerWorkerStore):
raise SynapseError(400, "Role doesn't exist")
# TODO: Check role is part of the summary already
- role_exists = self.db.simple_select_one_onecol_txn(
+ role_exists = self.db_pool.simple_select_one_onecol_txn(
txn,
table="group_summary_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -872,7 +867,7 @@ class GroupServerStore(GroupServerWorkerStore):
(group_id, role_id, group_id, role_id),
)
- existing = self.db.simple_select_one_txn(
+ existing = self.db_pool.simple_select_one_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -901,7 +896,7 @@ class GroupServerStore(GroupServerWorkerStore):
to_update["user_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="group_summary_users",
keyvalues={
@@ -915,7 +910,7 @@ class GroupServerStore(GroupServerWorkerStore):
if is_public is None:
is_public = True
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_summary_users",
values={
@@ -931,7 +926,7 @@ class GroupServerStore(GroupServerWorkerStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
@@ -940,7 +935,7 @@ class GroupServerStore(GroupServerWorkerStore):
def add_group_invite(self, group_id, user_id):
"""Record that the group server has invited a user
"""
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
@@ -970,7 +965,7 @@ class GroupServerStore(GroupServerWorkerStore):
"""
def _add_user_to_group_txn(txn):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_users",
values={
@@ -981,14 +976,14 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
if local_attestation:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -998,60 +993,60 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
if remote_attestation:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json.dumps(remote_attestation),
+ "attestation_json": json_encoder.encode(remote_attestation),
},
)
- return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
+ return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
def add_room_to_group(self, group_id, room_id, is_public):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
def update_room_in_group_visibility(self, group_id, room_id, is_public):
- return self.db.simple_update(
+ return self.db_pool.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
@@ -1060,67 +1055,67 @@ class GroupServerStore(GroupServerWorkerStore):
def remove_room_from_group(self, group_id, room_id):
def _remove_room_from_group_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_summary_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
def update_group_publicity(self, group_id, user_id, publicise):
"""Update whether the user is publicising their membership of the group
"""
- return self.db.simple_update_one(
+ return self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
desc="update_group_publicity",
)
- @defer.inlineCallbacks
- def register_user_group_membership(
+ async def register_user_group_membership(
self,
- group_id,
- user_id,
- membership,
- is_admin=False,
- content={},
- local_attestation=None,
- remote_attestation=None,
- is_publicised=False,
- ):
+ group_id: str,
+ user_id: str,
+ membership: str,
+ is_admin: bool = False,
+ content: JsonDict = {},
+ local_attestation: Optional[dict] = None,
+ remote_attestation: Optional[dict] = None,
+ is_publicised: bool = False,
+ ) -> int:
"""Registers that a local user is a member of a (local or remote) group.
Args:
- group_id (str)
- user_id (str)
- membership (str)
- is_admin (bool)
- content (dict): Content of the membership, e.g. includes the inviter
+ group_id: The group the member is being added to.
+ user_id: The user ID to add to the group.
+ membership: The type of group membership.
+ is_admin: Whether the user should be added as a group admin.
+ content: Content of the membership, e.g. includes the inviter
if the user has been invited.
- local_attestation (dict): If remote group then store the fact that we
+ local_attestation: If the group is remote, store the fact that we
have given out an attestation, else None.
- remote_attestation (dict): If remote group then store the remote
+ remote_attestation: If the group is remote, store the remote
attestation from the group, else None.
+ is_publicised: Whether this membership change should be publicised.
"""
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="local_group_membership",
values={
@@ -1129,11 +1124,11 @@ class GroupServerStore(GroupServerWorkerStore):
"is_admin": is_admin,
"membership": membership,
"is_publicised": is_publicised,
- "content": json.dumps(content),
+ "content": json_encoder.encode(content),
},
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="local_group_updates",
values={
@@ -1141,7 +1136,7 @@ class GroupServerStore(GroupServerWorkerStore):
"group_id": group_id,
"user_id": user_id,
"type": "membership",
- "content": json.dumps(
+ "content": json_encoder.encode(
{"membership": membership, "content": content}
),
},
@@ -1152,7 +1147,7 @@ class GroupServerStore(GroupServerWorkerStore):
if membership == "join":
if local_attestation:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -1162,23 +1157,23 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
if remote_attestation:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
"group_id": group_id,
"user_id": user_id,
"valid_until_ms": remote_attestation["valid_until_ms"],
- "attestation_json": json.dumps(remote_attestation),
+ "attestation_json": json_encoder.encode(remote_attestation),
},
)
else:
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -1187,18 +1182,17 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
with self._group_updates_id_gen.get_next() as next_id:
- res = yield self.db.runInteraction(
+ res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
next_id,
)
return res
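The surrounding code allocates the stream id with a context manager before running the transaction. A sketch of the pattern (the generator is assumed to behave like the `_group_updates_id_gen` used above; the transaction function is a placeholder):

    async def record_group_update(self):
        with self._group_updates_id_gen.get_next() as next_id:
            # next_id is reserved while the block runs; persisting inside
            # it means readers never see a later id before an earlier one.
            await self.db_pool.runInteraction(
                "record_group_update", self._record_group_update_txn, next_id
            )
        return next_id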
- @defer.inlineCallbacks
- def create_group(
+ async def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
- ):
- yield self.db.simple_insert(
+ ) -> None:
+ await self.db_pool.simple_insert(
table="groups",
values={
"group_id": group_id,
@@ -1211,9 +1205,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group",
)
- @defer.inlineCallbacks
- def update_group_profile(self, group_id, profile):
- yield self.db.simple_update_one(
+ async def update_group_profile(self, group_id, profile):
+ await self.db_pool.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
@@ -1223,7 +1216,7 @@ class GroupServerStore(GroupServerWorkerStore):
def update_attestation_renewal(self, group_id, user_id, attestation):
"""Update an attestation that we have renewed
"""
- return self.db.simple_update_one(
+ return self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@@ -1233,12 +1226,12 @@ class GroupServerStore(GroupServerWorkerStore):
def update_remote_attestion(self, group_id, user_id, attestation):
"""Update an attestation that a remote has renewed
"""
- return self.db.simple_update_one(
+ return self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
"valid_until_ms": attestation["valid_until_ms"],
- "attestation_json": json.dumps(attestation),
+ "attestation_json": json_encoder.encode(attestation),
},
desc="update_remote_attestion",
)
@@ -1252,7 +1245,7 @@ class GroupServerStore(GroupServerWorkerStore):
group_id (str)
user_id (str)
"""
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
@@ -1288,8 +1281,8 @@ class GroupServerStore(GroupServerWorkerStore):
]
for table in tables:
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table=table, keyvalues={"group_id": group_id}
)
- return self.db.runInteraction("delete_group", _delete_group_txn)
+ return self.db_pool.runInteraction("delete_group", _delete_group_txn)
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/databases/main/keys.py
index 4e1642a27a..384e9c5eb0 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -86,7 +86,7 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
- return self.db.runInteraction("get_server_verify_keys", _txn)
+ return self.db_pool.runInteraction("get_server_verify_keys", _txn)
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers.
@@ -121,9 +121,9 @@ class KeyStore(SQLBaseStore):
f((i,))
return res
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"store_server_verify_keys",
- self.db.simple_upsert_many_txn,
+ self.db_pool.simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
@@ -151,7 +151,7 @@ class KeyStore(SQLBaseStore):
ts_valid_until_ms (int): The time when this json stops being valid.
key_json (bytes): The encoded JSON.
"""
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
@@ -190,7 +190,7 @@ class KeyStore(SQLBaseStore):
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
- rows = self.db.simple_select_list_txn(
+ rows = self.db_pool.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
@@ -205,4 +205,6 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows
return results
- return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
+ return self.db_pool.runInteraction(
+ "get_server_keys_json", _get_server_keys_json_txn
+ )
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 15bc13cbd0..80fc1cd009 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,16 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryBackgroundUpdateStore, self).__init__(
database, db_conn, hs
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
update_name="local_media_repository_url_idx",
index_name="local_media_repository_url_idx",
table="local_media_repository",
@@ -34,7 +34,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
def get_local_media(self, media_id):
@@ -42,7 +42,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
Returns:
None if the media_id doesn't exist.
"""
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -67,7 +67,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id,
url_cache=None,
):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@@ -83,7 +83,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def mark_local_media_as_safe(self, media_id: str):
"""Mark a local media as safe from quarantining."""
- return self.db.simple_update_one(
+ return self.db_pool.simple_update_one(
table="local_media_repository",
keyvalues={"media_id": media_id},
updatevalues={"safe_from_quarantine": True},
@@ -136,12 +136,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
- return self.db.runInteraction("get_url_cache", get_url_cache_txn)
+ return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
def get_local_media_thumbnails(self, media_id):
- return self.db.simple_select_list(
+ return self.db_pool.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -178,7 +178,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@@ -192,7 +192,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
def get_cached_remote_media(self, origin, media_id):
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -217,7 +217,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name,
filesystem_id,
):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@@ -262,12 +262,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
def get_remote_media_thumbnails(self, origin, media_id):
- return self.db.simple_select_list(
+ return self.db_pool.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -292,7 +292,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@@ -314,24 +314,26 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE last_access_ts < ?"
)
- return self.db.execute(
- "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
+ return self.db_pool.execute(
+ "get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
+ return self.db_pool.runInteraction(
+ "delete_remote_media", delete_remote_media_txn
+ )
def get_expired_url_cache(self, now_ts):
sql = (
@@ -345,7 +347,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
@@ -358,7 +360,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
+ return await self.db_pool.runInteraction(
+ "delete_url_cache", _delete_url_cache_txn
+ )
def get_url_cache_media_before(self, before_ts):
sql = (
@@ -372,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
@@ -389,6 +393,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/databases/main/metrics.py
index dad5bbc602..686052bd83 100644
--- a/synapse/storage/data_stores/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -15,15 +15,13 @@
import typing
from collections import Counter
-from twisted.internet import defer
-
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.event_push_actions import (
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
-from synapse.storage.database import Database
class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
@@ -31,7 +29,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
# Collect metrics on the number of forward extremities that exist.
@@ -66,11 +64,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
return txn.fetchall()
- res = await self.db.runInteraction("read_forward_extremities", fetch)
+ res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = Counter([x[0] for x in res])
- @defer.inlineCallbacks
- def count_daily_messages(self):
+ async def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
@@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db.runInteraction("count_messages", _count_messages)
- return ret
+ return await self.db_pool.runInteraction("count_messages", _count_messages)
- @defer.inlineCallbacks
- def count_daily_sent_messages(self):
+ async def count_daily_sent_messages(self):
def _count_messages(txn):
# This is good enough: if you have silly characters in your own
# hostname then that's your own fault.
@@ -109,11 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
- return ret
+ return await self.db_pool.runInteraction(
+ "count_daily_sent_messages", _count_messages
+ )
- @defer.inlineCallbacks
- def count_daily_active_rooms(self):
+ async def count_daily_active_rooms(self):
def _count(txn):
sql = """
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
@@ -124,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
- return ret
+ return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 6572f41971..1d4db758d4 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,10 +15,8 @@
import logging
from typing import List
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -29,7 +27,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@@ -48,7 +46,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- return self.db.runInteraction("count_users", _count_users)
+ return self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0)
def get_monthly_active_count_by_service(self):
@@ -76,7 +74,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
result = txn.fetchall()
return dict(result)
- return self.db.runInteraction("count_users_by_service", _count_users_by_service)
+ return self.db_pool.runInteraction(
+ "count_users_by_service", _count_users_by_service
+ )
async def get_registered_reserved_users(self) -> List[str]:
"""Of the reserved threepids defined in config, retrieve those that are associated
@@ -109,7 +109,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
@@ -119,7 +119,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
@@ -128,7 +128,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# Do not add more reserved users than the total allowable number
# cur = LoggingTransaction(
- self.db.new_transaction(
+ self.db_pool.new_transaction(
db_conn,
"initialise_mau_threepids",
[],
@@ -162,7 +162,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
# We do this manually here to avoid hitting #6791
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
@@ -246,20 +246,16 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
reserved_users = await self.get_registered_reserved_users()
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
- @defer.inlineCallbacks
- def upsert_monthly_active_user(self, user_id):
+ async def upsert_monthly_active_user(self, user_id: str) -> None:
"""Updates or inserts the user into the monthly active user table, which
is used to track the current MAU usage of the server
Args:
- user_id (str): user to add/update
-
- Returns:
- Deferred
+ user_id: user to add/update
"""
# Support users are never included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
@@ -269,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# _initialise_reserved_users reasoning that it would be very strange to
# include a support user in this context.
- is_support = yield self.is_support_user(user_id)
+ is_support = await self.is_support_user(user_id)
if is_support:
return
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
@@ -303,7 +299,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = self.db.simple_upsert_txn(
+ is_insert = self.db_pool.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
@@ -320,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
return is_insert
- @defer.inlineCallbacks
- def populate_monthly_active_users(self, user_id):
+ async def populate_monthly_active_users(self, user_id):
"""Checks on the state of monthly active user limits and optionally
add the user to the monthly active tables
@@ -330,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
- is_guest = yield self.is_guest(user_id)
+ is_guest = await self.is_guest(user_id)
if is_guest:
return
- is_trial = yield self.is_trial_user(user_id)
+ is_trial = await self.is_trial_user(user_id)
if is_trial:
return
- last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+ last_seen_timestamp = await self.user_last_seen_monthly_active(user_id)
now = self.hs.get_clock().time_msec()
# We want to reduce the total number of db writes, and are happy
@@ -350,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# False, there is no point in checking get_monthly_active_count - it
# adds no value and will break the logic if max_mau_value is exceeded.
if not self._limit_usage_by_mau:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
else:
- count = yield self.get_monthly_active_count()
+ count = await self.get_monthly_active_count()
if count < self._max_mau_value:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
- yield self.upsert_monthly_active_user(user_id)
+ await self.upsert_monthly_active_user(user_id)
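Flattened out of the class, the gating that `populate_monthly_active_users` applies before writing is: skip guests and trial users, upsert unconditionally in stats-only mode, admit users while under the cap, and otherwise only refresh a stale timestamp. A standalone sketch of that decision logic follows; `store` is any object exposing the methods used, and the None guard on `last_seen` is a safety addition of this sketch, not in the original.

```python
LAST_SEEN_GRANULARITY = 60 * 60 * 1000  # refresh a user's row at most hourly


async def maybe_track_mau(store, user_id: str) -> None:
    """Illustrative flattening of populate_monthly_active_users."""
    if not (store._limit_usage_by_mau or store._mau_stats_only):
        return  # MAU tracking disabled entirely
    if await store.is_guest(user_id):
        return  # guests never count towards MAU
    if await store.is_trial_user(user_id):
        return  # neither do users inside their trial period
    last_seen = await store.user_last_seen_monthly_active(user_id)
    now = store.hs.get_clock().time_msec()
    if not store._limit_usage_by_mau:
        # stats-only mode: always track
        await store.upsert_monthly_active_user(user_id)
    elif await store.get_monthly_active_count() < store._max_mau_value:
        # under the cap: admit the user
        await store.upsert_monthly_active_user(user_id)
    elif last_seen is not None and now - last_seen > LAST_SEEN_GRANULARITY:
        # at the cap: only refresh users we already track, and not too often
        await store.upsert_monthly_active_user(user_id)
```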
diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/databases/main/openid.py
index cc21437e92..dcd1ff911a 100644
--- a/synapse/storage/data_stores/main/openid.py
+++ b/synapse/storage/databases/main/openid.py
@@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
table="open_id_tokens",
values={
"token": token,
@@ -28,6 +28,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/databases/main/presence.py
index 7574612619..59ba12820a 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -15,8 +15,6 @@
from typing import List, Tuple
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList
@@ -24,14 +22,13 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
- @defer.inlineCallbacks
- def update_presence(self, presence_states):
+ async def update_presence(self, presence_states):
stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
with stream_ordering_manager as stream_orderings:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
stream_orderings,
@@ -48,7 +45,7 @@ class PresenceStore(SQLBaseStore):
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="presence_stream",
values=[
@@ -124,7 +121,7 @@ class PresenceStore(SQLBaseStore):
return updates, upper_bound, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_presence_updates", get_all_presence_updates_txn
)
@@ -139,7 +136,7 @@ class PresenceStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_presence_for_users(self, user_ids):
- rows = yield self.db.simple_select_many_batch(
+ rows = yield self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -165,7 +162,7 @@ class PresenceStore(SQLBaseStore):
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
table="presence_allow_inbound",
values={
"observed_user_id": observed_localpart,
@@ -176,7 +173,7 @@ class PresenceStore(SQLBaseStore):
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self.db.simple_delete_one(
+ return self.db_pool.simple_delete_one(
table="presence_allow_inbound",
keyvalues={
"observed_user_id": observed_localpart,
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/databases/main/profile.py
index bfc9369f0b..b8261357d4 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.roommember import ProfileInfo
+from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_profileinfo(self, user_localpart):
+ async def get_profileinfo(self, user_localpart):
try:
- profile = yield self.db.simple_select_one(
+ profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
@@ -42,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_displayname(self, user_localpart):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
@@ -50,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_avatar_url(self, user_localpart):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
@@ -58,7 +55,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_from_remote_profile_cache(self, user_id):
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -67,12 +64,12 @@ class ProfileWorkerStore(SQLBaseStore):
)
def create_profile(self, user_localpart):
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
def set_profile_displayname(self, user_localpart, new_displayname):
- return self.db.simple_update_one(
+ return self.db_pool.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
@@ -80,7 +77,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self.db.simple_update_one(
+ return self.db_pool.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
@@ -95,7 +92,7 @@ class ProfileStore(ProfileWorkerStore):
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -107,7 +104,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self.db.simple_update(
+ return self.db_pool.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
updatevalues={
@@ -118,14 +115,13 @@ class ProfileStore(ProfileWorkerStore):
desc="update_remote_profile_cache",
)
- @defer.inlineCallbacks
- def maybe_delete_remote_profile_cache(self, user_id):
+ async def maybe_delete_remote_profile_cache(self, user_id):
"""Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache
"""
- subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
+ subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
- yield self.db.simple_delete(
+ await self.db_pool.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
@@ -144,18 +140,17 @@ class ProfileStore(ProfileWorkerStore):
txn.execute(sql, (last_checked,))
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
- @defer.inlineCallbacks
- def is_subscribed_remote_profile_for_user(self, user_id):
+ async def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
- res = yield self.db.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -166,7 +161,7 @@ class ProfileStore(ProfileWorkerStore):
if res:
return True
- res = yield self.db.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index b53fe35c33..3526b6fd66 100644
--- a/synapse/storage/data_stores/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -18,7 +18,7 @@ from typing import Any, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__)
@@ -43,7 +43,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
deleted events.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@@ -293,7 +293,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
Deferred[List[int]]: The list of state groups to delete.
"""
- return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+ return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index c229248101..6562db5c2b 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -18,28 +18,27 @@ import abc
import logging
from typing import List, Tuple, Union
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.pusher import PusherWorkerStore
-from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-def _load_rules(rawrules, enabled_map):
+def _load_rules(rawrules, enabled_map, use_new_defaults=False):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -49,7 +48,7 @@ def _load_rules(rawrules, enabled_map):
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
- rules = list(list_with_base_rules(ruleslist))
+ rules = list(list_with_base_rules(ruleslist, use_new_defaults))
for i, rule in enumerate(rules):
rule_id = rule["rule_id"]
@@ -79,7 +78,7 @@ class PushRulesWorkerStore(
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
@@ -91,7 +90,7 @@ class PushRulesWorkerStore(
db_conn, "push_rules_stream", "stream_id"
)
- push_rules_prefill, push_rules_id = self.db.get_cache_dict(
+ push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
db_conn,
"push_rules_stream",
entity_column="user_id",
@@ -105,6 +104,8 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
+ self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+
@abc.abstractmethod
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
@@ -116,7 +117,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
- rows = yield self.db.simple_select_list(
+ rows = yield self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -134,13 +135,15 @@ class PushRulesWorkerStore(
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
- rules = _load_rules(rows, enabled_map)
+ use_new_defaults = user_id in self._users_new_default_push_rules
+
+ rules = _load_rules(rows, enabled_map, use_new_defaults)
return rules
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
- results = yield self.db.simple_select_list(
+ results = yield self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -162,7 +165,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
@@ -178,7 +181,7 @@ class PushRulesWorkerStore(
results = {user_id: [] for user_id in user_ids}
- rows = yield self.db.simple_select_many_batch(
+ rows = yield self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -194,7 +197,11 @@ class PushRulesWorkerStore(
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
- results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+ use_new_defaults = user_id in self._users_new_default_push_rules
+
+ results[user_id] = _load_rules(
+ rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
+ )
return results
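The `use_new_defaults` plumbing above threads a per-user config flag (`users_new_default_push_rules`) down into `_load_rules`, which merges the user's stored rules with the server's base rules and then applies the enabled map. Schematically, with `list_with_base_rules` stubbed out and toy rule shapes rather than the real push-rule schema:

```python
def list_with_base_rules(user_rules, use_new_defaults=False):
    # Stub: the real function interleaves server-default rules around the
    # user's own rules; use_new_defaults selects which default set applies.
    default_set = "new" if use_new_defaults else "old"
    base = [{"rule_id": ".m.rule.master", "default": True, "enabled": True,
             "set": default_set}]
    return base + user_rules


def _load_rules(rawrules, enabled_map, use_new_defaults=False):
    ruleslist = [dict(rawrule) for rawrule in rawrules]
    rules = list(list_with_base_rules(ruleslist, use_new_defaults))
    for i, rule in enumerate(rules):
        rule_id = rule["rule_id"]
        if rule_id in enabled_map:
            rules[i] = dict(rule, enabled=enabled_map[rule_id])
    return rules


users_new_default_push_rules = {"@alice:example.com"}  # from homeserver config
user_id = "@alice:example.com"
rules = _load_rules([], {}, use_new_defaults=user_id in users_new_default_push_rules)
```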
@@ -249,81 +256,6 @@ class PushRulesWorkerStore(
):
yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
- @defer.inlineCallbacks
- def bulk_get_push_rules_for_room(self, event, context):
- state_group = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
- result = yield self._bulk_get_push_rules_for_room(
- event.room_id, state_group, current_state_ids, event=event
- )
- return result
-
- @cachedInlineCallbacks(num_args=2, cache_context=True)
- def _bulk_get_push_rules_for_room(
- self, room_id, state_group, current_state_ids, cache_context, event=None
- ):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- # We also will want to generate notifs for other people in the room so
- # their unread countss are correct in the event stream, but to avoid
- # generating them for bot / AS users etc, we only do so for people who've
- # sent a read receipt into the room.
-
- users_in_room = yield self._get_joined_users_from_context(
- room_id,
- state_group,
- current_state_ids,
- on_invalidate=cache_context.invalidate,
- event=event,
- )
-
- # We ignore app service users for now. This is so that we don't fill
- # up the `get_if_users_have_pushers` cache with AS entries that we
- # know don't have pushers, nor even read receipts.
- local_users_in_room = {
- u
- for u in users_in_room
- if self.hs.is_mine_id(u)
- and not self.get_if_app_services_interested_in_user(u)
- }
-
- # users in the room who have pushers need to get push rules run because
- # that's how their pushers work
- if_users_with_pushers = yield self.get_if_users_have_pushers(
- local_users_in_room, on_invalidate=cache_context.invalidate
- )
- user_ids = {
- uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- }
-
- users_with_receipts = yield self.get_users_with_read_receipts_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
-
- # any users with pushers must be ours: they have pushers
- for uid in users_with_receipts:
- if uid in local_users_in_room:
- user_ids.add(uid)
-
- rules_by_user = yield self.bulk_get_push_rules(
- user_ids, on_invalidate=cache_context.invalidate
- )
-
- rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
-
- return rules_by_user
-
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
@@ -336,7 +268,7 @@ class PushRulesWorkerStore(
results = {user_id: {} for user_id in user_ids}
- rows = yield self.db.simple_select_many_batch(
+ rows = yield self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -394,7 +326,7 @@ class PushRulesWorkerStore(
return updates, upper_bound, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
@@ -411,12 +343,12 @@ class PushRuleStore(PushRulesWorkerStore):
before=None,
after=None,
):
- conditions_json = json.dumps(conditions)
- actions_json = json.dumps(actions)
+ conditions_json = json_encoder.encode(conditions)
+ actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id,
@@ -430,7 +362,7 @@ class PushRuleStore(PushRulesWorkerStore):
after,
)
else:
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id,
@@ -461,7 +393,7 @@ class PushRuleStore(PushRulesWorkerStore):
relative_to_rule = before or after
- res = self.db.simple_select_one_txn(
+ res = self.db_pool.simple_select_one_txn(
txn,
table="push_rules",
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
@@ -584,7 +516,7 @@ class PushRuleStore(PushRulesWorkerStore):
# We didn't update a row with the given rule_id so insert one
push_rule_id = self._push_rule_id_gen.get_next()
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="push_rules",
values={
@@ -627,7 +559,7 @@ class PushRuleStore(PushRulesWorkerStore):
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
- self.db.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@@ -637,7 +569,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
stream_id,
@@ -648,7 +580,7 @@ class PushRuleStore(PushRulesWorkerStore):
def set_push_rule_enabled(self, user_id, rule_id, enabled):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id,
@@ -662,7 +594,7 @@ class PushRuleStore(PushRulesWorkerStore):
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next()
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
"push_rules_enable",
{"user_name": user_id, "rule_id": rule_id},
@@ -681,7 +613,7 @@ class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
- actions_json = json.dumps(actions)
+ actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
if is_default_rule:
@@ -702,7 +634,7 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
"push_rules",
{"user_name": user_id, "rule_id": rule_id},
@@ -721,7 +653,7 @@ class PushRuleStore(PushRulesWorkerStore):
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
@@ -741,7 +673,7 @@ class PushRuleStore(PushRulesWorkerStore):
if data is not None:
values.update(data)
- self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
+ self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
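The other change threaded through push_rule.py is serialisation: `json.dumps` calls become `json_encoder.encode`, using a shared encoder from `synapse.util`. Given the changelog entry about reducing JSON whitespace, that encoder is presumably configured with compact separators, along these lines (the exact definition is an assumption):

```python
import json

# Assumed shape of synapse.util's json_encoder: a shared encoder with compact
# separators, so serialised JSON carries no spaces after ',' or ':'.
json_encoder = json.JSONEncoder(separators=(",", ":"))

print(json.dumps({"actions": ["notify"], "x": 1}))           # {"actions": ["notify"], "x": 1}
print(json_encoder.encode({"actions": ["notify"], "x": 1}))  # {"actions":["notify"],"x":1}
```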
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/databases/main/pusher.py
index e18f1ca87c..b5200fbe79 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -50,7 +50,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
- ret = yield self.db.simple_select_one_onecol(
+ ret = yield self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -63,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
- ret = yield self.db.simple_select_list(
+ ret = yield self.db_pool.simple_select_list(
"pushers",
keyvalues,
[
@@ -91,11 +91,11 @@ class PusherWorkerStore(SQLBaseStore):
def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
+ rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
return rows
async def get_all_updated_pushers_rows(
@@ -160,7 +160,7 @@ class PusherWorkerStore(SQLBaseStore):
return updates, upper_bound, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@@ -176,7 +176,7 @@ class PusherWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids):
- rows = yield self.db.simple_select_many_batch(
+ rows = yield self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
):
- yield self.db.simple_update_one(
+ yield self.db_pool.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
@@ -216,7 +216,7 @@ class PusherWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
"""
- updated = yield self.db.simple_update(
+ updated = yield self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -230,7 +230,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self.db.simple_update(
+ yield self.db_pool.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
@@ -239,7 +239,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id):
- res = yield self.db.simple_select_list(
+ res = yield self.db_pool.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -259,7 +259,7 @@ class PusherWorkerStore(SQLBaseStore):
def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry
- yield self.db.simple_upsert(
+ yield self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
@@ -291,7 +291,7 @@ class PusherStore(PusherWorkerStore):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
- yield self.db.simple_upsert(
+ yield self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -316,7 +316,7 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since the user might not have had a pusher before
- yield self.db.runInteraction(
+ yield self.db_pool.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
@@ -330,7 +330,7 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self.db.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -339,7 +339,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="deleted_pushers",
values={
@@ -351,4 +351,6 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
- yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
+ yield self.db_pool.runInteraction(
+ "delete_pusher", delete_pusher_txn, stream_id
+ )
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1d723f2d34..1920a8a152 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -18,13 +18,12 @@ import abc
import logging
from typing import List, Tuple
-from canonicaljson import json
-
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -41,7 +40,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
@@ -64,7 +63,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
- return self.db.simple_select_list(
+ return self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@@ -73,7 +72,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -87,7 +86,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self.db.simple_select_list(
+ rows = yield self.db_pool.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -111,7 +110,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
+ rows = yield self.db_pool.runInteraction(
+ "get_receipts_for_user_with_orderings", f
+ )
return {
row[0]: {
"event_id": row[1],
@@ -190,11 +191,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (room_id, to_key))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
return rows
- rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
+ rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
return []
@@ -240,9 +241,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql + clause, [to_key] + list(args))
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- txn_results = yield self.db.runInteraction(
+ txn_results = yield self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
@@ -288,7 +289,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [r[0] for r in txn]
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
)
@@ -340,7 +341,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return updates, upper_bound, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
@@ -371,7 +372,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
@@ -393,7 +394,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
- res = self.db.simple_select_one_txn(
+ res = self.db_pool.simple_select_one_txn(
txn,
table="events",
retcols=["stream_ordering", "received_ts"],
@@ -446,7 +447,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -457,7 +458,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
values={
"stream_id": stream_id,
"event_id": event_id,
- "data": json.dumps(data),
+ "data": json_encoder.encode(data),
},
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
@@ -506,13 +507,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.db.runInteraction(
+ linearized_event_id = yield self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- event_ts = yield self.db.runInteraction(
+ event_ts = yield self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -541,7 +542,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
return stream_id, max_persisted_id
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -567,7 +568,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
@@ -576,14 +577,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
},
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="receipts_graph",
values={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
- "event_ids": json.dumps(event_ids),
- "data": json.dumps(data),
+ "event_ids": json_encoder.encode(event_ids),
+ "data": json_encoder.encode(data),
},
)
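Many of the read paths above funnel rows through `db_pool.cursor_to_dict`. Its job is simply to zip column names from the DB-API cursor description onto each fetched row; roughly (a sketch, not the exact implementation):

```python
import sqlite3


def cursor_to_dict(cursor):
    """Convert the current result set to a list of dicts keyed by column name.

    d[0] is the column name in each DB-API 2.0 cursor.description tuple.
    """
    col_headers = [d[0] for d in cursor.description]
    return [dict(zip(col_headers, row)) for row in cursor.fetchall()]


conn = sqlite3.connect(":memory:")
cur = conn.execute("SELECT 1 AS stream_ordering, 42 AS received_ts")
print(cursor_to_dict(cur))  # [{'stream_ordering': 1, 'received_ts': 42}]
```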
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/databases/main/registration.py
index 27d2c5028c..402ae25571 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,20 +17,19 @@
import logging
import re
-from typing import Optional
+from typing import Dict, List, Optional
-from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
@@ -38,7 +37,7 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config
@@ -50,7 +49,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_id",
)
- @defer.inlineCallbacks
- def is_trial_user(self, user_id):
+ async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
Args:
- user_id (str)
-
- Returns:
- Deferred[bool]
+ user_id: The user to check for trial status.
"""
- info = yield self.get_user_by_id(user_id)
+ info = await self.get_user_by_id(user_id)
if not info:
return False
@@ -101,50 +96,51 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
- @cachedInlineCallbacks()
- def get_expiration_ts_for_user(self, user_id):
+ @cached()
+ async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
- user_id (str): The ID of the user.
+ user_id: The ID of the user.
Returns:
- defer.Deferred: None, if the account has no expiration timestamp,
- otherwise int representation of the timestamp (as a number of
- milliseconds since epoch).
+ None, if the account has no expiration timestamp, otherwise int
+ representation of the timestamp (as a number of milliseconds since epoch).
"""
- res = yield self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
desc="get_expiration_ts_for_user",
)
- return res
- @defer.inlineCallbacks
- def set_account_validity_for_user(
- self, user_id, expiration_ts, email_sent, renewal_token=None
- ):
+ async def set_account_validity_for_user(
+ self,
+ user_id: str,
+ expiration_ts: int,
+ email_sent: bool,
+ renewal_token: Optional[str] = None,
+ ) -> None:
"""Updates the account validity properties of the given account, with the
given values.
Args:
- user_id (str): ID of the account to update properties for.
- expiration_ts (int): New expiration date, as a timestamp in milliseconds
+ user_id: ID of the account to update properties for.
+ expiration_ts: New expiration date, as a timestamp in milliseconds
since epoch.
- email_sent (bool): True means a renewal email has been sent for this
- account and there's no need to send another one for the current validity
+ email_sent: True means a renewal email has been sent for this account
+ and there's no need to send another one for the current validity
period.
- renewal_token (str): Renewal token the user can use to extend the validity
+ renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
def set_account_validity_for_user_txn(txn):
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
@@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,)
)
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn
)
- @defer.inlineCallbacks
- def set_renewal_token_for_user(self, user_id, renewal_token):
+ async def set_renewal_token_for_user(
+ self, user_id: str, renewal_token: str
+ ) -> None:
"""Defines a renewal token for a given user.
Args:
- user_id (str): ID of the user to set the renewal token for.
- renewal_token (str): Random unique string that will be used to renew the
+ user_id: ID of the user to set the renewal token for.
+ renewal_token: Random unique string that will be used to renew the
user's account.
Raises:
StoreError: The provided token is already set for another user.
"""
- yield self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
desc="set_renewal_token_for_user",
)
- @defer.inlineCallbacks
- def get_user_from_renewal_token(self, renewal_token):
+ async def get_user_from_renewal_token(self, renewal_token: str) -> str:
"""Get a user ID from a renewal token.
Args:
- renewal_token (str): The renewal token to perform the lookup with.
+ renewal_token: The renewal token to perform the lookup with.
Returns:
- defer.Deferred[str]: The ID of the user to which the token belongs.
+ The ID of the user to which the token belongs.
"""
- res = yield self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
desc="get_user_from_renewal_token",
)
- return res
-
- @defer.inlineCallbacks
- def get_renewal_token_for_user(self, user_id):
+ async def get_renewal_token_for_user(self, user_id: str) -> str:
"""Get the renewal token associated with a given user ID.
Args:
- user_id (str): The user ID to lookup a token for.
+ user_id: The user ID to look up a token for.
Returns:
- defer.Deferred[str]: The renewal token associated with this user ID.
+ The renewal token associated with this user ID.
"""
- res = yield self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
desc="get_renewal_token_for_user",
)
- return res
-
- @defer.inlineCallbacks
- def get_users_expiring_soon(self):
+ async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
- Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
+ A list of dictionaries mapping user ID to expiration time (in milliseconds).
"""
def select_users_txn(txn, now_ms, renew_at):
@@ -236,58 +226,54 @@ class RegistrationWorkerStore(SQLBaseStore):
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- res = yield self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
self.config.account_validity.renew_at,
)
- return res
-
- @defer.inlineCallbacks
- def set_renewal_mail_status(self, user_id, email_sent):
+ async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
"""Sets or unsets the flag that indicates whether a renewal email has been sent
to the user (and the user hasn't renewed their account yet).
Args:
- user_id (str): ID of the user to set/unset the flag for.
- email_sent (bool): Flag which indicates whether a renewal email has been sent
+ user_id: ID of the user to set/unset the flag for.
+ email_sent: Flag which indicates whether a renewal email has been sent
to this user.
"""
- yield self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
desc="set_renewal_mail_status",
)
- @defer.inlineCallbacks
- def delete_account_validity_for_user(self, user_id):
+ async def delete_account_validity_for_user(self, user_id: str) -> None:
"""Deletes the entry for the given user in the account validity table, removing
their expiration date and renewal token.
Args:
- user_id (str): ID of the user to remove from the account validity table.
+ user_id: ID of the user to remove from the account validity table.
"""
- yield self.db.simple_delete_one(
+ await self.db_pool.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
)
- async def is_server_admin(self, user):
+ async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
Args:
- user (UserID): user ID of the user to test
+ user: user ID of the user to test
- Returns (bool):
+ Returns:
true iff the user is a server admin, false otherwise.
"""
- res = await self.db.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@@ -307,14 +293,14 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def set_server_admin_txn(txn):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user.to_string(),)
)
- return self.db.runInteraction("set_server_admin", set_server_admin_txn)
+ return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
@@ -326,43 +312,42 @@ class RegistrationWorkerStore(SQLBaseStore):
)
txn.execute(sql, (token,))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]
return None
- @cachedInlineCallbacks()
- def is_real_user(self, user_id):
+ @cached()
+ async def is_real_user(self, user_id: str) -> bool:
"""Determines if the user is a real user, ie does not have a 'user_type'.
Args:
- user_id (str): user id to test
+ user_id: user id to test
Returns:
- Deferred[bool]: True if user 'user_type' is null or empty string
+ True if user 'user_type' is null or empty string
"""
- res = yield self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"is_real_user", self.is_real_user_txn, user_id
)
- return res
@cached()
- def is_support_user(self, user_id):
+ async def is_support_user(self, user_id: str) -> bool:
"""Determines if the user is of type UserTypes.SUPPORT
Args:
- user_id (str): user id to test
+ user_id: user id to test
Returns:
- Deferred[bool]: True if user is of type UserTypes.SUPPORT
+ True if user is of type UserTypes.SUPPORT
"""
- return self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
def is_real_user_txn(self, txn, user_id):
- res = self.db.simple_select_one_onecol_txn(
+ res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -372,7 +357,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return res is None
def is_support_user_txn(self, txn, user_id):
- res = self.db.simple_select_one_onecol_txn(
+ res = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -391,7 +376,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
- return self.db.runInteraction("get_users_by_id_case_insensitive", f)
+ return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@@ -405,7 +390,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: the mxid of the user, or None if they are not known
"""
- return await self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id",
@@ -413,19 +398,17 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_external_id",
)
- @defer.inlineCallbacks
- def count_all_users(self):
+ async def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
- ret = yield self.db.runInteraction("count_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_users", _count_users)
def count_daily_user_type(self):
"""
@@ -456,10 +439,11 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
- return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
+ return self.db_pool.runInteraction(
+ "count_daily_user_type", _count_daily_user_type
+ )
- @defer.inlineCallbacks
- def count_nonbridged_users(self):
+ async def count_nonbridged_users(self):
def _count_users(txn):
txn.execute(
"""
@@ -470,29 +454,26 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
- ret = yield self.db.runInteraction("count_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_users", _count_users)
- @defer.inlineCallbacks
- def count_real_users(self):
+ async def count_real_users(self):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
- ret = yield self.db.runInteraction("count_real_users", _count_users)
- return ret
+ return await self.db_pool.runInteraction("count_real_users", _count_users)
async def generate_user_id(self) -> str:
"""Generate a suitable localpart for a guest user
Returns: a (hopefully) free localpart
"""
- next_id = await self.db.runInteraction(
+ next_id = await self.db_pool.runInteraction(
"generate_user_id", self._user_id_seq.get_next_id_txn
)
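`generate_user_id` draws its numeric localpart from a sequence built by `build_sequence_generator`, so concurrent workers can mint guest user IDs without colliding. A toy stand-in for the shape — on Postgres the real generator would issue `nextval`; the in-memory counter here is purely illustrative:

```python
class ToySequence:
    """Stand-in for the sequence generator: hands out monotonically
    increasing IDs. The real one uses a database sequence on Postgres and
    an emulated, lock-protected counter on SQLite."""

    def __init__(self, start: int = 0):
        self._next = start

    def get_next_id_txn(self, txn) -> int:
        # Postgres flavour would be: txn.execute("SELECT nextval('user_id_seq')")
        self._next += 1
        return self._next


async def generate_user_id(db_pool, user_id_seq: ToySequence) -> str:
    """Mirror of the store method: allocate an ID, use it as the localpart."""
    next_id = await db_pool.runInteraction(
        "generate_user_id", user_id_seq.get_next_id_txn
    )
    return str(next_id)
```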
@@ -508,7 +489,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
The user ID or None if no user id/threepid mapping exists
"""
- user_id = await self.db.runInteraction(
+ user_id = await self.db_pool.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
return user_id
@@ -524,7 +505,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
- ret = self.db.simple_select_one_txn(
+ ret = self.db_pool.simple_select_one_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
@@ -535,26 +516,23 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret["user_id"]
return None
- @defer.inlineCallbacks
- def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self.db.simple_upsert(
+ async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
+ await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
- @defer.inlineCallbacks
- def user_get_threepids(self, user_id):
- ret = yield self.db.simple_select_list(
+ async def user_get_threepids(self, user_id):
+ return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
)
- return ret
def user_delete_threepid(self, user_id, medium, address):
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
desc="user_delete_threepid",
@@ -567,7 +545,7 @@ class RegistrationWorkerStore(SQLBaseStore):
user_id: The user id to delete all threepids of
"""
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id},
desc="user_delete_threepids",
@@ -589,7 +567,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
# We need to use an upsert, in case the user had already bound the
# threepid
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -615,7 +593,7 @@ class RegistrationWorkerStore(SQLBaseStore):
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
"""
- return self.db.simple_select_list(
+ return self.db_pool.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
@@ -636,7 +614,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred
"""
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -659,25 +637,25 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
desc="get_id_servers_user_bound",
)
- @cachedInlineCallbacks()
- def get_user_deactivated_status(self, user_id):
+ @cached()
+ async def get_user_deactivated_status(self, user_id: str) -> bool:
"""Retrieve the value for the `deactivated` property for the provided user.
Args:
- user_id (str): The ID of the user to retrieve the status for.
+ user_id: The ID of the user to retrieve the status for.
Returns:
- defer.Deferred(bool): The requested value.
+ True if the user was deactivated, False if the user is still active.
"""
- res = yield self.db.simple_select_one_onecol(
+ res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@@ -744,13 +722,13 @@ class RegistrationWorkerStore(SQLBaseStore):
sql += " LIMIT 1"
txn.execute(sql, list(keyvalues.values()))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return None
return rows[0]
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
@@ -764,37 +742,37 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def delete_threepid_session_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
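Note: both deletes above run inside a single `runInteraction` call, so the token and session rows disappear atomically; if the second delete failed, the first would roll back with it. The general shape of the pattern, sketched with hypothetical tables:

    def delete_widget_txn(txn):
        # Both statements share one transaction: either both commit or neither does.
        # `widget_id` is a hypothetical parameter from the enclosing method.
        self.db_pool.simple_delete_txn(
            txn, table="widget_tokens", keyvalues={"widget_id": widget_id}
        )
        self.db_pool.simple_delete_txn(
            txn, table="widget_sessions", keyvalues={"widget_id": widget_id}
        )

    return self.db_pool.runInteraction("delete_widget", delete_widget_txn)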
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"users_creation_ts",
index_name="users_creation_ts",
table="users",
@@ -804,18 +782,19 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
- self.db.updates.register_noop_background_update("refresh_tokens_device_index")
+ self.db_pool.updates.register_noop_background_update(
+ "refresh_tokens_device_index"
+ )
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"user_threepids_grandfather", self._bg_user_threepids_grandfather
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
- @defer.inlineCallbacks
- def _background_update_set_deactivated_flag(self, progress, batch_size):
+ async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
@@ -843,7 +822,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
(last_user, batch_size),
)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return True, 0
@@ -857,7 +836,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
logger.info("Marked %d rows as deactivated", rows_processed_nb)
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
)
@@ -866,17 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
else:
return False, len(rows)
- end, nb_processed = yield self.db.runInteraction(
+ end, nb_processed = await self.db_pool.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
- yield self.db.updates._end_background_update("users_set_deactivated_flag")
+ await self.db_pool.updates._end_background_update(
+ "users_set_deactivated_flag"
+ )
return nb_processed
- @defer.inlineCallbacks
- def _bg_user_threepids_grandfather(self, progress, batch_size):
+ async def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
@@ -897,17 +877,17 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
- yield self.db.updates._end_background_update("user_threepids_grandfather")
+ await self.db_pool.updates._end_background_update("user_threepids_grandfather")
return 1
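Note: the background updates converted above all follow a common contract: the handler is registered by name, processes up to `batch_size` rows per call, persists its progress, and calls `_end_background_update` once there is nothing left, returning the number of rows handled. A compressed sketch of that contract, with a hypothetical update name and query:

    async def _bg_example_update(self, progress, batch_size):
        last_id = progress.get("last_id", "")

        def _txn(txn):
            txn.execute(
                "SELECT id FROM widgets WHERE id > ? ORDER BY id LIMIT ?",
                (last_id, batch_size),
            )
            rows = txn.fetchall()
            if rows:
                # Persist progress so a restart resumes from this point.
                self.db_pool.updates._background_update_progress_txn(
                    txn, "example_update", {"last_id": rows[-1][0]}
                )
            return len(rows)

        processed = await self.db_pool.runInteraction("example_update", _txn)
        if processed == 0:
            await self.db_pool.updates._end_background_update("example_update")
        return processed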
class RegistrationStore(RegistrationBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
@@ -931,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
- @defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
+ async def add_access_token_to_user(
+ self,
+ user_id: str,
+ token: str,
+ device_id: Optional[str],
+ valid_until_ms: Optional[int],
+ ) -> None:
"""Adds an access token for the given user.
Args:
- user_id (str): The user ID.
- token (str): The new access token to add.
- device_id (str): ID of the device to associate with the access
- token
- valid_until_ms (int|None): when the token is valid until. None for
- no expiry.
+ user_id: The user ID.
+ token: The new access token to add.
+ device_id: ID of the device to associate with the access token
+ valid_until_ms: when the token expires, or None for no expiry.
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._access_tokens_id_gen.get_next()
- yield self.db.simple_insert(
+ await self.db_pool.simple_insert(
"access_tokens",
{
"id": next_id,
@@ -992,7 +975,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Returns:
Deferred
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"register_user",
self._register_user,
user_id,
@@ -1026,7 +1009,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
- self.db.simple_select_one_txn(
+ self.db_pool.simple_select_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -1034,7 +1017,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
allow_none=False,
)
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -1048,7 +1031,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
else:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"users",
values={
@@ -1091,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- txn.call_after(self.is_guest.invalidate, (user_id,))
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
@@ -1103,7 +1085,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
@@ -1121,12 +1103,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def user_set_password_hash_txn(txn):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
@@ -1143,7 +1125,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@@ -1151,7 +1133,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db.runInteraction("user_set_consent_version", f)
+ return self.db_pool.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
"""Updates the user table to record that we have sent the user a server
@@ -1167,7 +1149,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@@ -1175,7 +1157,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.db.runInteraction("user_set_consent_server_notice_sent", f)
+ return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
@@ -1221,11 +1203,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
- return self.db.runInteraction("user_delete_access_tokens", f)
+ return self.db_pool.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
- self.db.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -1233,11 +1215,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
- return self.db.runInteraction("delete_access_token", f)
+ return self.db_pool.runInteraction("delete_access_token", f)
- @cachedInlineCallbacks()
- def is_guest(self, user_id):
- res = yield self.db.simple_select_one_onecol(
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@@ -1252,7 +1234,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
@@ -1265,7 +1247,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
- return self.db.simple_delete(
+ return self.db_pool.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
@@ -1276,7 +1258,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1306,7 +1288,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
- row = self.db.simple_select_one_txn(
+ row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1324,7 +1306,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
400, "This client_secret does not match the provided session_id"
)
- row = self.db.simple_select_one_txn(
+ row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
@@ -1349,7 +1331,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Looks good. Validate the session
- self.db.simple_update_txn(
+ self.db_pool.simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1359,7 +1341,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
@@ -1392,7 +1374,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
if validated_at:
insertion_values["validated_at"] = validated_at
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
@@ -1430,7 +1412,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1443,7 +1425,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Create a new validation token with this session ID
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="threepid_validation_token",
values={
@@ -1454,7 +1436,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
@@ -1469,22 +1451,23 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
return txn.execute(sql, (ts,))
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
)
- @defer.inlineCallbacks
- def set_user_deactivated_status(self, user_id, deactivated):
+ async def set_user_deactivated_status(
+ self, user_id: str, deactivated: bool
+ ) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
Args:
- user_id (str): The ID of the user to set the status for.
- deactivated (bool): The value to set for `deactivated`.
+ user_id: The ID of the user to set the status for.
+ deactivated: The value to set for `deactivated`.
"""
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
@@ -1492,7 +1475,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -1501,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
+ txn.call_after(self.is_guest.invalidate, (user_id,))
- @defer.inlineCallbacks
- def _set_expiration_date_when_missing(self):
+ async def _set_expiration_date_when_missing(self):
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
@@ -1520,14 +1503,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
txn.execute(sql, [])
- res = self.db.cursor_to_dict(txn)
+ res = self.db_pool.cursor_to_dict(txn)
if res:
for user in res:
self.set_expiration_date_for_user_txn(
txn, user["name"], use_delta=True
)
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)
@@ -1551,7 +1534,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expiration_ts,
)
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
"account_validity",
keyvalues={"user_id": user_id},
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/databases/main/rejections.py
index 27e5a2084a..cf9ba51205 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def get_rejection_reason(self, event_id):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/databases/main/relations.py
index 7d477f8d01..a9ceffc20e 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,18 +14,20 @@
# limitations under the License.
import logging
+from typing import Optional
import attr
from synapse.api.constants import RelationTypes
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.stream import generate_pagination_where_clause
+from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -129,7 +131,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
@@ -223,22 +225,22 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
- @cachedInlineCallbacks()
- def get_applicable_edit(self, event_id):
+ @cached()
+ async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given
event.
Correctly handles checking whether edits were allowed to happen.
Args:
- event_id (str): The original event ID
+ event_id: The original event ID
Returns:
- Deferred[EventBase|None]: Returns the most recent edit, if any.
+ The most recent edit, if any.
"""
# We only allow edits for `m.room.message` events that have the same sender
@@ -268,15 +270,14 @@ class RelationsWorkerStore(SQLBaseStore):
if row:
return row[0]
- edit_id = yield self.db.runInteraction(
+ edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn
)
if not edit_id:
- return
+ return None
- edit_event = yield self.get_event(edit_id, allow_none=True)
- return edit_event
+ return await self.get_event(edit_id, allow_none=True)
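Note: at call sites nothing changes syntactically once a `cachedInlineCallbacks` method becomes a cached coroutine: callers still await it, and a cache hit resolves immediately. A hypothetical call-site sketch, where `store` is a RelationsWorkerStore instance:

    edit = await store.get_applicable_edit("$original_event_id")
    if edit is not None:
        new_content = edit.content.get("m.new_content")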
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
"""Check if a user has already annotated an event with the same key
@@ -318,7 +319,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone())
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/databases/main/room.py
index ab48052cdc..f4008e6221 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -27,8 +27,8 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.data_stores.main.search import SearchStore
-from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached
@@ -73,7 +73,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config
@@ -86,7 +86,7 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
A dict containing the room information, or None if the room is unknown.
"""
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -118,7 +118,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, [room_id])
# If the query returns no rows, catch the IndexError and return None instead of an error
try:
- res = self.db.cursor_to_dict(txn)[0]
+ res = self.db_pool.cursor_to_dict(txn)[0]
except IndexError:
return None
@@ -126,12 +126,12 @@ class RoomWorkerStore(SQLBaseStore):
res["public"] = bool(res["public"])
return res
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
)
def get_public_room_ids(self):
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
table="rooms",
keyvalues={"is_public": True},
retcol="room_id",
@@ -188,7 +188,9 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
- return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
+ return self.db_pool.runInteraction(
+ "count_public_rooms", _count_public_rooms_txn
+ )
async def get_largest_public_rooms(
self,
@@ -320,21 +322,21 @@ class RoomWorkerStore(SQLBaseStore):
def _get_largest_public_rooms_txn(txn):
txn.execute(sql, query_args)
- results = self.db.cursor_to_dict(txn)
+ results = self.db_pool.cursor_to_dict(txn)
if not forwards:
results.reverse()
return results
- ret_val = await self.db.runInteraction(
+ ret_val = await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
return ret_val
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@@ -502,7 +504,7 @@ class RoomWorkerStore(SQLBaseStore):
room_count = txn.fetchone()
return rooms, room_count[0]
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_rooms_paginate", _get_rooms_paginate_txn,
)
@@ -519,7 +521,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimiting has been
disabled for that user entirely.
"""
- row = await self.db.simple_select_one(
+ row = await self.db_pool.simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
@@ -561,9 +563,9 @@ class RoomWorkerStore(SQLBaseStore):
(room_id,),
)
- return self.db.cursor_to_dict(txn)
+ return self.db_pool.cursor_to_dict(txn)
- ret = await self.db.runInteraction(
+ ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
)
@@ -613,7 +615,7 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
@@ -630,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -714,7 +716,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
@@ -730,7 +732,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
@@ -848,7 +850,7 @@ class RoomWorkerStore(SQLBaseStore):
return updates, upto_token, limited
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
@@ -857,21 +859,21 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.config = hs.config
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"insert_room_retention", self._background_insert_retention,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
self._remove_tombstoned_rooms_from_directory,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.ADD_ROOMS_ROOM_VERSION_COLUMN,
self._background_add_rooms_room_version_column,
)
@@ -900,7 +902,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
(last_room, batch_size),
)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return True
@@ -912,7 +914,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
ev = db_to_json(row["json"])
retention_policy = ev["content"]
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn=txn,
table="room_retention",
values={
@@ -925,7 +927,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Inserted %d rows into room_retention", len(rows))
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
)
@@ -934,12 +936,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
else:
return False
- end = await self.db.runInteraction(
+ end = await self.db_pool.runInteraction(
"insert_room_retention", _background_insert_retention_txn,
)
if end:
- await self.db.updates._end_background_update("insert_room_retention")
+ await self.db_pool.updates._end_background_update("insert_room_retention")
return batch_size
@@ -983,7 +985,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
# mainly for paranoia, as much badness would happen if we don't
# insert the row and then try and get the room version for the
# room.
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
@@ -992,19 +994,19 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
)
new_last_room_id = room_id
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
)
return False
- end = await self.db.runInteraction(
+ end = await self.db_pool.runInteraction(
"_background_add_rooms_room_version_column",
_background_add_rooms_room_version_column_txn,
)
if end:
- await self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.ADD_ROOMS_ROOM_VERSION_COLUMN
)
@@ -1038,12 +1040,12 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return [row[0] for row in txn]
- rooms = await self.db.runInteraction(
+ rooms = await self.db_pool.runInteraction(
"get_tombstoned_directory_rooms", _get_rooms
)
if not rooms:
- await self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
)
return 0
@@ -1052,7 +1054,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
logger.info("Removing tombstoned room %s from the directory", room_id)
await self.set_room_is_public(room_id, False)
- await self.db.updates._background_update_progress(
+ await self.db_pool.updates._background_update_progress(
self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
)
@@ -1068,7 +1070,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs)
self.config = hs.config
@@ -1079,7 +1081,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
Called when we join a room over federation, and overwrites any room version
currently in the table.
"""
- await self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
desc="upsert_room_on_join",
table="rooms",
keyvalues={"room_id": room_id},
@@ -1111,7 +1113,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
try:
def store_room_txn(txn, next_id):
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
"rooms",
{
@@ -1122,7 +1124,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
if is_public:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -1133,7 +1135,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
+ await self.db_pool.runInteraction(
+ "store_room_txn", store_room_txn, next_id
+ )
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
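Note: `_public_room_id_gen.get_next()` is a context manager: the stream id is reserved on entry and only advertised to readers once the block exits cleanly, so replication consumers never see an id whose row has not yet committed. The pattern in miniature (sketch; `store_thing_txn` is a hypothetical transaction function):

    with self._public_room_id_gen.get_next() as next_id:
        # The id becomes visible to stream consumers only after this
        # transaction has committed and the `with` block exits cleanly.
        await self.db_pool.runInteraction("store_thing", store_thing_txn, next_id)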
@@ -1143,7 +1147,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
When we receive an invite over federation, store the version of the room if we
don't already know the room version.
"""
- await self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
desc="maybe_store_room_on_invite",
table="rooms",
keyvalues={"room_id": room_id},
@@ -1160,14 +1164,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
async def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"is_public": is_public},
)
- entries = self.db.simple_select_list_txn(
+ entries = self.db_pool.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@@ -1185,7 +1189,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -1198,7 +1202,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
@@ -1224,7 +1228,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="appservice_room_list",
values={
@@ -1237,7 +1241,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
# We've already inserted, nothing to do.
return
else:
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="appservice_room_list",
keyvalues={
@@ -1247,7 +1251,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- entries = self.db.simple_select_list_txn(
+ entries = self.db_pool.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@@ -1265,7 +1269,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -1278,7 +1282,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
next_id,
@@ -1295,13 +1299,13 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
- return self.db.runInteraction("get_rooms", f)
+ return self.db_pool.runInteraction("get_rooms", f)
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
next_id = self._event_reports_id_gen.get_next()
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
table="event_reports",
values={
"id": next_id,
@@ -1325,14 +1329,14 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
room_id: Room to block
user_id: Who blocked it
"""
- await self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="blocked_rooms",
keyvalues={"room_id": room_id},
values={},
insertion_values={"user_id": user_id},
desc="block_room",
)
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked,
@@ -1388,7 +1392,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql, args)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
rooms_dict = {}
for row in rows:
@@ -1404,7 +1408,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
txn.execute(sql)
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
# If a room isn't already in the dict (i.e. it doesn't have a retention
# policy in its state), add it with a null policy.
@@ -1417,7 +1421,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict
- rooms = await self.db.runInteraction(
+ rooms = await self.db_pool.runInteraction(
"get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn,
)
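Note: `cursor_to_dict`, used throughout these hunks, converts raw cursor rows into dicts keyed by column name. Its behaviour is roughly the following (a sketch, not the actual source):

    def cursor_to_dict(txn):
        # Column names come from the DB-API cursor description.
        col_headers = [col[0] for col in txn.description]
        return [dict(zip(col_headers, row)) for row in txn]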
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/databases/main/roommember.py
index a92e401e88..b2fcfc9bfe 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,11 +15,13 @@
# limitations under the License.
import logging
-from typing import Iterable, List, Set
+from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import (
@@ -28,8 +30,8 @@ from synapse.storage._base import (
db_to_json,
make_in_list_sql_clause,
)
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
GetRoomsForUserWithStreamOrdering,
@@ -40,9 +42,12 @@ from synapse.storage.roommember import (
from synapse.types import Collection, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.state import _StateCacheEntry
+
logger = logging.getLogger(__name__)
@@ -51,7 +56,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
# Is the current_state_events.membership up to date? Or is the
@@ -116,7 +121,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
- count = yield self.db.runInteraction("get_known_servers", _transact)
+ count = yield self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -128,7 +133,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership column is up to date
"""
- pending_update = self.db.simple_select_one_txn(
+ pending_update = self.db_pool.simple_select_one_txn(
txn,
table="background_updates",
keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
@@ -144,18 +149,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
15.0,
run_as_background_process,
"_check_safe_current_state_events_membership_updated",
- self.db.runInteraction,
+ self.db_pool.runInteraction,
"_check_safe_current_state_events_membership_updated",
self._check_safe_current_state_events_membership_updated_txn,
)
@cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id):
- return self.db.runInteraction(
+ def get_users_in_room(self, room_id: str):
+ return self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
- def get_users_in_room_txn(self, txn, room_id):
+ def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
- def get_room_summary(self, room_id):
+ def get_room_summary(self, room_id: str):
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
- room_id (str): The room ID to query
+ room_id: The room ID to query
Returns:
Deferred[dict[str, MemberSummary]]:
dict of membership states, pointing to a MemberSummary named tuple.
@@ -259,80 +264,61 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
- return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
-
- def _get_user_counts_in_room_txn(self, txn, room_id):
- """
- Get the user count in a room by membership.
-
- Args:
- room_id (str)
- membership (Membership)
-
- Returns:
- Deferred[int]
- """
- sql = """
- SELECT m.membership, count(*) FROM room_memberships as m
- INNER JOIN current_state_events as c USING(event_id)
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
-
- txn.execute(sql, (room_id,))
- return {row[0]: row[1] for row in txn}
+ return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
@cached()
- def get_invited_rooms_for_local_user(self, user_id):
- """ Get all the rooms the *local* user is invited to
+ def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+ """Get all the rooms the *local* user is invited to.
Args:
- user_id (str): The user ID.
+ user_id: The user ID.
+
Returns:
- A deferred list of RoomsForUser.
+ An awaitable list of RoomsForUser.
"""
return self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
- @defer.inlineCallbacks
- def get_invite_for_local_user_in_room(self, user_id, room_id):
- """Gets the invite for the given *local* user and room
+ async def get_invite_for_local_user_in_room(
+ self, user_id: str, room_id: str
+ ) -> Optional[RoomsForUser]:
+ """Gets the invite for the given *local* user and room.
Args:
- user_id (str)
- room_id (str)
+ user_id: The user ID to find the invite of.
+ room_id: The room the user was invited to.
Returns:
- Deferred: Resolves to either a RoomsForUser or None if no invite was
- found.
+ Either a RoomsForUser or None if no invite was found.
"""
- invites = yield self.get_invited_rooms_for_local_user(user_id)
+ invites = await self.get_invited_rooms_for_local_user(user_id)
for invite in invites:
if invite.room_id == room_id:
return invite
return None
- @defer.inlineCallbacks
- def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
- """ Get all the rooms for this *local* user where the membership for this user
+ async def get_rooms_for_local_user_where_membership_is(
+ self, user_id: str, membership_list: List[str]
+ ) -> Optional[List[RoomsForUser]]:
+ """Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
Args:
- user_id (str): The user ID.
- membership_list (list): A list of synapse.api.constants.Membership
- values which the user must be in.
+ user_id: The user ID.
+ membership_list: A list of synapse.api.constants.Membership
+ values which the user must be in.
Returns:
- Deferred[list[RoomsForUser]]
+ The RoomsForUser entries for rooms where the user's membership matches one of the given types.
"""
if not membership_list:
- return defer.succeed(None)
+ return None
- rooms = yield self.db.runInteraction(
+ rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
self._get_rooms_for_local_user_where_membership_is_txn,
user_id,
@@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
# Now we filter out forgotten rooms
- forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+ forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
def _get_rooms_for_local_user_where_membership_is_txn(
- self, txn, user_id, membership_list
- ):
+ self, txn, user_id: str, membership_list: List[str]
+ ) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
raise Exception(
@@ -369,32 +355,32 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+ results = [RoomsForUser(**r) for r in self.db_pool.cursor_to_dict(txn)]
return results
@cached(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id):
+ def get_rooms_for_user_with_stream_ordering(self, user_id: str):
"""Returns a set of room_ids the user is currently joined to.
For a remote user, this only returns rooms in which this server is
currently participating.
Args:
- user_id (str)
+ user_id: The user ID.
Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
- def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+ def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@@ -453,42 +439,44 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0] for row in txn}
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"get_users_server_still_shares_room_with",
_get_users_server_still_shares_room_with_txn,
)
- @defer.inlineCallbacks
- def get_rooms_for_user(self, user_id, on_invalidate=None):
+ async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to.
For a remote user, this only returns rooms in which this server is
currently participating.
"""
- rooms = yield self.get_rooms_for_user_with_stream_ordering(
+ rooms = await self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
)
return frozenset(r.room_id for r in rooms)
- @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
- def get_users_who_share_room_with_user(self, user_id, cache_context):
+ @cached(max_entries=500000, cache_context=True, iterable=True)
+ async def get_users_who_share_room_with_user(
+ self, user_id: str, cache_context: _CacheContext
+ ) -> Set[str]:
"""Returns the set of users who share a room with `user_id`
"""
- room_ids = yield self.get_rooms_for_user(
+ room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate
)
user_who_share_room = set()
for room_id in room_ids:
- user_ids = yield self.get_users_in_room(
+ user_ids = await self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
user_who_share_room.update(user_ids)
return user_who_share_room
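Note: `cache_context=True` threads a `_CacheContext` into the method so that nested cached lookups can chain invalidation: passing `cache_context.invalidate` as `on_invalidate` means that evicting the `get_rooms_for_user` or `get_users_in_room` entries also evicts this one. A minimal sketch of the pattern, with a hypothetical method:

    @cached(cache_context=True)
    async def get_room_count_for_user(self, user_id: str, cache_context) -> int:
        rooms = await self.get_rooms_for_user(
            user_id, on_invalidate=cache_context.invalidate
        )
        return len(rooms)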
- @defer.inlineCallbacks
- def get_joined_users_from_context(self, event, context):
+ async def get_joined_users_from_context(
+ self, event: EventBase, context: EventContext
+ ):
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
- current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
- result = yield self._get_joined_users_from_context(
+ current_state_ids = await context.get_current_state_ids()
+ return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
- return result
- @defer.inlineCallbacks
- def get_joined_users_from_state(self, room_id, state_entry):
+ async def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
with Measure(self._clock, "get_joined_users_from_state"):
- return (
- yield self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
- )
+ return await self._get_joined_users_from_context(
+ room_id, state_group, state_entry.state, context=state_entry
)
- @cachedInlineCallbacks(
- num_args=2, cache_context=True, iterable=True, max_entries=100000
- )
- def _get_joined_users_from_context(
+ @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+ async def _get_joined_users_from_context(
self,
room_id,
state_group,
@@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
users_in_room = {}
@@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
- event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+ event_to_memberships = await self._get_joined_profiles_from_event_ids(
missing_member_event_ids
)
users_in_room.update((row for row in event_to_memberships.values() if row))
@@ -612,19 +593,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
list_name="event_ids",
inlineCallbacks=True,
)
- def _get_joined_profiles_from_event_ids(self, event_ids):
+ def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
Args:
- event_ids (Iterable[str]): The member event IDs to lookup
+ event_ids: The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
to `user_id` and ProfileInfo (or None if not a join event).
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = yield self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
for row in rows
}
- @cachedInlineCallbacks(max_entries=10000)
- def is_host_joined(self, room_id, host):
+ @cached(max_entries=10000)
+ async def is_host_joined(self, room_id: str, host: str) -> bool:
if "%" in host or "_" in host:
raise Exception("Invalid host name")
@@ -664,47 +645,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
- rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
-
- if not rows:
- return False
-
- user_id = rows[0][0]
- if get_domain_from_id(user_id) != host:
- # This can only happen if the host name has something funky in it
- raise Exception("Invalid host name")
-
- return True
-
- @cachedInlineCallbacks()
- def was_host_joined(self, room_id, host):
- """Check whether the server is or ever was in the room.
-
- Args:
- room_id (str)
- host (str)
-
- Returns:
- Deferred: Resolves to True if the host is/was in the room, otherwise
- False.
- """
- if "%" in host or "_" in host:
- raise Exception("Invalid host name")
-
- sql = """
- SELECT user_id FROM room_memberships
- WHERE room_id = ?
- AND user_id LIKE ?
- AND membership = 'join'
- LIMIT 1
- """
-
- # We do need to be careful to ensure that host doesn't have any wild cards
- # in it, but we checked above for known ones and we'll check below that
- # the returned user actually has the correct domain.
- like_clause = "%:" + host
-
- rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
+ rows = await self.db_pool.execute(
+ "is_host_joined", None, sql, room_id, like_clause
+ )
if not rows:
return False
@@ -716,8 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
- @defer.inlineCallbacks
- def get_joined_hosts(self, room_id, state_entry):
+ async def get_joined_hosts(self, room_id: str, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
@@ -727,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
with Measure(self._clock, "get_joined_hosts"):
- return (
- yield self._get_joined_hosts(
- room_id, state_group, state_entry.state, state_entry=state_entry
- )
+ return await self._get_joined_hosts(
+ room_id, state_group, state_entry.state, state_entry=state_entry
)
- @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
- # @defer.inlineCallbacks
- def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+ @cached(num_args=2, max_entries=10000, iterable=True)
+ async def _get_joined_hosts(
+ self, room_id, state_group, current_state_ids, state_entry
+ ):
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None
- cache = yield self._get_joined_hosts_cache(room_id)
- joined_hosts = yield cache.get_destinations(state_entry)
-
- return joined_hosts
+ cache = await self._get_joined_hosts_cache(room_id)
+ return await cache.get_destinations(state_entry)
@cached(max_entries=10000)
- def _get_joined_hosts_cache(self, room_id):
+ def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
return _JoinedHostsCache(self, room_id)
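Note: `num_args=2` on `_get_joined_hosts` means only `(room_id, state_group)` form the cache key; `current_state_ids` and `state_entry` are passed through uncached. This is also why the callers substitute a fresh `object()` when `state_group` is None: two unrelated states must never collide on the same key. A sketch of the keying behaviour, with a hypothetical function:

    @cached(num_args=2, max_entries=10000, iterable=True)
    async def _get_widget_hosts(self, room_id, state_group, current_state, entry):
        # Cache key is (room_id, state_group); the trailing arguments are
        # deliberately excluded from the key.
        ...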
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
+ @cached(num_args=2)
+ async def did_forget(self, user_id: str, room_id: str) -> bool:
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
@@ -774,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = txn.fetchall()
return rows[0][0]
- count = yield self.db.runInteraction("did_forget_membership", f)
+ count = await self.db_pool.runInteraction("did_forget_membership", f)
return count == 0
@cached()
- def get_forgotten_rooms_for_user(self, user_id):
+ def get_forgotten_rooms_for_user(self, user_id: str):
"""Gets all rooms the user has forgotten.
Args:
- user_id (str)
+ user_id: The user ID.
Returns:
Deferred[set[str]]
@@ -811,22 +749,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0}
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
- @defer.inlineCallbacks
- def get_rooms_user_has_been_in(self, user_id):
+ async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
Args:
- user_id (str)
+ user_id: The user ID to get the rooms of.
Returns:
- Deferred[set[str]]: Set of room IDs.
+ Set of room IDs.
"""
- room_ids = yield self.db.simple_select_onecol(
+ room_ids = await self.db_pool.simple_select_onecol(
table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id",
@@ -841,7 +778,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Get user_id and membership of a set of event IDs.
"""
- return self.db.simple_select_many_batch(
+ return self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -877,23 +814,23 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return bool(txn.fetchone())
- return await self.db.runInteraction(
+ return await self.db_pool.runInteraction(
"is_local_host_in_room_ignoring_users",
_is_local_host_in_room_ignoring_users_txn,
)
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
self._background_current_state_membership,
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
"room_membership_forgotten_idx",
index_name="room_memberships_user_room_forgotten",
table="room_memberships",
@@ -901,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
- @defer.inlineCallbacks
- def _background_add_membership_profile(self, progress, batch_size):
+ async def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start
)
@@ -926,7 +862,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return 0
@@ -961,25 +897,24 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive": min_stream_id,
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
)
return len(rows)
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
_MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
)
if not result:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
_MEMBERSHIP_PROFILE_UPDATE_NAME
)
return result
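
Both background-update handlers in this class follow the same contract: do one batch of work inside a `runInteraction`, persist a progress dict so the job can resume after a restart, and return how much was done; a falsy result makes the caller invoke `_end_background_update` so the job stops being rescheduled. A minimal sketch of that loop, with in-memory stand-ins for the progress store:

```python
import asyncio
from typing import Dict, List

async def run_background_update(name: str, items: List[int], batch_size: int = 4) -> None:
    # Stand-in for the persisted progress dict: after a restart the job
    # would resume from progress["next"] rather than from zero.
    progress: Dict[str, int] = {"next": 0}

    def run_batch() -> int:
        # In Synapse this body runs inside db_pool.runInteraction and ends
        # by writing `progress` back via _background_update_progress_txn.
        start = progress["next"]
        batch = items[start : start + batch_size]
        progress["next"] = start + len(batch)
        return len(batch)

    while True:
        if not run_batch():
            # Falsy result -> _end_background_update: stop rescheduling.
            print(f"{name}: done after {progress['next']} rows")
            return
        await asyncio.sleep(0)  # yield between batches instead of hogging the loop

asyncio.run(run_background_update("membership_profile", list(range(10))))
```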
- @defer.inlineCallbacks
- def _background_current_state_membership(self, progress, batch_size):
+ async def _background_current_state_membership(self, progress, batch_size):
"""Update the new membership column on current_state_events.
        This works by iterating over all rooms in alphabetical order.
@@ -1013,7 +948,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
last_processed_room = next_room
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn,
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
{"last_processed_room": last_processed_room},
@@ -1025,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
# string, which will compare before all room IDs correctly.
last_processed_room = progress.get("last_processed_room", "")
- row_count, finished = yield self.db.runInteraction(
+ row_count, finished = await self.db_pool.runInteraction(
"_background_current_state_membership_update",
_background_current_state_membership_txn,
last_processed_room,
)
if finished:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
)
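
The resume token for this update leans on a small ordering trick noted in the comment above: rooms are processed in sorted order, and the initial `last_processed_room` is the empty string, which compares before every real room ID, so a fresh run and a resumed run use the same `room_id > ?` query. In plain Python terms:

```python
rooms = sorted(["!aaa:hs", "!bbb:hs", "!ccc:hs"])

last_processed_room = ""  # "" < any room ID, so a fresh run starts at the beginning
while True:
    # SQL equivalent: SELECT room_id ... WHERE room_id > ? ORDER BY room_id LIMIT 1
    remaining = [r for r in rooms if r > last_processed_room]
    if not remaining:
        break  # nothing left -> the update is finished
    last_processed_room = remaining[0]
    print("processing", last_processed_room)  # persisted as the new progress value
```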
@@ -1040,10 +975,10 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def forget(self, user_id, room_id):
+ def forget(self, user_id: str, room_id: str):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -1064,7 +999,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
- return self.db.runInteraction("forget_membership", f)
+ return self.db_pool.runInteraction("forget_membership", f)
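
`forget` pairs the membership write with an invalidation of the `get_forgotten_rooms_for_user` cache entry for `(user_id,)` inside the same interaction, so a reader can never observe the new row together with a stale cached set (Synapse additionally streams the invalidation to cache replicas). A toy write-then-invalidate version, with a plain dict standing in for the `@cached` store:

```python
import sqlite3

forgotten_rooms_cache = {}  # user_id -> set of room IDs (stand-in for @cached)

def forget(db: sqlite3.Connection, user_id: str, room_id: str) -> None:
    with db:  # one transaction: commit on success, roll back on error
        db.execute(
            "UPDATE room_memberships SET forgotten = 1"
            " WHERE user_id = ? AND room_id = ?",
            (user_id, room_id),
        )
        # Invalidate in the same unit of work as the write, so the next
        # read recomputes from the updated rows.
        forgotten_rooms_cache.pop(user_id, None)

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE room_memberships (user_id TEXT, room_id TEXT, forgotten INTEGER DEFAULT 0)")
db.execute("INSERT INTO room_memberships VALUES ('@a:hs', '!r:hs', 0)")
forget(db, "@a:hs", "!r:hs")
print(db.execute("SELECT forgotten FROM room_memberships").fetchone())  # (1,)
```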
class _JoinedHostsCache(object):
@@ -1084,17 +1019,19 @@ class _JoinedHostsCache(object):
self._len = 0
- @defer.inlineCallbacks
- def get_destinations(self, state_entry):
+ async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
"""Get set of destinations for a state entry
Args:
- state_entry(synapse.state._StateCacheEntry)
+            state_entry: The state entry to get destinations for.
+
+ Returns:
+ The destinations as a set.
"""
if state_entry.state_group == self.state_group:
return frozenset(self.hosts_to_joined_users)
- with (yield self.linearizer.queue(())):
+ with (await self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
@@ -1106,7 +1043,7 @@ class _JoinedHostsCache(object):
user_id = state_key
known_joins = self.hosts_to_joined_users.setdefault(host, set())
- event = yield self.store.get_event(event_id)
+ event = await self.store.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
@@ -1115,7 +1052,7 @@ class _JoinedHostsCache(object):
if not known_joins:
self.hosts_to_joined_users.pop(host, None)
else:
- joined_users = yield self.store.get_joined_users_from_state(
+ joined_users = await self.store.get_joined_users_from_state(
self.room_id, state_entry
)
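
`_JoinedHostsCache.get_destinations` avoids a full recomputation when the new state is a known delta away from the cached one: if `state_entry.prev_group` matches the cached `state_group`, only the changed membership events are fetched and applied to `hosts_to_joined_users`; anything else falls back to `get_joined_users_from_state`. The linearizer serialises concurrent callers so only one coroutine mutates the cache at a time. A condensed, synchronous sketch of the two paths (dict-based state and illustrative names, ignoring the linearizer):

```python
from collections import defaultdict

class JoinedHostsCache:
    def __init__(self):
        self.state_group = object()  # sentinel that matches nothing yet
        self.hosts_to_joined_users = defaultdict(set)

    def apply(self, state_group, prev_group, delta, full_state):
        if state_group == self.state_group:
            pass  # cache is already current
        elif prev_group == self.state_group:
            # Incremental path: touch only hosts whose membership changed.
            for user_id, membership in delta.items():
                host = user_id.split(":", 1)[1]
                if membership == "join":
                    self.hosts_to_joined_users[host].add(user_id)
                else:
                    self.hosts_to_joined_users[host].discard(user_id)
                    if not self.hosts_to_joined_users[host]:
                        del self.hosts_to_joined_users[host]
        else:
            # Unknown ancestry: rebuild the mapping from the full state.
            self.hosts_to_joined_users = defaultdict(set)
            for user_id, membership in full_state.items():
                if membership == "join":
                    self.hosts_to_joined_users[user_id.split(":", 1)[1]].add(user_id)
        self.state_group = state_group
        return frozenset(self.hosts_to_joined_users)

cache = JoinedHostsCache()
print(cache.apply(1, None, {}, {"@a:one.example": "join"}))  # frozenset({'one.example'})
print(cache.apply(2, 1, {"@a:one.example": "leave"}, {}))    # frozenset()
```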
diff --git a/synapse/storage/data_stores/main/schema/delta/12/v12.sql b/synapse/storage/databases/main/schema/delta/12/v12.sql
index 5964c5aaac..5964c5aaac 100644
--- a/synapse/storage/data_stores/main/schema/delta/12/v12.sql
+++ b/synapse/storage/databases/main/schema/delta/12/v12.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/13/v13.sql b/synapse/storage/databases/main/schema/delta/13/v13.sql
index f8649e5d99..f8649e5d99 100644
--- a/synapse/storage/data_stores/main/schema/delta/13/v13.sql
+++ b/synapse/storage/databases/main/schema/delta/13/v13.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/14/v14.sql b/synapse/storage/databases/main/schema/delta/14/v14.sql
index a831920da6..a831920da6 100644
--- a/synapse/storage/data_stores/main/schema/delta/14/v14.sql
+++ b/synapse/storage/databases/main/schema/delta/14/v14.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql
index e4f5e76aec..e4f5e76aec 100644
--- a/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
+++ b/synapse/storage/databases/main/schema/delta/15/appservice_txns.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql
index 6b8d0f1ca7..6b8d0f1ca7 100644
--- a/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
+++ b/synapse/storage/databases/main/schema/delta/15/presence_indices.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/15/v15.sql b/synapse/storage/databases/main/schema/delta/15/v15.sql
index 9523d2bcc3..9523d2bcc3 100644
--- a/synapse/storage/data_stores/main/schema/delta/15/v15.sql
+++ b/synapse/storage/databases/main/schema/delta/15/v15.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql
index a48f215170..a48f215170 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
+++ b/synapse/storage/databases/main/schema/delta/16/events_order_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql
index 7a15265cb1..7a15265cb1 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
+++ b/synapse/storage/databases/main/schema/delta/16/remote_media_cache_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql
index 65c97b5e2f..65c97b5e2f 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
+++ b/synapse/storage/databases/main/schema/delta/16/remove_duplicates.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql
index f82486132b..f82486132b 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
+++ b/synapse/storage/databases/main/schema/delta/16/room_alias_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql
index 5b8de52c33..5b8de52c33 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
+++ b/synapse/storage/databases/main/schema/delta/16/unique_constraints.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/16/users.sql b/synapse/storage/databases/main/schema/delta/16/users.sql
index cd0709250d..cd0709250d 100644
--- a/synapse/storage/data_stores/main/schema/delta/16/users.sql
+++ b/synapse/storage/databases/main/schema/delta/16/users.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql
index 7c9a90e27f..7c9a90e27f 100644
--- a/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
+++ b/synapse/storage/databases/main/schema/delta/17/drop_indexes.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql b/synapse/storage/databases/main/schema/delta/17/server_keys.sql
index 70b247a06b..70b247a06b 100644
--- a/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/17/server_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql
index c17715ac80..c17715ac80 100644
--- a/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
+++ b/synapse/storage/databases/main/schema/delta/17/user_threepids.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql
index 6e0871c92b..6e0871c92b 100644
--- a/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
+++ b/synapse/storage/databases/main/schema/delta/18/server_keys_bigger_ints.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql b/synapse/storage/databases/main/schema/delta/19/event_index.sql
index 18b97b4332..18b97b4332 100644
--- a/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
+++ b/synapse/storage/databases/main/schema/delta/19/event_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql b/synapse/storage/databases/main/schema/delta/20/dummy.sql
index e0ac49d1ec..e0ac49d1ec 100644
--- a/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
+++ b/synapse/storage/databases/main/schema/delta/20/dummy.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..3edfcfd783 100644
--- a/synapse/storage/data_stores/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
diff --git a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql
index 4c2fb20b77..4c2fb20b77 100644
--- a/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/21/end_to_end_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql b/synapse/storage/databases/main/schema/delta/21/receipts.sql
index d070845477..d070845477 100644
--- a/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
+++ b/synapse/storage/databases/main/schema/delta/21/receipts.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql
index bfc0b3bcaa..bfc0b3bcaa 100644
--- a/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
+++ b/synapse/storage/databases/main/schema/delta/22/receipts_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql
index 87edfa454c..87edfa454c 100644
--- a/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
+++ b/synapse/storage/databases/main/schema/delta/22/user_threepids_unique.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql
index acea7483bd..acea7483bd 100644
--- a/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
+++ b/synapse/storage/databases/main/schema/delta/24/stats_reporting.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..ee675e71ff 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
diff --git a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql b/synapse/storage/databases/main/schema/delta/25/guest_access.sql
index 1ea389b471..1ea389b471 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
+++ b/synapse/storage/databases/main/schema/delta/25/guest_access.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql
index f468fc1897..f468fc1897 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
+++ b/synapse/storage/databases/main/schema/delta/25/history_visibility.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/25/tags.sql b/synapse/storage/databases/main/schema/delta/25/tags.sql
index 7a32ce68e4..7a32ce68e4 100644
--- a/synapse/storage/data_stores/main/schema/delta/25/tags.sql
+++ b/synapse/storage/databases/main/schema/delta/25/tags.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql b/synapse/storage/databases/main/schema/delta/26/account_data.sql
index e395de2b5e..e395de2b5e 100644
--- a/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
+++ b/synapse/storage/databases/main/schema/delta/26/account_data.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql b/synapse/storage/databases/main/schema/delta/27/account_data.sql
index bf0558b5b3..bf0558b5b3 100644
--- a/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
+++ b/synapse/storage/databases/main/schema/delta/27/account_data.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql
index e2094f37fe..e2094f37fe 100644
--- a/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
+++ b/synapse/storage/databases/main/schema/delta/27/forgotten_memberships.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..b7972cfa8e 100644
--- a/synapse/storage/data_stores/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
diff --git a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql
index 4d519849df..4d519849df 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
+++ b/synapse/storage/databases/main/schema/delta/28/event_push_actions.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql
index 36609475f1..36609475f1 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/28/events_room_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql
index 6c1fd68c5b..6c1fd68c5b 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
+++ b/synapse/storage/databases/main/schema/delta/28/public_roms_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql
index cb84c69baa..cb84c69baa 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
+++ b/synapse/storage/databases/main/schema/delta/28/receipts_user_id_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql
index 3e4a9ab455..3e4a9ab455 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
+++ b/synapse/storage/databases/main/schema/delta/28/upgrade_times.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql
index 21d2b420bf..21d2b420bf 100644
--- a/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
+++ b/synapse/storage/databases/main/schema/delta/28/users_is_guest.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql b/synapse/storage/databases/main/schema/delta/29/push_actions.sql
index 84b21cf813..84b21cf813 100644
--- a/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
+++ b/synapse/storage/databases/main/schema/delta/29/push_actions.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql
index c9d0dde638..c9d0dde638 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
+++ b/synapse/storage/databases/main/schema/delta/30/alias_creator.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..b42c02710a 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
diff --git a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql
index 712c454aa1..712c454aa1 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
+++ b/synapse/storage/databases/main/schema/delta/30/deleted_pushers.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql
index 606bbb037d..606bbb037d 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/30/presence_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql
index f09db4faa6..f09db4faa6 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
+++ b/synapse/storage/databases/main/schema/delta/30/public_rooms.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql
index 735aa8d5f6..735aa8d5f6 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/30/push_rule_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql
index 0dd2f1360c..0dd2f1360c 100644
--- a/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
+++ b/synapse/storage/databases/main/schema/delta/30/threepid_guest_access_tokens.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/31/invites.sql b/synapse/storage/databases/main/schema/delta/31/invites.sql
index 2c57846d5a..2c57846d5a 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/invites.sql
+++ b/synapse/storage/databases/main/schema/delta/31/invites.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql
index 9efb4280eb..9efb4280eb 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
+++ b/synapse/storage/databases/main/schema/delta/31/local_media_repository_url_cache.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..9bb504aad5 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
diff --git a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql
index a82add88fd..a82add88fd 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
+++ b/synapse/storage/databases/main/schema/delta/31/pushers_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..63b757ade6 100644
--- a/synapse/storage/data_stores/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
diff --git a/synapse/storage/data_stores/main/schema/delta/32/events.sql b/synapse/storage/databases/main/schema/delta/32/events.sql
index 1dd0f9e170..1dd0f9e170 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/events.sql
+++ b/synapse/storage/databases/main/schema/delta/32/events.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/32/openid.sql b/synapse/storage/databases/main/schema/delta/32/openid.sql
index 36f37b11c8..36f37b11c8 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/openid.sql
+++ b/synapse/storage/databases/main/schema/delta/32/openid.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql
index d86d30c13c..d86d30c13c 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
+++ b/synapse/storage/databases/main/schema/delta/32/pusher_throttle.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql
index 2de50d408c..2de50d408c 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/databases/main/schema/delta/32/remove_indices.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/32/reports.sql b/synapse/storage/databases/main/schema/delta/32/reports.sql
index d13609776f..d13609776f 100644
--- a/synapse/storage/data_stores/main/schema/delta/32/reports.sql
+++ b/synapse/storage/databases/main/schema/delta/32/reports.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql
index 61ad3fe3e8..61ad3fe3e8 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
+++ b/synapse/storage/databases/main/schema/delta/33/access_tokens_device_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices.sql b/synapse/storage/databases/main/schema/delta/33/devices.sql
index eca7268d82..eca7268d82 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/devices.sql
+++ b/synapse/storage/databases/main/schema/delta/33/devices.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql
index aa4a3b9f2f..aa4a3b9f2f 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
index 6671573398..6671573398 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
+++ b/synapse/storage/databases/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..a3e81eeac7 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
diff --git a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..a26057dfb6 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
diff --git a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql
index 473f75a78e..473f75a78e 100644
--- a/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
+++ b/synapse/storage/databases/main/schema/delta/33/user_ips_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql
index 69e16eda0f..69e16eda0f 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/34/appservice_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py b/synapse/storage/databases/main/schema/delta/34/cache_stream.py
index cf09e43e2b..cf09e43e2b 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
+++ b/synapse/storage/databases/main/schema/delta/34/cache_stream.py
diff --git a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql
index e68844c74a..e68844c74a 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
+++ b/synapse/storage/databases/main/schema/delta/34/device_inbox.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql
index 0d9fe1a99a..0d9fe1a99a 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
+++ b/synapse/storage/databases/main/schema/delta/34/push_display_name_rename.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py
index 67d505e68b..67d505e68b 100644
--- a/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
+++ b/synapse/storage/databases/main/schema/delta/34/received_txn_purge.py
diff --git a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql b/synapse/storage/databases/main/schema/delta/35/contains_url.sql
index 6cd123027b..6cd123027b 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
+++ b/synapse/storage/databases/main/schema/delta/35/contains_url.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql
index 17e6c43105..17e6c43105 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
+++ b/synapse/storage/databases/main/schema/delta/35/device_outbox.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql
index 7ab7d942e2..7ab7d942e2 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
+++ b/synapse/storage/databases/main/schema/delta/35/device_stream_id.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql
index 2e836d8e9c..2e836d8e9c 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
+++ b/synapse/storage/databases/main/schema/delta/35/event_push_actions_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql
index dd2bf2e28a..dd2bf2e28a 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/35/public_room_list_change_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql
index 2b945d8a57..2b945d8a57 100644
--- a/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
+++ b/synapse/storage/databases/main/schema/delta/35/stream_order_to_extrem.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql
index 90d8fd18f9..90d8fd18f9 100644
--- a/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
+++ b/synapse/storage/databases/main/schema/delta/36/readd_public_rooms.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py
index a377884169..a377884169 100644
--- a/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/databases/main/schema/delta/37/remove_auth_idx.py
diff --git a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql
index cf7a90dd10..cf7a90dd10 100644
--- a/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
+++ b/synapse/storage/databases/main/schema/delta/37/user_threepids.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql
index 515e6b8e84..515e6b8e84 100644
--- a/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/databases/main/schema/delta/38/postgres_fts_gist.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql
index 74bdc49073..74bdc49073 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
+++ b/synapse/storage/databases/main/schema/delta/39/appservice_room_list.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql
index 00be801e90..00be801e90 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/39/device_federation_stream_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql
index de2ad93e5c..de2ad93e5c 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
+++ b/synapse/storage/databases/main/schema/delta/39/event_push_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql
index 5af814290b..5af814290b 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
+++ b/synapse/storage/databases/main/schema/delta/39/federation_out_position.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql
index 1bf911c8ab..1bf911c8ab 100644
--- a/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
+++ b/synapse/storage/databases/main/schema/delta/39/membership_profile.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql
index 7ffa189f39..7ffa189f39 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/40/current_state_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql
index b9fe1f0480..b9fe1f0480 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
+++ b/synapse/storage/databases/main/schema/delta/40/device_inbox.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql
index dd6dcb65f1..dd6dcb65f1 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
+++ b/synapse/storage/databases/main/schema/delta/40/device_list_streams.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql
index 3918f0b794..3918f0b794 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
+++ b/synapse/storage/databases/main/schema/delta/40/event_push_summary.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql b/synapse/storage/databases/main/schema/delta/40/pushers.sql
index 054a223f14..054a223f14 100644
--- a/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
+++ b/synapse/storage/databases/main/schema/delta/40/pushers.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql
index b7bee8b692..b7bee8b692 100644
--- a/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/41/device_list_stream_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql
index 62f0b9892b..62f0b9892b 100644
--- a/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
+++ b/synapse/storage/databases/main/schema/delta/41/device_outbound_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql
index 5d9cfecf36..5d9cfecf36 100644
--- a/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/41/event_search_event_id_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql
index a194bf0238..a194bf0238 100644
--- a/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
+++ b/synapse/storage/databases/main/schema/delta/41/ratelimit.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql
index d28851aff8..d28851aff8 100644
--- a/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
+++ b/synapse/storage/databases/main/schema/delta/42/current_state_delta.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql
index 9ab8c14fa3..9ab8c14fa3 100644
--- a/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
+++ b/synapse/storage/databases/main/schema/delta/42/device_list_last_id.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql
index b8821ac759..b8821ac759 100644
--- a/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
+++ b/synapse/storage/databases/main/schema/delta/42/event_auth_state_only.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py b/synapse/storage/databases/main/schema/delta/42/user_dir.py
index 506f326f4d..506f326f4d 100644
--- a/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
+++ b/synapse/storage/databases/main/schema/delta/42/user_dir.py
diff --git a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql
index 0e3cd143ff..0e3cd143ff 100644
--- a/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
+++ b/synapse/storage/databases/main/schema/delta/43/blocked_rooms.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql
index 630907ec4f..630907ec4f 100644
--- a/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
+++ b/synapse/storage/databases/main/schema/delta/43/quarantine_media.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql b/synapse/storage/databases/main/schema/delta/43/url_cache.sql
index 45ebe020da..45ebe020da 100644
--- a/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
+++ b/synapse/storage/databases/main/schema/delta/43/url_cache.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql b/synapse/storage/databases/main/schema/delta/43/user_share.sql
index ee7062abe4..ee7062abe4 100644
--- a/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
+++ b/synapse/storage/databases/main/schema/delta/43/user_share.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql
index b12f9b2ebf..b12f9b2ebf 100644
--- a/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
+++ b/synapse/storage/databases/main/schema/delta/44/expire_url_cache.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql b/synapse/storage/databases/main/schema/delta/45/group_server.sql
index b2333848a0..b2333848a0 100644
--- a/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
+++ b/synapse/storage/databases/main/schema/delta/45/group_server.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql
index e5ddc84df0..e5ddc84df0 100644
--- a/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
+++ b/synapse/storage/databases/main/schema/delta/45/profile_cache.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql
index 68c48a89a9..68c48a89a9 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
+++ b/synapse/storage/databases/main/schema/delta/46/drop_refresh_tokens.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql
index bb307889c1..bb307889c1 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
+++ b/synapse/storage/databases/main/schema/delta/46/drop_unique_deleted_pushers.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql b/synapse/storage/databases/main/schema/delta/46/group_server.sql
index 097679bc9a..097679bc9a 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
+++ b/synapse/storage/databases/main/schema/delta/46/group_server.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql
index bbfc7f5d1a..bbfc7f5d1a 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/46/local_media_repository_url_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql
index cb0d5a2576..cb0d5a2576 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
+++ b/synapse/storage/databases/main/schema/delta/46/user_dir_null_room_ids.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql
index d9505f8da1..d9505f8da1 100644
--- a/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
+++ b/synapse/storage/databases/main/schema/delta/46/user_dir_typos.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql
index f505fb22b5..f505fb22b5 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
+++ b/synapse/storage/databases/main/schema/delta/47/last_access_media.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql
index 31d7a817eb..31d7a817eb 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
+++ b/synapse/storage/databases/main/schema/delta/47/postgres_fts_gin.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql
index edccf4a96f..edccf4a96f 100644
--- a/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
+++ b/synapse/storage/databases/main/schema/delta/47/push_actions_staging.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql
index 5237491506..5237491506 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
+++ b/synapse/storage/databases/main/schema/delta/48/add_user_consent.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql
index 9248b0b24a..9248b0b24a 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
+++ b/synapse/storage/databases/main/schema/delta/48/add_user_ips_last_seen_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql
index e9013a6969..e9013a6969 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
+++ b/synapse/storage/databases/main/schema/delta/48/deactivated_users.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py
index 49f5f2c003..49f5f2c003 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
+++ b/synapse/storage/databases/main/schema/delta/48/group_unique_indexes.py
diff --git a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql
index ce26eaf0c9..ce26eaf0c9 100644
--- a/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
+++ b/synapse/storage/databases/main/schema/delta/48/groups_joinable.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql
index 14dcf18d73..14dcf18d73 100644
--- a/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
+++ b/synapse/storage/databases/main/schema/delta/49/add_user_consent_server_notice_sent.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql
index 3dd478196f..3dd478196f 100644
--- a/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
+++ b/synapse/storage/databases/main/schema/delta/49/add_user_daily_visits.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
index 3a4ed59b5b..3a4ed59b5b 100644
--- a/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
+++ b/synapse/storage/databases/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql
index c93ae47532..c93ae47532 100644
--- a/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
+++ b/synapse/storage/databases/main/schema/delta/50/add_creation_ts_users_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql
index 5d8641a9ab..5d8641a9ab 100644
--- a/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
+++ b/synapse/storage/databases/main/schema/delta/50/erasure_store.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py
index b1684a8441..b1684a8441 100644
--- a/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
+++ b/synapse/storage/databases/main/schema/delta/50/make_event_content_nullable.py
diff --git a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql
index c0e66a697d..c0e66a697d 100644
--- a/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/51/e2e_room_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql
index c9d537d5a3..c9d537d5a3 100644
--- a/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
+++ b/synapse/storage/databases/main/schema/delta/51/monthly_active_users.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql
index 91e03d13e1..91e03d13e1 100644
--- a/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
+++ b/synapse/storage/databases/main/schema/delta/52/add_event_to_state_group_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql
index bfa49e6f92..bfa49e6f92 100644
--- a/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/52/device_list_streams_unique_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql
index db687cccae..db687cccae 100644
--- a/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/52/e2e_room_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql
index 88ec2f83e5..88ec2f83e5 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
+++ b/synapse/storage/databases/main/schema/delta/53/add_user_type_to_users.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql
index e372f5a44a..e372f5a44a 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
+++ b/synapse/storage/databases/main/schema/delta/53/drop_sent_transactions.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql
index 1d977c2834..1d977c2834 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
+++ b/synapse/storage/databases/main/schema/delta/53/event_format_version.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql
index ffcc896b58..ffcc896b58 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
+++ b/synapse/storage/databases/main/schema/delta/53/user_dir_populate.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql
index b812c5794f..b812c5794f 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
+++ b/synapse/storage/databases/main/schema/delta/53/user_ips_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql b/synapse/storage/databases/main/schema/delta/53/user_share.sql
index 5831b1a6f8..5831b1a6f8 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
+++ b/synapse/storage/databases/main/schema/delta/53/user_share.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql
index 80c2c573b6..80c2c573b6 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
+++ b/synapse/storage/databases/main/schema/delta/53/user_threepid_id.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql
index f7827ca6d2..f7827ca6d2 100644
--- a/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
+++ b/synapse/storage/databases/main/schema/delta/53/users_in_public_rooms.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql
index 0adb2ad55e..0adb2ad55e 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
+++ b/synapse/storage/databases/main/schema/delta/54/account_validity_with_renewal.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql
index c01aa9d2d9..c01aa9d2d9 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/54/add_validity_to_server_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql
index b062ec840c..b062ec840c 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
+++ b/synapse/storage/databases/main/schema/delta/54/delete_forward_extremities.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql
index dbbe682697..dbbe682697 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
+++ b/synapse/storage/databases/main/schema/delta/54/drop_legacy_tables.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql
index e6ee70c623..e6ee70c623 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
+++ b/synapse/storage/databases/main/schema/delta/54/drop_presence_list.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/relations.sql b/synapse/storage/databases/main/schema/delta/54/relations.sql
index 134862b870..134862b870 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/relations.sql
+++ b/synapse/storage/databases/main/schema/delta/54/relations.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats.sql b/synapse/storage/databases/main/schema/delta/54/stats.sql
index 652e58308e..652e58308e 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/stats.sql
+++ b/synapse/storage/databases/main/schema/delta/54/stats.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql b/synapse/storage/databases/main/schema/delta/54/stats2.sql
index 3b2d48447f..3b2d48447f 100644
--- a/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
+++ b/synapse/storage/databases/main/schema/delta/54/stats2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql
index 4590604bfd..4590604bfd 100644
--- a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
+++ b/synapse/storage/databases/main/schema/delta/55/access_token_expiry.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql
index a8eced2e0a..a8eced2e0a 100644
--- a/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
+++ b/synapse/storage/databases/main/schema/delta/55/track_threepid_validations.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql
index dabdde489b..dabdde489b 100644
--- a/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
+++ b/synapse/storage/databases/main/schema/delta/55/users_alter_deactivated.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql
index 41807eb1e7..41807eb1e7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
+++ b/synapse/storage/databases/main/schema/delta/56/add_spans_to_device_lists.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql
index 473018676f..473018676f 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
+++ b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql
index 3133d42d4a..3133d42d4a 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
+++ b/synapse/storage/databases/main/schema/delta/56/current_state_events_membership_mk2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql
index 1d2ddb1b1a..1d2ddb1b1a 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
+++ b/synapse/storage/databases/main/schema/delta/56/delete_keys_from_deleted_backups.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql
index f00889290b..f00889290b 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
+++ b/synapse/storage/databases/main/schema/delta/56/destinations_failure_ts.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
index b9bbb18a91..b9bbb18a91 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql
index c2f557fde9..c2f557fde9 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
+++ b/synapse/storage/databases/main/schema/delta/56/device_stream_id_insert.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql
index dfa902d0ba..dfa902d0ba 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
+++ b/synapse/storage/databases/main/schema/delta/56/devices_last_seen.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql
index 9f09922c67..9f09922c67 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
+++ b/synapse/storage/databases/main/schema/delta/56/drop_unused_event_tables.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql
index 81a36a8b1d..81a36a8b1d 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_expiry.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
index 5e29c1da19..5e29c1da19 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql
index 5f5e0499ae..5f5e0499ae 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels_background_update.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql
index 014cb3b538..014cb3b538 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
+++ b/synapse/storage/databases/main/schema/delta/56/fix_room_keys_index.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql
index 67f8b20297..67f8b20297 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql
+++ b/synapse/storage/databases/main/schema/delta/56/hidden_devices.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite
index e8b1fd35d8..e8b1fd35d8 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
+++ b/synapse/storage/databases/main/schema/delta/56/hidden_devices_fix.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql
index 4f24c1405d..4f24c1405d 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
+++ b/synapse/storage/databases/main/schema/delta/56/nuke_empty_communities_from_db.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql
index 7be31ffebb..7be31ffebb 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/56/public_room_list_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql
index ea95db0ed7..ea95db0ed7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql
index 49ce35d794..49ce35d794 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
+++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
index 67471f3ef5..67471f3ef5 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql
index b7550f6f4e..b7550f6f4e 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
+++ b/synapse/storage/databases/main/schema/delta/56/redaction_censor4.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
index aeb17813d3..aeb17813d3 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
+++ b/synapse/storage/databases/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql
index 7d70dd071e..7d70dd071e 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
+++ b/synapse/storage/databases/main/schema/delta/56/room_key_etag.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql
index 92ab1f5e65..92ab1f5e65 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/56/room_membership_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/databases/main/schema/delta/56/room_retention.sql
index ee6cdf7a14..ee6cdf7a14 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
+++ b/synapse/storage/databases/main/schema/delta/56/room_retention.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql
index 5c5fffcafb..5c5fffcafb 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
+++ b/synapse/storage/databases/main/schema/delta/56/signing_keys.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql
index 0aa90ebf0c..0aa90ebf0c 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
+++ b/synapse/storage/databases/main/schema/delta/56/signing_keys_nonunique_signatures.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql
index bbdde121e8..bbdde121e8 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
+++ b/synapse/storage/databases/main/schema/delta/56/stats_separated.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..1de8b54961 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql
index 91390c4527..91390c4527 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
+++ b/synapse/storage/databases/main/schema/delta/56/user_external_ids.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql
index 149f8be8b6..149f8be8b6 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/56/users_in_public_rooms_idx.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql
index aec06c8261..aec06c8261 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql
+++ b/synapse/storage/databases/main/schema/delta/57/delete_old_current_state_events.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql
index c3b6de2099..c3b6de2099 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql
+++ b/synapse/storage/databases/main/schema/delta/57/device_list_remote_cache_stale.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..63b5acdcf7 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql
index 133d80af35..133d80af35 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql
+++ b/synapse/storage/databases/main/schema/delta/57/remove_sent_outbound_pokes.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql
index 352a66f5b0..352a66f5b0 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres
index c601cff6de..c601cff6de 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite
index 335c6f2074..335c6f2074 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_2.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres
index 92aaadde0d..92aaadde0d 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite
index e19dab97cb..e19dab97cb 100644
--- a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
+++ b/synapse/storage/databases/main/schema/delta/57/rooms_version_column_3.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql
index fdc39e9ba5..fdc39e9ba5 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql
+++ b/synapse/storage/databases/main/schema/delta/58/02remove_dup_outbound_pokes.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql
index dcb593fc2d..dcb593fc2d 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
+++ b/synapse/storage/databases/main/schema/delta/58/03persist_ui_auth.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres
index aa46eb0e10..aa46eb0e10 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/05cache_instance.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py
index d353f2bcb3..d353f2bcb3 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
+++ b/synapse/storage/databases/main/schema/delta/58/06dlols_unique_idx.py
diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
index 597f2ffd3d..597f2ffd3d 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
index 69db89ac0e..69db89ac0e 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
+++ b/synapse/storage/databases/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql
index eb57203e46..eb57203e46 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/10drop_local_rejections_stream.sql
+++ b/synapse/storage/databases/main/schema/delta/58/10drop_local_rejections_stream.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql
index 1cc2633aad..1cc2633aad 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/10federation_pos_instance_name.sql
+++ b/synapse/storage/databases/main/schema/delta/58/10federation_pos_instance_name.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py
index 2011f6bceb..4310ec12ce 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/11user_id_seq.py
+++ b/synapse/storage/databases/main/schema/delta/58/11user_id_seq.py
@@ -16,7 +16,7 @@
Adds a postgres SEQUENCE for generating guest user IDs.
"""
-from synapse.storage.data_stores.main.registration import (
+from synapse.storage.databases.main.registration import (
find_max_generated_user_id_localpart,
)
from synapse.storage.engines import PostgresEngine
diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
new file mode 100644
index 0000000000..cade5dcca8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
@@ -0,0 +1,32 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Recalculate the stats for all rooms after the fix to joined_members erroneously
+-- incrementing on per-room profile changes.
+
+-- Note that the populate_stats_process_rooms background update is already set to
+-- run if you're upgrading from Synapse <1.0.0.
+
+-- Additionally, if you upgraded to v1.18.0 (which doesn't include this fix),
+-- let this background job run, and then updated to v1.19.0, you'd end up with
+-- only half of your rooms having their stats recalculated after the fix.
+
+-- So we've switched the old `populate_stats_process_rooms` background job to a
+-- no-op and kicked off a background job with a new name, but with the same
+-- functionality as the old one. This effectively restarts the background job
+-- from the beginning, without running it twice in a row, supporting both
+-- upgrade use cases.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('populate_stats_process_rooms_2', '{}');
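For context, a minimal sketch (not part of the patch) of the restart pattern this migration relies on, using the registration helpers that appear in the stats.py hunk further down; `ExampleStatsStore` and the handler bodies are illustrative only:

    from synapse.storage.database import DatabasePool
    from synapse.storage.databases.main.state_deltas import StateDeltasStore


    class ExampleStatsStore(StateDeltasStore):
        def __init__(self, database: DatabasePool, db_conn, hs):
            super().__init__(database, db_conn, hs)

            # Deployments that still have the old update queued simply mark
            # it as finished the next time the updater picks it up...
            self.db_pool.updates.register_background_update_handler(
                "populate_stats_process_rooms", self._old_update_noop
            )
            # ...while the row inserted above is run under the new name and
            # performs the actual recalculation from the beginning.
            self.db_pool.updates.register_background_update_handler(
                "populate_stats_process_rooms_2", self._process_rooms_2
            )

        async def _old_update_noop(self, progress, batch_size):
            await self.db_pool.updates._end_background_update(
                "populate_stats_process_rooms"
            )
            return 1

        async def _process_rooms_2(self, progress, batch_size):
            # Recalculate stats for a batch of rooms here; once no rooms
            # remain, end the background update so it is not run again.
            await self.db_pool.updates._end_background_update(
                "populate_stats_process_rooms_2"
            )
            return 1

This is the same shape as the real registration in the stats.py hunk below; upstream also uses register_noop_background_update for fully retired jobs, as it does for populate_stats_cleanup.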
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql
index 883fcd10b2..883fcd10b2 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/application_services.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql
index 10ce2aa7a0..10ce2aa7a0 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/event_edges.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql
index 95826da431..95826da431 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/event_signatures.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql b/synapse/storage/databases/main/schema/full_schemas/16/im.sql
index a1a2aa8e5b..a1a2aa8e5b 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/im.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql
index 11cdffdbb3..11cdffdbb3 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/keys.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql
index 8f3759bb2a..8f3759bb2a 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/media_repository.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql
index 01d2d8f833..01d2d8f833 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/presence.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql
index c04f4747d9..c04f4747d9 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/profiles.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql b/synapse/storage/databases/main/schema/full_schemas/16/push.sql
index e44465cf45..e44465cf45 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/push.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql
index 318f0d9aa5..318f0d9aa5 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/redactions.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql
index d47da3b12f..d47da3b12f 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/room_aliases.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql b/synapse/storage/databases/main/schema/full_schemas/16/state.sql
index 96391a8f0e..96391a8f0e 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/state.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql
index 17e67bedac..17e67bedac 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/transactions.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql b/synapse/storage/databases/main/schema/full_schemas/16/users.sql
index f013aa8b18..f013aa8b18 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/16/users.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..889a9a0ce4 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..a0411ede7e 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql
index 91d21b2921..91d21b2921 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
+++ b/synapse/storage/databases/main/schema/full_schemas/54/stream_positions.sql
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/databases/main/schema/full_schemas/README.md
index c00f287190..c00f287190 100644
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.md
+++ b/synapse/storage/databases/main/schema/full_schemas/README.md
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/databases/main/search.py
index a79533dfad..dcbdeab36e 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -16,13 +16,12 @@
import logging
import re
from collections import namedtuple
-
-from twisted.internet import defer
+from typing import List, Optional
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
logger = logging.getLogger(__name__)
@@ -88,16 +87,16 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
if not hs.config.enable_search:
return
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
)
@@ -106,16 +105,15 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the
# background update.
- self.db.updates.register_noop_background_update(
+ self.db_pool.updates.register_noop_background_update(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
- @defer.inlineCallbacks
- def _background_reindex_search(self, progress, batch_size):
+ async def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -140,7 +138,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.db.cursor_to_dict(txn)
+ rows = self.db_pool.cursor_to_dict(txn)
if not rows:
return 0
@@ -200,23 +198,24 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
"rows_inserted": rows_inserted + len(event_search_rows),
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_UPDATE_NAME, progress
)
return len(event_search_rows)
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
- yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
+ await self.db_pool.updates._end_background_update(
+ self.EVENT_SEARCH_UPDATE_NAME
+ )
return result
- @defer.inlineCallbacks
- def _background_reindex_gin_search(self, progress, batch_size):
+ async def _background_reindex_gin_search(self, progress, batch_size):
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
@@ -253,15 +252,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
- yield self.db.runWithConnection(create_index)
+ await self.db_pool.runWithConnection(create_index)
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
)
return 1
- @defer.inlineCallbacks
- def _background_reindex_search_order(self, progress, batch_size):
+ async def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@@ -286,14 +284,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
conn.set_session(autocommit=False)
- yield self.db.runWithConnection(create_index)
+ await self.db_pool.runWithConnection(create_index)
pg = dict(progress)
pg["have_added_indexes"] = True
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
- self.db.updates._background_update_progress_txn,
+ self.db_pool.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
pg,
)
@@ -323,18 +321,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
"have_added_indexes": True,
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
)
return len(rows), True
- num_rows, finished = yield self.db.runInteraction(
+ num_rows, finished = await self.db_pool.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)
if not finished:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.EVENT_SEARCH_ORDER_UPDATE_NAME
)
@@ -342,11 +340,10 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SearchStore, self).__init__(database, db_conn, hs)
- @defer.inlineCallbacks
- def search_msgs(self, room_ids, search_term, keys):
+ async def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
Args:
@@ -423,15 +420,15 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = yield self.db.execute(
- "search_msgs", self.db.cursor_to_dict, sql, *args
+ results = await self.db_pool.execute(
+ "search_msgs", self.db_pool.cursor_to_dict, sql, *args
)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -440,12 +437,12 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = yield self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
- count_results = yield self.db.execute(
- "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
+ count_results = await self.db_pool.execute(
+ "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -460,19 +457,25 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
- @defer.inlineCallbacks
- def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
+ async def search_rooms(
+ self,
+ room_ids: List[str],
+ search_term: str,
+ keys: List[str],
+ limit,
+ pagination_token: Optional[str] = None,
+ ) -> List[dict]:
"""Performs a full text search over events with given keys.
Args:
- room_id (list): The room_ids to search in
- search_term (str): Search term to search for
- keys (list): List of keys to search in, currently supports
- "content.body", "content.name", "content.topic"
- pagination_token (str): A pagination token previously returned
+ room_ids: The room_ids to search in
+ search_term: Search term to search for
+ keys: List of keys to search in, currently supports "content.body",
+ "content.name", "content.topic"
+ pagination_token: A pagination token previously returned
Returns:
- list of dicts
+ Each match as a dictionary.
"""
clauses = []
@@ -575,15 +578,15 @@ class SearchStore(SearchBackgroundUpdateStore):
args.append(limit)
- results = yield self.db.execute(
- "search_rooms", self.db.cursor_to_dict, sql, *args
+ results = await self.db_pool.execute(
+ "search_rooms", self.db_pool.cursor_to_dict, sql, *args
)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
- events = yield self.get_events_as_list(
+ events = await self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
@@ -592,12 +595,12 @@ class SearchStore(SearchBackgroundUpdateStore):
highlights = None
if isinstance(self.database_engine, PostgresEngine):
- highlights = yield self._find_highlights_in_postgres(search_query, events)
+ highlights = await self._find_highlights_in_postgres(search_query, events)
count_sql += " GROUP BY room_id"
- count_results = yield self.db.execute(
- "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
+ count_results = await self.db_pool.execute(
+ "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -682,7 +685,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
- return self.db.runInteraction("_find_highlights", f)
+ return self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
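The search.py hunk above shows the two mechanical conversions applied throughout this diff: `@defer.inlineCallbacks`/`yield` becomes `async`/`await`, and the pool is reached as `db_pool` instead of `db`. A minimal before/after sketch, with an illustrative query and store argument rather than the real search SQL:

    from twisted.internet import defer


    # Before: a Twisted generator-based coroutine, driven by inlineCallbacks.
    @defer.inlineCallbacks
    def count_events_old(store, room_id):
        rows = yield store.db.execute(
            "count_events", None,
            "SELECT count(*) FROM events WHERE room_id = ?", room_id,
        )
        return rows[0][0]


    # After: a native coroutine against the renamed pool; callers await it.
    async def count_events(store, room_id):
        rows = await store.db_pool.execute(
            "count_events", None,
            "SELECT count(*) FROM events WHERE room_id = ?", room_id,
        )
        return rows[0][0]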
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 36244d9f5d..be191dd870 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -15,8 +15,6 @@
from unpaddedbase64 import encode_base64
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -38,11 +36,10 @@ class SignatureWorkerStore(SQLBaseStore):
for event_id in event_ids
}
- return self.db.runInteraction("get_event_reference_hashes", f)
+ return self.db_pool.runInteraction("get_event_reference_hashes", f)
- @defer.inlineCallbacks
- def add_event_hashes(self, event_ids):
- hashes = yield self.get_event_reference_hashes(event_ids)
+ async def add_event_hashes(self, event_ids):
+ hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
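The comprehension in `add_event_hashes` above keeps only the sha256 entry of each hash dict and encodes it with unpadded base64. A small runnable sketch with made-up data:

    from unpaddedbase64 import encode_base64

    hashes = {"$event:example.org": {"sha256": b"\x00" * 32, "md5": b"\x01" * 16}}

    encoded = {
        e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
        for e_id, h in hashes.items()
    }
    # Only the sha256 entry survives, base64-encoded without padding.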
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/databases/main/state.py
index a360699408..96e0378e50 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -23,9 +23,9 @@ from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -54,7 +54,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -93,7 +93,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# We really should have an entry in the rooms table for every room we
# care about, but let's be a bit paranoid (at least while the background
# update is happening) to avoid breaking existing rooms.
- version = await self.db.simple_select_one_onecol(
+ version = await self.db_pool.simple_select_one_onecol(
table="rooms",
keyvalues={"room_id": room_id},
retcol="room_version",
@@ -184,7 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)
@@ -231,7 +231,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@@ -261,7 +261,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@@ -278,7 +278,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = yield self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
@@ -301,7 +301,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
The subset of state groups that are referenced.
"""
- rows = await self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
@@ -319,25 +319,25 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
index_name="event_to_state_groups_sg_index",
table="event_to_state_groups",
columns=["state_group"],
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
)
@@ -429,7 +429,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
# potentially stale, since there may have been a period where the
# server didn't share a room with the remote user and therefore may
# have missed any device updates.
- rows = self.db.simple_select_many_txn(
+ rows = self.db_pool.simple_select_many_txn(
txn,
table="current_state_events",
column="room_id",
@@ -441,7 +441,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
potentially_left_users = {row["state_key"] for row in rows}
# Now lets actually delete the rooms from the DB.
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="current_state_events",
column="room_id",
@@ -449,7 +449,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
keyvalues={},
)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="event_forward_extremities",
column="room_id",
@@ -457,7 +457,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
keyvalues={},
)
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn,
self.DELETE_CURRENT_STATE_UPDATE_NAME,
{"last_room_id": room_ids[-1]},
@@ -465,12 +465,12 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
return False, potentially_left_users
- finished, potentially_left_users = await self.db.runInteraction(
+ finished, potentially_left_users = await self.db_pool.runInteraction(
"_background_remove_left_rooms", _background_remove_left_rooms_txn
)
if finished:
- await self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.DELETE_CURRENT_STATE_UPDATE_NAME
)
@@ -505,5 +505,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StateStore, self).__init__(database, db_conn, hs)
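Every store constructor in this diff changes the same way: the annotation moves from the old `Database` class to `DatabasePool`, and queries go through `self.db_pool`. A minimal sketch, reusing the `rooms` lookup shown in the hunk above; `ExampleStore` is illustrative:

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import DatabasePool


    class ExampleStore(SQLBaseStore):
        def __init__(self, database: DatabasePool, db_conn, hs):
            super().__init__(database, db_conn, hs)

        def get_room_version_id(self, room_id):
            # Returns a single column of a single row, or None if absent.
            return self.db_pool.simple_select_one_onecol(
                table="rooms",
                keyvalues={"room_id": room_id},
                retcol="room_version",
                allow_none=True,
            )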
diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 725e12507f..0d963c98ff 100644
--- a/synapse/storage/data_stores/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -100,14 +100,14 @@ class StateDeltasStore(SQLBaseStore):
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, clipped_stream_id))
- return clipped_stream_id, self.db.cursor_to_dict(txn)
+ return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
- return self.db.simple_select_one_onecol_txn(
+ return self.db_pool.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
keyvalues={},
@@ -115,7 +115,7 @@ class StateDeltasStore(SQLBaseStore):
)
def get_max_stream_id_in_current_state_deltas(self):
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
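state_deltas.py uses the transaction-function idiom that recurs throughout these stores: an inner function receives a cursor (`txn`) and its whole body runs inside one transaction via `runInteraction`. A minimal sketch, with the table name taken from the hunk above and the rest illustrative:

    from synapse.storage._base import SQLBaseStore


    class ExampleDeltasStore(SQLBaseStore):
        def get_max_stream_id(self):
            def _get_max_stream_id_txn(txn):
                # Runs inside a single transaction; txn is a DB cursor.
                txn.execute(
                    "SELECT coalesce(max(stream_id), 0)"
                    " FROM current_state_delta_stream"
                )
                (stream_id,) = txn.fetchone()
                return stream_id

            return self.db_pool.runInteraction(
                "get_max_stream_id", _get_max_stream_id_txn
            )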
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/databases/main/stats.py
index 922400a7c3..802c9019b9 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -21,8 +21,8 @@ from typing import Tuple
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
-from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine
from synapse.util.caches.descriptors import cached
@@ -59,7 +59,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
class StatsStore(StateDeltasStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StatsStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -69,17 +69,20 @@ class StatsStore(StateDeltasStore):
self.stats_delta_processing_lock = DeferredLock()
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_stats_process_rooms", self._populate_stats_process_rooms
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
+ "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2
+ )
+ self.db_pool.updates.register_background_update_handler(
"populate_stats_process_users", self._populate_stats_process_users
)
# we no longer need to perform clean-up, but we will give ourselves
# the potential to reintroduce it in the future – so documentation
# will still encourage the use of this no-op handler.
- self.db.updates.register_noop_background_update("populate_stats_cleanup")
- self.db.updates.register_noop_background_update("populate_stats_prepare")
+ self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
+ self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
def quantise_stats_time(self, ts):
"""
@@ -102,7 +105,9 @@ class StatsStore(StateDeltasStore):
This is a background update which regenerates statistics for users.
"""
if not self.stats_enabled:
- await self.db.updates._end_background_update("populate_stats_process_users")
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_users"
+ )
return 1
last_user_id = progress.get("last_user_id", "")
@@ -117,22 +122,24 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_user_id, batch_size))
return [r for r, in txn]
- users_to_work_on = await self.db.runInteraction(
+ users_to_work_on = await self.db_pool.runInteraction(
"_populate_stats_process_users", _get_next_batch
)
# No more rooms -- complete the transaction.
if not users_to_work_on:
- await self.db.updates._end_background_update("populate_stats_process_users")
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_users"
+ )
return 1
for user_id in users_to_work_on:
await self._calculate_and_set_initial_state_for_user(user_id)
progress["last_user_id"] = user_id
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"populate_stats_process_users",
- self.db.updates._background_update_progress_txn,
+ self.db_pool.updates._background_update_progress_txn,
"populate_stats_process_users",
progress,
)
@@ -141,10 +148,31 @@ class StatsStore(StateDeltasStore):
async def _populate_stats_process_rooms(self, progress, batch_size):
"""
+ This was a background update which regenerated statistics for rooms.
+
+    It has been replaced by StatsStore._populate_stats_process_rooms_2. This background
+    job had been scheduled to run as part of Synapse v1.0.0, and is scheduled again now.
+    So that someone upgrading from <v1.0.0 does not run this potentially expensive task
+    twice, the old background update has been turned into a no-op.
+
+ Further context: https://github.com/matrix-org/synapse/pull/7977
+ """
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_rooms"
+ )
+ return 1
+
+ async def _populate_stats_process_rooms_2(self, progress, batch_size):
+ """
This is a background update which regenerates statistics for rooms.
+
+ It replaces StatsStore._populate_stats_process_rooms. See its docstring for the
+ reasoning.
"""
if not self.stats_enabled:
- await self.db.updates._end_background_update("populate_stats_process_rooms")
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_rooms_2"
+ )
return 1
last_room_id = progress.get("last_room_id", "")
@@ -159,23 +187,25 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_room_id, batch_size))
return [r for r, in txn]
- rooms_to_work_on = await self.db.runInteraction(
- "populate_stats_rooms_get_batch", _get_next_batch
+ rooms_to_work_on = await self.db_pool.runInteraction(
+ "populate_stats_rooms_2_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
- await self.db.updates._end_background_update("populate_stats_process_rooms")
+ await self.db_pool.updates._end_background_update(
+ "populate_stats_process_rooms_2"
+ )
return 1
for room_id in rooms_to_work_on:
await self._calculate_and_set_initial_state_for_room(room_id)
progress["last_room_id"] = room_id
- await self.db.runInteraction(
- "_populate_stats_process_rooms",
- self.db.updates._background_update_progress_txn,
- "populate_stats_process_rooms",
+ await self.db_pool.runInteraction(
+ "_populate_stats_process_rooms_2",
+ self.db_pool.updates._background_update_progress_txn,
+ "populate_stats_process_rooms_2",
progress,
)
@@ -185,7 +215,7 @@ class StatsStore(StateDeltasStore):
"""
Returns the stats processor positions.
"""
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -214,7 +244,7 @@ class StatsStore(StateDeltasStore):
if field and "\0" in field:
fields[col] = None
- return self.db.simple_upsert(
+ return self.db_pool.simple_upsert(
table="room_stats_state",
keyvalues={"room_id": room_id},
values=fields,
@@ -235,7 +265,7 @@ class StatsStore(StateDeltasStore):
Deferred[list[dict]], where the dict has the keys of
ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_statistics_for_subject",
self._get_statistics_for_subject_txn,
stats_type,
@@ -256,7 +286,7 @@ class StatsStore(StateDeltasStore):
ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
)
- slice_list = self.db.simple_select_list_paginate_txn(
+ slice_list = self.db_pool.simple_select_list_paginate_txn(
txn,
table + "_historical",
"end_ts",
@@ -282,7 +312,7 @@ class StatsStore(StateDeltasStore):
"""
table, id_col = TYPE_TO_TABLE[stats_type]
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
@@ -318,14 +348,14 @@ class StatsStore(StateDeltasStore):
complete_with_stream_id=stream_id,
)
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": stream_id},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"bulk_update_stats_delta", _bulk_update_stats_delta_txn
)
@@ -356,7 +386,7 @@ class StatsStore(StateDeltasStore):
Does not work with per-slice fields.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"update_stats_delta",
self._update_stats_delta_txn,
ts,
@@ -491,17 +521,17 @@ class StatsStore(StateDeltasStore):
else:
self.database_engine.lock_table(txn, table)
retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
- current_row = self.db.simple_select_one_txn(
+ current_row = self.db_pool.simple_select_one_txn(
txn, table, keyvalues, retcols, allow_none=True
)
if current_row is None:
merged_dict = {**keyvalues, **absolutes, **additive_relatives}
- self.db.simple_insert_txn(txn, table, merged_dict)
+ self.db_pool.simple_insert_txn(txn, table, merged_dict)
else:
for (key, val) in additive_relatives.items():
current_row[key] += val
current_row.update(absolutes)
- self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
+ self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
def _upsert_copy_from_table_with_additive_relatives_txn(
self,
@@ -588,11 +618,11 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, qargs)
else:
self.database_engine.lock_table(txn, into_table)
- src_row = self.db.simple_select_one_txn(
+ src_row = self.db_pool.simple_select_one_txn(
txn, src_table, keyvalues, copy_columns
)
all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
- dest_current_row = self.db.simple_select_one_txn(
+ dest_current_row = self.db_pool.simple_select_one_txn(
txn,
into_table,
keyvalues=all_dest_keyvalues,
@@ -608,11 +638,13 @@ class StatsStore(StateDeltasStore):
**src_row,
**additive_relatives,
}
- self.db.simple_insert_txn(txn, into_table, merged_dict)
+ self.db_pool.simple_insert_txn(txn, into_table, merged_dict)
else:
for (key, val) in additive_relatives.items():
src_row[key] = dest_current_row[key] + val
- self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+ self.db_pool.simple_update_txn(
+ txn, into_table, all_dest_keyvalues, src_row
+ )
def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
"""Fetches the counts of events in the given range of stream IDs.
@@ -626,7 +658,7 @@ class StatsStore(StateDeltasStore):
changes.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"stats_incremental_total_events_and_bytes",
self.get_changes_room_total_events_and_bytes_txn,
min_pos,
@@ -709,7 +741,7 @@ class StatsStore(StateDeltasStore):
def _fetch_current_state_stats(txn):
pos = self.get_room_max_stream_ordering()
- rows = self.db.simple_select_many_txn(
+ rows = self.db_pool.simple_select_many_txn(
txn,
table="current_state_events",
column="type",
@@ -765,7 +797,7 @@ class StatsStore(StateDeltasStore):
current_state_events_count,
users_in_room,
pos,
- ) = await self.db.runInteraction(
+ ) = await self.db_pool.runInteraction(
"get_initial_state_for_room", _fetch_current_state_stats
)
@@ -839,7 +871,7 @@ class StatsStore(StateDeltasStore):
(count,) = txn.fetchone()
return count, pos
- joined_rooms, pos = await self.db.runInteraction(
+ joined_rooms, pos = await self.db_pool.runInteraction(
"calculate_and_set_initial_state_for_user",
_calculate_and_set_initial_state_for_user_txn,
)
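The `room_stats_state` write in the hunk above goes through the pool's upsert helper. A minimal sketch of that call shape; only `simple_upsert` and the table name come from the diff, the rest is illustrative:

    from synapse.storage._base import SQLBaseStore


    class ExampleRoomStatsStore(SQLBaseStore):
        def update_room_stats_state(self, room_id, **fields):
            # Inserts the row if it is missing, otherwise updates it in place.
            return self.db_pool.simple_upsert(
                table="room_stats_state",
                keyvalues={"room_id": room_id},
                values=fields,
            )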
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/databases/main/stream.py
index 10d39b3699..aaf225894e 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,13 +39,14 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
+from typing import Optional
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.database import Database, make_in_list_sql_clause
+from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -250,7 +251,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -264,7 +265,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._need_to_reset_federation_stream_positions = self._send_federation
events_max = self.get_room_max_stream_ordering()
- event_cache_prefill, min_event_val = self.db.get_cache_dict(
+ event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
@@ -409,7 +410,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
+ rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
@@ -459,7 +460,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
+ rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
ret = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
@@ -518,7 +519,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.db.runInteraction(
+ rows, token = yield self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -555,21 +556,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
- return self.db.runInteraction("get_room_event_before_stream_ordering", _f)
+ return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
- @defer.inlineCallbacks
- def get_room_events_max_id(self, room_id=None):
+ async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
"""Returns the current token for rooms stream.
By default, it returns the current global stream token. Specifying a
`room_id` causes it to return the current room specific topological
token.
"""
- token = yield self.get_room_max_stream_ordering()
+ token = self.get_room_max_stream_ordering()
if room_id is None:
return "s%d" % (token,)
else:
- topo = yield self.db.runInteraction(
+ topo = await self.db_pool.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
return "t%d-%d" % (topo, token)
@@ -583,7 +583,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "s%d" stream token.
"""
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
@@ -596,7 +596,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "t%d-%d" topological token.
"""
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
@@ -620,7 +620,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self.db.execute(
+ return self.db_pool.execute(
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
@@ -674,7 +674,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = yield self.db.runInteraction(
+ results = yield self.db_pool.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -716,7 +716,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = self.db.simple_select_one_txn(
+ results = self.db_pool.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -795,7 +795,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.db.runInteraction(
+ upper_bound, event_ids = yield self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
@@ -805,12 +805,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions:
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn
)
self._need_to_reset_federation_stream_positions = False
- return await self.db.simple_select_one_onecol(
+ return await self.db_pool.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ, "instance_name": self._instance_name},
@@ -819,12 +819,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def update_federation_out_pos(self, typ, stream_id):
if self._need_to_reset_federation_stream_positions:
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn
)
self._need_to_reset_federation_stream_positions = False
- return await self.db.simple_update_one(
+ return await self.db_pool.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id},
@@ -854,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
elif self._instance_name not in configured_instances:
return
- instances_in_table = self.db.simple_select_onecol_txn(
+ instances_in_table = self.db_pool.simple_select_onecol_txn(
txn,
table="federation_stream_position",
keyvalues={},
@@ -885,7 +885,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql % (clause,), args)
for typ, stream_id in min_positions.items():
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name},
@@ -1036,7 +1036,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.db.runInteraction(
+ rows, token = yield self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
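A small illustration of the two token formats get_room_events_max_id returns after its conversion to async/await above: a global stream token when no room_id is given, and a room-specific topological token otherwise (the orderings below are made-up values):

def format_stream_token(stream_ordering):
    # Global stream token, e.g. "s3542".
    return "s%d" % (stream_ordering,)

def format_topological_token(topological_ordering, stream_ordering):
    # Room-specific topological token, e.g. "t12-3542".
    return "t%d-%d" % (topological_ordering, stream_ordering)

print(format_stream_token(3542))           # s3542
print(format_topological_token(12, 3542))  # t12-3542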
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/databases/main/tags.py
index bd7227773a..e4e0a0c433 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,14 +15,13 @@
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import Dict, List, Tuple
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.storage._base import db_to_json
-from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -30,30 +29,26 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
- def get_tags_for_user(self, user_id):
+ async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for a user.
Args:
- user_id(str): The user to get the tags for.
+ user_id: The user to get the tags for.
Returns:
- A deferred dict mapping from room_id strings to dicts mapping from
- tag strings to tag content.
+ A mapping from room_id strings to dicts mapping from tag strings to
+ tag content.
"""
- deferred = self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
- @deferred.addCallback
- def tags_by_room(rows):
- tags_by_room = {}
- for row in rows:
- room_tags = tags_by_room.setdefault(row["room_id"], {})
- room_tags[row["tag"]] = db_to_json(row["content"])
- return tags_by_room
-
- return deferred
+ tags_by_room = {}
+ for row in rows:
+ room_tags = tags_by_room.setdefault(row["room_id"], {})
+ room_tags[row["tag"]] = db_to_json(row["content"])
+ return tags_by_room
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -92,7 +87,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- tag_ids = await self.db.runInteraction(
+ tag_ids = await self.db_pool.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
@@ -112,7 +107,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
- tags = await self.db.runInteraction(
+ tags = await self.db_pool.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
@@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
return results, upto_token, limited
- @defer.inlineCallbacks
- def get_updated_tags(self, user_id, stream_id):
+ async def get_updated_tags(
+ self, user_id: str, stream_id: int
+ ) -> Dict[str, List[str]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
+
Returns:
- A deferred dict mapping from room_id strings to lists of tag
- strings for all the rooms that changed since the stream_id token.
+ A mapping from room_id strings to lists of tag strings for all the
+ rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
@@ -155,52 +152,58 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed:
return {}
- room_ids = yield self.db.runInteraction(
+ room_ids = await self.db_pool.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
results = {}
if room_ids:
- tags_by_room = yield self.get_tags_for_user(user_id)
+ tags_by_room = await self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
return results
- def get_tags_for_room(self, user_id, room_id):
+ async def get_tags_for_room(
+ self, user_id: str, room_id: str
+ ) -> Dict[str, JsonDict]:
"""Get all the tags for the given room
+
Args:
- user_id(str): The user to get tags for
- room_id(str): The room to get tags for
+ user_id: The user to get tags for
+ room_id: The room to get tags for
+
Returns:
- A deferred list of string tags.
+ A mapping of tags to tag content.
"""
- return self.db.simple_select_list(
+ rows = await self.db_pool.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
- ).addCallback(
- lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
)
+ return {row["tag"]: db_to_json(row["content"]) for row in rows}
class TagsStore(TagsWorkerStore):
- @defer.inlineCallbacks
- def add_tag_to_room(self, user_id, room_id, tag, content):
+ async def add_tag_to_room(
+ self, user_id: str, room_id: str, tag: str, content: JsonDict
+ ) -> int:
"""Add a tag to a room for a user.
+
Args:
- user_id(str): The user to add a tag for.
- room_id(str): The room to add a tag for.
- tag(str): The tag name to add.
- content(dict): A json object to associate with the tag.
+ user_id: The user to add a tag for.
+ room_id: The room to add a tag for.
+ tag: The tag name to add.
+ content: A json object to associate with the tag.
+
Returns:
- A deferred that completes once the tag has been added.
+ The next account data ID.
"""
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
+ await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- @defer.inlineCallbacks
- def remove_tag_from_room(self, user_id, room_id, tag):
+ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
"""Remove a tag from a room for a user.
+
Returns:
- A deferred that completes once the tag has been removed
+ The next account data ID.
"""
def remove_tag_txn(txn, next_id):
@@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
+ await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
- result = self._account_data_id_gen.get_current_token()
- return result
+ return self._account_data_id_gen.get_current_token()
- def _update_revision_txn(self, txn, user_id, room_id, next_id):
+ def _update_revision_txn(
+ self, txn, user_id: str, room_id: str, next_id: int
+ ) -> None:
"""Update the latest revision of the tags for the given user and room.
Args:
txn: The database cursor
- user_id(str): The ID of the user.
- room_id(str): The ID of the room.
- next_id(int): The the revision to advance to.
+ user_id: The ID of the user.
+ room_id: The ID of the room.
+ next_id: The revision to advance to.
"""
txn.call_after(
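With the Deferred callbacks gone, get_tags_for_user builds its result inline. A self-contained sketch of the same row-to-mapping fold, using json.loads as a stand-in for db_to_json and made-up rows:

import json

rows = [
    {"room_id": "!a:example.org", "tag": "m.favourite", "content": json.dumps({"order": 0.5})},
    {"room_id": "!a:example.org", "tag": "u.work", "content": json.dumps({})},
    {"room_id": "!b:example.org", "tag": "m.lowpriority", "content": json.dumps({})},
]

# Fold flat rows into room_id -> tag -> content, as in the code above.
tags_by_room = {}
for row in rows:
    room_tags = tags_by_room.setdefault(row["room_id"], {})
    room_tags[row["tag"]] = json.loads(row["content"])

print(tags_by_room["!a:example.org"])  # {'m.favourite': {'order': 0.5}, 'u.work': {}}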
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/databases/main/transactions.py
index a9bf457939..52668dbdf9 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -18,11 +18,9 @@ from collections import namedtuple
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.util.caches.expiringcache import ExpiringCache
db_binary_type = memoryview
@@ -46,7 +44,7 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(TransactionStore, self).__init__(database, db_conn, hs)
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
@@ -71,7 +69,7 @@ class TransactionStore(SQLBaseStore):
this transaction or a 2-tuple of (int, dict)
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@@ -79,7 +77,7 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
- result = self.db.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -113,7 +111,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
- return self.db.simple_insert(
+ return self.db_pool.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
@@ -126,8 +124,7 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response",
)
- @defer.inlineCallbacks
- def get_destination_retry_timings(self, destination):
+ async def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
Args:
@@ -142,7 +139,7 @@ class TransactionStore(SQLBaseStore):
if result is not SENTINEL:
return result
- result = yield self.db.runInteraction(
+ result = await self.db_pool.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
@@ -154,7 +151,7 @@ class TransactionStore(SQLBaseStore):
return result
def _get_destination_retry_timings(self, txn, destination):
- result = self.db.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@@ -181,7 +178,7 @@ class TransactionStore(SQLBaseStore):
"""
self._destination_retry_cache.pop(destination, None)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
@@ -221,7 +218,7 @@ class TransactionStore(SQLBaseStore):
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
- prev_row = self.db.simple_select_one_txn(
+ prev_row = self.db_pool.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@@ -230,7 +227,7 @@ class TransactionStore(SQLBaseStore):
)
if not prev_row:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="destinations",
values={
@@ -241,7 +238,7 @@ class TransactionStore(SQLBaseStore):
},
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
"destinations",
keyvalues={"destination": destination},
@@ -264,6 +261,6 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)
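get_destination_retry_timings above is a read-through cache: check the in-memory expiring cache first, fall back to the database, then remember the answer (including a None answer). A minimal in-memory analogue of that pattern; the cache class and fetch function are stand-ins, not Synapse's types:

SENTINEL = object()

class RetryTimingsCache:
    def __init__(self, db_fetch):
        self._cache = {}           # destination -> timings dict or None
        self._db_fetch = db_fetch  # fallback lookup, e.g. a database query

    def get(self, destination):
        result = self._cache.get(destination, SENTINEL)
        if result is not SENTINEL:
            return result          # cache hit, possibly a cached None
        result = self._db_fetch(destination)
        self._cache[destination] = result
        return result

cache = RetryTimingsCache(lambda d: {"retry_last_ts": 0, "retry_interval": 0})
print(cache.get("remote.example.org"))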
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 5f1b919748..37276f73f8 100644
--- a/synapse/storage/data_stores/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -81,7 +81,7 @@ class UIAuthWorkerStore(SQLBaseStore):
session_id = stringutils.random_string(24)
try:
- await self.db.simple_insert(
+ await self.db_pool.simple_insert(
table="ui_auth_sessions",
values={
"session_id": session_id,
@@ -97,7 +97,7 @@ class UIAuthWorkerStore(SQLBaseStore):
return UIAuthSessionData(
session_id, clientdict, uri, method, description
)
- except self.db.engine.module.IntegrityError:
+ except self.db_pool.engine.module.IntegrityError:
attempts += 1
raise StoreError(500, "Couldn't generate a session ID.")
@@ -111,7 +111,7 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session is not found.
"""
- result = await self.db.simple_select_one(
+ result = await self.db_pool.simple_select_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("clientdict", "uri", "method", "description"),
@@ -140,13 +140,13 @@ class UIAuthWorkerStore(SQLBaseStore):
# Note that we need to allow for the same stage to complete multiple
# times here so that registration is idempotent.
try:
- await self.db.simple_upsert(
+ await self.db_pool.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
values={"result": json.dumps(result)},
desc="mark_ui_auth_stage_complete",
)
- except self.db.engine.module.IntegrityError:
+ except self.db_pool.engine.module.IntegrityError:
raise StoreError(400, "Unknown session ID: %s" % (session_id,))
async def get_completed_ui_auth_stages(
@@ -162,7 +162,7 @@ class UIAuthWorkerStore(SQLBaseStore):
that auth-type.
"""
results = {}
- for row in await self.db.simple_select_list(
+ for row in await self.db_pool.simple_select_list(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id},
retcols=("stage_type", "result"),
@@ -186,7 +186,7 @@ class UIAuthWorkerStore(SQLBaseStore):
# The clientdict gets stored as JSON.
clientdict_json = json.dumps(clientdict)
- await self.db.simple_update_one(
+ await self.db_pool.simple_update_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
updatevalues={"clientdict": clientdict_json},
@@ -206,7 +206,7 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session cannot be found.
"""
- await self.db.runInteraction(
+ await self.db_pool.runInteraction(
"set_ui_auth_session_data",
self._set_ui_auth_session_data_txn,
session_id,
@@ -216,7 +216,7 @@ class UIAuthWorkerStore(SQLBaseStore):
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
# Get the current value.
- result = self.db.simple_select_one_txn(
+ result = self.db_pool.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
@@ -227,7 +227,7 @@ class UIAuthWorkerStore(SQLBaseStore):
serverdict = db_to_json(result["serverdict"])
serverdict[key] = value
- self.db.simple_update_one_txn(
+ self.db_pool.simple_update_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
@@ -247,7 +247,7 @@ class UIAuthWorkerStore(SQLBaseStore):
Raises:
StoreError if the session cannot be found.
"""
- result = await self.db.simple_select_one(
+ result = await self.db_pool.simple_select_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
@@ -269,7 +269,7 @@ class UIAuthStore(UIAuthWorkerStore):
This is an epoch time in milliseconds.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,
@@ -282,7 +282,7 @@ class UIAuthStore(UIAuthWorkerStore):
session_ids = [r[0] for r in txn.fetchall()]
# Delete the corresponding completed credentials.
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="ui_auth_sessions_credentials",
column="session_id",
@@ -291,7 +291,7 @@ class UIAuthStore(UIAuthWorkerStore):
)
# Finally, delete the sessions.
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="ui_auth_sessions",
column="session_id",
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 942e51fd3a..af21fe457a 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -16,12 +16,10 @@
import logging
import re
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules
-from synapse.storage.data_stores.main.state import StateFilter
-from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.state import StateFilter
+from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@@ -38,29 +36,28 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_user_directory_createtables",
self._populate_user_directory_createtables,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_user_directory_process_rooms",
self._populate_user_directory_process_rooms,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_user_directory_process_users",
self._populate_user_directory_process_users,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
- @defer.inlineCallbacks
- def _populate_user_directory_createtables(self, progress, batch_size):
+ async def _populate_user_directory_createtables(self, progress, batch_size):
# Get all the rooms that we want to process.
def _make_staging_area(txn):
@@ -85,7 +82,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"""
txn.execute(sql)
rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
# If search all users is on, get all the users we want to add.
@@ -100,43 +97,45 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("SELECT name FROM users")
users = [{"user_id": x[0]} for x in txn.fetchall()]
- self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
- new_pos = yield self.get_max_stream_id_in_current_state_deltas()
- yield self.db.runInteraction(
+ new_pos = await self.get_max_stream_id_in_current_state_deltas()
+ await self.db_pool.runInteraction(
"populate_user_directory_temp_build", _make_staging_area
)
- yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ await self.db_pool.simple_insert(
+ TEMP_TABLE + "_position", {"position": new_pos}
+ )
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_createtables"
)
return 1
- @defer.inlineCallbacks
- def _populate_user_directory_cleanup(self, progress, batch_size):
+ async def _populate_user_directory_cleanup(self, progress, batch_size):
"""
Update the user directory stream position, then clean up the old tables.
"""
- position = yield self.db.simple_select_one_onecol(
+ position = await self.db_pool.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
- yield self.update_user_directory_stream_pos(position)
+ await self.update_user_directory_stream_pos(position)
def _delete_staging_area(txn):
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory_cleanup", _delete_staging_area
)
- yield self.db.updates._end_background_update("populate_user_directory_cleanup")
+ await self.db_pool.updates._end_background_update(
+ "populate_user_directory_cleanup"
+ )
return 1
- @defer.inlineCallbacks
- def _populate_user_directory_process_rooms(self, progress, batch_size):
+ async def _populate_user_directory_process_rooms(self, progress, batch_size):
"""
Args:
progress (dict)
@@ -147,7 +146,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
    # If we don't have any progress recorded, delete everything.
if not progress:
- yield self.delete_all_from_user_dir()
+ await self.delete_all_from_user_dir()
def _get_next_batch(txn):
# Only fetch 250 rooms, so we don't fetch too many at once, even
@@ -172,13 +171,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return rooms_to_work_on
- rooms_to_work_on = yield self.db.runInteraction(
+ rooms_to_work_on = await self.db_pool.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_rooms"
)
return 1
@@ -191,21 +190,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
processed_event_count = 0
for room_id, event_count in rooms_to_work_on:
- is_in_room = yield self.is_host_joined(room_id, self.server_name)
+ is_in_room = await self.is_host_joined(room_id, self.server_name)
if is_in_room:
- is_public = yield self.is_room_world_readable_or_publicly_joinable(
+ is_public = await self.is_room_world_readable_or_publicly_joinable(
room_id
)
- users_with_profile = yield defer.ensureDeferred(
- state.get_current_users_in_room(room_id)
- )
+ users_with_profile = await state.get_current_users_in_room(room_id)
user_ids = set(users_with_profile)
# Update each user in the user directory.
for user_id, profile in users_with_profile.items():
- yield self.update_profile_in_user_dir(
+ await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
@@ -219,7 +216,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
to_insert.add(user_id)
if to_insert:
- yield self.add_users_in_public_rooms(room_id, to_insert)
+ await self.add_users_in_public_rooms(room_id, to_insert)
to_insert.clear()
else:
for user_id in user_ids:
@@ -239,22 +236,24 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# If it gets too big, stop and write to the database
# to prevent storing too much in RAM.
if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
- yield self.add_users_who_share_private_room(
+ await self.add_users_who_share_private_room(
room_id, to_insert
)
to_insert.clear()
if to_insert:
- yield self.add_users_who_share_private_room(room_id, to_insert)
+ await self.add_users_who_share_private_room(room_id, to_insert)
to_insert.clear()
# We've finished a room. Delete it from the table.
- yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+ await self.db_pool.simple_delete_one(
+ TEMP_TABLE + "_rooms", {"room_id": room_id}
+ )
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory",
- self.db.updates._background_update_progress_txn,
+ self.db_pool.updates._background_update_progress_txn,
"populate_user_directory_process_rooms",
progress,
)
@@ -267,13 +266,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return processed_event_count
- @defer.inlineCallbacks
- def _populate_user_directory_process_users(self, progress, batch_size):
+ async def _populate_user_directory_process_users(self, progress, batch_size):
"""
If search_all_users is enabled, add all of the users to the user directory.
"""
if not self.hs.config.user_directory_search_all_users:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_users"
)
return 1
@@ -299,13 +297,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return users_to_work_on
- users_to_work_on = yield self.db.runInteraction(
+ users_to_work_on = await self.db_pool.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more users -- complete the transaction.
if not users_to_work_on:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
"populate_user_directory_process_users"
)
return 1
@@ -316,26 +314,27 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
for user_id in users_to_work_on:
- profile = yield self.get_profileinfo(get_localpart_from_id(user_id))
- yield self.update_profile_in_user_dir(
+ profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+ await self.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url
)
# We've finished processing a user. Delete it from the table.
- yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+ await self.db_pool.simple_delete_one(
+ TEMP_TABLE + "_users", {"user_id": user_id}
+ )
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.db.runInteraction(
+ await self.db_pool.runInteraction(
"populate_user_directory",
- self.db.updates._background_update_progress_txn,
+ self.db_pool.updates._background_update_progress_txn,
"populate_user_directory_process_users",
progress,
)
return len(users_to_work_on)
- @defer.inlineCallbacks
- def is_room_world_readable_or_publicly_joinable(self, room_id):
+ async def is_room_world_readable_or_publicly_joinable(self, room_id):
"""Check if the room is either world_readable or publically joinable
"""
@@ -345,20 +344,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
- current_state_ids = yield self.get_filtered_current_state_ids(
+ current_state_ids = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types(types_to_filter)
)
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
if join_rules_id:
- join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
+ join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
- hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
+ hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
return True
@@ -371,7 +370,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"""
def _update_profile_in_user_dir_txn(txn):
- new_entry = self.db.simple_upsert_txn(
+ new_entry = self.db_pool.simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
@@ -445,7 +444,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name) if display_name else user_id
- self.db.simple_upsert_txn(
+ self.db_pool.simple_upsert_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
@@ -458,7 +457,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
@@ -472,7 +471,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
"""
def _add_users_who_share_room_txn(txn):
- self.db.simple_upsert_many_txn(
+ self.db_pool.simple_upsert_many_txn(
txn,
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
@@ -484,7 +483,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
@@ -499,7 +498,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
def _add_users_in_public_rooms_txn(txn):
- self.db.simple_upsert_many_txn(
+ self.db_pool.simple_upsert_many_txn(
txn,
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
@@ -508,7 +507,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
@@ -523,13 +522,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@cached()
def get_user_in_directory(self, user_id):
- return self.db.simple_select_one(
+ return self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -538,7 +537,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
def update_user_directory_stream_pos(self, stream_id):
- return self.db.simple_update_one(
+ return self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
@@ -552,47 +551,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+ return self.db_pool.runInteraction(
+ "remove_from_user_dir", _remove_from_user_dir_txn
+ )
- @defer.inlineCallbacks
- def get_users_in_dir_due_to_room(self, room_id):
+ async def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
- user_ids_share_pub = yield self.db.simple_select_onecol(
+ user_ids_share_pub = await self.db_pool.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids_share_priv = yield self.db.simple_select_onecol(
+ user_ids_share_priv = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
@@ -615,28 +615,27 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"""
def _remove_user_who_share_room_txn(txn):
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id, "room_id": room_id},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="users_in_public_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
- @defer.inlineCallbacks
- def get_user_dir_rooms_user_is_in(self, user_id):
+ async def get_user_dir_rooms_user_is_in(self, user_id):
"""
Returns the rooms that a user is in.
@@ -646,14 +645,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
Returns:
list: The room IDs the user is in.
"""
- rows = yield self.db.simple_select_onecol(
+ rows = await self.db_pool.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
desc="get_rooms_user_is_in",
)
- pub_rows = yield self.db.simple_select_onecol(
+ pub_rows = await self.db_pool.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
@@ -664,42 +663,15 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- @defer.inlineCallbacks
- def get_rooms_in_common_for_users(self, user_id, other_user_id):
- """Given two user_ids find out the list of rooms they share.
- """
- sql = """
- SELECT room_id FROM (
- SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (event_id)
- WHERE type = 'm.room.member'
- AND m.membership = 'join'
- AND state_key = ?
- ) AS f1 INNER JOIN (
- SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (event_id)
- WHERE type = 'm.room.member'
- AND m.membership = 'join'
- AND state_key = ?
- ) f2 USING (room_id)
- """
-
- rows = yield self.db.execute(
- "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
- )
-
- return [room_id for room_id, in rows]
-
def get_user_directory_stream_pos(self):
- return self.db.simple_select_one_onecol(
+ return self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
desc="get_user_directory_stream_pos",
)
- @defer.inlineCallbacks
- def search_user_dir(self, user_id, search_term, limit):
+ async def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
Returns:
@@ -796,8 +768,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = yield self.db.execute(
- "search_user_dir", self.db.cursor_to_dict, sql, *args
+ results = await self.db_pool.execute(
+ "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
)
limited = len(results) > limit
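Directory visibility in the queries above is always the union of two sources: rows from users_in_public_rooms and rows from users_who_share_private_rooms. A tiny standalone sketch of that union with made-up rows:

users_in_public_rooms = ["@alice:example.org", "@bob:example.org"]
users_who_share_private_rooms = ["@bob:example.org", "@carol:example.org"]

# Deduplicate via a set, as get_users_in_dir_due_to_room does above.
user_ids = set(users_in_public_rooms)
user_ids.update(users_who_share_private_rooms)
print(sorted(user_ids))
# ['@alice:example.org', '@bob:example.org', '@carol:example.org']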
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index d3038ff06d..ab6cb2c1f6 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if the user has requested erasure
"""
- return self.db.simple_select_onecol(
+ return self.db_pool.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
@@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- rows = yield self.db.simple_select_many_batch(
+ rows = yield self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -88,7 +88,7 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db.runInteraction("mark_user_erased", f)
+ return self.db_pool.runInteraction("mark_user_erased", f)
def mark_user_not_erased(self, user_id: str) -> None:
"""Indicate that user_id is no longer erased.
@@ -110,4 +110,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.db.runInteraction("mark_user_not_erased", f)
+ return self.db_pool.runInteraction("mark_user_not_erased", f)
diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/databases/state/__init__.py
index 86e09f6229..c90d022899 100644
--- a/synapse/storage/data_stores/state/__init__.py
+++ b/synapse/storage/databases/state/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401
+from synapse.storage.databases.state.store import StateGroupDataStore # noqa: F401
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index be1fe97d79..139085b672 100644
--- a/synapse/storage/data_stores/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -15,10 +15,8 @@
import logging
-from twisted.internet import defer
-
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
@@ -62,7 +60,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
count = 0
while next_group:
- next_group = self.db.simple_select_one_onecol_txn(
+ next_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -165,7 +163,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
):
break
- next_group = self.db.simple_select_one_onecol_txn(
+ next_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -182,24 +180,23 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
- self.db.updates.register_background_update_handler(
+ self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
)
- self.db.updates.register_background_index_update(
+ self.db_pool.updates.register_background_index_update(
self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
index_name="state_groups_room_id_idx",
table="state_groups",
columns=["room_id"],
)
- @defer.inlineCallbacks
- def _background_deduplicate_state(self, progress, batch_size):
+ async def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
- rows = yield self.db.execute(
+ rows = await self.db_pool.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
@@ -282,13 +279,13 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
if prev_state.get(key, None) != value
}
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
values={
@@ -297,13 +294,13 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
},
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -324,25 +321,24 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
"max_group": max_group,
}
- self.db.updates._background_update_progress_txn(
+ self.db_pool.updates._background_update_progress_txn(
txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
)
return False, batch_size
- finished, result = yield self.db.runInteraction(
+ finished, result = await self.db_pool.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
)
if finished:
- yield self.db.updates._end_background_update(
+ await self.db_pool.updates._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
return result * BATCH_SIZE_SCALE_FACTOR
- @defer.inlineCallbacks
- def _background_index_state(self, progress, batch_size):
+ async def _background_index_state(self, progress, batch_size):
def reindex_txn(conn):
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
@@ -365,8 +361,10 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- yield self.db.runWithConnection(reindex_txn)
+ await self.db_pool.runWithConnection(reindex_txn)
- yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+ await self.db_pool.updates._end_background_update(
+ self.STATE_GROUP_INDEX_UPDATE_NAME
+ )
return 1
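The background updates converted above all follow the same async contract: take a progress dict and a batch_size, do a bounded chunk of work, persist the new progress, and signal completion (via _end_background_update in the real code). A schematic, self-contained version of that loop; the work items and return values are illustrative:

import asyncio

TOTAL_ITEMS = 10  # stand-in for the rows a real update walks through

async def example_background_update(progress, batch_size):
    position = progress.get("position", 0)  # resume from persisted progress
    items = list(range(position, min(position + batch_size, TOTAL_ITEMS)))
    for _ in items:
        await asyncio.sleep(0)  # stand-in for per-item database work
    if not items:
        return 0  # finished; the real code would end the background update here
    progress["position"] = position + len(items)
    return len(items)  # items processed this batch

async def run():
    progress = {}
    while await example_background_update(progress, batch_size=4):
        pass
    print("done at position", progress["position"])

asyncio.run(run())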
diff --git a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644
--- a/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
+++ b/synapse/storage/databases/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql b/synapse/storage/databases/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644
--- a/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
+++ b/synapse/storage/databases/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql
index 1450313bfa..1450313bfa 100644
--- a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
+++ b/synapse/storage/databases/state/schema/delta/32/remove_state_indices.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql
index 33980d02f0..33980d02f0 100644
--- a/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/databases/state/schema/delta/35/add_state_index.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/35/state.sql b/synapse/storage/databases/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644
--- a/synapse/storage/data_stores/state/schema/delta/35/state.sql
+++ b/synapse/storage/databases/state/schema/delta/35/state.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644
--- a/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
+++ b/synapse/storage/databases/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py
index 9fd1ccf6f7..9fd1ccf6f7 100644
--- a/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/databases/state/schema/delta/47/state_group_seq.py
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql
index 7916ef18b2..7916ef18b2 100644
--- a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
+++ b/synapse/storage/databases/state/schema/delta/56/state_group_room_idx.sql
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/databases/state/schema/full_schemas/54/full.sql
index 35f97d6b3d..35f97d6b3d 100644
--- a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
+++ b/synapse/storage/databases/state/schema/full_schemas/54/full.sql
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres
index fcd926c9fb..fcd926c9fb 100644
--- a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
+++ b/synapse/storage/databases/state/schema/full_schemas/54/sequence.sql.postgres
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/databases/state/store.py
index 7dada7f75f..7f104ad936 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -21,8 +21,8 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
@@ -53,7 +53,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups.
"""
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(StateGroupDataStore, self).__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
@@ -112,7 +112,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""
def _get_state_group_delta_txn(txn):
- prev_group = self.db.simple_select_one_onecol_txn(
+ prev_group = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
@@ -123,7 +123,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self.db.simple_select_list_txn(
+ delta_ids = self.db_pool.simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
@@ -135,7 +135,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"get_state_group_delta", _get_state_group_delta_txn
)
@@ -156,7 +156,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
- res = await self.db.runInteraction(
+ res = await self.db_pool.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn,
chunk,
@@ -393,7 +393,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_group = self._state_group_seq_gen.get_next_id_txn(txn)
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="state_groups",
values={"id": state_group, "room_id": room_id, "event_id": event_id},
@@ -402,7 +402,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if prev_group:
- is_in_db = self.db.simple_select_one_onecol_txn(
+ is_in_db = self.db_pool.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
@@ -417,13 +417,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self.db.simple_insert_txn(
+ self.db_pool.simple_insert_txn(
txn,
table="state_group_edges",
values={"state_group": state_group, "prev_state_group": prev_group},
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -438,7 +438,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
],
)
else:
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -484,7 +484,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_group
- return self.db.runInteraction("store_state_group", _store_state_group_txn)
+ return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
def purge_unreferenced_state_groups(
self, room_id: str, state_groups_to_delete
@@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to delete.
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
@@ -511,7 +511,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
- rows = self.db.simple_select_many_txn(
+ rows = self.db_pool.simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
@@ -538,15 +538,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
curr_state = curr_state[sg]
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
)
- self.db.simple_delete_txn(
+ self.db_pool.simple_delete_txn(
txn, table="state_group_edges", keyvalues={"state_group": sg}
)
- self.db.simple_insert_many_txn(
+ self.db_pool.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -583,7 +583,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
A mapping from state group to previous state group.
"""
- rows = await self.db.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,
@@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete (list[int]): State groups to delete
"""
- return self.db.runInteraction(
+ return self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
@@ -613,7 +613,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="state_groups_state",
column="state_group",
@@ -624,7 +624,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# ... and the state group edges
logger.info("[purge] removing %s from state_group_edges", room_id)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="state_group_edges",
column="state_group",
@@ -635,7 +635,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# ... and the state groups
logger.info("[purge] removing %s from state_groups", room_id)
- self.db.simple_delete_many_txn(
+ self.db_pool.simple_delete_many_txn(
txn,
table="state_groups",
column="id",
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 4a164834d9..f15b95e633 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -29,8 +29,8 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.data_stores import DataStores
-from synapse.storage.data_stores.main.events import DeltaState
+from synapse.storage.databases import Databases
+from synapse.storage.databases.main.events import DeltaState
from synapse.types import StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -179,7 +179,7 @@ class EventsPersistenceStorage(object):
current state and forward extremity changes.
"""
- def __init__(self, hs, stores: DataStores):
+ def __init__(self, hs, stores: Databases):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 9cc3b51fe6..1c5f305132 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -47,8 +47,8 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
- """Prepares a database for usage. Will either create all necessary tables
+def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+ """Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
If `config` is None then prepare_database will assert that no upgrade is
@@ -60,8 +60,8 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
- data_stores (list[str]): The name of the data stores that will be used
- with this database. Defaults to all data stores.
+ databases (list[str]): The name of the databases that will be used
+ with this physical database. Defaults to all databases.
"""
try:
@@ -87,10 +87,10 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta
upgraded,
database_engine,
config,
- data_stores=data_stores,
+ databases=databases,
)
else:
- _setup_new_database(cur, database_engine, data_stores=data_stores)
+ _setup_new_database(cur, database_engine, databases=databases)
# check if any of our configured dynamic modules want a database
if config is not None:
@@ -103,9 +103,9 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main", "sta
raise
-def _setup_new_database(cur, database_engine, data_stores):
- """Sets up the database by finding a base set of "full schemas" and then
- applying any necessary deltas, including schemas from the given data
+def _setup_new_database(cur, database_engine, databases):
+ """Sets up the physical database by finding a base set of "full schemas" and
+ then applying any necessary deltas, including schemas from the given
+ databases.
The "full_schemas" directory has subdirectories named after versions. This
@@ -138,8 +138,8 @@ def _setup_new_database(cur, database_engine, data_stores):
Args:
cur (Cursor): a database cursor
database_engine (DatabaseEngine)
- data_stores (list[str]): The names of the data stores to instantiate
- on the given database.
+ databases (list[str]): The names of the databases to instantiate
+ on the given physical database.
"""
# We're about to set up a brand new database so we check that its
@@ -176,13 +176,13 @@ def _setup_new_database(cur, database_engine, data_stores):
directories.extend(
os.path.join(
dir_path,
- "data_stores",
- data_store,
+ "databases",
+ database,
"schema",
"full_schemas",
str(max_current_ver),
)
- for data_store in data_stores
+ for database in databases
)
directory_entries = []
@@ -219,7 +219,7 @@ def _setup_new_database(cur, database_engine, data_stores):
upgraded=False,
database_engine=database_engine,
config=None,
- data_stores=data_stores,
+ databases=databases,
is_empty=True,
)
@@ -231,10 +231,10 @@ def _upgrade_existing_database(
upgraded,
database_engine,
config,
- data_stores,
+ databases,
is_empty=False,
):
- """Upgrades an existing database.
+ """Upgrades an existing physical database.
Delta files can either be SQL stored in *.sql files, or python modules
in *.py.
@@ -285,8 +285,8 @@ def _upgrade_existing_database(
config (synapse.config.homeserver.HomeServerConfig|None):
None if we are initialising a blank database, otherwise the application
config
- data_stores (list[str]): The names of the data stores to instantiate
- on the given database.
+ databases (list[str]): The names of the databases to instantiate
+ on the given physical database.
is_empty (bool): Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts.
"""
@@ -303,8 +303,8 @@ def _upgrade_existing_database(
# some of the deltas assume that config.server_name is set correctly, so now
# is a good time to run the sanity check.
- if not is_empty and "main" in data_stores:
- from synapse.storage.data_stores.main import check_database_before_upgrade
+ if not is_empty and "main" in databases:
+ from synapse.storage.databases.main import check_database_before_upgrade
check_database_before_upgrade(cur, database_engine, config)
@@ -330,11 +330,9 @@ def _upgrade_existing_database(
# First we find the directories to search in
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
directories = [delta_dir]
- for data_store in data_stores:
+ for database in databases:
directories.append(
- os.path.join(
- dir_path, "data_stores", data_store, "schema", "delta", str(v)
- )
+ os.path.join(dir_path, "databases", database, "schema", "delta", str(v))
)
# Used to check if we have any duplicate file names
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 787cebfbec..e2ddd01290 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,7 +20,7 @@ from typing import Dict, Set, Tuple
from typing_extensions import Deque
-from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.util.sequence import PostgresSequenceGenerator
@@ -239,7 +239,7 @@ class MultiWriterIdGenerator:
def __init__(
self,
db_conn,
- db: Database,
+ db: DatabasePool,
instance_name: str,
table: str,
instance_column: str,
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 5d3eddcfdc..393e34b9fb 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -15,8 +15,6 @@
from typing import Any, Dict
-from twisted.internet import defer
-
from synapse.handlers.account_data import AccountDataEventSource
from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.receipts import ReceiptEventSource
@@ -40,19 +38,18 @@ class EventSources(object):
} # type: Dict[str, Any]
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_current_token(self):
+ def get_current_token(self) -> StreamToken:
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
token = StreamToken(
- room_key=(yield self.sources["room"].get_current_key()),
- presence_key=(yield self.sources["presence"].get_current_key()),
- typing_key=(yield self.sources["typing"].get_current_key()),
- receipt_key=(yield self.sources["receipt"].get_current_key()),
- account_data_key=(yield self.sources["account_data"].get_current_key()),
+ room_key=self.sources["room"].get_current_key(),
+ presence_key=self.sources["presence"].get_current_key(),
+ typing_key=self.sources["typing"].get_current_key(),
+ receipt_key=self.sources["receipt"].get_current_key(),
+ account_data_key=self.sources["account_data"].get_current_key(),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
@@ -60,8 +57,7 @@ class EventSources(object):
)
return token
- @defer.inlineCallbacks
- def get_current_token_for_pagination(self):
+ def get_current_token_for_pagination(self) -> StreamToken:
"""Get the current token for a given room to be used to paginate
events.
@@ -69,10 +65,10 @@ class EventSources(object):
than `room`, since they are not used during pagination.
Returns:
- Deferred[StreamToken]
+ The current token for pagination.
"""
token = StreamToken(
- room_key=(yield self.sources["room"].get_current_key()),
+ room_key=self.sources["room"].get_current_key(),
presence_key=0,
typing_key=0,
receipt_key=0,
diff --git a/synapse/types.py b/synapse/types.py
index 238b938064..9e580f4295 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -13,11 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import re
import string
import sys
from collections import namedtuple
-from typing import Any, Dict, Tuple, TypeVar
+from typing import Any, Dict, Tuple, Type, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@@ -33,7 +34,7 @@ else:
T_co = TypeVar("T_co", covariant=True)
- class Collection(Iterable[T_co], Container[T_co], Sized):
+ class Collection(Iterable[T_co], Container[T_co], Sized): # type: ignore
__slots__ = ()
@@ -141,6 +142,9 @@ def get_localpart_from_id(string):
return string[1:idx]
+DS = TypeVar("DS", bound="DomainSpecificString")
+
+
class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -151,6 +155,10 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
'domain' : The domain part of the name
"""
+ __metaclass__ = abc.ABCMeta
+
+ SIGIL = abc.abstractproperty() # type: str # type: ignore
+
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
@@ -166,7 +174,7 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
return self
@classmethod
- def from_string(cls, s: str):
+ def from_string(cls: Type[DS], s: str) -> DS:
"""Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0:1] != cls.SIGIL:
raise SynapseError(
@@ -190,12 +198,12 @@ class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "dom
# names on one HS
return cls(localpart=parts[0], domain=domain)
- def to_string(self):
+ def to_string(self) -> str:
"""Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
@classmethod
- def is_valid(cls, s):
+ def is_valid(cls: Type[DS], s: str) -> bool:
try:
cls.from_string(s)
return True
@@ -235,8 +243,9 @@ class GroupID(DomainSpecificString):
SIGIL = "+"
@classmethod
- def from_string(cls, s):
- group_id = super(GroupID, cls).from_string(s)
+ def from_string(cls: Type[DS], s: str) -> DS:
+ group_id = super().from_string(s) # type: DS # type: ignore
+
if not group_id.localpart:
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
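The `cls: Type[DS]` annotations above rely on a TypeVar bound to the base
class so that type checkers infer the concrete subclass from `from_string`.
A self-contained sketch of the pattern (names here are illustrative, not
Synapse's):

from typing import Type, TypeVar

S = TypeVar("S", bound="Sigiled")

class Sigiled:
    def __init__(self, body: str) -> None:
        self.body = body

    @classmethod
    def parse(cls: Type[S], s: str) -> S:
        # Returns an instance of whichever subclass this is called on.
        return cls(s)

class UserLike(Sigiled):
    pass

u = UserLike.parse("alice")  # inferred as UserLike, not Sigiled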
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index c63256d3bd..b3f76428b6 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -17,6 +17,7 @@ import logging
import re
import attr
+from canonicaljson import json
from twisted.internet import defer, task
@@ -24,6 +25,9 @@ from synapse.logging import context
logger = logging.getLogger(__name__)
+# Create a custom encoder to reduce the whitespace produced by JSON encoding.
+json_encoder = json.JSONEncoder(separators=(",", ":"))
+
def unwrapFirstError(failure):
# defer.gatherResults and DeferredLists wrap failures.
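For reference, the effect of the custom separators (shown here with the
stdlib json module; the code above uses canonicaljson's json): the default
encoder emits a space after every "," and ":", which adds up across large
responses.

import json

payload = {"event_id": "$abc", "content": {"body": "hi"}}
print(json.dumps(payload))
# {"event_id": "$abc", "content": {"body": "hi"}}
print(json.JSONEncoder(separators=(",", ":")).encode(payload))
# {"event_id":"$abc","content":{"body":"hi"}}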
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 9b09c08b89..c2d72a82cf 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -192,7 +192,7 @@ class Cache(object):
callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
- observer = defer.maybeDeferred(observable.observe)
+ observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
new file mode 100644
index 0000000000..23393cf49b
--- /dev/null
+++ b/synapse/util/daemonize.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2012, 2013, 2014 Ilya Otyutskiy <ilya.otyutskiy@icloud.com>
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import atexit
+import fcntl
+import logging
+import os
+import signal
+import sys
+
+
+def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
+ """daemonize the current process
+
+ This calls fork(), and has the main process exit. When it returns we will be
+ running in the child process.
+ """
+
+ # If the pidfile already exists, read the old pid from it first: opening the
+ # file in "w" mode below truncates it, so if locking then fails we need the
+ # old contents in order to restore them.
+ if os.path.isfile(pid_file):
+ with open(pid_file, "r") as pid_fh:
+ old_pid = pid_fh.read()
+
+ # Create a lockfile so that only one instance of this daemon is running at any time.
+ try:
+ lock_fh = open(pid_file, "w")
+ except IOError:
+ print("Unable to create the pidfile.")
+ sys.exit(1)
+
+ try:
+ # Try to get an exclusive lock on the file. This will fail if another process
+ # has the file locked.
+ fcntl.flock(lock_fh, fcntl.LOCK_EX | fcntl.LOCK_NB)
+ except IOError:
+ print("Unable to lock the pidfile.")
+ # We need to overwrite the pidfile if we got here.
+ #
+ # XXX: better to avoid overwriting it, surely. This looks racy, as the pid
+ # file could be created between us trying to read it and us trying to lock it.
+ with open(pid_file, "w") as pid_fh:
+ pid_fh.write(old_pid)
+ sys.exit(1)
+
+ # Fork, creating a new process for the child.
+ process_id = os.fork()
+
+ if process_id != 0:
+ # parent process: exit.
+
+ # we use os._exit to avoid running the atexit handlers. In particular, that
+ # means we don't flush the logs. This is important because if we are using
+ # a MemoryHandler, we could have logs buffered which are now buffered in both
+ # the main and the child process, so if we let the main process flush the logs,
+ # we'll get two copies.
+ os._exit(0)
+
+ # This is the child process. Continue.
+
+ # Detach from the parent's session so that signals sent to the parent (or
+ # its process group) no longer reach us. setsid() puts the process in a new
+ # session and process group and detaches it from its controlling terminal;
+ # setpgrp() is a weaker alternative that only changes the process group.
+
+ os.setsid()
+
+ # point stdin, stdout, stderr at /dev/null
+ devnull = "/dev/null"
+ if hasattr(os, "devnull"):
+ # Python has set os.devnull on this system, use it instead as it might be
+ # different than /dev/null.
+ devnull = os.devnull
+
+ devnull_fd = os.open(devnull, os.O_RDWR)
+ os.dup2(devnull_fd, 0)
+ os.dup2(devnull_fd, 1)
+ os.dup2(devnull_fd, 2)
+ os.close(devnull_fd)
+
+ # now that we have redirected stderr to /dev/null, any uncaught exceptions will
+ # get sent to /dev/null, so make sure we log them.
+ #
+ # (we don't normally expect reactor.run to raise any exceptions, but this will
+ # also catch any other uncaught exceptions before we get that far.)
+
+ def excepthook(type_, value, traceback):
+ logger.critical("Unhandled exception", exc_info=(type_, value, traceback))
+
+ sys.excepthook = excepthook
+
+ # Set the umask so that files the daemon creates default to safe
+ # permissions. 027 is octal, written 0o27 in Python 3 syntax.
+ os.umask(0o27)
+
+ # Change to a known directory. If this isn't done, starting a daemon in a
+ # subdirectory that needs to be deleted results in "directory busy" errors.
+ os.chdir(chdir)
+
+ try:
+ lock_fh.write("%s" % (os.getpid()))
+ lock_fh.flush()
+ except IOError:
+ logger.error("Unable to write pid to the pidfile.")
+ print("Unable to write pid to the pidfile.")
+ sys.exit(1)
+
+ # write a log line on SIGTERM.
+ def sigterm(signum, frame):
+ logger.warning("Caught signal %s. Stopping daemon.", signum)
+ sys.exit(0)
+
+ signal.signal(signal.SIGTERM, sigterm)
+
+ # Cleanup pid file at exit.
+ def exit():
+ logger.warning("Stopping daemon.")
+ os.remove(pid_file)
+ sys.exit(0)
+
+ atexit.register(exit)
+
+ logger.warning("Starting daemon.")
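A hypothetical caller, sketched on the assumption that it runs before the
reactor starts (the parent exits inside daemonize_process; only the detached
child returns):

import logging

from synapse.util.daemonize import daemonize_process

logger = logging.getLogger("demo")
daemonize_process("/var/run/demo.pid", logger)
# Only the daemonized child reaches this point; stdio now points at
# /dev/null and the working directory is "/" (the default chdir).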
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index eab78dd256..0e445e01d7 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -63,5 +63,8 @@ def _handle_frozendict(obj):
)
-# A JSONEncoder which is capable of encoding frozendicts without barfing
-frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)
+# A JSONEncoder which is capable of encoding frozendicts without barfing.
+# Additionally reduce the whitespace produced by JSON encoding.
+frozendict_json_encoder = json.JSONEncoder(
+ default=_handle_frozendict, separators=(",", ":"),
+)
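For illustration, the `default` hook is only consulted for types the encoder
cannot serialise natively, and composes with the compact separators (a
stand-in class is used here, since frozendict is a third-party type):

import json

class FrozenLike:
    """Stand-in for frozendict; not the real class."""

    def __init__(self, data):
        self._data = dict(data)

def _default(obj):
    if isinstance(obj, FrozenLike):
        return obj._data
    raise TypeError("%r is not JSON serializable" % (obj,))

encoder = json.JSONEncoder(default=_default, separators=(",", ":"))
print(encoder.encode({"k": FrozenLike({"a": 1})}))  # {"k":{"a":1}}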
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ec61e14423..13775b43f9 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,14 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from functools import wraps
+from typing import Any, Callable, Optional, TypeVar, cast
from prometheus_client import Counter
-from twisted.internet import defer
-
from synapse.logging.context import LoggingContext, current_context
from synapse.metrics import InFlightGauge
@@ -60,29 +58,37 @@ in_flight = InFlightGauge(
sub_metrics=["real_time_max", "real_time_sum"],
)
+T = TypeVar("T", bound=Callable[..., Any])
-def measure_func(name=None):
- def wrapper(func):
- block_name = func.__name__ if name is None else name
- if inspect.iscoroutinefunction(func):
+def measure_func(name: Optional[str] = None) -> Callable[[T], T]:
+ """
+ Used to decorate an async function with a `Measure` context manager.
+
+ Usage:
+
+ @measure_func()
+ async def foo(...):
+ ...
- @wraps(func)
- async def measured_func(self, *args, **kwargs):
- with Measure(self.clock, block_name):
- r = await func(self, *args, **kwargs)
- return r
+ Which is analogous to:
- else:
+ async def foo(...):
+ with Measure(...):
+ ...
+
+ """
+
+ def wrapper(func: T) -> T:
+ block_name = func.__name__ if name is None else name
- @wraps(func)
- @defer.inlineCallbacks
- def measured_func(self, *args, **kwargs):
- with Measure(self.clock, block_name):
- r = yield func(self, *args, **kwargs)
- return r
+ @wraps(func)
+ async def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, block_name):
+ r = await func(self, *args, **kwargs)
+ return r
- return measured_func
+ return cast(T, measured_func)
return wrapper
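The TypeVar-plus-cast combination above is the usual way to type a decorator
so that it preserves the wrapped function's signature; a self-contained
sketch, independent of Synapse's Measure:

from functools import wraps
from typing import Any, Callable, TypeVar, cast

F = TypeVar("F", bound=Callable[..., Any])

def logged(func: F) -> F:
    @wraps(func)
    async def inner(*args: Any, **kwargs: Any) -> Any:
        print("calling", func.__name__)
        return await func(*args, **kwargs)

    # cast() tells the checker the wrapper keeps the original signature.
    return cast(F, inner)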
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 8794317caa..919988d3bc 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -15,8 +15,6 @@
import logging
import random
-from twisted.internet import defer
-
import synapse.logging.context
from synapse.api.errors import CodeMessageException
@@ -54,8 +52,7 @@ class NotRetryingDestination(Exception):
self.destination = destination
-@defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
+async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
@@ -73,9 +70,9 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
Example usage:
try:
- limiter = yield get_retry_limiter(destination, clock, store)
+ limiter = await get_retry_limiter(destination, clock, store)
with limiter:
- response = yield do_request()
+ response = await do_request()
except NotRetryingDestination:
# We aren't ready to retry that destination.
raise
@@ -83,7 +80,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
failure_ts = None
retry_last_ts, retry_interval = (0, 0)
- retry_timings = yield store.get_destination_retry_timings(destination)
+ retry_timings = await store.get_destination_retry_timings(destination)
if retry_timings:
failure_ts = retry_timings["failure_ts"]
@@ -222,10 +219,9 @@ class RetryDestinationLimiter(object):
if self.failure_ts is None:
self.failure_ts = retry_last_ts
- @defer.inlineCallbacks
- def store_retry_timings():
+ async def store_retry_timings():
try:
- yield self.store.set_destination_retry_timings(
+ await self.store.set_destination_retry_timings(
self.destination,
self.failure_ts,
retry_last_ts,
diff --git a/synmark/__init__.py b/synmark/__init__.py
index afe4fad8cb..53698bd5ab 100644
--- a/synmark/__init__.py
+++ b/synmark/__init__.py
@@ -47,9 +47,9 @@ async def make_homeserver(reactor, config=None):
stor = hs.get_datastore()
# Run the database background updates.
- if hasattr(stor.db.updates, "do_next_background_update"):
- while not await stor.db.updates.has_completed_background_updates():
- await stor.db.updates.do_next_background_update(1)
+ if hasattr(stor.db_pool.updates, "do_next_background_update"):
+ while not await stor.db_pool.updates.has_completed_background_updates():
+ await stor.db_pool.updates.do_next_background_update(1)
def cleanup():
for i in cleanup_tasks:
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0bfb86bf1f..5d45689c8c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
+ self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
self.store.is_support_user = Mock(return_value=defer.succeed(False))
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
@@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ # This just needs to return a truth-y value.
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
- return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
+ return_value=defer.succeed(
+ {"name": "@baldrick:matrix.org", "device_id": "device"}
+ )
)
user_id = "@baldrick:matrix.org"
@@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
- self.store.get_user_by_id = Mock(return_value={"is_guest": True})
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
def get_user(tok):
if token != tok:
- return None
- return {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ return defer.succeed(None)
+ return defer.succeed(
+ {
+ "name": USER_ID,
+ "is_guest": False,
+ "token_id": 1234,
+ "device_id": "DEVICE",
+ }
+ )
self.store.get_user_by_access_token = get_user
- self.store.get_user_by_id = Mock(return_value={"is_guest": False})
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
# check the token works
request = Mock(args={})
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..1fab1d6b69 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart + "2", filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart + "2", filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events=events)
@@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events)
@@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(
user_filter_json,
(
- yield self.datastore.get_user_filter(
- user_localpart=user_localpart, filter_id=0
+ yield defer.ensureDeferred(
+ self.datastore.get_user_filter(
+ user_localpart=user_localpart, filter_id=0
+ )
)
),
)
@@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
user_localpart=user_localpart, user_filter=user_filter_json
)
- filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
self.assertEquals(filter.get_filter_json(), user_filter_json)
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index e0ad8e8a77..0d4b05304b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -40,6 +40,7 @@ from synapse.logging.context import (
from synapse.storage.keys import FetchKeyResult
from tests import unittest
+from tests.test_utils import make_awaitable
class MockPerspectiveServer(object):
@@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
with a null `ts_valid_until_ms`
"""
mock_fetcher = keyring.KeyFetcher()
- mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+ mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring(
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@@ -244,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
- def get_keys(keys_to_fetch):
+ async def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
- }
+ return {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
- )
+ }
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -281,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
- def get_keys1(keys_to_fetch):
+ async def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
- }
- }
- )
+ return {
+ "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+ }
- def get_keys2(keys_to_fetch):
+ async def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
- return defer.succeed(
- {
- "server1": {
- get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
- }
+ return {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
- )
+ }
mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
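make_awaitable (imported above from tests.test_utils) is not shown in this
diff; a plausible minimal implementation, for reference only, wraps the value
in an already-resolved Future, since a Future (unlike a coroutine) can be
awaited more than once:

from asyncio import Future
from typing import Any, Awaitable

def make_awaitable(result: Any) -> Awaitable[Any]:
    """Wrap `result` in an awaitable that resolves immediately."""
    future = Future()  # type: Future[Any]
    future.set_result(result)
    return future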
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 628f7d8db0..2a0b7c1b56 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -120,7 +120,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.query_alias.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
+ self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
Mock(room_id=room_id, servers=servers)
)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 6d45c4b233..e364b1bd62 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,6 +22,7 @@ from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
from .. import unittest
@@ -187,7 +188,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.is_real_user = Mock(return_value=defer.succeed(False))
+ self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -199,8 +200,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
room_alias_str = "#room:test"
- self.store.count_real_users = Mock(return_value=defer.succeed(1))
- self.store.is_real_user = Mock(return_value=defer.succeed(True))
+ self.store.count_real_users = Mock(return_value=make_awaitable(1))
+ self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@@ -214,8 +215,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
- self.store.count_real_users = Mock(return_value=defer.succeed(2))
- self.store.is_real_user = Mock(return_value=defer.succeed(True))
+ self.store.count_real_users = Mock(return_value=make_awaitable(2))
+ self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index d9d312f0fb..0e666492f6 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -15,7 +15,7 @@
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
-from synapse.storage.data_stores.main import stats
+from synapse.storage.databases.main import stats
from tests import unittest
@@ -42,36 +42,36 @@ class StatsRoomTests(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms",
+ "update_name": "populate_stats_process_rooms_2",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
def get_all_room_state(self):
- return self.store.db.simple_select_list(
+ return self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias")
)
@@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase):
end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
return self.get_success(
- self.store.db.simple_select_one(
+ self.store.db_pool.simple_select_one(
table + "_historical",
{id_col: stat_id, end_ts: end_ts},
cols,
@@ -109,10 +109,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def test_initial_room(self):
@@ -146,10 +146,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
r = self.get_success(self.get_all_room_state())
@@ -186,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# the position that the deltas should begin at, once they take over.
self.hs.config.stats_enabled = True
self.handler.stats_enabled = True
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_update_one(
+ self.store.db_pool.simple_update_one(
table="stats_incremental_position",
keyvalues={},
updatevalues={"stream_id": 0},
@@ -196,17 +196,17 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "populate_stats_prepare", "progress_json": "{}"},
)
)
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Now, before the table is actually ingested, add some more events.
@@ -217,28 +217,31 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# Now do the initial ingestion.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
- {"update_name": "populate_stats_process_rooms", "progress_json": "{}"},
+ {
+ "update_name": "populate_stats_process_rooms_2",
+ "progress_json": "{}",
+ },
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
self.reactor.advance(86401)
@@ -346,6 +349,37 @@ class StatsRoomTests(unittest.HomeserverTestCase):
self.assertEqual(r1stats_post["total_events"] - r1stats_ante["total_events"], 1)
+ def test_updating_profile_information_does_not_increase_joined_members_count(self):
+ """
+ Check that the joined_members count does not increase when a user changes their
+ profile information (which is done by sending another join membership event into
+ the room).
+ """
+ self._perform_background_initial_update()
+
+ # Create a user and room
+ u1 = self.register_user("u1", "pass")
+ u1token = self.login("u1", "pass")
+ r1 = self.helper.create_room_as(u1, tok=u1token)
+
+ # Get the current room stats
+ r1stats_ante = self._get_current_stats("room", r1)
+
+ # Send a profile update into the room
+ new_profile = {"displayname": "bob"}
+ self.helper.change_membership(
+ r1, u1, u1, "join", extra_data=new_profile, tok=u1token
+ )
+
+ # Get the new room stats
+ r1stats_post = self._get_current_stats("room", r1)
+
+ # Ensure that the user count did not change
+ self.assertEqual(r1stats_post["joined_members"], r1stats_ante["joined_members"])
+ self.assertEqual(
+ r1stats_post["local_users_in_room"], r1stats_ante["local_users_in_room"]
+ )
+
def test_send_state_event_nonoverwriting(self):
"""
When we send a non-overwriting state event, it increments total_events AND current_state_events
@@ -669,15 +703,15 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# preparation stage of the initial background update
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_delete(
+ self.store.db_pool.simple_delete(
"room_stats_current", {"1": 1}, "test_delete_stats"
)
)
self.get_success(
- self.store.db.simple_delete(
+ self.store.db_pool.simple_delete(
"user_stats_current", {"1": 1}, "test_delete_stats"
)
)
@@ -689,29 +723,29 @@ class StatsRoomTests(unittest.HomeserverTestCase):
# now do the background updates
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
- "update_name": "populate_stats_process_rooms",
+ "update_name": "populate_stats_process_rooms_2",
"progress_json": "{}",
"depends_on": "populate_stats_prepare",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_process_users",
"progress_json": "{}",
- "depends_on": "populate_stats_process_rooms",
+ "depends_on": "populate_stats_process_rooms_2",
},
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_stats_cleanup",
@@ -722,10 +756,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
)
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
r1stats_complete = self._get_current_stats("room", r1)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5878f74175..64afd581bc 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import register_federation_servlets
@@ -115,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res
)
- self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
+ self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
(0, [])
)
@@ -126,10 +127,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
- return defer.succeed(None)
+ return None
hs.get_auth().check_user_in_room = check_user_in_room
@@ -151,7 +152,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
- self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
+ self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 23fcc372dd..31ed89a5cd 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -339,7 +339,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_in_public_rooms(self):
r = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
"users_in_public_rooms", None, ("user_id", "room_id")
)
)
@@ -350,7 +350,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def get_users_who_share_private_rooms(self):
return self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
@@ -362,10 +362,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
Add the background updates we need to run.
"""
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_createtables",
@@ -374,7 +374,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_rooms",
@@ -384,7 +384,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_process_users",
@@ -394,7 +394,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{
"update_name": "populate_user_directory_cleanup",
@@ -437,10 +437,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
@@ -476,10 +476,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self._add_background_updates()
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
shares_private = self.get_users_who_share_private_rooms()
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 06575ba0a6..ae60874ec3 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -65,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Since we use sqlite in memory databases we need to make sure the
# databases objects are the same.
- self.worker_hs.get_datastore().db = hs.get_datastore().db
+ self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
@@ -198,7 +198,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer = self.hs.get_replication_streamer()
store = self.hs.get_datastore()
- self.database = store.db
+ self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -254,7 +254,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
)
store = worker_hs.get_datastore()
- store.db._db_pool = self.database._db_pool
+ store.db_pool._db_pool = self.database_pool._db_pool
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index cec1cf928f..408c568a27 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -566,7 +566,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
"state_groups_state",
):
count = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table=table,
keyvalues={"room_id": room_id},
retcol="COUNT(*)",
@@ -667,7 +667,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase):
"state_groups_state",
):
count = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table=table,
keyvalues={"room_id": room_id},
retcol="COUNT(*)",
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index f16eef15f7..17d0aae2e9 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -20,6 +20,8 @@ import urllib.parse
from mock import Mock
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError
@@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
# Set monthly active users to the limit
- store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+ store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(self.hs.config.max_mau_value)
+ )
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
self.get_failure(
@@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index e54ffea150..0b191d13c6 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -144,7 +144,9 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
- message_handler.get_room_data(self.user_id, room_id, EventTypes.Create)
+ message_handler.get_room_data(
+ self.user_id, room_id, EventTypes.Create, state_key="", is_guest=False
+ )
)
# Send a first event to the room. This is the event we'll want to be purged at the
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler,
)
- def _get_user_by_req(request=None, allow_guest=False):
- return defer.succeed(synapse.types.create_requester(myid))
+ async def _get_user_by_req(request=None, allow_guest=False):
+ return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
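Where an entire stub function is monkeypatched in, the same conversion is done by declaring the stub `async` and returning the plain value; awaiting callers see no difference, and the `defer` import can often be dropped (as the next hunk does). A sketch, with `requester` standing in for whatever the test builds:

    # Old style: a plain function returning a Deferred.
    # def _get_user_by_req(request=None, allow_guest=False):
    #     return defer.succeed(requester)

    # New style: an async stub; awaiting it yields the value directly.
    async def _get_user_by_req(request=None, allow_guest=False):
        return requester

    hs.get_auth().get_user_by_req = _get_user_by_req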
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5ccda8b2bd..ef6b775ed2 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -23,8 +23,6 @@ from urllib import parse as urlparse
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock()
- def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 7f8252330a..8933b560d2 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -88,7 +88,28 @@ class RestHelper(object):
expect_code=expect_code,
)
- def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
+ def change_membership(
+ self,
+ room: str,
+ src: str,
+ targ: str,
+ membership: str,
+ extra_data: dict = {},
+ tok: Optional[str] = None,
+ expect_code: int = 200,
+ ) -> None:
+ """
+ Send a membership state event into a room.
+
+ Args:
+ room: The ID of the room to send to
+ src: The mxid of the event sender
+ targ: The mxid of the event's target. Used as the event's state key
+ membership: The type of membership event
+ extra_data: Extra information to include in the content of the event
+ tok: The user access token to use
+ expect_code: The expected HTTP response code
+ """
temp_id = self.auth_user_id
self.auth_user_id = src
@@ -97,6 +118,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok
data = {"membership": membership}
+ data.update(extra_data)
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8")
@@ -143,26 +165,6 @@ class RestHelper(object):
return channel.json_body
- def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200):
- if txn_id is None:
- txn_id = "m%s" % (str(time.time()))
-
- path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id)
- if tok:
- path = path + "?access_token=%s" % tok
-
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8")
- )
- render(request, self.resource, self.hs.get_reactor())
-
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
- )
-
- return channel.json_body
-
def _read_write_state(
self,
room_id: str,
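For reference, a hypothetical call to the extended helper, using `extra_data` to attach extra keys to the membership event's content (room, users and token are placeholders):

    self.helper.change_membership(
        room="!room:test",
        src="@admin:test",
        targ="@spammer:test",   # becomes the event's state key
        membership="leave",     # a kick, since src != targ
        extra_data={"reason": "spam"},  # merged into the event content
        tok=admin_tok,          # placeholder access token
        expect_code=200,
    )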
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7deaf5b24a..53a43038f0 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
+ @override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
- self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
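Mutating `self.hs.config` inside the test body only changes the attribute after the homeserver has been built; the `@override_config` decorator (imported from `tests.unittest`, as elsewhere in this patch) instead folds the override into the config used to construct the test homeserver. A sketch of the pattern:

    from tests.unittest import HomeserverTestCase, override_config

    class ExampleTestCase(HomeserverTestCase):
        @override_config({"enable_registration": False})
        def test_registration_disabled(self):
            # self.hs was built with enable_registration=False, so the
            # registration endpoint should reject new users.
            ...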
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index a31e44c97e..fa3a3ec1bd 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import read_marker, sync
+from synapse.rest.client.v2_alpha import sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,156 +324,3 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
-
-
-class UnreadMessagesTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- read_marker.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.url = "/sync?since=%s"
- self.next_batch = "s0"
-
- # Register the first user (used to check the unread counts).
- self.user_id = self.register_user("kermit", "monkey")
- self.tok = self.login("kermit", "monkey")
-
- # Create the room we'll check unread counts for.
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- # Register the second user (used to send events to the room).
- self.user2 = self.register_user("kermit2", "monkey")
- self.tok2 = self.login("kermit2", "monkey")
-
- # Change the power levels of the room so that the second user can send state
- # events.
- self.helper.send_state(
- self.room_id,
- EventTypes.PowerLevels,
- {
- "users": {self.user_id: 100, self.user2: 100},
- "users_default": 0,
- "events": {
- "m.room.name": 50,
- "m.room.power_levels": 100,
- "m.room.history_visibility": 100,
- "m.room.canonical_alias": 50,
- "m.room.avatar": 50,
- "m.room.tombstone": 100,
- "m.room.server_acl": 100,
- "m.room.encryption": 100,
- },
- "events_default": 0,
- "state_default": 50,
- "ban": 50,
- "kick": 50,
- "redact": 50,
- "invite": 0,
- },
- tok=self.tok,
- )
-
- def test_unread_counts(self):
- """Tests that /sync returns the right value for the unread count (MSC2654)."""
-
- # Check that our own messages don't increase the unread count.
- self.helper.send(self.room_id, "hello", tok=self.tok)
- self._check_unread_count(0)
-
- # Join the new user and check that this doesn't increase the unread count.
- self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
- self._check_unread_count(0)
-
- # Check that the new user sending a message increases our unread count.
- res = self.helper.send(self.room_id, "hello", tok=self.tok2)
- self._check_unread_count(1)
-
- # Send a read receipt to tell the server we've read the latest event.
- body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
- request, channel = self.make_request(
- "POST",
- "/rooms/%s/read_markers" % self.room_id,
- body,
- access_token=self.tok,
- )
- self.render(request)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check that the unread counter is back to 0.
- self._check_unread_count(0)
-
- # Check that room name changes increase the unread counter.
- self.helper.send_state(
- self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
- )
- self._check_unread_count(1)
-
- # Check that room topic changes increase the unread counter.
- self.helper.send_state(
- self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
- )
- self._check_unread_count(2)
-
- # Check that encrypted messages increase the unread counter.
- self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
- self._check_unread_count(3)
-
- # Check that custom events with a body increase the unread counter.
- self.helper.send_event(
- self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
- )
- self._check_unread_count(4)
-
- # Check that edits don't increase the unread counter.
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={
- "body": "hello",
- "msgtype": "m.text",
- "m.relates_to": {"rel_type": RelationTypes.REPLACE},
- },
- tok=self.tok2,
- )
- self._check_unread_count(4)
-
- # Check that notices don't increase the unread counter.
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"body": "hello", "msgtype": "m.notice"},
- tok=self.tok2,
- )
- self._check_unread_count(4)
-
- # Check that tombstone events changes increase the unread counter.
- self.helper.send_state(
- self.room_id,
- EventTypes.Tombstone,
- {"replacement_room": "!someroom:test"},
- tok=self.tok2,
- )
- self._check_unread_count(5)
-
- def _check_unread_count(self, expected_count: True):
- """Syncs and compares the unread count with the expected value."""
-
- request, channel = self.make_request(
- "GET", self.url % self.next_batch, access_token=self.tok,
- )
- self.render(request)
-
- self.assertEqual(channel.code, 200, channel.json_body)
-
- room_entry = channel.json_body["rooms"]["join"][self.room_id]
- self.assertEqual(
- room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
- )
-
- # Store the next batch for the next request.
- self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
new file mode 100644
index 0000000000..2d021f6565
--- /dev/null
+++ b/tests/rest/test_health.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.rest.health import HealthResource
+
+from tests import unittest
+
+
+class HealthCheckTests(unittest.HomeserverTestCase):
+ def setUp(self):
+ super().setUp()
+
+ # replace the JsonResource with a HealthResource.
+ self.resource = HealthResource()
+
+ def test_health(self):
+ request, channel = self.make_request("GET", "/health", shorthand=False)
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(channel.result["body"], b"OK")
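Outside the test harness the same resource serves `/health` on a running homeserver, so a liveness probe only needs to check for a 200 and the literal body `OK`; a sketch against an assumed local listener:

    import urllib.request

    # Hypothetical probe; host and port depend on the deployment.
    with urllib.request.urlopen("http://localhost:8008/health") as resp:
        assert resp.status == 200
        assert resp.read() == b"OK"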
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 99908edba3..2858d13558 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import (
)
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import default_config
@@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
- self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
+ self._rlsn._store.get_tags_for_room = Mock(
+ side_effect=lambda user_id, room_id: make_awaitable({})
+ )
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self):
@@ -258,7 +261,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(return_value=1000)
+ self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
@@ -275,7 +278,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.server_notices_manager.get_or_create_notice_room_for_user(self.user_id)
)
- token = self.get_success(self.event_source.get_current_token())
+ token = self.event_source.get_current_token()
events, _ = self.get_success(
self.store.get_recent_events_for_room(
room_id, limit=100, end_token=token.room_key
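`make_awaitable` (imported from `tests.test_utils` above) is the test-utility counterpart of `defer.succeed`: it wraps a plain value so `async` callers can await it. Combining it with `side_effect=lambda ...` creates a fresh awaitable on every call, since a single coroutine or Deferred can only be consumed once. One plausible implementation, shown only to illustrate the contract:

    from typing import Any, Awaitable

    def make_awaitable(result: Any) -> Awaitable[Any]:
        """Wrap `result` so that awaiting the return value yields it."""

        async def _inner() -> Any:
            return result

        return _inner()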
diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py
index 5a50e4fdd4..319e2c2325 100644
--- a/tests/storage/test__base.py
+++ b/tests/storage/test__base.py
@@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
self.table_name = "table_" + hs.get_secrets().token_hex(6)
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"create",
lambda x, *a: x.execute(*a),
"CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
@@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"index",
lambda x, *a: x.execute(*a),
"CREATE UNIQUE INDEX %sindex ON %s(id, username)"
@@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["hello"], ["there"]]
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"test",
- self.storage.db.simple_upsert_many_txn,
+ self.storage.db_pool.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage.db.simple_select_list(
+ self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
@@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase):
value_values = [["bleb"]]
self.get_success(
- self.storage.db.runInteraction(
+ self.storage.db_pool.runInteraction(
"test",
- self.storage.db.simple_upsert_many_txn,
+ self.storage.db_pool.simple_upsert_many_txn,
self.table_name,
key_names,
key_values,
@@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase):
# Check results are what we expect
res = self.get_success(
- self.storage.db.simple_select_list(
+ self.storage.db_pool.simple_select_list(
self.table_name, None, ["id, username, value"]
)
)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index ef296e7dab..98b74890d5 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,11 +24,11 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.config._base import ConfigError
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.database import DatabasePool, make_conn
+from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
-from synapse.storage.database import Database, make_conn
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -178,14 +178,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
service = Mock(id="999")
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(None, state)
@defer.inlineCallbacks
def test_get_appservice_state_up(self):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
service = Mock(id=self.as_list[0]["id"])
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.UP, state)
@defer.inlineCallbacks
@@ -194,13 +194,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
service = Mock(id=self.as_list[1]["id"])
- state = yield self.store.get_appservice_state(service)
+ state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.DOWN, state)
@defer.inlineCallbacks
def test_get_appservices_by_state_none(self):
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(0, len(services))
@@ -339,7 +339,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_get_oldest_unsent_txn_none(self):
service = Mock(id=self.as_list[0]["id"])
- txn = yield self.store.get_oldest_unsent_txn(service)
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(None, txn)
@defer.inlineCallbacks
@@ -349,14 +349,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
- self.store.get_events_as_list = Mock(return_value=events)
+ self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
yield self._insert_txn(service.id, 11, other_events)
yield self._insert_txn(service.id, 12, other_events)
- txn = yield self.store.get_oldest_unsent_txn(service)
+ txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(service, txn.service)
self.assertEquals(10, txn.id)
self.assertEquals(events, txn.events)
@@ -366,8 +366,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(1, len(services))
self.assertEquals(self.as_list[0]["id"], services[0].id)
@@ -379,8 +379,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
- services = yield self.store.get_appservices_by_state(
- ApplicationServiceState.DOWN
+ services = yield defer.ensureDeferred(
+ self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
)
self.assertEquals(2, len(services))
self.assertEquals(
@@ -391,7 +391,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
# required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(TestTransactionStore, self).__init__(database, db_conn, hs)
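The `defer.ensureDeferred(...)` wrappers in the hunks above bridge the two async styles: store methods converted to `async def` now return coroutines, but a `@defer.inlineCallbacks` generator only resumes on yielded Deferreds; a yielded coroutine would be handed straight back unrun. A minimal sketch of the pattern, with a stand-in store method:

    from twisted.internet import defer

    async def get_appservice_state(service):  # stand-in for the store method
        return "up"

    @defer.inlineCallbacks
    def check_state():
        # `yield get_appservice_state(...)` would not run the coroutine;
        # ensureDeferred() converts it into a Deferred that inlineCallbacks
        # knows how to wait on.
        state = yield defer.ensureDeferred(get_appservice_state(None))
        assert state == "up"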
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index 940b166129..2efbc97c2e 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -9,7 +9,9 @@ from tests import unittest
class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
- self.updates = self.hs.get_datastore().db.updates # type: BackgroundUpdater
+ self.updates = (
+ self.hs.get_datastore().db_pool.updates
+ ) # type: BackgroundUpdater
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
@@ -29,7 +31,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
self.get_success(
- store.db.simple_insert(
+ store.db_pool.simple_insert(
"background_updates",
values={"update_name": "test_update", "progress_json": '{"my_key": 1}'},
)
@@ -40,7 +42,7 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase):
def update(progress, count):
yield self.clock.sleep((count * duration_ms) / 1000)
progress = {"my_key": progress["my_key"] + 1}
- yield store.db.runInteraction(
+ yield store.db_pool.runInteraction(
"update_progress",
self.updates._background_update_progress_txn,
"test_update",
diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py
index b589506c60..efcaeef1e7 100644
--- a/tests/storage/test_base.py
+++ b/tests/storage/test_base.py
@@ -21,7 +21,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.engines import create_engine
from tests import unittest
@@ -57,7 +57,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
- db = Database(Mock(), Mock(config=sqlite_config), fake_engine)
+ db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
self.datastore = SQLBaseStore(db, None, hs)
@@ -66,7 +66,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_insert(
+ yield self.datastore.db_pool.simple_insert(
table="tablename", values={"columname": "Value"}
)
@@ -78,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_insert_3cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_insert(
+ yield self.datastore.db_pool.simple_insert(
table="tablename",
# Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
@@ -93,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
- value = yield self.datastore.db.simple_select_one_onecol(
+ value = yield self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
@@ -107,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
- ret = yield self.datastore.db.simple_select_one(
+ ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
@@ -123,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
- ret = yield self.datastore.db.simple_select_one(
+ ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
@@ -138,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),)
- ret = yield self.datastore.db.simple_select_list(
+ ret = yield self.datastore.db_pool.simple_select_list(
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
)
@@ -151,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_update_one(
+ yield self.datastore.db_pool.simple_update_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columnname": "New Value"},
@@ -166,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_update_one(
+ yield self.datastore.db_pool.simple_update_one(
table="tablename",
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
@@ -181,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_delete_one(self):
self.mock_txn.rowcount = 1
- yield self.datastore.db.simple_delete_one(
+ yield self.datastore.db_pool.simple_delete_one(
table="tablename", keyvalues={"keycol": "Go away"}
)
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 43425c969a..3fab5a5248 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -47,12 +47,12 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
"""
# Make sure we don't clash with in progress updates.
self.assertTrue(
- self.store.db.updates._all_done, "Background updates are still ongoing"
+ self.store.db_pool.updates._all_done, "Background updates are still ongoing"
)
schema_path = os.path.join(
prepare_database.dir_path,
- "data_stores",
+ "databases",
"main",
"schema",
"delta",
@@ -64,19 +64,19 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
prepare_database.executescript(txn, schema_path)
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"test_delete_forward_extremities", run_delta_file
)
)
# Ugh, have to reset this flag
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
def test_soft_failed_extremities_handled_correctly(self):
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 3b483bc7f0..224ea6fd79 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -86,7 +86,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -117,7 +117,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.pump(0)
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -204,10 +204,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_devices_last_seen_bg_update(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
user_id = "@user:id"
@@ -225,7 +225,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But clear the associated entry in devices table
self.get_success(
- self.store.db.simple_update(
+ self.store.db_pool.simple_update(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
@@ -252,7 +252,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
table="background_updates",
values={
"update_name": "devices_last_seen",
@@ -263,14 +263,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# We should now get the correct result again
@@ -293,10 +293,10 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def test_old_user_ips_pruned(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
user_id = "@user:id"
@@ -315,7 +315,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should see that in the DB
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
@@ -341,7 +341,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should get no results.
result = self.get_success(
- self.store.db.simple_select_list(
+ self.store.db_pool.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"],
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index c2539b353a..87ed8f8cd1 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_store_new_device(self):
- yield self.store.store_device("user_id", "device_id", "display_name")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device_id", "display_name")
+ )
res = yield self.store.get_device("user_id", "device_id")
self.assertDictContainsSubset(
@@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_get_devices_by_user(self):
- yield self.store.store_device("user_id", "device1", "display_name 1")
- yield self.store.store_device("user_id", "device2", "display_name 2")
- yield self.store.store_device("user_id2", "device3", "display_name 3")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device1", "display_name 1")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device2", "display_name 2")
+ )
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id2", "device3", "display_name 3")
+ )
- res = yield self.store.get_devices_by_user("user_id")
+ res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset(
{
@@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
- yield self.store.add_device_change_to_streams(
- "user_id", device_ids, ["somehost"]
+ yield defer.ensureDeferred(
+ self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
)
# Get all device updates ever meant for this remote
- now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
- "somehost", -1, limit=100
+ now_stream_id, device_updates = yield defer.ensureDeferred(
+ self.store.get_device_updates_by_remote("somehost", -1, limit=100)
)
# Check original device_ids are contained within these updates
@@ -99,19 +107,23 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_update_device(self):
- yield self.store.store_device("user_id", "device_id", "display_name 1")
+ yield defer.ensureDeferred(
+ self.store.store_device("user_id", "device_id", "display_name 1")
+ )
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
- yield self.store.update_device("user_id", "device_id")
+ yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do the update
- yield self.store.update_device(
- "user_id", "device_id", new_display_name="display_name 2"
+ yield defer.ensureDeferred(
+ self.store.update_device(
+ "user_id", "device_id", new_display_name="display_name 2"
+ )
)
# check it worked
@@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
- yield self.store.update_device(
- "user_id", "unknown_device_id", new_display_name="display_name 2"
+ yield defer.ensureDeferred(
+ self.store.update_device(
+ "user_id", "unknown_device_id", new_display_name="display_name 2"
+ )
)
self.assertEqual(404, cm.exception.code)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..daac947cb2 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_room_to_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertEquals(
@@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_alias_to_room(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (yield self.store.get_association_from_room_alias(self.alias)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ ),
)
@defer.inlineCallbacks
def test_delete_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
- room_id = yield self.store.delete_room_alias(self.alias)
+ room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (yield self.store.get_association_from_room_alias(self.alias))
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ )
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..d57cdffd8b 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -30,11 +30,13 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.store_device("user", "device", None)
+ yield defer.ensureDeferred(self.store.store_device("user", "device", None))
yield self.store.set_e2e_device_keys("user", "device", now, json)
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -45,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070
json = {"key": "value"}
- yield self.store.store_device("user", "device", None)
+ yield defer.ensureDeferred(self.store.store_device("user", "device", None))
changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
self.assertTrue(changed)
@@ -61,9 +63,13 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
json = {"key": "value"}
yield self.store.set_e2e_device_keys("user", "device", now, json)
- yield self.store.store_device("user", "device", "display_name")
+ yield defer.ensureDeferred(
+ self.store.store_device("user", "device", "display_name")
+ )
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -75,18 +81,18 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self):
now = 1470174257070
- yield self.store.store_device("user1", "device1", None)
- yield self.store.store_device("user1", "device2", None)
- yield self.store.store_device("user2", "device1", None)
- yield self.store.store_device("user2", "device2", None)
+ yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
+ yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
+ yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
+ yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
- res = yield self.store.get_e2e_device_keys(
- (("user1", "device1"), ("user2", "device2"))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 3aeec0dc0f..d4c3b867e3 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -56,7 +56,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
for i in range(0, 20):
- self.get_success(self.store.db.runInteraction("insert", insert_event, i))
+ self.get_success(
+ self.store.db_pool.runInteraction("insert", insert_event, i)
+ )
# this should get the last ten
r = self.get_success(self.store.get_prev_events_for_room(room_id))
@@ -81,13 +83,13 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for i in range(0, 20):
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room1)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room1)
)
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room2)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room2)
)
self.get_success(
- self.store.db.runInteraction("insert", insert_event, i, room3)
+ self.store.db_pool.runInteraction("insert", insert_event, i, room3)
)
# Test simple case
@@ -164,7 +166,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
depth = depth_map[event_id]
- self.store.db.simple_insert_txn(
+ self.store.db_pool.simple_insert_txn(
txn,
table="events",
values={
@@ -179,7 +181,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
},
)
- self.store.db.simple_insert_many_txn(
+ self.store.db_pool.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -192,7 +194,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
for event_id in auth_graph:
next_stream_ordering += 1
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"insert", insert_event, event_id, next_stream_ordering
)
)
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 2b1580feeb..857db071d4 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -60,7 +60,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
- counts = yield self.store.db.runInteraction(
+ counts = yield self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
)
self.assertEquals(
@@ -81,7 +81,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.event_id, {user_id: action}
)
)
- yield self.store.db.runInteraction(
+ yield self.store.db_pool.runInteraction(
"",
self.persist_events_store._set_push_actions_for_event_and_users_txn,
[(event, None)],
@@ -89,12 +89,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
def _rotate(stream):
- return self.store.db.runInteraction(
+ return self.store.db_pool.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
def _mark_read(stream, depth):
- return self.store.db.runInteraction(
+ return self.store.db_pool.runInteraction(
"",
self.store._remove_old_push_actions_before_txn,
room_id,
@@ -123,7 +123,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
- yield self.store.db.simple_delete(
+ yield self.store.db_pool.simple_delete(
table="event_push_actions", keyvalues={"1": 1}, desc=""
)
@@ -142,7 +142,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts):
- return self.store.db.simple_insert(
+ return self.store.db_pool.simple_insert(
"events",
{
"stream_ordering": so,
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 55e9ecf264..e845410dae 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
@@ -27,9 +27,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.db = self.store.db # type: Database
+ self.db_pool = self.store.db_pool # type: DatabasePool
- self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
@@ -47,7 +47,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
def _create(conn):
return MultiWriterIdGenerator(
conn,
- self.db,
+ self.db_pool,
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
@@ -55,7 +55,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
sequence_name="foobar_seq",
)
- return self.get_success(self.db.runWithConnection(_create))
+ return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
def _insert(txn):
@@ -65,7 +65,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
(instance_name,),
)
- self.get_success(self.db.runInteraction("test_single_instance", _insert))
+ self.get_success(self.db_pool.runInteraction("test_single_instance", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
@@ -178,7 +178,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
- self.get_success(self.db.runInteraction("test", _get_next_txn))
+ self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 9c04e92577..9870c74883 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
@@ -78,7 +79,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
# XXX why are we doing this here? this function is only run at startup
# so it is odd to re-run it here.
self.get_success(
- self.store.db.runInteraction(
+ self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
)
@@ -204,7 +205,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.user_add_threepid(user, "email", email, now, now)
)
- d = self.store.db.runInteraction(
+ d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.get_success(d)
@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
@@ -280,7 +287,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
]
self.hs.config.mau_limits_reserved_threepids = threepids
- d = self.store.db.runInteraction(
+ d = self.store.db_pool.runInteraction(
"initialise", self.store._initialise_reserved_users, threepids
)
self.get_success(d)
@@ -293,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.register_user(user_id=user2, password_hash=None))
now = int(self.hs.get_clock().time_msec())
- self.store.user_add_threepid(user1, "email", user1_email, now, now)
- self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ self.get_success(
+ self.store.user_add_threepid(user1, "email", user1_email, now, now)
+ )
+ self.get_success(
+ self.store.user_add_threepid(user2, "email", user2_email, now, now)
+ )
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids))
@@ -333,7 +344,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 0f0e1cd09b..1ea35d60c1 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -251,6 +251,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def room_id(self):
return self._base_builder.room_id
+ @property
+ def type(self):
+ return self._base_builder.type
+
event_1, context_1 = self.get_success(
self.event_creation_handler.create_new_client_event(
EventIdManglingBuilder(
@@ -343,7 +347,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
)
event_json = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
@@ -361,7 +365,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.reactor.advance(60 * 60 * 2)
event_json = self.get_success(
- self.store.db.simple_select_one_onecol(
+ self.store.db_pool.simple_select_one_onecol(
table="event_json",
keyvalues={"event_id": msg_event.event_id},
retcol="json",
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py
index 71a40a0a49..840db66072 100644
--- a/tests/storage/test_registration.py
+++ b/tests/storage/test_registration.py
@@ -58,8 +58,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register_user(self.user_id, self.pwhash)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ )
)
result = yield self.store.get_user_by_access_token(self.tokens[1])
@@ -74,11 +76,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
def test_user_delete_access_tokens(self):
# add some tokens
yield self.store.register_user(self.user_id, self.pwhash)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
+ )
)
- yield self.store.add_access_token_to_user(
- self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.store.add_access_token_to_user(
+ self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
+ )
)
# now delete some
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index f282921538..17c9da4838 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -179,10 +179,10 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
def test_can_rerun_update(self):
# First make sure we have completed all updates.
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Now let's create a room, which will insert a membership
@@ -192,7 +192,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Register the background update to run again.
self.get_success(
- self.store.db.simple_insert(
+ self.store.db_pool.simple_insert(
table="background_updates",
values={
"update_name": "current_state_events_membership",
@@ -203,12 +203,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
)
# ... and tell the DataStore that it hasn't finished all updates yet
- self.store.db.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
while not self.get_success(
- self.store.db.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- self.store.db.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..ecfafe68a9 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
@@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True
try:
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
self.assertDictEqual(
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c2f12c2741..f2fa42bfb9 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,3 +1,18 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from mock import Mock
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
@@ -10,6 +25,7 @@ from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
+from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
@@ -173,7 +189,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.homeserver.get_datastore()
- store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
+ store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
diff --git a/tests/test_server.py b/tests/test_server.py
index 073b2362cc..d628070e48 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -157,6 +157,29 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+ def test_head_request(self):
+ """
+ A HEAD request to a path registered for GET should be handled by the
+ GET callback, returning the response headers (including Content-Length)
+ but no body.
+ """
+
+ def _callback(request, **kwargs):
+ return 200, {"result": True}
+
+ res = JsonResource(self.homeserver)
+ res.register_paths(
+ "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
+ )
+
+ # The path was registered as GET, but this is a HEAD request.
+ request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertNotIn("body", channel.result)
+ self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"])
+
class OptionsResourceTests(unittest.TestCase):
def setUp(self):
@@ -255,7 +278,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor = ThreadedMemoryReactorClock()
def test_good_response(self):
- def callback(request):
+ async def callback(request):
request.write(b"response")
request.finish()
@@ -275,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
with the right location.
"""
- def callback(request, **kwargs):
+ async def callback(request, **kwargs):
raise RedirectException(b"/look/an/eagle", 301)
res = WrapHtmlRequestHandlerTests.TestResource()
@@ -295,7 +318,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
returned too
"""
- def callback(request, **kwargs):
+ async def callback(request, **kwargs):
e = RedirectException(b"/no/over/there", 304)
e.cookies.append(b"session=yespls")
raise e
@@ -312,3 +335,19 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(location_headers, [b"/no/over/there"])
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
+
+ def test_head_request(self):
+ """A head request should work by being turned into a GET request."""
+
+ async def callback(request):
+ request.write(b"response")
+ request.finish()
+
+ res = WrapHtmlRequestHandlerTests.TestResource()
+ res.callback = callback
+
+ request, channel = make_request(self.reactor, b"HEAD", b"/path")
+ render(request, res, self.reactor)
+
+ self.assertEqual(channel.result["code"], b"200")
+ self.assertNotIn("body", channel.result)
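
Both `test_head_request` additions exercise the same server-side behaviour: a HEAD request is answered by running the GET handler, keeping the response headers (including the Content-Length the body would have had) and discarding the body. A minimal sketch of that dispatch logic (illustrative names, not Synapse's resource code):

def handle_request(method, handlers):
    # Fall back to the GET handler when a HEAD request has no dedicated
    # handler of its own.
    lookup = "GET" if method == "HEAD" and "HEAD" not in handlers else method
    handler = handlers.get(lookup)
    if handler is None:
        error = b'{"error":"Unrecognized request"}'
        return 400, len(error), error

    status, body = handler()
    if method == "HEAD":
        # Keep the headers the GET response would have carried (here just
        # Content-Length) but send no body.
        return status, len(body), b""
    return status, len(body), body


handlers = {"GET": lambda: (200, b'{"result":true}')}
# The compact JSON body {"result":true} is 15 bytes, matching the
# Content-Length asserted in the test above.
assert handle_request("HEAD", handlers) == (200, 15, b"")
assert handle_request("GET", handlers) == (200, 15, b'{"result":true}')
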
diff --git a/tests/unittest.py b/tests/unittest.py
index 68d2586efd..d0bba3ddef 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
- def get_user_by_access_token(token=None, allow_guest=False):
- return succeed(
- {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- )
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return succeed(
- create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
- )
+ async def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ async def get_user_by_req(request, allow_guest=False, rights="access"):
+ return create_requester(
+ UserID.from_string(self.helper.auth_user_id), 1, False, None
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -422,8 +418,8 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates():
with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
- while not await stor.db.updates.has_completed_background_updates():
- await stor.db.updates.do_next_background_update(1)
+ while not await stor.db_pool.updates.has_completed_background_updates():
+ await stor.db_pool.updates.do_next_background_update(1)
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
@@ -571,7 +567,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
- self.hs.get_datastore().db.simple_insert(
+ self.hs.get_datastore().db_pool.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9e348694ad..bc42ffce88 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self):
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastore()
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
# advance the clock a bit before making the request
self.pump(1)
@@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
with limiter:
pass
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
def test_limiter(self):
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastore()
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
try:
@@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
- # wait for the update to land
- self.pump()
-
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], failure_ts)
self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
# now if we try again we should get a failure
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- self.failureResultOf(d, NotRetryingDestination)
+ self.get_failure(
+ get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+ )
#
# advance the clock and try again
#
self.pump(MIN_RETRY_INTERVAL)
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
try:
@@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
- # wait for the update to land
- self.pump()
-
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], retry_ts)
self.assertGreaterEqual(
@@ -110,9 +92,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
# one more go, with success
#
self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
with limiter:
@@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase):
# wait for the update to land
self.pump()
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
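
Every `pump()` / `successResultOf` pair above collapses into a single `get_success` call, and the expected-failure case into `get_failure`. A sketch of how such helpers can be built on trial's assertions and a fake clock (illustrative; Synapse's own helpers live in `tests/unittest.py` and may differ):

from twisted.internet import defer
from twisted.internet.task import Clock
from twisted.trial import unittest


class HelperExample(unittest.TestCase):
    def setUp(self):
        self.reactor = Clock()

    def pump(self, by=0.0):
        # Advance the fake reactor so pending callbacks get a chance to run.
        self.reactor.advance(by)

    def get_success(self, awaitable, by=0.0):
        d = defer.ensureDeferred(awaitable)
        self.pump(by)
        return self.successResultOf(d)

    def get_failure(self, awaitable, exc):
        d = defer.ensureDeferred(awaitable)
        self.pump()
        return self.failureResultOf(d, exc)

    def test_helpers(self):
        async def ok():
            return 42

        async def boom():
            raise ValueError("nope")

        self.assertEqual(self.get_success(ok()), 42)
        self.get_failure(boom(), ValueError)
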
diff --git a/tests/utils.py b/tests/utils.py
index b33b6860d4..a61cbdef44 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -154,6 +154,10 @@ def default_config(name, parse=False):
"account": {"per_second": 10000, "burst_count": 10000},
"failed_attempts": {"per_second": 10000, "burst_count": 10000},
},
+ "rc_joins": {
+ "local": {"per_second": 10000, "burst_count": 10000},
+ "remote": {"per_second": 10000, "burst_count": 10000},
+ },
"saml2_enabled": False,
"public_baseurl": None,
"default_identity_server": None,
diff --git a/tox.ini b/tox.ini
index a394f6eadc..e5413eb110 100644
--- a/tox.ini
+++ b/tox.ini
@@ -179,6 +179,7 @@ commands = mypy \
synapse/appservice \
synapse/config \
synapse/event_auth.py \
+ synapse/events/builder.py \
synapse/events/spamcheck.py \
synapse/federation \
synapse/handlers/auth.py \
@@ -186,6 +187,7 @@ commands = mypy \
synapse/handlers/directory.py \
synapse/handlers/federation.py \
synapse/handlers/identity.py \
+ synapse/handlers/message.py \
synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
synapse/handlers/room_member.py \
@@ -198,18 +200,23 @@ commands = mypy \
synapse/logging/ \
synapse/metrics \
synapse/module_api \
+ synapse/notifier.py \
synapse/push/pusherpool.py \
synapse/push/push_rule_evaluator.py \
synapse/replication \
synapse/rest \
+ synapse/server.py \
+ synapse/server_notices \
synapse/spam_checker_api \
- synapse/storage/data_stores/main/ui_auth.py \
+ synapse/storage/databases/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
synapse/storage/state.py \
synapse/storage/util \
synapse/streams \
+ synapse/types.py \
synapse/util/caches/stream_change_cache.py \
+ synapse/util/metrics.py \
tests/replication \
tests/test_utils \
tests/rest/client/v2_alpha/test_auth.py \