diff options
author | Neil Johnson <neil@matrix.org> | 2019-12-10 11:25:28 +0000 |
---|---|---|
committer | Neil Johnson <neil@matrix.org> | 2019-12-10 11:25:28 +0000 |
commit | 0a522121a0801e5474397b4f01730bf1896bc497 (patch) | |
tree | 80b491e8965c0a99127e33789ac28c39ae660ad4 | |
parent | Merge branch 'release-v1.6.1' of github.com:matrix-org/synapse into matrix-or... (diff) | |
parent | Fix erroneous reference for new room directory defaults. (diff) | |
download | synapse-0a522121a0801e5474397b4f01730bf1896bc497.tar.xz |
Merge branch 'release-v1.7.0' of github.com:matrix-org/synapse into matrix-org-hotfixes
214 files changed, 7455 insertions, 4720 deletions
diff --git a/.buildkite/merge_base_branch.sh b/.buildkite/merge_base_branch.sh index eb7219a56d..361440fd1a 100755 --- a/.buildkite/merge_base_branch.sh +++ b/.buildkite/merge_base_branch.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -set -ex +set -e if [[ "$BUILDKITE_BRANCH" =~ ^(develop|master|dinsic|shhs|release-.*)$ ]]; then echo "Not merging forward, as this is a release branch" @@ -18,6 +18,8 @@ else GITBASE=$BUILDKITE_PULL_REQUEST_BASE_BRANCH fi +echo "--- merge_base_branch $GITBASE" + # Show what we are before git --no-pager show -s diff --git a/.buildkite/postgres-config.yaml b/.buildkite/postgres-config.yaml index a35fec394d..2acbe66f4c 100644 --- a/.buildkite/postgres-config.yaml +++ b/.buildkite/postgres-config.yaml @@ -1,7 +1,7 @@ # Configuration file used for testing the 'synapse_port_db' script. # Tells the script to connect to the postgresql database that will be available in the # CI's Docker setup at the point where this file is considered. -server_name: "test" +server_name: "localhost:8800" signing_key_path: "/src/.buildkite/test.signing.key" diff --git a/.buildkite/sqlite-config.yaml b/.buildkite/sqlite-config.yaml index 635b921764..6d9bf80d84 100644 --- a/.buildkite/sqlite-config.yaml +++ b/.buildkite/sqlite-config.yaml @@ -1,7 +1,7 @@ # Configuration file used for testing the 'synapse_port_db' script. # Tells the 'update_database' script to connect to the test SQLite database to upgrade its # schema and run background updates on it. -server_name: "test" +server_name: "localhost:8800" signing_key_path: "/src/.buildkite/test.signing.key" diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist index cda5c84e94..7950d19db3 100644 --- a/.buildkite/worker-blacklist +++ b/.buildkite/worker-blacklist @@ -28,3 +28,39 @@ User sees updates to presence from other users in the incremental sync. Gapped incremental syncs include all state changes Old members are included in gappy incr LL sync if they start speaking + +# new failures as of https://github.com/matrix-org/sytest/pull/732 +Device list doesn't change if remote server is down +Remote servers cannot set power levels in rooms without existing powerlevels +Remote servers should reject attempts by non-creators to set the power levels + +# new failures as of https://github.com/matrix-org/sytest/pull/753 +GET /rooms/:room_id/messages returns a message +GET /rooms/:room_id/messages lazy loads members correctly +Read receipts are sent as events +Only original members of the room can see messages from erased users +Device deletion propagates over federation +If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes +Changing user-signing key notifies local users +Newly updated tags appear in an incremental v2 /sync +Server correctly handles incoming m.device_list_update +Local device key changes get to remote servers with correct prev_id +AS-ghosted users can use rooms via AS +Ghost user must register before joining room +Test that a message is pushed +Invites are pushed +Rooms with aliases are correctly named in pushed +Rooms with names are correctly named in pushed +Rooms with canonical alias are correctly named in pushed +Rooms with many users are correctly pushed +Don't get pushed for rooms you've muted +Rejected events are not pushed +Test that rejected pushers are removed. +Events come down the correct room + +# https://buildkite.com/matrix-dot-org/sytest/builds/326#cca62404-a88a-4fcb-ad41-175fd3377603 +Presence changes to UNAVAILABLE are reported to remote room members +If remote user leaves room, changes device and rejoins we see update in sync +uploading self-signing key notifies over federation +Inbound federation can receive redacted events +Outbound federation can request missing events diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8939fda67d..11fb05ca96 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,8 +1,8 @@ ### Pull Request Checklist -<!-- Please read CONTRIBUTING.rst before submitting your pull request --> +<!-- Please read CONTRIBUTING.md before submitting your pull request --> * [ ] Pull request is based on the develop branch -* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog) -* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off) -* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style)) +* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#changelog) +* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#sign-off) +* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#code-style)) diff --git a/CHANGES.md b/CHANGES.md index a9afd36d2c..c30ea4718d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,75 @@ +Synapse 1.7.0rc1 (2019-12-09) +============================= + +Features +-------- + +- Implement per-room message retention policies. ([\#5815](https://github.com/matrix-org/synapse/issues/5815), [\#6436](https://github.com/matrix-org/synapse/issues/6436)) +- Add etag and count fields to key backup endpoints to help clients guess if there are new keys. ([\#5858](https://github.com/matrix-org/synapse/issues/5858)) +- Add `/admin/v2/users` endpoint with pagination. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925)) +- Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account. ([\#6119](https://github.com/matrix-org/synapse/issues/6119)) +- Implement the `/_matrix/federation/unstable/net.atleastfornow/state/<context>` API as drafted in MSC2314. ([\#6176](https://github.com/matrix-org/synapse/issues/6176)) +- Configure privacy-preserving settings by default for the room directory. ([\#6355](https://github.com/matrix-org/synapse/issues/6355)) +- Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). ([\#6409](https://github.com/matrix-org/synapse/issues/6409)) +- Add support for [MSC 2367](https://github.com/matrix-org/matrix-doc/pull/2367), which allows specifying a reason on all membership events. ([\#6434](https://github.com/matrix-org/synapse/issues/6434)) + + +Bugfixes +-------- + +- Transfer non-standard power levels on room upgrade. ([\#6237](https://github.com/matrix-org/synapse/issues/6237)) +- Fix error from the Pillow library when uploading RGBA images. ([\#6241](https://github.com/matrix-org/synapse/issues/6241)) +- Correctly apply the event filter to the `state`, `events_before` and `events_after` fields in the response to `/context` requests. ([\#6329](https://github.com/matrix-org/synapse/issues/6329)) +- Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. ([\#6332](https://github.com/matrix-org/synapse/issues/6332)) +- Prevent account data syncs getting lost across TCP replication. ([\#6333](https://github.com/matrix-org/synapse/issues/6333)) +- Fix bug: TypeError in `register_user()` while using LDAP auth module. ([\#6406](https://github.com/matrix-org/synapse/issues/6406)) +- Fix an intermittent exception when handling read-receipts. ([\#6408](https://github.com/matrix-org/synapse/issues/6408)) +- Fix broken guest registration when there are existing blocks of numeric user IDs. ([\#6420](https://github.com/matrix-org/synapse/issues/6420)) +- Fix startup error when http proxy is defined. ([\#6421](https://github.com/matrix-org/synapse/issues/6421)) +- Fix error when using synapse_port_db on a vanilla synapse db. ([\#6449](https://github.com/matrix-org/synapse/issues/6449)) +- Fix uploading multiple cross signing signatures for the same user. ([\#6451](https://github.com/matrix-org/synapse/issues/6451)) +- Fix bug which lead to exceptions being thrown in a loop when a cross-signed device is deleted. ([\#6462](https://github.com/matrix-org/synapse/issues/6462)) +- Fix `synapse_port_db` not exiting with a 0 code if something went wrong during the port process. ([\#6470](https://github.com/matrix-org/synapse/issues/6470)) +- Improve sanity-checking when receiving events over federation. ([\#6472](https://github.com/matrix-org/synapse/issues/6472)) +- Fix inaccurate per-block Prometheus metrics. ([\#6491](https://github.com/matrix-org/synapse/issues/6491)) +- Fix small performance regression for sending invites. ([\#6493](https://github.com/matrix-org/synapse/issues/6493)) +- Back out cross-signing code added in Synapse 1.5.0, which caused a performance regression. ([\#6494](https://github.com/matrix-org/synapse/issues/6494)) + + +Improved Documentation +---------------------- + +- Update documentation and variables in user contributed systemd reference file. ([\#6369](https://github.com/matrix-org/synapse/issues/6369), [\#6490](https://github.com/matrix-org/synapse/issues/6490)) +- Fix link in the user directory documentation. ([\#6388](https://github.com/matrix-org/synapse/issues/6388)) +- Add build instructions to the docker readme. ([\#6390](https://github.com/matrix-org/synapse/issues/6390)) +- Switch Ubuntu package install recommendation to use python3 packages in INSTALL.md. ([\#6443](https://github.com/matrix-org/synapse/issues/6443)) +- Write some docs for the quarantine_media api. ([\#6458](https://github.com/matrix-org/synapse/issues/6458)) +- Convert CONTRIBUTING.rst to markdown (among other small fixes). ([\#6461](https://github.com/matrix-org/synapse/issues/6461)) + + +Deprecations and Removals +------------------------- + +- Remove admin/v1/users_paginate endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#5925](https://github.com/matrix-org/synapse/issues/5925)) +- Remove fallback for federation with old servers which lack the /federation/v1/state_ids API. ([\#6488](https://github.com/matrix-org/synapse/issues/6488)) + + +Internal Changes +---------------- + +- Add benchmarks for structured logging and improve output performance. ([\#6266](https://github.com/matrix-org/synapse/issues/6266)) +- Improve the performance of outputting structured logging. ([\#6322](https://github.com/matrix-org/synapse/issues/6322)) +- Refactor some code in the event authentication path for clarity. ([\#6343](https://github.com/matrix-org/synapse/issues/6343), [\#6468](https://github.com/matrix-org/synapse/issues/6468), [\#6480](https://github.com/matrix-org/synapse/issues/6480)) +- Clean up some unnecessary quotation marks around the codebase. ([\#6362](https://github.com/matrix-org/synapse/issues/6362)) +- Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. ([\#6379](https://github.com/matrix-org/synapse/issues/6379)) +- Add a test scenario to make sure room history purges don't break `/messages` in the future. ([\#6392](https://github.com/matrix-org/synapse/issues/6392)) +- Clarifications for the email configuration settings. ([\#6423](https://github.com/matrix-org/synapse/issues/6423)) +- Add more tests to the blacklist when running in worker mode. ([\#6429](https://github.com/matrix-org/synapse/issues/6429)) +- Refactor data store layer to support multiple databases in the future. ([\#6454](https://github.com/matrix-org/synapse/issues/6454), [\#6464](https://github.com/matrix-org/synapse/issues/6464), [\#6469](https://github.com/matrix-org/synapse/issues/6469), [\#6487](https://github.com/matrix-org/synapse/issues/6487)) +- Port synapse.rest.client.v1 to async/await. ([\#6482](https://github.com/matrix-org/synapse/issues/6482)) +- Port synapse.rest.client.v2_alpha to async/await. ([\#6483](https://github.com/matrix-org/synapse/issues/6483)) +- Port SyncHandler to async/await. ([\#6484](https://github.com/matrix-org/synapse/issues/6484)) + Synapse 1.6.1 (2019-11-28) ========================== diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..c0091346f3 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,210 @@ +# Contributing code to Matrix + +Everyone is welcome to contribute code to Matrix +(https://github.com/matrix-org), provided that they are willing to license +their contributions under the same license as the project itself. We follow a +simple 'inbound=outbound' model for contributions: the act of submitting an +'inbound' contribution means that the contributor agrees to license the code +under the same terms as the project's overall 'outbound' license - in our +case, this is almost always Apache Software License v2 (see [LICENSE](LICENSE)). + +## How to contribute + +The preferred and easiest way to contribute changes to Matrix is to fork the +relevant project on github, and then [create a pull request]( +https://help.github.com/articles/using-pull-requests/) to ask us to pull +your changes into our repo. + +**The single biggest thing you need to know is: please base your changes on +the develop branch - *not* master.** + +We use the master branch to track the most recent release, so that folks who +blindly clone the repo and automatically check out master get something that +works. Develop is the unstable branch where all the development actually +happens: the workflow is that contributors should fork the develop branch to +make a 'feature' branch for a particular contribution, and then make a pull +request to merge this back into the matrix.org 'official' develop branch. We +use github's pull request workflow to review the contribution, and either ask +you to make any refinements needed or merge it and make them ourselves. The +changes will then land on master when we next do a release. + +We use [Buildkite](https://buildkite.com/matrix-dot-org/synapse) for continuous +integration. If your change breaks the build, this will be shown in GitHub, so +please keep an eye on the pull request for feedback. + +To run unit tests in a local development environment, you can use: + +- ``tox -e py35`` (requires tox to be installed by ``pip install tox``) + for SQLite-backed Synapse on Python 3.5. +- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6. +- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6 + (requires a running local PostgreSQL with access to create databases). +- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5 + (requires Docker). Entirely self-contained, recommended if you don't want to + set up PostgreSQL yourself. + +Docker images are available for running the integration tests (SyTest) locally, +see the [documentation in the SyTest repo]( +https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more +information. + +## Code style + +All Matrix projects have a well-defined code-style - and sometimes we've even +got as far as documenting it... For instance, synapse's code style doc lives +[here](docs/code_style.md). + +To facilitate meeting these criteria you can run `scripts-dev/lint.sh` +locally. Since this runs the tools listed in the above document, you'll need +python 3.6 and to install each tool: + +``` +# Install the dependencies +pip install -U black flake8 isort + +# Run the linter script +./scripts-dev/lint.sh +``` + +**Note that the script does not just test/check, but also reformats code, so you +may wish to ensure any new code is committed first**. By default this script +checks all files and can take some time; if you alter only certain files, you +might wish to specify paths as arguments to reduce the run-time: + +``` +./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder +``` + +Before pushing new changes, ensure they don't produce linting errors. Commit any +files that were corrected. + +Please ensure your changes match the cosmetic style of the existing project, +and **never** mix cosmetic and functional changes in the same commit, as it +makes it horribly hard to review otherwise. + + +## Changelog + +All changes, even minor ones, need a corresponding changelog / newsfragment +entry. These are managed by [Towncrier](https://github.com/hawkowl/towncrier). + +To create a changelog entry, make a new file in the `changelog.d` directory named +in the format of `PRnumber.type`. The type can be one of the following: + +* `feature` +* `bugfix` +* `docker` (for updates to the Docker image) +* `doc` (for updates to the documentation) +* `removal` (also used for deprecations) +* `misc` (for internal-only changes) + +The content of the file is your changelog entry, which should be a short +description of your change in the same style as the rest of our [changelog]( +https://github.com/matrix-org/synapse/blob/master/CHANGES.md). The file can +contain Markdown formatting, and should end with a full stop ('.') for +consistency. + +Adding credits to the changelog is encouraged, we value your +contributions and would like to have you shouted out in the release notes! + +For example, a fix in PR #1234 would have its changelog entry in +`changelog.d/1234.bugfix`, and contain content like "The security levels of +Florbs are now validated when received over federation. Contributed by Jane +Matrix.". + +## Debian changelog + +Changes which affect the debian packaging files (in `debian`) are an +exception. + +In this case, you will need to add an entry to the debian changelog for the +next release. For this, run the following command: + +``` +dch +``` + +This will make up a new version number (if there isn't already an unreleased +version in flight), and open an editor where you can add a new changelog entry. +(Our release process will ensure that the version number and maintainer name is +corrected for the release.) + +If your change affects both the debian packaging *and* files outside the debian +directory, you will need both a regular newsfragment *and* an entry in the +debian changelog. (Though typically such changes should be submitted as two +separate pull requests.) + +## Sign off + +In order to have a concrete record that your contribution is intentional +and you agree to license it under the same terms as the project's license, we've adopted the +same lightweight approach that the Linux Kernel +[submitting patches process]( +https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin>), +[Docker](https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other +projects use: the DCO (Developer Certificate of Origin: +http://developercertificate.org/). This is a simple declaration that you wrote +the contribution or otherwise have the right to contribute it to Matrix: + +``` +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +660 York Street, Suite 102, +San Francisco, CA 94110 USA + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. +``` + +If you agree to this for your contribution, then all that's needed is to +include the line in your commit or pull request comment: + +``` +Signed-off-by: Your Name <your@email.example.org> +``` + +We accept contributions under a legally identifiable name, such as +your name on government documentation or common-law names (names +claimed by legitimate usage or repute). Unfortunately, we cannot +accept anonymous contributions at this time. + +Git allows you to add this signoff automatically when using the `-s` +flag to `git commit`, which uses the name and email set in your +`user.name` and `user.email` git configs. + +## Conclusion + +That's it! Matrix is a very open and collaborative project as you might expect +given our obsession with open communication. If we're going to successfully +matrix together all the fragmented communication technologies out there we are +reliant on contributions and collaboration from the community to do so. So +please get involved - and we hope you have as much fun hacking on Matrix as we +do! diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst deleted file mode 100644 index df81f6e54f..0000000000 --- a/CONTRIBUTING.rst +++ /dev/null @@ -1,206 +0,0 @@ -Contributing code to Matrix -=========================== - -Everyone is welcome to contribute code to Matrix -(https://github.com/matrix-org), provided that they are willing to license -their contributions under the same license as the project itself. We follow a -simple 'inbound=outbound' model for contributions: the act of submitting an -'inbound' contribution means that the contributor agrees to license the code -under the same terms as the project's overall 'outbound' license - in our -case, this is almost always Apache Software License v2 (see LICENSE). - -How to contribute -~~~~~~~~~~~~~~~~~ - -The preferred and easiest way to contribute changes to Matrix is to fork the -relevant project on github, and then create a pull request to ask us to pull -your changes into our repo -(https://help.github.com/articles/using-pull-requests/) - -**The single biggest thing you need to know is: please base your changes on -the develop branch - /not/ master.** - -We use the master branch to track the most recent release, so that folks who -blindly clone the repo and automatically check out master get something that -works. Develop is the unstable branch where all the development actually -happens: the workflow is that contributors should fork the develop branch to -make a 'feature' branch for a particular contribution, and then make a pull -request to merge this back into the matrix.org 'official' develop branch. We -use github's pull request workflow to review the contribution, and either ask -you to make any refinements needed or merge it and make them ourselves. The -changes will then land on master when we next do a release. - -We use `Buildkite <https://buildkite.com/matrix-dot-org/synapse>`_ for -continuous integration. Buildkite builds need to be authorised by a -maintainer. If your change breaks the build, this will be shown in GitHub, so -please keep an eye on the pull request for feedback. - -To run unit tests in a local development environment, you can use: - -- ``tox -e py35`` (requires tox to be installed by ``pip install tox``) - for SQLite-backed Synapse on Python 3.5. -- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6. -- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6 - (requires a running local PostgreSQL with access to create databases). -- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5 - (requires Docker). Entirely self-contained, recommended if you don't want to - set up PostgreSQL yourself. - -Docker images are available for running the integration tests (SyTest) locally, -see the `documentation in the SyTest repo -<https://github.com/matrix-org/sytest/blob/develop/docker/README.md>`_ for more -information. - -Code style -~~~~~~~~~~ - -All Matrix projects have a well-defined code-style - and sometimes we've even -got as far as documenting it... For instance, synapse's code style doc lives -at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md. - -To facilitate meeting these criteria you can run ``scripts-dev/lint.sh`` -locally. Since this runs the tools listed in the above document, you'll need -python 3.6 and to install each tool. **Note that the script does not just -test/check, but also reformats code, so you may wish to ensure any new code is -committed first**. By default this script checks all files and can take some -time; if you alter only certain files, you might wish to specify paths as -arguments to reduce the run-time. - -Please ensure your changes match the cosmetic style of the existing project, -and **never** mix cosmetic and functional changes in the same commit, as it -makes it horribly hard to review otherwise. - -Before doing a commit, ensure the changes you've made don't produce -linting errors. You can do this by running the linters as follows. Ensure to -commit any files that were corrected. - -:: - # Install the dependencies - pip install -U black flake8 isort - - # Run the linter script - ./scripts-dev/lint.sh - -Changelog -~~~~~~~~~ - -All changes, even minor ones, need a corresponding changelog / newsfragment -entry. These are managed by Towncrier -(https://github.com/hawkowl/towncrier). - -To create a changelog entry, make a new file in the ``changelog.d`` file named -in the format of ``PRnumber.type``. The type can be one of the following: - -* ``feature``. -* ``bugfix``. -* ``docker`` (for updates to the Docker image). -* ``doc`` (for updates to the documentation). -* ``removal`` (also used for deprecations). -* ``misc`` (for internal-only changes). - -The content of the file is your changelog entry, which should be a short -description of your change in the same style as the rest of our `changelog -<https://github.com/matrix-org/synapse/blob/master/CHANGES.md>`_. The file can -contain Markdown formatting, and should end with a full stop ('.') for -consistency. - -Adding credits to the changelog is encouraged, we value your -contributions and would like to have you shouted out in the release notes! - -For example, a fix in PR #1234 would have its changelog entry in -``changelog.d/1234.bugfix``, and contain content like "The security levels of -Florbs are now validated when recieved over federation. Contributed by Jane -Matrix.". - -Debian changelog ----------------- - -Changes which affect the debian packaging files (in ``debian``) are an -exception. - -In this case, you will need to add an entry to the debian changelog for the -next release. For this, run the following command:: - - dch - -This will make up a new version number (if there isn't already an unreleased -version in flight), and open an editor where you can add a new changelog entry. -(Our release process will ensure that the version number and maintainer name is -corrected for the release.) - -If your change affects both the debian packaging *and* files outside the debian -directory, you will need both a regular newsfragment *and* an entry in the -debian changelog. (Though typically such changes should be submitted as two -separate pull requests.) - -Sign off -~~~~~~~~ - -In order to have a concrete record that your contribution is intentional -and you agree to license it under the same terms as the project's license, we've adopted the -same lightweight approach that the Linux Kernel -`submitting patches process <https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin>`_, Docker -(https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other -projects use: the DCO (Developer Certificate of Origin: -http://developercertificate.org/). This is a simple declaration that you wrote -the contribution or otherwise have the right to contribute it to Matrix:: - - Developer Certificate of Origin - Version 1.1 - - Copyright (C) 2004, 2006 The Linux Foundation and its contributors. - 660 York Street, Suite 102, - San Francisco, CA 94110 USA - - Everyone is permitted to copy and distribute verbatim copies of this - license document, but changing it is not allowed. - - Developer's Certificate of Origin 1.1 - - By making a contribution to this project, I certify that: - - (a) The contribution was created in whole or in part by me and I - have the right to submit it under the open source license - indicated in the file; or - - (b) The contribution is based upon previous work that, to the best - of my knowledge, is covered under an appropriate open source - license and I have the right under that license to submit that - work with modifications, whether created in whole or in part - by me, under the same open source license (unless I am - permitted to submit under a different license), as indicated - in the file; or - - (c) The contribution was provided directly to me by some other - person who certified (a), (b) or (c) and I have not modified - it. - - (d) I understand and agree that this project and the contribution - are public and that a record of the contribution (including all - personal information I submit with it, including my sign-off) is - maintained indefinitely and may be redistributed consistent with - this project or the open source license(s) involved. - -If you agree to this for your contribution, then all that's needed is to -include the line in your commit or pull request comment:: - - Signed-off-by: Your Name <your@email.example.org> - -We accept contributions under a legally identifiable name, such as -your name on government documentation or common-law names (names -claimed by legitimate usage or repute). Unfortunately, we cannot -accept anonymous contributions at this time. - -Git allows you to add this signoff automatically when using the ``-s`` -flag to ``git commit``, which uses the name and email set in your -``user.name`` and ``user.email`` git configs. - -Conclusion -~~~~~~~~~~ - -That's it! Matrix is a very open and collaborative project as you might expect -given our obsession with open communication. If we're going to successfully -matrix together all the fragmented communication technologies out there we are -reliant on contributions and collaboration from the community to do so. So -please get involved - and we hope you have as much fun hacking on Matrix as we -do! diff --git a/INSTALL.md b/INSTALL.md index 9b7360f0ef..9da2e3c734 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -109,8 +109,8 @@ Installing prerequisites on Ubuntu or Debian: ``` sudo apt-get install build-essential python3-dev libffi-dev \ - python-pip python-setuptools sqlite3 \ - libssl-dev python-virtualenv libjpeg-dev libxslt1-dev + python3-pip python3-setuptools sqlite3 \ + libssl-dev python3-virtualenv libjpeg-dev libxslt1-dev ``` #### ArchLinux diff --git a/UPGRADE.rst b/UPGRADE.rst index 5ebf16a73e..d9020f2663 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,23 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.7.0 +=================== + +In an attempt to configure Synapse in a privacy preserving way, the default +behaviours of ``allow_public_rooms_without_auth`` and +``allow_public_rooms_over_federation`` have been inverted. This means that by +default, only authenticated users querying the Client/Server API will be able +to query the room directory, and relatedly that the server will not share +room directory information with other servers over federation. + +If your installation does not explicitly set these settings one way or the other +and you want either setting to be ``true`` then it will necessary to update +your homeserver configuration file accordingly. + +For more details on the surrounding context see our `explainer +<https://matrix.org/blog/2019/11/09/avoiding-unwelcome-visitors-on-private-matrix-servers>`_. + Upgrading to v1.5.0 =================== diff --git a/changelog.d/6322.misc b/changelog.d/6322.misc deleted file mode 100644 index 70ef36ca80..0000000000 --- a/changelog.d/6322.misc +++ /dev/null @@ -1 +0,0 @@ -Improve the performance of outputting structured logging. diff --git a/changelog.d/6332.bugfix b/changelog.d/6332.bugfix deleted file mode 100644 index 67d5170ba0..0000000000 --- a/changelog.d/6332.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. diff --git a/changelog.d/6333.bugfix b/changelog.d/6333.bugfix deleted file mode 100644 index a25d6ef3cb..0000000000 --- a/changelog.d/6333.bugfix +++ /dev/null @@ -1 +0,0 @@ -Prevent account data syncs getting lost across TCP replication. \ No newline at end of file diff --git a/changelog.d/6343.misc b/changelog.d/6343.misc deleted file mode 100644 index d9a44389b9..0000000000 --- a/changelog.d/6343.misc +++ /dev/null @@ -1 +0,0 @@ -Refactor some code in the event authentication path for clarity. diff --git a/changelog.d/6362.misc b/changelog.d/6362.misc deleted file mode 100644 index b79a5bea99..0000000000 --- a/changelog.d/6362.misc +++ /dev/null @@ -1 +0,0 @@ -Clean up some unnecessary quotation marks around the codebase. \ No newline at end of file diff --git a/changelog.d/6379.misc b/changelog.d/6379.misc deleted file mode 100644 index 725c2e7d87..0000000000 --- a/changelog.d/6379.misc +++ /dev/null @@ -1 +0,0 @@ -Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. \ No newline at end of file diff --git a/changelog.d/6388.doc b/changelog.d/6388.doc deleted file mode 100644 index c777cb6b8f..0000000000 --- a/changelog.d/6388.doc +++ /dev/null @@ -1 +0,0 @@ -Fix link in the user directory documentation. diff --git a/changelog.d/6390.doc b/changelog.d/6390.doc deleted file mode 100644 index 093411bec1..0000000000 --- a/changelog.d/6390.doc +++ /dev/null @@ -1 +0,0 @@ -Add build instructions to the docker readme. \ No newline at end of file diff --git a/changelog.d/6392.misc b/changelog.d/6392.misc deleted file mode 100644 index a00257944f..0000000000 --- a/changelog.d/6392.misc +++ /dev/null @@ -1 +0,0 @@ -Add a test scenario to make sure room history purges don't break `/messages` in the future. diff --git a/changelog.d/6408.bugfix b/changelog.d/6408.bugfix deleted file mode 100644 index c9babe599b..0000000000 --- a/changelog.d/6408.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix an intermittent exception when handling read-receipts. diff --git a/changelog.d/6420.bugfix b/changelog.d/6420.bugfix deleted file mode 100644 index aef47cccaa..0000000000 --- a/changelog.d/6420.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix broken guest registration when there are existing blocks of numeric user IDs. diff --git a/changelog.d/6421.bugfix b/changelog.d/6421.bugfix deleted file mode 100644 index 7969f7f71d..0000000000 --- a/changelog.d/6421.bugfix +++ /dev/null @@ -1 +0,0 @@ -Fix startup error when http proxy is defined. diff --git a/contrib/systemd/README.md b/contrib/systemd/README.md new file mode 100644 index 0000000000..5d42b3464f --- /dev/null +++ b/contrib/systemd/README.md @@ -0,0 +1,17 @@ +# Setup Synapse with Systemd +This is a setup for managing synapse with a user contributed systemd unit +file. It provides a `matrix-synapse` systemd unit file that should be tailored +to accommodate your installation in accordance with the installation +instructions provided in [installation instructions](../../INSTALL.md). + +## Setup +1. Under the service section, ensure the `User` variable matches which user +you installed synapse under and wish to run it as. +2. Under the service section, ensure the `WorkingDirectory` variable matches +where you have installed synapse. +3. Under the service section, ensure the `ExecStart` variable matches the +appropriate locations of your installation. +4. Copy the `matrix-synapse.service` to `/etc/systemd/system/` +5. Start Synapse: `sudo systemctl start matrix-synapse` +6. Verify Synapse is running: `sudo systemctl status matrix-synapse` +7. *optional* Enable Synapse to start at system boot: `sudo systemctl enable matrix-synapse` diff --git a/contrib/systemd/matrix-synapse.service b/contrib/systemd/matrix-synapse.service index 38d369ea3d..813717b032 100644 --- a/contrib/systemd/matrix-synapse.service +++ b/contrib/systemd/matrix-synapse.service @@ -4,8 +4,11 @@ # systemctl enable matrix-synapse # systemctl start matrix-synapse # +# This assumes that Synapse has been installed by a user named +# synapse. +# # This assumes that Synapse has been installed in a virtualenv in -# /opt/synapse/env. +# the user's home directory: `/home/synapse/synapse/env`. # # **NOTE:** This is an example service file that may change in the future. If you # wish to use this please copy rather than symlink it. @@ -22,8 +25,8 @@ Restart=on-abort User=synapse Group=nogroup -WorkingDirectory=/opt/synapse -ExecStart=/opt/synapse/env/bin/python -m synapse.app.homeserver --config-path=/opt/synapse/homeserver.yaml +WorkingDirectory=/home/synapse/synapse +ExecStart=/home/synapse/synapse/env/bin/python -m synapse.app.homeserver --config-path=/home/synapse/synapse/homeserver.yaml SyslogIdentifier=matrix-synapse # adjust the cache factor if necessary diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 5e9f8e5d84..8b3666d5f5 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -21,3 +21,20 @@ It returns a JSON body like the following: ] } ``` + +# Quarantine media in a room + +This API 'quarantines' all the media in a room. + +The API is: + +``` +POST /_synapse/admin/v1/quarantine_media/<room_id> + +{} +``` + +Quarantining media means that it is marked as inaccessible by users. It applies +to any local media, and any locally-cached copies of remote media. + +The media file itself (and any thumbnails) is not deleted from the server. diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index d0871f9438..b451dc5014 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -1,3 +1,48 @@ +List Accounts +============= + +This API returns all local user accounts. + +The api is:: + + GET /_synapse/admin/v2/users?from=0&limit=10&guests=false + +including an ``access_token`` of a server admin. +The parameters ``from`` and ``limit`` are required only for pagination. +By default, a ``limit`` of 100 is used. +The parameter ``user_id`` can be used to select only users with user ids that +contain this value. +The parameter ``guests=false`` can be used to exclude guest users, +default is to include guest users. +The parameter ``deactivated=true`` can be used to include deactivated users, +default is to exclude deactivated users. +If the endpoint does not return a ``next_token`` then there are no more users left. +It returns a JSON body like the following: + +.. code:: json + + { + "users": [ + { + "name": "<user_id1>", + "password_hash": "<password_hash1>", + "is_guest": 0, + "admin": 0, + "user_type": null, + "deactivated": 0 + }, { + "name": "<user_id2>", + "password_hash": "<password_hash2>", + "is_guest": 0, + "admin": 1, + "user_type": null, + "deactivated": 0 + } + ], + "next_token": "100" + } + + Query Account ============= diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 896159394c..10664ae8f7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,15 +54,16 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true -# If set to 'false', requires authentication to access the server's public rooms -# directory through the client API. Defaults to 'true'. +# If set to 'true', removes the need for authentication to access the server's +# public rooms directory through the client API, meaning that anyone can +# query the room directory. Defaults to 'false'. # -#allow_public_rooms_without_auth: false +#allow_public_rooms_without_auth: true -# If set to 'false', forbids any other homeserver to fetch the server's public -# rooms directory via federation. Defaults to 'true'. +# If set to 'true', allows any other homeserver to fetch the server's public +# rooms directory via federation. Defaults to 'false'. # -#allow_public_rooms_over_federation: false +#allow_public_rooms_over_federation: true # The default room version for newly created rooms. # @@ -328,6 +329,69 @@ listeners: # #user_ips_max_age: 14d +# Message retention policy at the server level. +# +# Room admins and mods can define a retention period for their rooms using the +# 'm.room.retention' state event, and server admins can cap this period by setting +# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. +# +# If this feature is enabled, Synapse will regularly look for and purge events +# which are older than the room's maximum retention period. Synapse will also +# filter events received over federation so that events that should have been +# purged are ignored and not stored again. +# +retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h + ## TLS ## @@ -1270,8 +1334,23 @@ password_config: # smtp_user: "exampleusername" # smtp_pass: "examplepassword" # require_transport_security: false +# +# # notif_from defines the "From" address to use when sending emails. +# # It must be set if email sending is enabled. +# # +# # The placeholder '%(app)s' will be replaced by the application name, +# # which is normally 'app_name' (below), but may be overridden by the +# # Matrix client application. +# # +# # Note that the placeholder must be written '%(app)s', including the +# # trailing 's'. +# # # notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" -# app_name: Matrix +# +# # app_name defines the default value for '%(app)s' in notif_from. It +# # defaults to 'Matrix'. +# # +# #app_name: my_branded_matrix_server # # # Enable email notifications by default # # diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index d20f6db176..bf3862a386 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -27,7 +27,7 @@ class Store(object): "_store_pdu_reference_hash_txn" ] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] - _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] + simple_insert_txn = SQLBaseStore.__dict__["simple_insert_txn"] store = Store() diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 27a1ad1e7e..1776d202c5 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -58,10 +58,10 @@ if __name__ == "__main__": " on it." ) ) - parser.add_argument("-v", action='store_true') + parser.add_argument("-v", action="store_true") parser.add_argument( "--database-config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), required=True, help="A database config file for either a SQLite3 database or a PostgreSQL one.", ) @@ -101,10 +101,7 @@ if __name__ == "__main__": # Instantiate and initialise the homeserver object. hs = MockHomeserver( - config, - database_engine, - db_conn, - db_config=config.database_config, + config, database_engine, db_conn, db_config=config.database_config, ) # setup instantiates the store within the homeserver object. hs.setup() @@ -112,13 +109,13 @@ if __name__ == "__main__": @defer.inlineCallbacks def run_background_updates(): - yield store.run_background_updates(sleep=False) + yield store.db.updates.run_background_updates(sleep=False) # Stop the reactor to exit the script once every background update is run. reactor.stop() # Apply all background updates on the database. - reactor.callWhenRunning(lambda: run_as_background_process( - "background_updates", run_background_updates - )) + reactor.callWhenRunning( + lambda: run_as_background_process("background_updates", run_background_updates) + ) reactor.run() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 0d3321682c..e393a9b2f7 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -47,6 +47,7 @@ from synapse.storage.data_stores.main.media_repository import ( from synapse.storage.data_stores.main.registration import ( RegistrationBackgroundUpdateStore, ) +from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore from synapse.storage.data_stores.main.state import StateBackgroundUpdateStore @@ -54,6 +55,7 @@ from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -131,54 +133,22 @@ class Store( EventsBackgroundUpdatesStore, MediaRepositoryBackgroundUpdateStore, RegistrationBackgroundUpdateStore, + RoomBackgroundUpdateStore, RoomMemberBackgroundUpdateStore, SearchBackgroundUpdateStore, StateBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore, StatsStore, ): - def __init__(self, db_conn, hs): - super().__init__(db_conn, hs) - self.db_pool = hs.get_db_pool() - - @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): - def r(conn): - try: - i = 0 - N = 5 - while True: - try: - txn = conn.cursor() - return func( - LoggingTransaction(txn, desc, self.database_engine, [], []), - *args, - **kwargs - ) - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N) - if i < N: - i += 1 - conn.rollback() - continue - raise - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", desc, e) - raise - - with PreserveLoggingContext(): - return (yield self.db_pool.runWithConnection(r)) - def execute(self, f, *args, **kwargs): - return self.runInteraction(f.__name__, f, *args, **kwargs) + return self.db.runInteraction(f.__name__, f, *args, **kwargs) def execute_sql(self, sql, *args): def r(txn): txn.execute(sql, args) return txn.fetchall() - return self.runInteraction("execute_sql", r) + return self.db.runInteraction("execute_sql", r) def insert_many_txn(self, txn, table, headers, rows): sql = "INSERT INTO %s (%s) VALUES (%s)" % ( @@ -221,7 +191,7 @@ class Porter(object): def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = yield self.postgres_store._simple_select_one( + row = yield self.postgres_store.db.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -231,12 +201,14 @@ class Porter(object): total_to_port = None if row is None: if table == "sent_transactions": - forward_chunk, already_ported, total_to_port = ( - yield self._setup_sent_transactions() - ) + ( + forward_chunk, + already_ported, + total_to_port, + ) = yield self._setup_sent_transactions() backward_chunk = 0 else: - yield self.postgres_store._simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -266,7 +238,7 @@ class Porter(object): yield self.postgres_store.execute(delete_all) - yield self.postgres_store._simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -320,7 +292,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - yield self.postgres_store._simple_insert( + yield self.postgres_store.db.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -361,7 +333,9 @@ class Porter(object): return headers, forward_rows, backward_rows - headers, frows, brows = yield self.sqlite_store.runInteraction("select", r) + headers, frows, brows = yield self.sqlite_store.db.runInteraction( + "select", r + ) if frows or brows: if frows: @@ -375,7 +349,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store._simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -414,7 +388,7 @@ class Porter(object): return headers, rows - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = yield self.sqlite_store.db.runInteraction("select", r) if rows: forward_chunk = rows[-1][0] + 1 @@ -431,8 +405,8 @@ class Porter(object): rows_dict = [] for row in rows: d = dict(zip(headers, row)) - if "\0" in d['value']: - logger.warning('dropping search row %s', d) + if "\0" in d["value"]: + logger.warning("dropping search row %s", d) else: rows_dict.append(d) @@ -452,7 +426,7 @@ class Porter(object): ], ) - self.postgres_store._simple_update_one_txn( + self.postgres_store.db.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -502,17 +476,14 @@ class Porter(object): self.progress.set_state("Preparing %s" % config["name"]) conn = self.setup_db(config, engine) - db_pool = adbapi.ConnectionPool( - config["name"], **config["args"] - ) + db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) hs = MockHomeserver(self.hs_config, engine, conn, db_pool) - store = Store(conn, hs) + store = Store(Database(hs), conn, hs) - yield store.runInteraction( - "%s_engine.check_database" % config["name"], - engine.check_database, + yield store.db.runInteraction( + "%s_engine.check_database" % config["name"], engine.check_database, ) return store @@ -520,7 +491,9 @@ class Porter(object): @defer.inlineCallbacks def run_background_updates_on_postgres(self): # Manually apply all background updates on the PostgreSQL database. - postgres_ready = yield self.postgres_store.has_completed_background_updates() + postgres_ready = ( + yield self.postgres_store.db.updates.has_completed_background_updates() + ) if not postgres_ready: # Only say that we're running background updates when there are background @@ -528,9 +501,9 @@ class Porter(object): self.progress.set_state("Running background updates on PostgreSQL") while not postgres_ready: - yield self.postgres_store.do_next_background_update(100) + yield self.postgres_store.db.updates.do_next_background_update(100) postgres_ready = yield ( - self.postgres_store.has_completed_background_updates() + self.postgres_store.db.updates.has_completed_background_updates() ) @defer.inlineCallbacks @@ -539,7 +512,9 @@ class Porter(object): self.sqlite_store = yield self.build_db_store(self.sqlite_config) # Check if all background updates are done, abort if not. - updates_complete = yield self.sqlite_store.has_completed_background_updates() + updates_complete = ( + yield self.sqlite_store.db.updates.has_completed_background_updates() + ) if not updates_complete: sys.stderr.write( "Pending background updates exist in the SQLite3 database." @@ -580,22 +555,22 @@ class Porter(object): ) try: - yield self.postgres_store.runInteraction("alter_table", alter_table) + yield self.postgres_store.db.runInteraction("alter_table", alter_table) except Exception: # On Error Resume Next pass - yield self.postgres_store.runInteraction( + yield self.postgres_store.db.runInteraction( "create_port_table", create_port_table ) # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store._simple_select_onecol( + sqlite_tables = yield self.sqlite_store.db.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store._simple_select_onecol( + postgres_tables = yield self.postgres_store.db.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -685,11 +660,11 @@ class Porter(object): rows = txn.fetchall() headers = [column[0] for column in txn.description] - ts_ind = headers.index('ts') + ts_ind = headers.index("ts") return headers, [r for r in rows if r[ts_ind] < yesterday] - headers, rows = yield self.sqlite_store.runInteraction("select", r) + headers, rows = yield self.sqlite_store.db.runInteraction("select", r) rows = self._convert_rows("sent_transactions", headers, rows) @@ -722,7 +697,7 @@ class Porter(object): next_chunk = yield self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store._simple_insert( + yield self.postgres_store.db.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -735,7 +710,7 @@ class Porter(object): txn.execute( "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,) ) - size, = txn.fetchone() + (size,) = txn.fetchone() return int(size) remaining_count = yield self.sqlite_store.execute(get_sent_table_size) @@ -782,10 +757,13 @@ class Porter(object): def _setup_state_group_id_seq(self): def r(txn): txn.execute("SELECT MAX(id) FROM state_groups") - next_id = txn.fetchone()[0] + 1 + curr_id = txn.fetchone()[0] + if not curr_id: + return + next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.runInteraction("setup_state_group_id_seq", r) + return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r) ############################################## @@ -866,7 +844,7 @@ class CursesProgress(Progress): duration = int(now) - int(self.start_time) minutes, seconds = divmod(duration, 60) - duration_str = '%02dm %02ds' % (minutes, seconds) + duration_str = "%02dm %02ds" % (minutes, seconds) if self.finished: status = "Time spent: %s (Done!)" % (duration_str,) @@ -876,7 +854,7 @@ class CursesProgress(Progress): left = float(self.total_remaining) / self.total_processed est_remaining = (int(now) - self.start_time) * left - est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) + est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60) else: est_remaining_str = "Unknown" status = "Time spent: %s (est. remaining: %s)" % ( @@ -962,7 +940,7 @@ if __name__ == "__main__": description="A script to port an existing synapse SQLite database to" " a new PostgreSQL database." ) - parser.add_argument("-v", action='store_true') + parser.add_argument("-v", action="store_true") parser.add_argument( "--sqlite-database", required=True, @@ -971,12 +949,12 @@ if __name__ == "__main__": ) parser.add_argument( "--postgres-config", - type=argparse.FileType('r'), + type=argparse.FileType("r"), required=True, help="The database config file for the PostgreSQL database", ) parser.add_argument( - "--curses", action='store_true', help="display a curses based progress UI" + "--curses", action="store_true", help="display a curses based progress UI" ) parser.add_argument( @@ -1052,3 +1030,4 @@ if __name__ == "__main__": if end_error_exec_info: exc_type, exc_value, exc_traceback = end_error_exec_info traceback.print_exception(exc_type, exc_value, exc_traceback) + sys.exit(5) diff --git a/synapse/__init__.py b/synapse/__init__.py index f99de2f3f3..c67a51a8d5 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -36,7 +36,7 @@ try: except ImportError: pass -__version__ = "1.6.1" +__version__ = "1.7.0rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 49c4b85054..0ade47e624 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,6 +95,8 @@ class EventTypes(object): ServerACL = "m.room.server_acl" Pinned = "m.room.pinned_events" + Retention = "m.room.retention" + class RejectedReason(object): AUTH_ERROR = "auth_error" @@ -145,3 +148,7 @@ class EventContentFields(object): # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 LABELS = "org.matrix.labels" + + # Timestamp to delete the event after + # cf https://github.com/matrix-org/matrix-doc/pull/2228 + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bec13f08d8..6eab1f13f0 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 2ac7d5c064..9c96816096 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -269,7 +269,7 @@ def start(hs, listeners=None): # It is now safe to start your Synapse. hs.start_listening(listeners) - hs.get_datastore().start_profiling() + hs.get_datastore().db.start_profiling() setup_sentry(hs) setup_sdnotify(hs) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 448e45e00f..f24920a7d6 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -40,6 +40,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -59,8 +60,8 @@ class FederationSenderSlaveStore( SlavedDeviceStore, SlavedPresenceStore, ): - def __init__(self, db_conn, hs): - super(FederationSenderSlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(FederationSenderSlaveStore, self).__init__(database, db_conn, hs) # We pull out the current federation stream position now so that we # always have a known value for the federation position in memory so diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 883b3fb70b..df65d0a989 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -68,9 +68,9 @@ from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer -from synapse.storage import DataStore, are_all_users_on_domain +from synapse.storage import DataStore from synapse.storage.engines import IncorrectDatabaseSetup, create_engine -from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database +from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -294,22 +294,6 @@ class SynapseHomeServer(HomeServer): else: logger.warning("Unrecognized listener type: %s", listener["type"]) - def run_startup_checks(self, db_conn, database_engine): - all_users_native = are_all_users_on_domain( - db_conn.cursor(), database_engine, self.hostname - ) - if not all_users_native: - quit_with_error( - "Found users in database not native to %s!\n" - "You cannot changed a synapse server_name after it's been configured" - % (self.hostname,) - ) - - try: - database_engine.check_database(db_conn.cursor()) - except IncorrectDatabaseSetup as e: - quit_with_error(str(e)) - # Gauges to expose monthly active user control metrics current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU") @@ -357,16 +341,12 @@ def setup(config_options): synapse.config.logger.setup_logging(hs, config, use_worker_options=False) - logger.info("Preparing database: %s...", config.database_config["name"]) + logger.info("Setting up server") try: - with hs.get_db_conn(run_new_connection=False) as db_conn: - prepare_database(db_conn, database_engine, config=config) - database_engine.on_new_connection(db_conn) - - hs.run_startup_checks(db_conn, database_engine) - - db_conn.commit() + hs.setup() + except IncorrectDatabaseSetup as e: + quit_with_error(str(e)) except UpgradeDatabaseException: sys.stderr.write( "\nFailed to upgrade database.\n" @@ -375,9 +355,6 @@ def setup(config_options): ) sys.exit(1) - logger.info("Database prepared in %s.", config.database_config["name"]) - - hs.setup() hs.setup_master() @defer.inlineCallbacks @@ -436,7 +413,7 @@ def setup(config_options): _base.start(hs, config.listeners) hs.get_pusherpool().start() - hs.get_datastore().start_doing_background_updates() + hs.get_datastore().db.updates.start_doing_background_updates() except Exception: # Print the exception and bail out. print("Error during startup:", file=sys.stderr) @@ -542,8 +519,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # Database version # - stats["database_engine"] = hs.get_datastore().database_engine_name - stats["database_server_version"] = hs.get_datastore().get_server_version() + stats["database_engine"] = hs.database_engine.module.__name__ + stats["database_server_version"] = hs.database_engine.server_version logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: yield hs.get_proxied_http_client().put_json( diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index b14da09f47..288ee64b42 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -151,7 +151,7 @@ class SynchrotronPresence(object): def set_state(self, user, state, ignore_status_msg=False): # TODO Hows this supposed to work? - pass + return defer.succeed(None) get_states = __func__(PresenceHandler.get_states) get_state = __func__(PresenceHandler.get_state) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 6cb100319f..c01fb34a9b 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -43,6 +43,7 @@ from synapse.replication.tcp.streams.events import ( from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree @@ -60,11 +61,11 @@ class UserDirectorySlaveStore( UserDirectoryStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): - super(UserDirectorySlaveStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectorySlaveStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index ac1724045f..18f42a87f9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -307,8 +307,23 @@ class EmailConfig(Config): # smtp_user: "exampleusername" # smtp_pass: "examplepassword" # require_transport_security: false + # + # # notif_from defines the "From" address to use when sending emails. + # # It must be set if email sending is enabled. + # # + # # The placeholder '%(app)s' will be replaced by the application name, + # # which is normally 'app_name' (below), but may be overridden by the + # # Matrix client application. + # # + # # Note that the placeholder must be written '%(app)s', including the + # # trailing 's'. + # # # notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" - # app_name: Matrix + # + # # app_name defines the default value for '%(app)s' in notif_from. It + # # defaults to 'Matrix'. + # # + # #app_name: my_branded_matrix_server # # # Enable email notifications by default # # diff --git a/synapse/config/server.py b/synapse/config/server.py index 11336d7549..a4bef00936 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import List +from typing import Dict, List, Optional import attr import yaml @@ -118,15 +118,16 @@ class ServerConfig(Config): self.allow_public_rooms_without_auth = False self.allow_public_rooms_over_federation = False else: - # If set to 'False', requires authentication to access the server's public - # rooms directory through the client API. Defaults to 'True'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. self.allow_public_rooms_without_auth = config.get( - "allow_public_rooms_without_auth", True + "allow_public_rooms_without_auth", False ) - # If set to 'False', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'True'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. self.allow_public_rooms_over_federation = config.get( - "allow_public_rooms_over_federation", True + "allow_public_rooms_over_federation", False ) default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) @@ -246,6 +247,124 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) + retention_config = config.get("retention") + if retention_config is None: + retention_config = {} + + self.retention_enabled = retention_config.get("enabled", False) + + retention_default_policy = retention_config.get("default_policy") + + if retention_default_policy is not None: + self.retention_default_min_lifetime = retention_default_policy.get( + "min_lifetime" + ) + if self.retention_default_min_lifetime is not None: + self.retention_default_min_lifetime = self.parse_duration( + self.retention_default_min_lifetime + ) + + self.retention_default_max_lifetime = retention_default_policy.get( + "max_lifetime" + ) + if self.retention_default_max_lifetime is not None: + self.retention_default_max_lifetime = self.parse_duration( + self.retention_default_max_lifetime + ) + + if ( + self.retention_default_min_lifetime is not None + and self.retention_default_max_lifetime is not None + and ( + self.retention_default_min_lifetime + > self.retention_default_max_lifetime + ) + ): + raise ConfigError( + "The default retention policy's 'min_lifetime' can not be greater" + " than its 'max_lifetime'" + ) + else: + self.retention_default_min_lifetime = None + self.retention_default_max_lifetime = None + + self.retention_allowed_lifetime_min = retention_config.get( + "allowed_lifetime_min" + ) + if self.retention_allowed_lifetime_min is not None: + self.retention_allowed_lifetime_min = self.parse_duration( + self.retention_allowed_lifetime_min + ) + + self.retention_allowed_lifetime_max = retention_config.get( + "allowed_lifetime_max" + ) + if self.retention_allowed_lifetime_max is not None: + self.retention_allowed_lifetime_max = self.parse_duration( + self.retention_allowed_lifetime_max + ) + + if ( + self.retention_allowed_lifetime_min is not None + and self.retention_allowed_lifetime_max is not None + and self.retention_allowed_lifetime_min + > self.retention_allowed_lifetime_max + ): + raise ConfigError( + "Invalid retention policy limits: 'allowed_lifetime_min' can not be" + " greater than 'allowed_lifetime_max'" + ) + + self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] + for purge_job_config in retention_config.get("purge_jobs", []): + interval_config = purge_job_config.get("interval") + + if interval_config is None: + raise ConfigError( + "A retention policy's purge jobs configuration must have the" + " 'interval' key set." + ) + + interval = self.parse_duration(interval_config) + + shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") + + if shortest_max_lifetime is not None: + shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) + + longest_max_lifetime = purge_job_config.get("longest_max_lifetime") + + if longest_max_lifetime is not None: + longest_max_lifetime = self.parse_duration(longest_max_lifetime) + + if ( + shortest_max_lifetime is not None + and longest_max_lifetime is not None + and shortest_max_lifetime > longest_max_lifetime + ): + raise ConfigError( + "A retention policy's purge jobs configuration's" + " 'shortest_max_lifetime' value can not be greater than its" + " 'longest_max_lifetime' value." + ) + + self.retention_purge_jobs.append( + { + "interval": interval, + "shortest_max_lifetime": shortest_max_lifetime, + "longest_max_lifetime": longest_max_lifetime, + } + ) + + if not self.retention_purge_jobs: + self.retention_purge_jobs = [ + { + "interval": self.parse_duration("1d"), + "shortest_max_lifetime": None, + "longest_max_lifetime": None, + } + ] + self.listeners = [] # type: List[dict] for listener in config.get("listeners", []): if not isinstance(listener.get("port", None), int): @@ -372,6 +491,8 @@ class ServerConfig(Config): "cleanup_extremities_with_dummy_events", True ) + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + def has_tls_listener(self) -> bool: return any(l["tls"] for l in self.listeners) @@ -500,15 +621,16 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true - # If set to 'false', requires authentication to access the server's public rooms - # directory through the client API. Defaults to 'true'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. # - #allow_public_rooms_without_auth: false + #allow_public_rooms_without_auth: true - # If set to 'false', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'true'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. # - #allow_public_rooms_over_federation: false + #allow_public_rooms_over_federation: true # The default room version for newly created rooms. # @@ -761,6 +883,69 @@ class ServerConfig(Config): # Defaults to `28d`. Set to `null` to disable clearing out of old rows. # #user_ips_max_age: 14d + + # Message retention policy at the server level. + # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h """ % locals() ) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 272426e105..9b90c9ce04 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import string_types +from six import integer_types, string_types from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -22,11 +22,12 @@ from synapse.types import EventID, RoomID, UserID class EventValidator(object): - def validate_new(self, event): + def validate_new(self, event, config): """Validates the event has roughly the right format Args: - event (FrozenEvent) + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. """ self.validate_builder(event) @@ -67,6 +68,99 @@ class EventValidator(object): Codes.INVALID_PARAM, ) + if event.type == EventTypes.Retention: + self._validate_retention(event, config) + + def _validate_retention(self, event, config): + """Checks that an event that defines the retention policy for a room respects the + boundaries imposed by the server's administrator. + + Args: + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. + """ + min_lifetime = event.content.get("min_lifetime") + max_lifetime = event.content.get("max_lifetime") + + if min_lifetime is not None: + if not isinstance(min_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'min_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and min_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be lower than the minimum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and min_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if max_lifetime is not None: + if not isinstance(max_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'max_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and max_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be lower than the minimum allowed value" + " enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and max_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + min_lifetime is not None + and max_lifetime is not None + and min_lifetime > max_lifetime + ): + raise SynapseError( + code=400, + msg="'min_lifetime' can't be greater than 'max_lifetime", + errcode=Codes.BAD_JSON, + ) + def validate_builder(self, event): """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 27f6aff004..709449c9e3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -324,87 +324,32 @@ class FederationClient(FederationBase): A list of events in the state, and a list of events in the auth chain for the given event. """ - try: - # First we try and ask for just the IDs, as thats far quicker if - # we have most of the state and auth_chain already. - # However, this may 404 if the other side has an old synapse. - result = yield self.transport_layer.get_room_state_ids( - destination, room_id, event_id=event_id - ) - - state_event_ids = result["pdu_ids"] - auth_event_ids = result.get("auth_chain_ids", []) - - fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( - destination, room_id, set(state_event_ids + auth_event_ids) - ) - - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, - failed_to_fetch, - ) - - event_map = {ev.event_id: ev for ev in fetched_events} - - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [ - event_map[e_id] for e_id in auth_event_ids if e_id in event_map - ] - - auth_chain.sort(key=lambda e: e.depth) - - return pdus, auth_chain - except HttpResponseException as e: - if e.code == 400 or e.code == 404: - logger.info("Failed to use get_room_state_ids API, falling back") - else: - raise e - - result = yield self.transport_layer.get_room_state( + result = yield self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id ) - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) - - pdus = [ - event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"] - ] + state_event_ids = result["pdu_ids"] + auth_event_ids = result.get("auth_chain_ids", []) - auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in result.get("auth_chain", []) - ] - - seen_events = yield self.store.get_events( - [ev.event_id for ev in itertools.chain(pdus, auth_chain)] + fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( + destination, room_id, set(state_event_ids + auth_event_ids) ) - signed_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, - [p for p in pdus if p.event_id not in seen_events], - outlier=True, - room_version=room_version, - ) - signed_pdus.extend( - seen_events[p.event_id] for p in pdus if p.event_id in seen_events - ) + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state/auth events for %s: %s", + room_id, + failed_to_fetch, + ) - signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, - [p for p in auth_chain if p.event_id not in seen_events], - outlier=True, - room_version=room_version, - ) - signed_auth.extend( - seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events - ) + event_map = {ev.event_id: ev for ev in fetched_events} - signed_auth.sort(key=lambda e: e.depth) + pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + + auth_chain.sort(key=lambda e: e.depth) - return signed_pdus, signed_auth + return pdus, auth_chain @defer.inlineCallbacks def get_events_from_store_or_dest(self, destination, room_id, event_ids): diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d942d77a72..84d4eca041 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,6 +74,7 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler + self.state = hs.get_state_handler() self._server_linearizer = Linearizer("fed_server") self._transaction_linearizer = Linearizer("fed_txn_handler") @@ -264,9 +266,6 @@ class FederationServer(FederationBase): await self.registry.on_edu(edu_type, origin, content) async def on_context_state_request(self, origin, room_id, event_id): - if not event_id: - raise NotImplementedError("Specify an event") - origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -280,13 +279,18 @@ class FederationServer(FederationBase): # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. with (await self._server_linearizer.queue((origin, room_id))): - resp = await self._state_resp_cache.wrap( - (room_id, event_id), - self._on_context_state_request_compute, - room_id, - event_id, + resp = dict( + await self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, + event_id, + ) ) + room_version = await self.store.get_room_version(room_id) + resp["room_version"] = room_version + return 200, resp async def on_state_ids_request(self, origin, room_id, event_id): @@ -306,7 +310,11 @@ class FederationServer(FederationBase): return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute(self, room_id, event_id): - pdus = await self.handler.get_state_for_pdu(room_id, event_id) + if event_id: + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + else: + pdus = (await self.state.get_current_state(room_id)).values() + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8f9d6ac067..8082c29121 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -39,30 +39,6 @@ class TransportLayerClient(object): self.client = hs.get_http_client() @log_function - def get_room_state(self, destination, room_id, event_id): - """ Requests all state for a given room from the given server at the - given event. - - Args: - destination (str): The host name of the remote homeserver we want - to get the state from. - context (str): The name of the context we want the state of - event_id (str): The event we want the context at. - - Returns: - Deferred: Results in a dict received from the remote homeserver. - """ - logger.debug("get_room_state dest=%s, room=%s", destination, room_id) - - path = _create_v1_path("/state/%s", room_id) - return self.client.get_json( - destination, - path=path, - args={"event_id": event_id}, - try_trailing_slash_on_400=True, - ) - - @log_function def get_room_state_ids(self, destination, room_id, event_id): """ Requests all state for a given room from the given server at the given event. Returns the state's event_id's diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 09baa9c57d..fefc789c85 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -421,7 +421,7 @@ class FederationEventServlet(BaseFederationServlet): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateServlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServlet): PATH = "/state/(?P<context>[^/]*)/?" # This is when someone asks for all data for a given context. @@ -429,7 +429,7 @@ class FederationStateServlet(BaseFederationServlet): return await self.handler.on_context_state_request( origin, context, - parse_string_from_args(query, "event_id", None, required=True), + parse_string_from_args(query, "event_id", None, required=False), ) @@ -1360,7 +1360,7 @@ class RoomComplexityServlet(BaseFederationServlet): FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationEventServlet, - FederationStateServlet, + FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, FederationQueryServlet, diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 6407d56f8e..14449b9a1e 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -56,7 +56,7 @@ class AdminHandler(BaseHandler): @defer.inlineCallbacks def get_users(self): - """Function to reterive a list of users in users table. + """Function to retrieve a list of users in users table. Args: Returns: @@ -67,19 +67,22 @@ class AdminHandler(BaseHandler): return ret @defer.inlineCallbacks - def get_users_paginate(self, order, start, limit): - """Function to reterive a paginated list of users from - users list. This will return a json object, which contains - list of users and the total number of users in users table. + def get_users_paginate(self, start, limit, name, guests, deactivated): + """Function to retrieve a paginated list of users from + users list. This will return a json list of users. Args: - order (str): column name to order the select by this column start (int): start number to begin the query from - limit (int): number of rows to reterive + limit (int): number of rows to retrieve + name (string): filter for user names + guests (bool): whether to in include guest users + deactivated (bool): whether to include deactivated users Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} + defer.Deferred: resolves to json list[dict[str, Any]] """ - ret = yield self.store.get_users_paginate(order, start, limit) + ret = yield self.store.get_users_paginate( + start, limit, name, guests, deactivated + ) return ret diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 28c12753c1..57a10daefd 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,7 +264,6 @@ class E2eKeysHandler(object): return ret - @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -284,35 +283,14 @@ class E2eKeysHandler(object): self_signing_keys = {} user_signing_keys = {} - for user_id in query: - # XXX: consider changing the store functions to allow querying - # multiple users simultaneously. - key = yield self.store.get_e2e_cross_signing_key( - user_id, "master", from_user_id - ) - if key: - master_keys[user_id] = key - - key = yield self.store.get_e2e_cross_signing_key( - user_id, "self_signing", from_user_id - ) - if key: - self_signing_keys[user_id] = key - - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - if from_user_id == user_id: - key = yield self.store.get_e2e_cross_signing_key( - user_id, "user_signing", from_user_id - ) - if key: - user_signing_keys[user_id] = key - - return { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } + # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 + return defer.succeed( + { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } + ) @trace @defer.inlineCallbacks diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 0cea445f0d..f1b4424a02 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017, 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -103,14 +104,35 @@ class E2eRoomKeysHandler(object): rooms session_id(string): session ID to delete keys for, for None to delete keys for all sessions + Raises: + NotFoundError: if the backup version does not exist Returns: - A deferred of the deletion transaction + A dict containing the count and etag for the backup version """ # lock for consistency with uploading with (yield self._upload_linearizer.queue(user_id)): + # make sure the backup version exists + try: + version_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version + ) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown backup version") + else: + raise + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + version_etag = version_info["etag"] + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag + ) + + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} + @trace @defer.inlineCallbacks def upload_room_keys(self, user_id, version, room_keys): @@ -138,6 +160,9 @@ class E2eRoomKeysHandler(object): } } + Returns: + A dict containing the count and etag for the backup version + Raises: NotFoundError: if there are no versions defined RoomKeysVersionError: if the uploaded version is not the current version @@ -171,59 +196,62 @@ class E2eRoomKeysHandler(object): else: raise - # go through the room_keys. - # XXX: this should/could be done concurrently, given we're in a lock. + # Fetch any existing room keys for the sessions that have been + # submitted. Then compare them with the submitted keys. If the + # key is new, insert it; if the key should be updated, then update + # it; otherwise, drop it. + existing_keys = yield self.store.get_e2e_room_keys_multi( + user_id, version, room_keys["rooms"] + ) + to_insert = [] # batch the inserts together + changed = False # if anything has changed, we need to update the etag for room_id, room in iteritems(room_keys["rooms"]): - for session_id, session in iteritems(room["sessions"]): - yield self._upload_room_key( - user_id, version, room_id, session_id, session + for session_id, room_key in iteritems(room["sessions"]): + log_kv( + { + "message": "Trying to upload room key", + "room_id": room_id, + "session_id": session_id, + "user_id": user_id, + } ) - - @defer.inlineCallbacks - def _upload_room_key(self, user_id, version, room_id, session_id, room_key): - """Upload a given room_key for a given room and session into a given - version of the backup. Merges the key with any which might already exist. - - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - """ - log_kv( - { - "message": "Trying to upload room key", - "room_id": room_id, - "session_id": session_id, - "user_id": user_id, - } - ) - # get the room_key for this particular row - current_room_key = None - try: - current_room_key = yield self.store.get_e2e_room_key( - user_id, version, room_id, session_id - ) - except StoreError as e: - if e.code == 404: - log_kv( - { - "message": "Room key not found.", - "room_id": room_id, - "user_id": user_id, - } + current_room_key = existing_keys.get(room_id, {}).get(session_id) + if current_room_key: + if self._should_replace_room_key(current_room_key, room_key): + log_kv({"message": "Replacing room key."}) + # updates are done one at a time in the DB, so send + # updates right away rather than batching them up, + # like we do with the inserts + yield self.store.update_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + changed = True + else: + log_kv({"message": "Not replacing room_key."}) + else: + log_kv( + { + "message": "Room key not found.", + "room_id": room_id, + "user_id": user_id, + } + ) + log_kv({"message": "Replacing room key."}) + to_insert.append((room_id, session_id, room_key)) + changed = True + + if len(to_insert): + yield self.store.add_e2e_room_keys(user_id, version, to_insert) + + version_etag = version_info["etag"] + if changed: + version_etag = version_etag + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag ) - else: - raise - if self._should_replace_room_key(current_room_key, room_key): - log_kv({"message": "Replacing room key."}) - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) - else: - log_kv({"message": "Not replacing room_key."}) + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} @staticmethod def _should_replace_room_key(current_room_key, room_key): @@ -314,6 +342,8 @@ class E2eRoomKeysHandler(object): raise NotFoundError("Unknown backup version") else: raise + + res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) return res @trace diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 45fe13c62f..ec18a42a68 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -16,8 +16,6 @@ import logging import random -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase @@ -50,9 +48,8 @@ class EventStreamHandler(BaseHandler): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks @log_function - def get_stream( + async def get_stream( self, auth_user_id, pagin_config, @@ -69,17 +66,17 @@ class EventStreamHandler(BaseHandler): """ if room_id: - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(auth_user_id) + await self._server_notices_sender.on_user_syncing(auth_user_id) auth_user = UserID.from_string(auth_user_id) presence_handler = self.hs.get_presence_handler() - context = yield presence_handler.user_syncing( + context = await presence_handler.user_syncing( auth_user_id, affect_presence=affect_presence ) with context: @@ -91,7 +88,7 @@ class EventStreamHandler(BaseHandler): # thundering herds on restart. timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1)) - events, tokens = yield self.notifier.get_events_for( + events, tokens = await self.notifier.get_events_for( auth_user, pagin_config, timeout, @@ -112,14 +109,14 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.state.get_current_users_in_room( + users = await self.state.get_current_users_in_room( event.room_id ) - states = yield presence_handler.get_states(users, as_event=True) + states = await presence_handler.get_states(users, as_event=True) to_add.extend(states) else: - ev = yield presence_handler.get_state( + ev = await presence_handler.get_state( UserID.from_string(event.state_key), as_event=True ) to_add.append(ev) @@ -128,7 +125,7 @@ class EventStreamHandler(BaseHandler): time_now = self.clock.time_msec() - chunks = yield self._event_serializer.serialize_events( + chunks = await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event, @@ -151,8 +148,7 @@ class EventHandler(BaseHandler): super(EventHandler, self).__init__(hs) self.storage = hs.get_storage() - @defer.inlineCallbacks - def get_event(self, user, room_id, event_id): + async def get_event(self, user, room_id, event_id): """Retrieve a single specified event. Args: @@ -167,15 +163,15 @@ class EventHandler(BaseHandler): AuthError if the user does not have the rights to inspect this event. """ - event = yield self.store.get_event(event_id, check_room_id=room_id) + event = await self.store.get_event(event_id, check_room_id=room_id) if not event: return None - users = yield self.store.get_users_in_room(event.room_id) + users = await self.store.get_users_in_room(event.room_id) is_peeking = user.to_string() not in users - filtered = yield filter_events_for_client( + filtered = await filter_events_for_client( self.storage, user.to_string(), [event], is_peeking=is_peeking ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 97d045db10..bc26921768 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,11 +19,13 @@ import itertools import logging +from typing import Dict, Iterable, Optional, Sequence, Tuple import six from six import iteritems, itervalues from six.moves import http_client, zip +import attr from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -45,6 +47,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( @@ -72,6 +75,23 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +@attr.s +class _NewEventInfo: + """Holds information about a received event, ready for passing to _handle_new_events + + Attributes: + event: the received event + + state: the state at that event + + auth_events: the auth_event map for that event + """ + + event = attr.ib(type=EventBase) + state = attr.ib(type=Optional[Sequence[EventBase]], default=None) + auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None) + + def shortstr(iterable, maxitems=5): """If iterable has maxitems or fewer, return the stringification of a list containing those items. @@ -121,6 +141,7 @@ class FederationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() @@ -141,6 +162,8 @@ class FederationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or @@ -594,14 +617,14 @@ class FederationHandler(BaseHandler): for e in auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } - event_infos.append({"event": e, "auth_events": auth}) + event_infos.append(_NewEventInfo(event=e, auth_events=auth)) seen_ids.add(e.event_id) logger.info( "[%s %s] persisting newly-received auth/state events %s", room_id, event_id, - [e["event"].event_id for e in event_infos], + [e.event.event_id for e in event_infos], ) yield self._handle_new_events(origin, event_infos) @@ -792,9 +815,9 @@ class FederationHandler(BaseHandler): a.internal_metadata.outlier = True ev_infos.append( - { - "event": a, - "auth_events": { + _NewEventInfo( + event=a, + auth_events={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -802,7 +825,7 @@ class FederationHandler(BaseHandler): for a_id in a.auth_event_ids() if a_id in auth_events }, - } + ) ) # Step 1b: persist the events in the chunk we fetched state for (i.e. @@ -814,10 +837,10 @@ class FederationHandler(BaseHandler): assert not ev.internal_metadata.is_outlier() ev_infos.append( - { - "event": ev, - "state": events_to_state[e_id], - "auth_events": { + _NewEventInfo( + event=ev, + state=events_to_state[e_id], + auth_events={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -825,7 +848,7 @@ class FederationHandler(BaseHandler): for a_id in ev.auth_event_ids() if a_id in auth_events }, - } + ) ) yield self._handle_new_events(dest, ev_infos, backfilled=True) @@ -1428,9 +1451,9 @@ class FederationHandler(BaseHandler): return event @defer.inlineCallbacks - def do_remotely_reject_invite(self, target_hosts, room_id, user_id): + def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, room_id, user_id, "leave" + target_hosts, room_id, user_id, "leave", content=content, ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -1710,7 +1733,12 @@ class FederationHandler(BaseHandler): return context @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False): + def _handle_new_events( + self, + origin: str, + event_infos: Iterable[_NewEventInfo], + backfilled: bool = False, + ): """Creates the appropriate contexts and persists events. The events should not depend on one another, e.g. this should be used to persist a bunch of outliers, but not a chunk of individual events that depend @@ -1720,14 +1748,14 @@ class FederationHandler(BaseHandler): """ @defer.inlineCallbacks - def prep(ev_info): - event = ev_info["event"] + def prep(ev_info: _NewEventInfo): + event = ev_info.event with nested_logging_context(suffix=event.event_id): res = yield self._prep_event( origin, event, - state=ev_info.get("state"), - auth_events=ev_info.get("auth_events"), + state=ev_info.state, + auth_events=ev_info.auth_events, backfilled=backfilled, ) return res @@ -1741,7 +1769,7 @@ class FederationHandler(BaseHandler): yield self.persist_events_and_notify( [ - (ev_info["event"], context) + (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) ], backfilled=backfilled, @@ -1843,7 +1871,14 @@ class FederationHandler(BaseHandler): yield self.persist_events_and_notify([(event, new_event_context)]) @defer.inlineCallbacks - def _prep_event(self, origin, event, state, auth_events, backfilled): + def _prep_event( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + auth_events: Optional[Dict[Tuple[str, str], EventBase]], + backfilled: bool, + ): """ Args: @@ -1851,7 +1886,7 @@ class FederationHandler(BaseHandler): event: state: auth_events: - backfilled (bool) + backfilled: Returns: Deferred, which resolves to synapse.events.snapshot.EventContext @@ -1887,15 +1922,16 @@ class FederationHandler(BaseHandler): return context @defer.inlineCallbacks - def _check_for_soft_fail(self, event, state, backfilled): + def _check_for_soft_fail( + self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool + ): """Checks if we should soft fail the event, if so marks the event as such. Args: - event (FrozenEvent) - state (dict|None): The state at the event if we don't have all the - event's prev events - backfilled (bool): Whether the event is from backfill + event + state: The state at the event if we don't have all the event's prev events + backfilled: Whether the event is from backfill Returns: Deferred @@ -2119,14 +2155,9 @@ class FederationHandler(BaseHandler): # # we start by checking if they are in the store, and then try calling /event_auth/. if missing_auth: - # TODO: can we use store.have_seen_events here instead? - have_events = yield self.store.get_seen_events_with_rejections(missing_auth) - logger.debug("Found events %s in the store", have_events) - missing_auth.difference_update(have_events.keys()) - else: - have_events = {} - - have_events.update({e.event_id: "" for e in auth_events.values()}) + have_events = yield self.store.have_seen_events(missing_auth) + logger.debug("Events %s are in the store", have_events) + missing_auth.difference_update(have_events) if missing_auth: # If we don't have all the auth events, we need to get them. @@ -2172,9 +2203,6 @@ class FederationHandler(BaseHandler): except AuthError: pass - have_events = yield self.store.get_seen_events_with_rejections( - event.auth_event_ids() - ) except Exception: logger.exception("Failed to get auth chain") @@ -2200,43 +2228,53 @@ class FederationHandler(BaseHandler): different_auth, ) - # now we state-resolve between our own idea of the auth events, and the remote's - # idea of them. + # XXX: currently this checks for redactions but I'm not convinced that is + # necessary? + different_events = yield self.store.get_events_as_list(different_auth) - room_version = yield self.store.get_room_version(event.room_id) - different_event_ids = [ - d for d in different_auth if d in have_events and not have_events[d] - ] + for d in different_events: + if d.room_id != event.room_id: + logger.warning( + "Event %s refers to auth_event %s which is in a different room", + event.event_id, + d.event_id, + ) - if different_event_ids: - # XXX: currently this checks for redactions but I'm not convinced that is - # necessary? - different_events = yield self.store.get_events_as_list(different_event_ids) + # don't attempt to resolve the claimed auth events against our own + # in this case: just use our own auth events. + # + # XXX: should we reject the event in this case? It feels like we should, + # but then shouldn't we also do so if we've failed to fetch any of the + # auth events? + return context - local_view = dict(auth_events) - remote_view = dict(auth_events) - remote_view.update({(d.type, d.state_key): d for d in different_events}) + # now we state-resolve between our own idea of the auth events, and the remote's + # idea of them. - new_state = yield self.state_handler.resolve_events( - room_version, - [list(local_view.values()), list(remote_view.values())], - event, - ) + local_state = auth_events.values() + remote_auth_events = dict(auth_events) + remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) + remote_state = remote_auth_events.values() - logger.info( - "After state res: updating auth_events with new state %s", - { - (d.type, d.state_key): d.event_id - for d in new_state.values() - if auth_events.get((d.type, d.state_key)) != d - }, - ) + room_version = yield self.store.get_room_version(event.room_id) + new_state = yield self.state_handler.resolve_events( + room_version, (local_state, remote_state), event + ) + + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id + for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) - auth_events.update(new_state) + auth_events.update(new_state) - context = yield self._update_context_for_auth_events( - event, context, auth_events - ) + context = yield self._update_context_for_auth_events( + event, context, auth_events + ) return context @@ -2466,7 +2504,7 @@ class FederationHandler(BaseHandler): room_version, event_dict, event, context ) - EventValidator().validate_new(event) + EventValidator().validate_new(event, self.config) # We need to tell the transaction queue to send this out, even # though the sender isn't a local user. @@ -2581,7 +2619,7 @@ class FederationHandler(BaseHandler): event, context = yield self.event_creation_handler.create_new_client_event( builder=builder ) - EventValidator().validate_new(event) + EventValidator().validate_new(event, self.config) return (event, context) @defer.inlineCallbacks @@ -2715,6 +2753,11 @@ class FederationHandler(BaseHandler): event_and_contexts, backfilled=backfilled ) + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: yield self._notify_persisted_event(event, max_stream_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c05b9e83be..762c600424 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from six import iteritems, itervalues, string_types @@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed +from twisted.internet.interfaces import IDelayedCall from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -62,6 +70,17 @@ class MessageHandler(object): self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._is_worker_app = bool(hs.config.worker_app) + + # The scheduled call to self._expire_event. None if no call is currently + # scheduled. + self._scheduled_expiry = None # type: Optional[IDelayedCall] + + if not hs.config.worker_app: + run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry + ) @defer.inlineCallbacks def get_room_data( @@ -138,7 +157,7 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.storage, user_id, last_events + self.storage, user_id, last_events, apply_retention_policies=False ) event = last_events[0] @@ -225,6 +244,100 @@ class MessageHandler(object): for user_id, profile in iteritems(users_with_profile) } + def maybe_schedule_expiry(self, event): + """Schedule the expiry of an event if there's not already one scheduled, + or if the one running is for an event that will expire after the provided + timestamp. + + This function needs to invalidate the event cache, which is only possible on + the master process, and therefore needs to be run on there. + + Args: + event (EventBase): The event to schedule the expiry of. + """ + assert not self._is_worker_app + + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if not isinstance(expiry_ts, int) or event.is_state(): + return + + # _schedule_expiry_for_event won't actually schedule anything if there's already + # a task scheduled for a timestamp that's sooner than the provided one. + self._schedule_expiry_for_event(event.event_id, expiry_ts) + + @defer.inlineCallbacks + def _schedule_next_expiry(self): + """Retrieve the ID and the expiry timestamp of the next event to be expired, + and schedule an expiry task for it. + + If there's no event left to expire, set _expiry_scheduled to None so that a + future call to save_expiry_ts can schedule a new expiry task. + """ + # Try to get the expiry timestamp of the next event to expire. + res = yield self.store.get_next_event_to_expire() + if res: + event_id, expiry_ts = res + self._schedule_expiry_for_event(event_id, expiry_ts) + + def _schedule_expiry_for_event(self, event_id, expiry_ts): + """Schedule an expiry task for the provided event if there's not already one + scheduled at a timestamp that's sooner than the provided one. + + Args: + event_id (str): The ID of the event to expire. + expiry_ts (int): The timestamp at which to expire the event. + """ + if self._scheduled_expiry: + # If the provided timestamp refers to a time before the scheduled time of the + # next expiry task, cancel that task and reschedule it for this timestamp. + next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000 + if expiry_ts < next_scheduled_expiry_ts: + self._scheduled_expiry.cancel() + else: + return + + # Figure out how many seconds we need to wait before expiring the event. + now_ms = self.clock.time_msec() + delay = (expiry_ts - now_ms) / 1000 + + # callLater doesn't support negative delays, so trim the delay to 0 if we're + # in that case. + if delay < 0: + delay = 0 + + logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay) + + self._scheduled_expiry = self.clock.call_later( + delay, + run_as_background_process, + "_expire_event", + self._expire_event, + event_id, + ) + + @defer.inlineCallbacks + def _expire_event(self, event_id): + """Retrieve and expire an event that needs to be expired from the database. + + If the event doesn't exist in the database, log it and delete the expiry date + from the database (so that we don't try to expire it again). + """ + assert self._ephemeral_events_enabled + + self._scheduled_expiry = None + + logger.info("Expiring event %s", event_id) + + try: + # Expire the event if we know about it. This function also deletes the expiry + # date from the database in the same database transaction. + yield self.store.expire_event(event_id) + except Exception as e: + logger.error("Could not expire event %s: %r", event_id, e) + + # Schedule the expiry of the next event to expire. + yield self._schedule_next_expiry() + # The duration (in ms) after which rooms should be removed # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try @@ -250,6 +363,8 @@ class EventCreationHandler(object): self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self.room_invite_state_types = self.hs.config.room_invite_state_types + self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users @@ -295,6 +410,10 @@ class EventCreationHandler(object): 5 * 60 * 1000, ) + self._message_handler = hs.get_message_handler() + + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def create_event( self, @@ -417,7 +536,7 @@ class EventCreationHandler(object): 403, "You must be in the room to create an alias for it" ) - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) return (event, context) @@ -634,7 +753,7 @@ class EventCreationHandler(object): if requester: context.app_service = requester.app_service - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an @@ -799,7 +918,7 @@ class EventCreationHandler(object): state_to_include_ids = [ e_id for k, e_id in iteritems(current_state_ids) - if k[0] in self.hs.config.room_invite_state_types + if k[0] in self.room_invite_state_types or k == (EventTypes.Member, event.sender) ] @@ -877,6 +996,10 @@ class EventCreationHandler(object): event, context=context ) + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 260a4351ca..8514ddc600 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -15,12 +15,15 @@ # limitations under the License. import logging +from six import iteritems + from twisted.internet import defer from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock @@ -80,6 +83,109 @@ class PaginationHandler(object): self._purges_by_id = {} self._event_serializer = hs.get_event_client_serializer() + self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + + if hs.config.retention_enabled: + # Run the purge jobs described in the configuration file. + for job in hs.config.retention_purge_jobs: + self.clock.looping_call( + run_as_background_process, + job["interval"], + "purge_history_for_rooms_in_range", + self.purge_history_for_rooms_in_range, + job["shortest_max_lifetime"], + job["longest_max_lifetime"], + ) + + @defer.inlineCallbacks + def purge_history_for_rooms_in_range(self, min_ms, max_ms): + """Purge outdated events from rooms within the given retention range. + + If a default retention policy is defined in the server's configuration and its + 'max_lifetime' is within this range, also targets rooms which don't have a + retention policy. + + Args: + min_ms (int|None): Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, it means that the range has no + lower limit. + max_ms (int|None): Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, it means that the range has no + upper limit. + """ + # We want the storage layer to to include rooms with no retention policy in its + # return value only if a default retention policy is defined in the server's + # configuration and that policy's 'max_lifetime' is either lower (or equal) than + # max_ms or higher than min_ms (or both). + if self._retention_default_max_lifetime is not None: + include_null = True + + if min_ms is not None and min_ms >= self._retention_default_max_lifetime: + # The default max_lifetime is lower than (or equal to) min_ms. + include_null = False + + if max_ms is not None and max_ms < self._retention_default_max_lifetime: + # The default max_lifetime is higher than max_ms. + include_null = False + else: + include_null = False + + rooms = yield self.store.get_rooms_for_retention_period_in_range( + min_ms, max_ms, include_null + ) + + for room_id, retention_policy in iteritems(rooms): + if room_id in self._purges_in_progress_by_room: + logger.warning( + "[purge] not purging room %s as there's an ongoing purge running" + " for this room", + room_id, + ) + continue + + max_lifetime = retention_policy["max_lifetime"] + + if max_lifetime is None: + # If max_lifetime is None, it means that include_null equals True, + # therefore we can safely assume that there is a default policy defined + # in the server's configuration. + max_lifetime = self._retention_default_max_lifetime + + # Figure out what token we should start purging at. + ts = self.clock.time_msec() - max_lifetime + + stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) + + r = yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, + ) + if not r: + logger.warning( + "[purge] purging events not possible: No event found " + "(ts %i => stream_ordering %i)", + ts, + stream_ordering, + ) + continue + + (stream, topo, _event_id) = r + token = "t%d-%d" % (topo, stream) + + purge_id = random_string(16) + + self._purges_by_id[purge_id] = PurgeStatus() + + logger.info( + "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id) + ) + + # We want to purge everything, including local events, and to run the purge in + # the background so that it's not blocking any other operation apart from + # other purges in the same room. + run_as_background_process( + "_purge_history", self._purge_history, purge_id, room_id, token, True, + ) + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 95806af41e..8a7d965feb 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -266,7 +266,7 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - yield self._register_email_threepid(user_id, threepid_dict, None, False) + yield self._register_email_threepid(user_id, threepid_dict, None) return user_id diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e92b2eafd5..22768e97ff 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -198,21 +199,21 @@ class RoomCreationHandler(BaseHandler): # finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state + requester, old_room_id, new_room_id, old_room_state, ) return new_room_id @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state + self, requester, old_room_id, new_room_id, old_room_state, ): """Send updated power levels in both rooms after an upgrade Args: requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (unicode): the id of the room to be replaced - new_room_id (unicode): the id of the replacement room + old_room_id (str): the id of the room to be replaced + new_room_id (str): the id of the replacement room old_room_state (dict[tuple[str, str], str]): the state map for the old room Returns: @@ -298,7 +299,7 @@ class RoomCreationHandler(BaseHandler): tombstone_event_id (unicode|str): the ID of the tombstone event in the old room. Returns: - Deferred[None] + Deferred """ user_id = requester.user.to_string() @@ -333,6 +334,7 @@ class RoomCreationHandler(BaseHandler): (EventTypes.Encryption, ""), (EventTypes.ServerACL, ""), (EventTypes.RelatedGroups, ""), + (EventTypes.PowerLevels, ""), ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( @@ -346,6 +348,31 @@ class RoomCreationHandler(BaseHandler): if old_event: initial_state[k] = old_event.content + # Resolve the minimum power level required to send any state event + # We will give the upgrading user this power level temporarily (if necessary) such that + # they are able to copy all of the state events over, then revert them back to their + # original power level afterwards in _update_upgraded_room_pls + + # Copy over user power levels now as this will not be possible with >100PL users once + # the room has been created + + power_levels = initial_state[(EventTypes.PowerLevels, "")] + + # Calculate the minimum power level needed to clone the room + event_power_levels = power_levels.get("events", {}) + state_default = power_levels.get("state_default", 0) + ban = power_levels.get("ban") + needed_power_level = max(state_default, ban, max(event_power_levels.values())) + + # Raise the requester's power level in the new room if necessary + current_power_level = power_levels["users"][requester.user.to_string()] + if current_power_level < needed_power_level: + # Assign this power level to the requester + power_levels["users"][requester.user.to_string()] = needed_power_level + + # Set the power levels to the modified state + initial_state[(EventTypes.PowerLevels, "")] = power_levels + yield self._send_events_for_new_room( requester, new_room_id, @@ -874,6 +901,10 @@ class RoomContextHandler(object): room_id, event_id, before_limit, after_limit, event_filter ) + if event_filter: + results["events_before"] = event_filter.filter(results["events_before"]) + results["events_after"] = event_filter.filter(results["events_after"]) + results["events_before"] = yield filter_evts(results["events_before"]) results["events_after"] = yield filter_evts(results["events_after"]) results["event"] = event @@ -902,7 +933,12 @@ class RoomContextHandler(object): state = yield self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) - results["state"] = list(state[last_event_id].values()) + + state_events = list(state[last_event_id].values()) + if event_filter: + state_events = event_filter.filter(state_events) + + results["state"] = state_events # We use a dummy token here as we only care about the room portion of # the token, which we replace. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index dd096a8608..b0abc322b5 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -95,7 +95,9 @@ class RoomMemberHandler(object): raise NotImplementedError() @abc.abstractmethod - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. @@ -105,6 +107,7 @@ class RoomMemberHandler(object): reject invite room_id (str) target (UserID): The user rejecting the invite + content (dict): The content for the rejection event Returns: Deferred[dict]: A dictionary to be returned to the client, may @@ -491,7 +494,7 @@ class RoomMemberHandler(object): # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] res = yield self._remote_reject_invite( - requester, remote_room_hosts, room_id, target + requester, remote_room_hosts, room_id, target, content, ) return res @@ -991,13 +994,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) @defer.inlineCallbacks - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, target.to_string() + remote_room_hosts, room_id, target.to_string(), content=content, ) return ret except Exception as e: diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 75e96ae1a2..69be86893b 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -55,7 +55,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Implements RoomMemberHandler._remote_reject_invite """ return self._remote_reject_client( @@ -63,6 +65,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), + content=content, ) def _user_joined_room(self, target, room_id): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 49c025e991..ca5eb04735 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -22,8 +22,6 @@ from six import iteritems, itervalues from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.logging.context import LoggingContext from synapse.push.clientformat import format_push_rules_for_user @@ -244,8 +242,7 @@ class SyncHandler(object): expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) - @defer.inlineCallbacks - def wait_for_sync_for_user( + async def wait_for_sync_for_user( self, sync_config, since_token=None, timeout=0, full_state=False ): """Get the sync for a client if we have new data for it now. Otherwise @@ -258,9 +255,9 @@ class SyncHandler(object): # not been exceeded (if not part of the group by this point, almost certain # auth_blocking will occur) user_id = sync_config.user.to_string() - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) - res = yield self.response_cache.wrap( + res = await self.response_cache.wrap( sync_config.request_key, self._wait_for_sync_for_user, sync_config, @@ -270,8 +267,9 @@ class SyncHandler(object): ) return res - @defer.inlineCallbacks - def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state): + async def _wait_for_sync_for_user( + self, sync_config, since_token, timeout, full_state + ): if since_token is None: sync_type = "initial_sync" elif full_state: @@ -286,7 +284,7 @@ class SyncHandler(object): if timeout == 0 or since_token is None or full_state: # we are going to return immediately, so don't bother calling # notifier.wait_for_events. - result = yield self.current_sync_for_user( + result = await self.current_sync_for_user( sync_config, since_token, full_state=full_state ) else: @@ -294,7 +292,7 @@ class SyncHandler(object): def current_sync_callback(before_token, after_token): return self.current_sync_for_user(sync_config, since_token) - result = yield self.notifier.wait_for_events( + result = await self.notifier.wait_for_events( sync_config.user.to_string(), timeout, current_sync_callback, @@ -317,15 +315,13 @@ class SyncHandler(object): """ return self.generate_sync_result(sync_config, since_token, full_state) - @defer.inlineCallbacks - def push_rules_for_user(self, user): + async def push_rules_for_user(self, user): user_id = user.to_string() - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(user, rules) return rules - @defer.inlineCallbacks - def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): + async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None): """Get the ephemeral events for each room the user is in Args: sync_result_builder(SyncResultBuilder) @@ -346,7 +342,7 @@ class SyncHandler(object): room_ids = sync_result_builder.joined_room_ids typing_source = self.event_sources.sources["typing"] - typing, typing_key = yield typing_source.get_new_events( + typing, typing_key = await typing_source.get_new_events( user=sync_config.user, from_key=typing_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -368,7 +364,7 @@ class SyncHandler(object): receipt_key = since_token.receipt_key if since_token else "0" receipt_source = self.event_sources.sources["receipt"] - receipts, receipt_key = yield receipt_source.get_new_events( + receipts, receipt_key = await receipt_source.get_new_events( user=sync_config.user, from_key=receipt_key, limit=sync_config.filter_collection.ephemeral_limit(), @@ -385,8 +381,7 @@ class SyncHandler(object): return now_token, ephemeral_by_room - @defer.inlineCallbacks - def _load_filtered_recents( + async def _load_filtered_recents( self, room_id, sync_config, @@ -418,10 +413,10 @@ class SyncHandler(object): # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = await self.state.get_current_state_ids(room_id) current_state_ids = frozenset(itervalues(current_state_ids)) - recents = yield filter_events_for_client( + recents = await filter_events_for_client( self.storage, sync_config.user.to_string(), recents, @@ -452,14 +447,14 @@ class SyncHandler(object): # Otherwise, we want to return the last N events in the room # in toplogical ordering. if since_key: - events, end_key = yield self.store.get_room_events_stream_for_room( + events, end_key = await self.store.get_room_events_stream_for_room( room_id, limit=load_limit + 1, from_key=since_key, to_key=end_key, ) else: - events, end_key = yield self.store.get_recent_events_for_room( + events, end_key = await self.store.get_recent_events_for_room( room_id, limit=load_limit + 1, end_token=end_key ) loaded_recents = sync_config.filter_collection.filter_room_timeline( @@ -471,10 +466,10 @@ class SyncHandler(object): # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = await self.state.get_current_state_ids(room_id) current_state_ids = frozenset(itervalues(current_state_ids)) - loaded_recents = yield filter_events_for_client( + loaded_recents = await filter_events_for_client( self.storage, sync_config.user.to_string(), loaded_recents, @@ -501,8 +496,7 @@ class SyncHandler(object): limited=limited or newly_joined_room, ) - @defer.inlineCallbacks - def get_state_after_event(self, event, state_filter=StateFilter.all()): + async def get_state_after_event(self, event, state_filter=StateFilter.all()): """ Get the room state after the given event @@ -514,7 +508,7 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( event.event_id, state_filter=state_filter ) if event.is_state(): @@ -522,8 +516,9 @@ class SyncHandler(object): state_ids[(event.type, event.state_key)] = event.event_id return state_ids - @defer.inlineCallbacks - def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()): + async def get_state_at( + self, room_id, stream_position, state_filter=StateFilter.all() + ): """ Get the room state at a particular stream position Args: @@ -539,13 +534,13 @@ class SyncHandler(object): # get_recent_events_for_room operates by topo ordering. This therefore # does not reliably give you the state at the given stream position. # (https://github.com/matrix-org/synapse/issues/3305) - last_events, _ = yield self.store.get_recent_events_for_room( + last_events, _ = await self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1 ) if last_events: last_event = last_events[-1] - state = yield self.get_state_after_event( + state = await self.get_state_after_event( last_event, state_filter=state_filter ) @@ -554,8 +549,7 @@ class SyncHandler(object): state = {} return state - @defer.inlineCallbacks - def compute_summary(self, room_id, sync_config, batch, state, now_token): + async def compute_summary(self, room_id, sync_config, batch, state, now_token): """ Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds @@ -577,7 +571,7 @@ class SyncHandler(object): # FIXME: we could/should get this from room_stats when matthew/stats lands # FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305 - last_events, _ = yield self.store.get_recent_event_ids_for_room( + last_events, _ = await self.store.get_recent_event_ids_for_room( room_id, end_token=now_token.room_key, limit=1 ) @@ -585,7 +579,7 @@ class SyncHandler(object): return None last_event = last_events[-1] - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -593,7 +587,7 @@ class SyncHandler(object): ) # this is heavily cached, thus: fast. - details = yield self.store.get_room_summary(room_id) + details = await self.store.get_room_summary(room_id) name_id = state_ids.get((EventTypes.Name, "")) canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, "")) @@ -611,12 +605,12 @@ class SyncHandler(object): # calculating heroes. Empty strings are falsey, so we check # for the "name" value and default to an empty string. if name_id: - name = yield self.store.get_event(name_id, allow_none=True) + name = await self.store.get_event(name_id, allow_none=True) if name and name.content.get("name"): return summary if canonical_alias_id: - canonical_alias = yield self.store.get_event( + canonical_alias = await self.store.get_event( canonical_alias_id, allow_none=True ) if canonical_alias and canonical_alias.content.get("alias"): @@ -681,7 +675,7 @@ class SyncHandler(object): ) ] - missing_hero_state = yield self.store.get_events(missing_hero_event_ids) + missing_hero_state = await self.store.get_events(missing_hero_event_ids) missing_hero_state = missing_hero_state.values() for s in missing_hero_state: @@ -700,8 +694,7 @@ class SyncHandler(object): logger.debug("found LruCache for %r", cache_key) return cache - @defer.inlineCallbacks - def compute_state_delta( + async def compute_state_delta( self, room_id, batch, sync_config, since_token, now_token, full_state ): """ Works out the difference in state between the start of the timeline @@ -762,16 +755,16 @@ class SyncHandler(object): if full_state: if batch: - current_state_ids = yield self.state_store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -786,13 +779,13 @@ class SyncHandler(object): ) elif batch.limited: if batch: - state_at_timeline_start = yield self.state_store.get_state_ids_for_event( + state_at_timeline_start = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) else: # We can get here if the user has ignored the senders of all # the recent events. - state_at_timeline_start = yield self.get_state_at( + state_at_timeline_start = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -810,19 +803,19 @@ class SyncHandler(object): # about them). state_filter = StateFilter.all() - state_at_previous_sync = yield self.get_state_at( + state_at_previous_sync = await self.get_state_at( room_id, stream_position=since_token, state_filter=state_filter ) if batch: - current_state_ids = yield self.state_store.get_state_ids_for_event( + current_state_ids = await self.state_store.get_state_ids_for_event( batch.events[-1].event_id, state_filter=state_filter ) else: # Its not clear how we get here, but empirically we do # (#5407). Logging has been added elsewhere to try and # figure out where this state comes from. - current_state_ids = yield self.get_state_at( + current_state_ids = await self.get_state_at( room_id, stream_position=now_token, state_filter=state_filter ) @@ -846,7 +839,7 @@ class SyncHandler(object): # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( @@ -886,7 +879,7 @@ class SyncHandler(object): state = {} if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) return { (e.type, e.state_key): e @@ -895,10 +888,9 @@ class SyncHandler(object): ) } - @defer.inlineCallbacks - def unread_notifs_for_room_id(self, room_id, sync_config): + async def unread_notifs_for_room_id(self, room_id, sync_config): with Measure(self.clock, "unread_notifs_for_room_id"): - last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user( + last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, receipt_type="m.read", @@ -906,7 +898,7 @@ class SyncHandler(object): notifs = [] if last_unread_event_id: - notifs = yield self.store.get_unread_event_push_actions_by_room_for_user( + notifs = await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id ) return notifs @@ -915,8 +907,9 @@ class SyncHandler(object): # count is whatever it was last time. return None - @defer.inlineCallbacks - def generate_sync_result(self, sync_config, since_token=None, full_state=False): + async def generate_sync_result( + self, sync_config, since_token=None, full_state=False + ): """Generates a sync result. Args: @@ -931,7 +924,7 @@ class SyncHandler(object): # this is due to some of the underlying streams not supporting the ability # to query up to a given point. # Always use the `now_token` in `SyncResultBuilder` - now_token = yield self.event_sources.get_current_token() + now_token = await self.event_sources.get_current_token() logger.info( "Calculating sync response for %r between %s and %s", @@ -947,10 +940,9 @@ class SyncHandler(object): # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() else: - joined_room_ids = yield self.get_rooms_for_user_at( + joined_room_ids = await self.get_rooms_for_user_at( user_id, now_token.room_stream_id ) - sync_result_builder = SyncResultBuilder( sync_config, full_state, @@ -959,11 +951,11 @@ class SyncHandler(object): joined_room_ids=joined_room_ids, ) - account_data_by_room = yield self._generate_sync_entry_for_account_data( + account_data_by_room = await self._generate_sync_entry_for_account_data( sync_result_builder ) - res = yield self._generate_sync_entry_for_rooms( + res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) newly_joined_rooms, newly_joined_or_invited_users, _, _ = res @@ -973,13 +965,13 @@ class SyncHandler(object): since_token is None and sync_config.filter_collection.blocks_all_presence() ) if self.hs_config.use_presence and not block_all_presence_data: - yield self._generate_sync_entry_for_presence( + await self._generate_sync_entry_for_presence( sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ) - yield self._generate_sync_entry_for_to_device(sync_result_builder) + await self._generate_sync_entry_for_to_device(sync_result_builder) - device_lists = yield self._generate_sync_entry_for_device_list( + device_lists = await self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, newly_joined_or_invited_users=newly_joined_or_invited_users, @@ -990,11 +982,11 @@ class SyncHandler(object): device_id = sync_config.device_id one_time_key_counts = {} if device_id: - one_time_key_counts = yield self.store.count_e2e_one_time_keys( + one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - yield self._generate_sync_entry_for_groups(sync_result_builder) + await self._generate_sync_entry_for_groups(sync_result_builder) # debug for https://github.com/matrix-org/synapse/issues/4422 for joined_room in sync_result_builder.joined: @@ -1018,18 +1010,17 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_groups") - @defer.inlineCallbacks - def _generate_sync_entry_for_groups(self, sync_result_builder): + async def _generate_sync_entry_for_groups(self, sync_result_builder): user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token if since_token and since_token.groups_key: - results = yield self.store.get_groups_changes_for_user( + results = await self.store.get_groups_changes_for_user( user_id, since_token.groups_key, now_token.groups_key ) else: - results = yield self.store.get_all_groups_for_user( + results = await self.store.get_all_groups_for_user( user_id, now_token.groups_key ) @@ -1062,8 +1053,7 @@ class SyncHandler(object): ) @measure_func("_generate_sync_entry_for_device_list") - @defer.inlineCallbacks - def _generate_sync_entry_for_device_list( + async def _generate_sync_entry_for_device_list( self, sync_result_builder, newly_joined_rooms, @@ -1111,32 +1101,32 @@ class SyncHandler(object): # room with by looking at all users that have left a room plus users # that were in a room we've left. - users_who_share_room = yield self.store.get_users_who_share_room_with_user( + users_who_share_room = await self.store.get_users_who_share_room_with_user( user_id ) # Step 1a, check for changes in devices of users we share a room with - users_that_have_changed = yield self.store.get_users_whose_devices_changed( + users_that_have_changed = await self.store.get_users_whose_devices_changed( since_token.device_list_key, users_who_share_room ) # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: - joined_users = yield self.state.get_current_users_in_room(room_id) + joined_users = await self.state.get_current_users_in_room(room_id) newly_joined_or_invited_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) - user_signatures_changed = yield self.store.get_users_whose_signatures_changed( + user_signatures_changed = await self.store.get_users_whose_signatures_changed( user_id, since_token.device_list_key ) users_that_have_changed.update(user_signatures_changed) # Now find users that we no longer track for room_id in newly_left_rooms: - left_users = yield self.state.get_current_users_in_room(room_id) + left_users = await self.state.get_current_users_in_room(room_id) newly_left_users.update(left_users) # Remove any users that we still share a room with. @@ -1146,8 +1136,7 @@ class SyncHandler(object): else: return DeviceLists(changed=[], left=[]) - @defer.inlineCallbacks - def _generate_sync_entry_for_to_device(self, sync_result_builder): + async def _generate_sync_entry_for_to_device(self, sync_result_builder): """Generates the portion of the sync response. Populates `sync_result_builder` with the result. @@ -1168,14 +1157,14 @@ class SyncHandler(object): # We only delete messages when a new message comes in, but that's # fine so long as we delete them at some point. - deleted = yield self.store.delete_messages_for_device( + deleted = await self.store.delete_messages_for_device( user_id, device_id, since_stream_id ) logger.debug( "Deleted %d to-device messages up to %d", deleted, since_stream_id ) - messages, stream_id = yield self.store.get_new_messages_for_device( + messages, stream_id = await self.store.get_new_messages_for_device( user_id, device_id, since_stream_id, now_token.to_device_key ) @@ -1193,8 +1182,7 @@ class SyncHandler(object): else: sync_result_builder.to_device = [] - @defer.inlineCallbacks - def _generate_sync_entry_for_account_data(self, sync_result_builder): + async def _generate_sync_entry_for_account_data(self, sync_result_builder): """Generates the account data portion of the sync response. Populates `sync_result_builder` with the result. @@ -1212,25 +1200,25 @@ class SyncHandler(object): ( account_data, account_data_by_room, - ) = yield self.store.get_updated_account_data_for_user( + ) = await self.store.get_updated_account_data_for_user( user_id, since_token.account_data_key ) - push_rules_changed = yield self.store.have_push_rules_changed_for_user( + push_rules_changed = await self.store.have_push_rules_changed_for_user( user_id, int(since_token.push_rules_key) ) if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: ( account_data, account_data_by_room, - ) = yield self.store.get_account_data_for_user(sync_config.user.to_string()) + ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) - account_data["m.push_rules"] = yield self.push_rules_for_user( + account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) @@ -1245,8 +1233,7 @@ class SyncHandler(object): return account_data_by_room - @defer.inlineCallbacks - def _generate_sync_entry_for_presence( + async def _generate_sync_entry_for_presence( self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users ): """Generates the presence portion of the sync response. Populates the @@ -1274,7 +1261,7 @@ class SyncHandler(object): presence_key = None include_offline = False - presence, presence_key = yield presence_source.get_new_events( + presence, presence_key = await presence_source.get_new_events( user=user, from_key=presence_key, is_guest=sync_config.is_guest, @@ -1286,12 +1273,12 @@ class SyncHandler(object): extra_users_ids = set(newly_joined_or_invited_users) for room_id in newly_joined_rooms: - users = yield self.state.get_current_users_in_room(room_id) + users = await self.state.get_current_users_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) if extra_users_ids: - states = yield self.presence_handler.get_states(extra_users_ids) + states = await self.presence_handler.get_states(extra_users_ids) presence.extend(states) # Deduplicate the presence entries so that there's at most one per user @@ -1301,8 +1288,9 @@ class SyncHandler(object): sync_result_builder.presence = presence - @defer.inlineCallbacks - def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + async def _generate_sync_entry_for_rooms( + self, sync_result_builder, account_data_by_room + ): """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1324,7 +1312,7 @@ class SyncHandler(object): if block_all_room_ephemeral: ephemeral_by_room = {} else: - now_token, ephemeral_by_room = yield self.ephemeral_by_room( + now_token, ephemeral_by_room = await self.ephemeral_by_room( sync_result_builder, now_token=sync_result_builder.now_token, since_token=sync_result_builder.since_token, @@ -1336,16 +1324,16 @@ class SyncHandler(object): since_token = sync_result_builder.since_token if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: - have_changed = yield self._have_rooms_changed(sync_result_builder) + have_changed = await self._have_rooms_changed(sync_result_builder) if not have_changed: - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) if not tags_by_room: logger.debug("no-oping sync") return [], [], [], [] - ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + ignored_account_data = await self.store.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id=user_id ) @@ -1355,18 +1343,18 @@ class SyncHandler(object): ignored_users = frozenset() if since_token: - res = yield self._get_rooms_changed(sync_result_builder, ignored_users) + res = await self._get_rooms_changed(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms, newly_left_rooms = res - tags_by_room = yield self.store.get_updated_tags( + tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - res = yield self._get_all_rooms(sync_result_builder, ignored_users) + res = await self._get_all_rooms(sync_result_builder, ignored_users) room_entries, invited, newly_joined_rooms = res newly_left_rooms = [] - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) def handle_room_entries(room_entry): return self._generate_room_entry( @@ -1379,7 +1367,7 @@ class SyncHandler(object): always_include=sync_result_builder.full_state, ) - yield concurrently_execute(handle_room_entries, room_entries, 10) + await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) @@ -1413,8 +1401,7 @@ class SyncHandler(object): newly_left_users, ) - @defer.inlineCallbacks - def _have_rooms_changed(self, sync_result_builder): + async def _have_rooms_changed(self, sync_result_builder): """Returns whether there may be any new events that should be sent down the sync. Returns True if there are. """ @@ -1425,7 +1412,7 @@ class SyncHandler(object): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1438,8 +1425,7 @@ class SyncHandler(object): return True return False - @defer.inlineCallbacks - def _get_rooms_changed(self, sync_result_builder, ignored_users): + async def _get_rooms_changed(self, sync_result_builder, ignored_users): """Gets the the changes that have happened since the last sync. Args: @@ -1464,7 +1450,7 @@ class SyncHandler(object): assert since_token # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( + rooms_changed = await self.store.get_membership_changes_for_user( user_id, since_token.room_key, now_token.room_key ) @@ -1502,11 +1488,11 @@ class SyncHandler(object): continue if room_id in sync_result_builder.joined_room_ids or has_join: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) @@ -1539,13 +1525,13 @@ class SyncHandler(object): newly_left_rooms.append(room_id) else: if not old_state_ids: - old_state_ids = yield self.get_state_at(room_id, since_token) + old_state_ids = await self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get( (EventTypes.Member, user_id), None ) old_mem_ev = None if old_mem_ev_id: - old_mem_ev = yield self.store.get_event( + old_mem_ev = await self.store.get_event( old_mem_ev_id, allow_none=True ) if old_mem_ev and old_mem_ev.membership == Membership.JOIN: @@ -1569,7 +1555,7 @@ class SyncHandler(object): if leave_events: leave_event = leave_events[-1] - leave_stream_token = yield self.store.get_stream_token_for_event( + leave_stream_token = await self.store.get_stream_token_for_event( leave_event.event_id ) leave_token = since_token.copy_and_replace( @@ -1606,7 +1592,7 @@ class SyncHandler(object): timeline_limit = sync_config.filter_collection.timeline_limit() # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_to_events = await self.store.get_room_events_stream_for_rooms( room_ids=sync_result_builder.joined_room_ids, from_key=since_token.room_key, to_key=now_token.room_key, @@ -1655,8 +1641,7 @@ class SyncHandler(object): return room_entries, invited, newly_joined_rooms, newly_left_rooms - @defer.inlineCallbacks - def _get_all_rooms(self, sync_result_builder, ignored_users): + async def _get_all_rooms(self, sync_result_builder, ignored_users): """Returns entries for all rooms for the user. Args: @@ -1680,7 +1665,7 @@ class SyncHandler(object): Membership.BAN, ) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_user_where_membership_is( user_id=user_id, membership_list=membership_list ) @@ -1703,7 +1688,7 @@ class SyncHandler(object): elif event.membership == Membership.INVITE: if event.sender in ignored_users: continue - invite = yield self.store.get_event(event.event_id) + invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) elif event.membership in (Membership.LEAVE, Membership.BAN): # Always send down rooms we were banned or kicked from. @@ -1729,8 +1714,7 @@ class SyncHandler(object): return room_entries, invited, [] - @defer.inlineCallbacks - def _generate_room_entry( + async def _generate_room_entry( self, sync_result_builder, ignored_users, @@ -1772,7 +1756,7 @@ class SyncHandler(object): since_token = room_builder.since_token upto_token = room_builder.upto_token - batch = yield self._load_filtered_recents( + batch = await self._load_filtered_recents( room_id, sync_config, now_token=upto_token, @@ -1799,7 +1783,7 @@ class SyncHandler(object): # tag was added by synapse e.g. for server notice rooms. if full_state: user_id = sync_result_builder.sync_config.user.to_string() - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) # If there aren't any tags, don't send the empty tags list down # sync @@ -1824,7 +1808,7 @@ class SyncHandler(object): ): return - state = yield self.compute_state_delta( + state = await self.compute_state_delta( room_id, batch, sync_config, since_token, now_token, full_state=full_state ) @@ -1847,7 +1831,7 @@ class SyncHandler(object): ) or since_token is None ): - summary = yield self.compute_summary( + summary = await self.compute_summary( room_id, sync_config, batch, state, now_token ) @@ -1864,7 +1848,7 @@ class SyncHandler(object): ) if room_sync or always_include: - notifs = yield self.unread_notifs_for_room_id(room_id, sync_config) + notifs = await self.unread_notifs_for_room_id(room_id, sync_config) if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"] @@ -1890,8 +1874,7 @@ class SyncHandler(object): else: raise Exception("Unrecognized rtype: %r", room_builder.rtype) - @defer.inlineCallbacks - def get_rooms_for_user_at(self, user_id, stream_ordering): + async def get_rooms_for_user_at(self, user_id, stream_ordering): """Get set of joined rooms for a user at the given stream ordering. The stream ordering *must* be recent, otherwise this may throw an @@ -1906,7 +1889,7 @@ class SyncHandler(object): Deferred[frozenset[str]]: Set of room_ids the user is in at given stream_ordering. """ - joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id) + joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id) joined_room_ids = set() @@ -1924,10 +1907,10 @@ class SyncHandler(object): logger.info("User joined room after current token: %s", room_id) - extrems = yield self.store.get_forward_extremeties_for_room( + extrems = await self.store.get_forward_extremeties_for_room( room_id, stream_ordering ) - users_in_room = yield self.state.get_current_users_in_room(room_id, extrems) + users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: joined_room_ids.add(room_id) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 856337b7e2..6f78454322 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -313,7 +313,7 @@ class TypingNotificationEventSource(object): events.append(self._make_event_for(room_id)) - return events, handler._latest_room_serial + return defer.succeed((events, handler._latest_room_serial)) def get_current_key(self): return self.get_typing_handler()._latest_room_serial diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 05fc64f409..03934956f4 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -256,6 +256,7 @@ class TerseJSONToTCPLogObserver(object): # transport is the same, just trigger a resumeProducing. if self._producer and r.transport is self._producer.transport: self._producer.resumeProducing() + self._connection_waiter = None return # If the producer is still producing, stop it. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 735b882363..305b9b0178 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -175,4 +175,4 @@ class ModuleApi(object): Returns: Deferred[object]: result of func """ - return self._store.runInteraction(desc, func, *args, **kwargs) + return self._store.db.runInteraction(desc, func, *args, **kwargs) diff --git a/synapse/notifier.py b/synapse/notifier.py index af161a81d7..5f5f765bea 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -304,8 +304,7 @@ class Notifier(object): without waking up any of the normal user event streams""" self.notify_replication() - @defer.inlineCallbacks - def wait_for_events( + async def wait_for_events( self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START ): """Wait until the callback returns a non empty response or the @@ -313,9 +312,9 @@ class Notifier(object): """ user_stream = self.user_to_user_stream.get(user_id) if user_stream is None: - current_token = yield self.event_sources.get_current_token() + current_token = await self.event_sources.get_current_token() if room_ids is None: - room_ids = yield self.store.get_rooms_for_user(user_id) + room_ids = await self.store.get_rooms_for_user(user_id) user_stream = _NotifierUserStream( user_id=user_id, rooms=room_ids, @@ -344,11 +343,11 @@ class Notifier(object): self.hs.get_reactor(), ) with PreserveLoggingContext(): - yield listener.deferred + await listener.deferred current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) if result: break @@ -364,12 +363,11 @@ class Notifier(object): # This happened if there was no timeout or if the timeout had # already expired. current_token = user_stream.current_token - result = yield callback(prev_token, current_token) + result = await callback(prev_token, current_token) return result - @defer.inlineCallbacks - def get_events_for( + async def get_events_for( self, user, pagination_config, @@ -391,15 +389,14 @@ class Notifier(object): """ from_token = pagination_config.from_token if not from_token: - from_token = yield self.event_sources.get_current_token() + from_token = await self.event_sources.get_current_token() limit = pagination_config.limit - room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id) + room_ids, is_joined = await self._get_room_ids(user, explicit_room_id) is_peeking = not is_joined - @defer.inlineCallbacks - def check_for_updates(before_token, after_token): + async def check_for_updates(before_token, after_token): if not after_token.is_after(before_token): return EventStreamResult([], (from_token, from_token)) @@ -415,7 +412,7 @@ class Notifier(object): if only_keys and name not in only_keys: continue - new_events, new_key = yield source.get_new_events( + new_events, new_key = await source.get_new_events( user=user, from_key=getattr(from_token, keyname), limit=limit, @@ -425,7 +422,7 @@ class Notifier(object): ) if name == "room": - new_events = yield filter_events_for_client( + new_events = await filter_events_for_client( self.storage, user.to_string(), new_events, @@ -461,7 +458,7 @@ class Notifier(object): user_id_for_stream, ) - result = yield self.wait_for_events( + result = await self.wait_for_events( user_id_for_stream, timeout, check_for_updates, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 1ba7bcd4d8..7881780760 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -386,15 +386,7 @@ class RulesForRoom(object): """ sequence = self.sequence - rows = yield self.store._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids.values(), - retcols=("user_id", "membership", "event_id"), - keyvalues={}, - batch_size=500, - desc="_get_rules_for_member_event_ids", - ) + rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index cc1f249740..3577611fd7 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -93,6 +93,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): { "requester": ..., "remote_room_hosts": [...], + "content": { ... } } """ @@ -107,7 +108,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts): + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): """ Args: requester(Requester) @@ -118,12 +119,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): return { "requester": requester.serialize(), "remote_room_hosts": remote_room_hosts, + "content": content, } async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] + event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) @@ -134,7 +137,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): try: event = await self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, user_id + remote_room_hosts, room_id, user_id, event_content, ) ret = event.get_pdu_json() except Exception as e: diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 456bc005a0..b91a528245 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,7 +18,9 @@ from typing import Dict import six -from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker @@ -34,8 +36,8 @@ def __func__(inp): class BaseSlavedStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(BaseSlavedStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = SlavedIdTracker( db_conn, "cache_invalidation_stream", "stream_id" @@ -62,7 +64,7 @@ class BaseSlavedStore(SQLBaseStore): if stream_name == "caches": self._cache_id_gen.advance(token) for row in rows: - if row.cache_func == _CURRENT_STATE_CACHE_NAME: + if row.cache_func == CURRENT_STATE_CACHE_NAME: room_id = row.keys[0] members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index bc2f6a12ae..ebe94909cb 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -18,15 +18,16 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore from synapse.storage.data_stores.main.tags import TagsWorkerStore +from synapse.storage.database import Database class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id" ) - super(SlavedAccountDataStore, self).__init__(db_conn, hs) + super(SlavedAccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index b4f58cea19..fbf996e33a 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache @@ -21,8 +22,8 @@ from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedClientIpStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 9fb6c5c6ff..0c237c6e0f 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( db_conn, "device_max_stream_id", "stream_id" ) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index de50748c30..dc625e0d7a 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -18,12 +18,13 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.storage.data_stores.main.devices import DeviceWorkerStore from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedDeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index d0a0eaf75b..29f35b9915 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -31,6 +31,7 @@ from synapse.storage.data_stores.main.signatures import SignatureWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore from synapse.storage.data_stores.main.stream import StreamWorkerStore from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -59,13 +60,13 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering") self._backfill_id_gen = SlavedIdTracker( db_conn, "events", "stream_ordering", step=-1 ) - super(SlavedEventStore, self).__init__(db_conn, hs) + super(SlavedEventStore, self).__init__(database, db_conn, hs) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py index 5c84ebd125..bcb0688954 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py @@ -14,13 +14,14 @@ # limitations under the License. from synapse.storage.data_stores.main.filtering import FilteringStore +from synapse.storage.database import Database from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedFilteringStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired get_user_filter = FilteringStore.__dict__["get_user_filter"] diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 28a46edd28..69a4ae42f9 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -14,6 +14,7 @@ # limitations under the License. from synapse.storage import DataStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -21,8 +22,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedGroupServerStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedGroupServerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index 747ced0c84..f552e7c972 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -15,6 +15,7 @@ from synapse.storage import DataStore from synapse.storage.data_stores.main.presence import PresenceStore +from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore, __func__ @@ -22,8 +23,8 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPresenceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") self._presence_on_startup = self._get_active_presence(db_conn) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 3655f05e54..eebd5a1fb6 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -15,17 +15,18 @@ # limitations under the License. from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.storage.database import Database from ._slaved_id_tracker import SlavedIdTracker from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id" ) - super(SlavedPushRuleStore, self).__init__(db_conn, hs) + super(SlavedPushRuleStore, self).__init__(database, db_conn, hs) def get_push_rules_stream_token(self): return ( diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index b4331d0799..f22c2d44a3 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,14 +15,15 @@ # limitations under the License. from synapse.storage.data_stores.main.pusher import PusherWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(SlavedPusherStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] ) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 43d823c601..d40dc6e1f5 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -29,14 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( db_conn, "receipts_linearized", "stream_id" ) - super(SlavedReceiptsStore, self).__init__(db_conn, hs) + super(SlavedReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index d9ad386b28..3a20f45316 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -14,14 +14,15 @@ # limitations under the License. from synapse.storage.data_stores.main.room import RoomWorkerStore +from synapse.storage.database import Database from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, db_conn, hs): - super(RoomStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" ) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 68a59a3424..c122c449f4 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -34,12 +34,12 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, DeactivateAccountRestServlet, - GetUsersPaginatedRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, UserAdminServlet, UserRegisterServlet, UsersRestServlet, + UsersRestServletV2, WhoisRestServlet, ) from synapse.util.versionstring import get_version_string @@ -191,6 +191,7 @@ def register_servlets(hs, http_server): SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UsersRestServletV2(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): @@ -201,7 +202,6 @@ def register_servlets_for_client_rest_resource(hs, http_server): PurgeHistoryRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server) ResetPasswordRestServlet(hs).register(http_server) - GetUsersPaginatedRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 58a83f93af..1937879dbe 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -25,6 +25,7 @@ from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_integer, parse_json_object_from_request, parse_string, @@ -59,71 +60,45 @@ class UsersRestServlet(RestServlet): return 200, ret -class GetUsersPaginatedRestServlet(RestServlet): - """Get request to get specific number of users from Synapse. +class UsersRestServletV2(RestServlet): + PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),) + + """Get request to list all local users. This needs user to have administrator access in Synapse. - Example: - http://localhost:8008/_synapse/admin/v1/users_paginate/ - @admin:user?access_token=admin_access_token&start=0&limit=10 - Returns: - 200 OK with json object {list[dict[str, Any]], count} or empty object. - """ - PATTERNS = historical_admin_path_patterns( - "/users_paginate/(?P<target_user_id>[^/]*)" - ) + GET /_synapse/admin/v2/users?from=0&limit=10&guests=false + + returns: + 200 OK with list of users if success otherwise an error. + + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `user_id` can be used to filter by user id. + The parameter `guests` can be used to exclude guest users. + The parameter `deactivated` can be used to include deactivated users. + """ def __init__(self, hs): - self.store = hs.get_datastore() self.hs = hs self.auth = hs.get_auth() - self.handlers = hs.get_handlers() + self.admin_handler = hs.get_handlers().admin_handler - async def on_GET(self, request, target_user_id): - """Get request to get specific number of users from Synapse. - This needs user to have administrator access in Synapse. - """ + async def on_GET(self, request): await assert_requester_is_admin(self.auth, request) - target_user = UserID.from_string(target_user_id) - - if not self.hs.is_mine(target_user): - raise SynapseError(400, "Can only users a local user") - - order = "name" # order by name in user table - start = parse_integer(request, "start", required=True) - limit = parse_integer(request, "limit", required=True) - - logger.info("limit: %s, start: %s", limit, start) - - ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit) - return 200, ret + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + user_id = parse_string(request, "user_id", default=None) + guests = parse_boolean(request, "guests", default=True) + deactivated = parse_boolean(request, "deactivated", default=False) - async def on_POST(self, request, target_user_id): - """Post request to get specific number of users from Synapse.. - This needs user to have administrator access in Synapse. - Example: - http://localhost:8008/_synapse/admin/v1/users_paginate/ - @admin:user?access_token=admin_access_token - JsonBodyToSend: - { - "start": "0", - "limit": "10 - } - Returns: - 200 OK with json object {list[dict[str, Any]], count} or empty object. - """ - await assert_requester_is_admin(self.auth, request) - UserID.from_string(target_user_id) - - order = "name" # order by name in user table - params = parse_json_object_from_request(request) - assert_params_in_dict(params, ["limit", "start"]) - limit = params["limit"] - start = params["start"] - logger.info("limit: %s, start: %s", limit, start) + users = await self.admin_handler.get_users_paginate( + start, limit, user_id, guests, deactivated + ) + ret = {"users": users} + if len(users) >= limit: + ret["next_token"] = str(start + len(users)) - ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit) return 200, ret diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 4ea3666874..5934b1fe8b 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import ( AuthError, Codes, @@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_alias): + async def on_GET(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) dir_handler = self.handlers.directory_handler - res = yield dir_handler.get_association(room_alias) + res = await dir_handler.get_association(room_alias) return 200, res - @defer.inlineCallbacks - def on_PUT(self, request, room_alias): + async def on_PUT(self, request, room_alias): room_alias = RoomAlias.from_string(room_alias) content = parse_json_object_from_request(request) @@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet): # TODO(erikj): Check types. - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if room is None: raise SynapseError(400, "Room does not exist") - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.create_association( + await self.handlers.directory_handler.create_association( requester, room_alias, room_id, servers ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_alias): + async def on_DELETE(self, request, room_alias): dir_handler = self.handlers.directory_handler try: - service = yield self.auth.get_appservice_by_req(request) + service = await self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association(service, room_alias) + await dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, @@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet): # fallback to default user behaviour if they aren't an AS pass - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user = requester.user room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association(requester, room_alias) + await dir_handler.delete_association(requester, room_alias) logger.info( "User %s deleted alias %s", user.to_string(), room_alias.to_string() @@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - room = yield self.store.get_room(room_id) + async def on_GET(self, request, room_id): + room = await self.store.get_room(room_id) if room is None: raise NotFoundError("Unknown room") return 200, {"visibility": "public" if room["is_public"] else "private"} - @defer.inlineCallbacks - def on_PUT(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) visibility = content.get("visibility", "public") - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, visibility ) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - yield self.handlers.directory_handler.edit_published_room_list( + await self.handlers.directory_handler.edit_published_room_list( requester, room_id, "private" ) @@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet): def on_DELETE(self, request, network_id, room_id): return self._edit(request, network_id, room_id, "private") - @defer.inlineCallbacks - def _edit(self, request, network_id, room_id, visibility): - requester = yield self.auth.get_user_by_req(request) + async def _edit(self, request, network_id, room_id, visibility): + requester = await self.auth.get_user_by_req(request) if not requester.app_service: raise AuthError( 403, "Only appservices can edit the appservice published room list" ) - yield self.handlers.directory_handler.edit_published_appservice_room_list( + await self.handlers.directory_handler.edit_published_appservice_room_list( requester.app_service.id, network_id, room_id, visibility ) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 6651b4cf07..4beb617733 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -16,8 +16,6 @@ """This module contains REST servlets to do with event streaming, /events.""" import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet): self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: @@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet): as_client_event = b"raw" not in request.args - chunk = yield self.event_stream_handler.get_stream( + chunk = await self.event_stream_handler.get_stream( requester.user.to_string(), pagin_config, timeout=timeout, @@ -83,14 +80,13 @@ class EventRestServlet(RestServlet): self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request, event_id): - requester = yield self.auth.get_user_by_req(request) - event = yield self.event_handler.get_event(requester.user, None, event_id) + async def on_GET(self, request, event_id): + requester = await self.auth.get_user_by_req(request) + event = await self.event_handler.get_event(requester.user, None, event_id) time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event else: return 404, "Event not found." diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 2da3cd7511..910b3b4eeb 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_boolean from synapse.rest.client.v2_alpha._base import client_patterns @@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args pagination_config = PaginationConfig.from_request(request) include_archived = parse_boolean(request, "archived", default=False) - content = yield self.initial_sync_handler.snapshot_all_rooms( + content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), pagin_config=pagination_config, as_client_event=as_client_event, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 19eb15003d..ff9c978fe7 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET from six.moves import urllib -from twisted.internet import defer from twisted.web.client import PartialDownloadError from synapse.api.errors import Codes, LoginError, SynapseError @@ -130,8 +129,7 @@ class LoginRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): self._address_ratelimiter.ratelimit( request.getClientIP(), time_now_s=self.hs.clock.time(), @@ -145,11 +143,11 @@ class LoginRestServlet(RestServlet): if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE ): - result = yield self.do_jwt_login(login_submission) + result = await self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - result = yield self.do_token_login(login_submission) + result = await self.do_token_login(login_submission) else: - result = yield self._do_other_login(login_submission) + result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -158,8 +156,7 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - @defer.inlineCallbacks - def _do_other_login(self, login_submission): + async def _do_other_login(self, login_submission): """Handle non-token/saml/jwt logins Args: @@ -219,20 +216,20 @@ class LoginRestServlet(RestServlet): ( canonical_user_id, callback_3pid, - ) = yield self.auth_handler.check_password_provider_3pid( + ) = await self.auth_handler.check_password_provider_3pid( medium, address, login_submission["password"] ) if canonical_user_id: # Authentication through password provider and 3pid succeeded - result = yield self._complete_login( + result = await self._complete_login( canonical_user_id, login_submission, callback_3pid ) return result # No password providers were able to handle this 3pid # Check local store - user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + user_id = await self.hs.get_datastore().get_user_id_by_threepid( medium, address ) if not user_id: @@ -280,7 +277,7 @@ class LoginRestServlet(RestServlet): ) try: - canonical_user_id, callback = yield self.auth_handler.validate_login( + canonical_user_id, callback = await self.auth_handler.validate_login( identifier["user"], login_submission ) except LoginError: @@ -297,13 +294,12 @@ class LoginRestServlet(RestServlet): ) raise - result = yield self._complete_login( + result = await self._complete_login( canonical_user_id, login_submission, callback ) return result - @defer.inlineCallbacks - def _complete_login( + async def _complete_login( self, user_id, login_submission, callback=None, create_non_existant_users=False ): """Called when we've successfully authed the user and now need to @@ -337,15 +333,15 @@ class LoginRestServlet(RestServlet): ) if create_non_existant_users: - user_id = yield self.auth_handler.check_user_exists(user_id) + user_id = await self.auth_handler.check_user_exists(user_id) if not user_id: - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( localpart=UserID.from_string(user_id).localpart ) device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name ) @@ -357,23 +353,21 @@ class LoginRestServlet(RestServlet): } if callback is not None: - yield callback(result) + await callback(result) return result - @defer.inlineCallbacks - def do_token_login(self, login_submission): + async def do_token_login(self, login_submission): token = login_submission["token"] auth_handler = self.auth_handler - user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( + user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = yield self._complete_login(user_id, login_submission) + result = await self._complete_login(user_id, login_submission) return result - @defer.inlineCallbacks - def do_jwt_login(self, login_submission): + async def do_jwt_login(self, login_submission): token = login_submission.get("token", None) if token is None: raise LoginError( @@ -397,7 +391,7 @@ class LoginRestServlet(RestServlet): raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user_id = UserID(user, self.hs.hostname).to_string() - result = yield self._complete_login( + result = await self._complete_login( user_id, login_submission, create_non_existant_users=True ) return result @@ -460,8 +454,7 @@ class CasTicketServlet(RestServlet): self._sso_auth_handler = SSOAuthHandler(hs) self._http_client = hs.get_proxied_http_client() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): client_redirect_url = parse_string(request, "redirectUrl", required=True) uri = self.cas_server_url + "/proxyValidate" args = { @@ -469,12 +462,12 @@ class CasTicketServlet(RestServlet): "service": self.cas_service_url, } try: - body = yield self._http_client.get_raw(uri, args) + body = await self._http_client.get_raw(uri, args) except PartialDownloadError as pde: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data body = pde.response - result = yield self.handle_cas_response(request, body, client_redirect_url) + result = await self.handle_cas_response(request, body, client_redirect_url) return result def handle_cas_response(self, request, cas_response_body, client_redirect_url): @@ -555,8 +548,7 @@ class SSOAuthHandler(object): self._registration_handler = hs.get_registration_handler() self._macaroon_gen = hs.get_macaroon_generator() - @defer.inlineCallbacks - def on_successful_auth( + async def on_successful_auth( self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. @@ -582,9 +574,9 @@ class SSOAuthHandler(object): """ localpart = map_username_to_mxid_localpart(username) user_id = UserID(localpart, self._hostname).to_string() - registered_user_id = yield self._auth_handler.check_user_exists(user_id) + registered_user_id = await self._auth_handler.check_user_exists(user_id) if not registered_user_id: - registered_user_id = yield self._registration_handler.register_user( + registered_user_id = await self._registration_handler.register_user( localpart=localpart, default_display_name=user_display_name ) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 4785a34d75..1cf3caf832 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) if requester.device_id is None: # the acccess token wasn't associated with a device. # Just delete the access token access_token = self.auth.get_access_token_from_request(request) - yield self._auth_handler.delete_access_token(access_token) + await self._auth_handler.delete_access_token(access_token) else: - yield self._device_handler.delete_device( + await self._device_handler.delete_device( requester.user.to_string(), requester.device_id ) @@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet): def on_OPTIONS(self, request): return 200, {} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # first delete all of the user's devices - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # .. and then delete any access tokens which weren't associated with # devices. - yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) return 200, {} diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 0153525cef..eec16f8ad8 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,8 +19,6 @@ import logging from six import string_types -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet): self.clock = hs.get_clock() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: - allowed = yield self.presence_handler.is_visible( + allowed = await self.presence_handler.is_visible( observed_user=user, observer_user=requester.user ) if not allowed: raise AuthError(403, "You are not allowed to see their presence.") - state = yield self.presence_handler.get_state(target_user=user) + state = await self.presence_handler.get_state(target_user=user) state = format_user_presence_state(state, self.clock.time_msec()) return 200, state - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if requester.user != user: @@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet): raise SynapseError(400, "Unable to parse state") if self.hs.config.use_presence: - yield self.presence_handler.set_state(user, state) + await self.presence_handler.set_state(user, state) return 200, {} diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index bbce2e2b71..1eac8a44c5 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -14,7 +14,6 @@ # limitations under the License. """ This module contains REST servlets to do with profile: /profile/<paths> """ -from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.client.v2_alpha._base import client_patterns @@ -30,19 +29,18 @@ class ProfileDisplaynameRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) + displayname = await self.profile_handler.get_displayname(user) ret = {} if displayname is not None: @@ -50,11 +48,10 @@ class ProfileDisplaynameRestServlet(RestServlet): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ class ProfileDisplaynameRestServlet(RestServlet): except Exception: return 400, "Unable to parse name" - yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) + await self.profile_handler.set_displayname(user, requester, new_name, is_admin) return 200, {} @@ -80,19 +77,18 @@ class ProfileAvatarURLRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if avatar_url is not None: @@ -100,11 +96,10 @@ class ProfileAvatarURLRestServlet(RestServlet): return 200, ret - @defer.inlineCallbacks - def on_PUT(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id): + requester = await self.auth.get_user_by_req(request) user = UserID.from_string(user_id) - is_admin = yield self.auth.is_server_admin(requester.user) + is_admin = await self.auth.is_server_admin(requester.user) content = parse_json_object_from_request(request) try: @@ -112,7 +107,7 @@ class ProfileAvatarURLRestServlet(RestServlet): except Exception: return 400, "Unable to parse name" - yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) + await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) return 200, {} @@ -129,20 +124,19 @@ class ProfileRestServlet(RestServlet): self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, user_id): + async def on_GET(self, request, user_id): requester_user = None if self.hs.config.require_auth_for_profile_requests: - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) requester_user = requester.user user = UserID.from_string(user_id) - yield self.profile_handler.check_profile_query_allowed(user, requester_user) + await self.profile_handler.check_profile_query_allowed(user, requester_user) - displayname = yield self.profile_handler.get_displayname(user) - avatar_url = yield self.profile_handler.get_avatar_url(user) + displayname = await self.profile_handler.get_displayname(user) + avatar_url = await self.profile_handler.get_avatar_url(user) ret = {} if displayname is not None: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 9f8c3d09e3..4f74600239 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import ( NotFoundError, @@ -46,8 +45,7 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None - @defer.inlineCallbacks - def on_PUT(self, request, path): + async def on_PUT(self, request, path): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -57,7 +55,7 @@ class PushRuleRestServlet(RestServlet): except InvalidRuleException as e: raise SynapseError(400, str(e)) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") @@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet): user_id = requester.user.to_string() if "attr" in spec: - yield self.set_rule_attr(user_id, spec, content) + await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} @@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet): after = _namespaced_rule_id(spec, after) try: - yield self.store.add_push_rule( + await self.store.add_push_rule( user_id=user_id, rule_id=_namespaced_rule_id_from_spec(spec), priority_class=priority_class, @@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, path): + async def on_DELETE(self, request, path): if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") spec = _rule_spec_from_path([x for x in path.split("/")]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule(user_id, namespaced_rule_id) + await self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) return 200, {} except StoreError as e: @@ -130,15 +127,14 @@ class PushRuleRestServlet(RestServlet): else: raise - @defer.inlineCallbacks - def on_GET(self, request, path): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, path): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() # we build up the full structure and then decide which bits of it # to send which means doing unnecessary work sometimes but is # is probably not going to make a whole lot of difference - rules = yield self.store.get_push_rules_for_user(user_id) + rules = await self.store.get_push_rules_for_user(user_id) rules = format_push_rules_for_user(requester.user, rules) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 41660682d9..0791866f55 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import ( @@ -39,12 +37,11 @@ class PushersRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) + pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) allowed_keys = [ "app_display_name", @@ -78,9 +75,8 @@ class PushersSetRestServlet(RestServlet): self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user = requester.user content = parse_json_object_from_request(request) @@ -91,7 +87,7 @@ class PushersSetRestServlet(RestServlet): and "kind" in content and content["kind"] is None ): - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( content["app_id"], content["pushkey"], user_id=user.to_string() ) return 200, {} @@ -117,14 +113,14 @@ class PushersSetRestServlet(RestServlet): append = content["append"] if not append: - yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( + await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], pushkey=content["pushkey"], not_user_id=user.to_string(), ) try: - yield self.pusher_pool.add_pusher( + await self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -164,16 +160,15 @@ class PushersRemoveRestServlet(RestServlet): self.auth = hs.get_auth() self.pusher_pool = self.hs.get_pusherpool() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, rights="delete_pusher") + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, rights="delete_pusher") user = requester.user app_id = parse_string(request, "app_id", required=True) pushkey = parse_string(request, "pushkey", required=True) try: - yield self.pusher_pool.remove_pusher( + await self.pusher_pool.remove_pusher( app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 86bbcc0eea..711d4ad304 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -714,7 +714,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): target = UserID.from_string(content["user_id"]) event_content = None - if "reason" in content and membership_action in ["kick", "ban"]: + if "reason" in content: event_content = {"reason": content["reason"]} await self.room_member_handler.update_membership( diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 2afdbb89e5..747d46eac2 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -17,8 +17,6 @@ import base64 import hashlib import hmac -from twisted.internet import defer - from synapse.http.servlet import RestServlet from synapse.rest.client.v2_alpha._base import client_patterns @@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req( request, self.hs.config.turn_allow_guests ) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 8250ae0ae1..2a3f4dd58f 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -78,7 +78,7 @@ def interactive_auth_handler(orig): """ def wrapped(*args, **kwargs): - res = defer.maybeDeferred(orig, *args, **kwargs) + res = defer.ensureDeferred(orig(*args, **kwargs)) res.addErrback(_catch_incomplete_interactive_auth) return res diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index f26eae794c..fc240f5cf8 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -18,8 +18,6 @@ import logging from six.moves import http_client -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.config.emailconfig import ThreepidBehaviour @@ -67,8 +65,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -95,7 +92,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -106,7 +103,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -115,7 +112,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) else: # Send password reset emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -153,8 +150,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): [self.config.email_password_reset_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request, medium): + async def on_GET(self, request, medium): # We currently only handle threepid token submissions for email if medium != "email": raise SynapseError( @@ -176,7 +172,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -218,8 +214,7 @@ class PasswordRestServlet(RestServlet): self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) # there are two possibilities here. Either the user does not have an @@ -233,14 +228,14 @@ class PasswordRestServlet(RestServlet): # In the second case, we require a password to confirm their identity. if self.auth.has_access_token(request): - requester = yield self.auth.get_user_by_req(request) - params = yield self.auth_handler.validate_user_via_ui_auth( + requester = await self.auth.get_user_by_req(request) + params = await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) user_id = requester.user.to_string() else: requester = None - result, params, _ = yield self.auth_handler.check_auth( + result, params, _ = await self.auth_handler.check_auth( [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) ) @@ -254,7 +249,7 @@ class PasswordRestServlet(RestServlet): # (See add_threepid in synapse/handlers/auth.py) threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! - threepid_user_id = yield self.datastore.get_user_id_by_threepid( + threepid_user_id = await self.datastore.get_user_id_by_threepid( threepid["medium"], threepid["address"] ) if not threepid_user_id: @@ -267,7 +262,7 @@ class PasswordRestServlet(RestServlet): assert_params_in_dict(params, ["new_password"]) new_password = params["new_password"] - yield self._set_password_handler.set_password(user_id, new_password, requester) + await self._set_password_handler.set_password(user_id, new_password, requester) return 200, {} @@ -286,8 +281,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -297,19 +291,19 @@ class DeactivateAccountRestServlet(RestServlet): Codes.BAD_JSON, ) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) # allow ASes to dectivate their own users if requester.app_service: - yield self._deactivate_account_handler.deactivate_account( + await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase ) return 200, {} - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - result = yield self._deactivate_account_handler.deactivate_account( + result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: @@ -346,8 +340,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -371,7 +364,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( "email", body["email"] ) @@ -382,7 +375,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -391,7 +384,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) else: # Send threepid validation emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -414,8 +407,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( body, ["client_secret", "country", "phone_number", "send_attempt"] @@ -435,7 +427,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -450,7 +442,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): "Adding phone numbers to user account is not supported by this homeserver", ) - ret = yield self.identity_handler.requestMsisdnToken( + ret = await self.identity_handler.requestMsisdnToken( self.hs.config.account_threepid_delegate_msisdn, country, phone_number, @@ -484,8 +476,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): [self.config.email_add_threepid_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -508,7 +499,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -558,8 +549,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if not self.config.account_threepid_delegate_msisdn: raise SynapseError( 400, @@ -571,7 +561,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): assert_params_in_dict(body, ["client_secret", "sid", "token"]) # Proxy submit_token request to msisdn threepid delegate - response = yield self.identity_handler.proxy_msisdn_submit_token( + response = await self.identity_handler.proxy_msisdn_submit_token( self.config.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], @@ -591,17 +581,15 @@ class ThreepidRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) + threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -615,11 +603,11 @@ class ThreepidRestServlet(RestServlet): client_secret = threepid_creds["client_secret"] sid = threepid_creds["sid"] - validation_session = yield self.identity_handler.validate_threepid_session( + validation_session = await self.identity_handler.validate_threepid_session( client_secret, sid ) if validation_session: - yield self.auth_handler.add_threepid( + await self.auth_handler.add_threepid( user_id, validation_session["medium"], validation_session["address"], @@ -642,9 +630,9 @@ class ThreepidAddRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + @interactive_auth_handler + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -652,11 +640,15 @@ class ThreepidAddRestServlet(RestServlet): client_secret = body["client_secret"] sid = body["sid"] - validation_session = yield self.identity_handler.validate_threepid_session( + await self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request) + ) + + validation_session = await self.identity_handler.validate_threepid_session( client_secret, sid ) if validation_session: - yield self.auth_handler.add_threepid( + await self.auth_handler.add_threepid( user_id, validation_session["medium"], validation_session["address"], @@ -678,8 +670,7 @@ class ThreepidBindRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) @@ -688,10 +679,10 @@ class ThreepidBindRestServlet(RestServlet): client_secret = body["client_secret"] id_access_token = body.get("id_access_token") # optional - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() - yield self.identity_handler.bind_threepid( + await self.identity_handler.bind_threepid( client_secret, sid, user_id, id_server, id_access_token ) @@ -708,12 +699,11 @@ class ThreepidUnbindRestServlet(RestServlet): self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) @@ -723,7 +713,7 @@ class ThreepidUnbindRestServlet(RestServlet): # Attempt to unbind the threepid from an identity server. If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past - result = yield self.identity_handler.try_unbind_threepid( + result = await self.identity_handler.try_unbind_threepid( requester.user.to_string(), {"address": address, "medium": medium, "id_server": id_server}, ) @@ -738,16 +728,15 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["medium", "address"]) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: - ret = yield self.auth_handler.delete_threepid( + ret = await self.auth_handler.delete_threepid( user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: @@ -772,9 +761,8 @@ class WhoamiRestServlet(RestServlet): super(WhoamiRestServlet, self).__init__() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) return 200, {"user_id": requester.user.to_string()} diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index f0db204ffa..64eb7fec3b 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -41,15 +39,14 @@ class AccountDataServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_account_data_for_user( + max_id = await self.store.add_account_data_for_user( user_id, account_data_type, body ) @@ -57,13 +54,12 @@ class AccountDataServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_global_account_data_by_type_for_user( + event = await self.store.get_global_account_data_by_type_for_user( account_data_type, user_id ) @@ -91,9 +87,8 @@ class RoomAccountDataServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -106,7 +101,7 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = yield self.store.add_account_data_to_room( + max_id = await self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) @@ -114,13 +109,12 @@ class RoomAccountDataServlet(RestServlet): return 200, {} - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id, account_data_type): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id, account_data_type): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") - event = yield self.store.get_account_data_for_room_and_type( + event = await self.store.get_account_data_for_room_and_type( user_id, room_id, account_data_type ) diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 33f6a23028..2f10fa64e2 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import RestServlet @@ -45,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet): self.success_html = hs.config.account_validity.account_renewed_html_content self.failure_html = hs.config.account_validity.invalid_token_html_content - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = yield self.account_activity_handler.renew_account( + token_valid = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) @@ -67,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet): request.setHeader(b"Content-Length", b"%d" % (len(response),)) request.write(response.encode("utf8")) finish_request(request) - defer.returnValue(None) class AccountValiditySendMailServlet(RestServlet): @@ -85,18 +81,17 @@ class AccountValiditySendMailServlet(RestServlet): self.auth = hs.get_auth() self.account_validity = self.hs.config.account_validity - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) - requester = yield self.auth.get_user_by_req(request, allow_expired=True) + requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() - yield self.account_activity_handler.send_renewal_email_to_user(user_id) + await self.account_activity_handler.send_renewal_email_to_user(user_id) - defer.returnValue((200, {})) + return 200, {} def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index f21aff39e5..7a256b6ecb 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import LoginType from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_API_PREFIX @@ -171,8 +169,7 @@ class AuthRestServlet(RestServlet): else: raise SynapseError(404, "Unknown auth stage type") - @defer.inlineCallbacks - def on_POST(self, request, stagetype): + async def on_POST(self, request, stagetype): session = parse_string(request, "session") if not session: @@ -186,7 +183,7 @@ class AuthRestServlet(RestServlet): authdict = {"response": response, "session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) @@ -215,7 +212,7 @@ class AuthRestServlet(RestServlet): session = request.args["session"][0] authdict = {"session": session} - success = yield self.auth_handler.add_oob_auth( + success = await self.auth_handler.add_oob_auth( LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index acd58af193..fe9d019c44 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -14,8 +14,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet @@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - user = yield self.store.get_user_by_id(requester.user.to_string()) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + user = await self.store.get_user_by_id(requester.user.to_string()) change_password = bool(user["password_hash"]) response = { diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 26d0235208..94ff73f384 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api import errors from synapse.http.servlet import ( RestServlet, @@ -42,10 +40,9 @@ class DevicesRestServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - devices = yield self.device_handler.get_devices_by_user( + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + devices = await self.device_handler.get_devices_by_user( requester.user.to_string() ) return 200, {"devices": devices} @@ -67,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -84,11 +80,11 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_devices( + await self.device_handler.delete_devices( requester.user.to_string(), body["devices"] ) return 200, {} @@ -108,18 +104,16 @@ class DeviceRestServlet(RestServlet): self.device_handler = hs.get_device_handler() self.auth_handler = hs.get_auth_handler() - @defer.inlineCallbacks - def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) - device = yield self.device_handler.get_device( + async def on_GET(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) + device = await self.device_handler.get_device( requester.user.to_string(), device_id ) return 200, device @interactive_auth_handler - @defer.inlineCallbacks - def on_DELETE(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, device_id): + requester = await self.auth.get_user_by_req(request) try: body = parse_json_object_from_request(request) @@ -132,19 +126,18 @@ class DeviceRestServlet(RestServlet): else: raise - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_device(requester.user.to_string(), device_id) + await self.device_handler.delete_device(requester.user.to_string(), device_id) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_PUT(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) body = parse_json_object_from_request(request) - yield self.device_handler.update_device( + await self.device_handler.update_device( requester.user.to_string(), device_id, body ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 17a8bc7366..b28da017cd 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID @@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_GET(self, request, user_id, filter_id): + async def on_GET(self, request, user_id, filter_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot get filters for other users") @@ -52,7 +49,7 @@ class GetFilterRestServlet(RestServlet): raise SynapseError(400, "Invalid filter_id") try: - filter_collection = yield self.filtering.get_user_filter( + filter_collection = await self.filtering.get_user_filter( user_localpart=target_user.localpart, filter_id=filter_id ) except StoreError as e: @@ -72,11 +69,10 @@ class CreateFilterRestServlet(RestServlet): self.auth = hs.get_auth() self.filtering = hs.get_filtering() - @defer.inlineCallbacks - def on_POST(self, request, user_id): + async def on_POST(self, request, user_id): target_user = UserID.from_string(user_id) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) if target_user != requester.user: raise AuthError(403, "Cannot create filters for other users") @@ -87,7 +83,7 @@ class CreateFilterRestServlet(RestServlet): content = parse_json_object_from_request(request) set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) - filter_id = yield self.filtering.add_user_filter( + filter_id = await self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content ) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 999a0fa80c..d84a6d7e11 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID @@ -38,24 +36,22 @@ class GroupServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - group_description = yield self.groups_handler.get_group_profile( + group_description = await self.groups_handler.get_group_profile( group_id, requester_user_id ) return 200, group_description - @defer.inlineCallbacks - def on_POST(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - yield self.groups_handler.update_group_profile( + await self.groups_handler.update_group_profile( group_id, requester_user_id, content ) @@ -74,12 +70,11 @@ class GroupSummaryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - get_group_summary = yield self.groups_handler.get_group_summary( + get_group_summary = await self.groups_handler.get_group_summary( group_id, requester_user_id ) @@ -106,13 +101,12 @@ class GroupSummaryRoomsCatServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_room( + resp = await self.groups_handler.update_group_summary_room( group_id, requester_user_id, room_id=room_id, @@ -122,12 +116,11 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_room( + resp = await self.groups_handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) @@ -148,35 +141,32 @@ class GroupCategoryServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_category( + category = await self.groups_handler.get_group_category( group_id, requester_user_id, category_id=category_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_category( + resp = await self.groups_handler.update_group_category( group_id, requester_user_id, category_id=category_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, category_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, category_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_category( + resp = await self.groups_handler.delete_group_category( group_id, requester_user_id, category_id=category_id ) @@ -195,12 +185,11 @@ class GroupCategoriesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_categories( + category = await self.groups_handler.get_group_categories( group_id, requester_user_id ) @@ -219,35 +208,32 @@ class GroupRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_role( + category = await self.groups_handler.get_group_role( group_id, requester_user_id, role_id=role_id ) return 200, category - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_role( + resp = await self.groups_handler.update_group_role( group_id, requester_user_id, role_id=role_id, content=content ) return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_role( + resp = await self.groups_handler.delete_group_role( group_id, requester_user_id, role_id=role_id ) @@ -266,12 +252,11 @@ class GroupRolesServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - category = yield self.groups_handler.get_group_roles( + category = await self.groups_handler.get_group_roles( group_id, requester_user_id ) @@ -298,13 +283,12 @@ class GroupSummaryUsersRoleServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - resp = yield self.groups_handler.update_group_summary_user( + resp = await self.groups_handler.update_group_summary_user( group_id, requester_user_id, user_id=user_id, @@ -314,12 +298,11 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, role_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, role_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - resp = yield self.groups_handler.delete_group_summary_user( + resp = await self.groups_handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) @@ -338,12 +321,11 @@ class GroupRoomServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group( + result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -362,12 +344,11 @@ class GroupUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group( + result = await self.groups_handler.get_users_in_group( group_id, requester_user_id ) @@ -386,12 +367,11 @@ class GroupInvitedUsersServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_invited_users_in_group( + result = await self.groups_handler.get_invited_users_in_group( group_id, requester_user_id ) @@ -409,14 +389,13 @@ class GroupSettingJoinPolicyServlet(RestServlet): self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.set_group_join_policy( + result = await self.groups_handler.set_group_join_policy( group_id, requester_user_id, content ) @@ -436,9 +415,8 @@ class GroupCreateServlet(RestServlet): self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() # TODO: Create group on remote server @@ -446,7 +424,7 @@ class GroupCreateServlet(RestServlet): localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() - result = yield self.groups_handler.create_group( + result = await self.groups_handler.create_group( group_id, requester_user_id, content ) @@ -467,24 +445,22 @@ class GroupAdminRoomsServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.add_room_to_group( + result = await self.groups_handler.add_room_to_group( group_id, requester_user_id, room_id, content ) return 200, result - @defer.inlineCallbacks - def on_DELETE(self, request, group_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, group_id, room_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.remove_room_from_group( + result = await self.groups_handler.remove_room_from_group( group_id, requester_user_id, room_id ) @@ -506,13 +482,12 @@ class GroupAdminRoomsConfigServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, room_id, config_key): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, room_id, config_key): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.update_room_in_group( + result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -535,14 +510,13 @@ class GroupAdminUsersInviteServlet(RestServlet): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) config = content.get("config", {}) - result = yield self.groups_handler.invite( + result = await self.groups_handler.invite( group_id, user_id, requester_user_id, config ) @@ -563,13 +537,12 @@ class GroupAdminUsersKickServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id, user_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) @@ -588,13 +561,12 @@ class GroupSelfLeaveServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.remove_user_from_group( + result = await self.groups_handler.remove_user_from_group( group_id, requester_user_id, requester_user_id, content ) @@ -613,13 +585,12 @@ class GroupSelfJoinServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.join_group( + result = await self.groups_handler.join_group( group_id, requester_user_id, content ) @@ -638,13 +609,12 @@ class GroupSelfAcceptInviteServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) - result = yield self.groups_handler.accept_invite( + result = await self.groups_handler.accept_invite( group_id, requester_user_id, content ) @@ -663,14 +633,13 @@ class GroupSelfUpdatePublicityServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_PUT(self, request, group_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, group_id): + requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity(group_id, requester_user_id, publicise) + await self.store.update_group_publicity(group_id, requester_user_id, publicise) return 200, {} @@ -688,11 +657,10 @@ class PublicisedGroupsForUserServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request, user_id): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, user_id): + await self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user(user_id) + result = await self.groups_handler.get_publicised_groups_for_user(user_id) return 200, result @@ -710,14 +678,13 @@ class PublicisedGroupsForUsersServlet(RestServlet): self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) + result = await self.groups_handler.bulk_get_publicised_groups(user_ids) return 200, result @@ -734,12 +701,11 @@ class GroupsForUserServlet(RestServlet): self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_joined_groups(requester_user_id) + result = await self.groups_handler.get_joined_groups(requester_user_id) return 200, result diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 341567ae21..f7ed4daf90 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import ( RestServlet, @@ -71,9 +69,8 @@ class KeyUploadServlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() @trace(opname="upload_keys") - @defer.inlineCallbacks - def on_POST(self, request, device_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, device_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -103,7 +100,7 @@ class KeyUploadServlet(RestServlet): 400, "To upload keys, you must pass device_id when authenticating" ) - result = yield self.e2e_keys_handler.upload_keys_for_user( + result = await self.e2e_keys_handler.upload_keys_for_user( user_id, device_id, body ) return 200, result @@ -154,13 +151,12 @@ class KeyQueryServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body, timeout, user_id) + result = await self.e2e_keys_handler.query_devices(body, timeout, user_id) return 200, result @@ -185,9 +181,8 @@ class KeyChangesServlet(RestServlet): self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) from_token_string = parse_string(request, "from") set_tag("from", from_token_string) @@ -200,7 +195,7 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed(user_id, from_token) + results = await self.device_handler.get_user_ids_changed(user_id, from_token) return 200, results @@ -231,12 +226,11 @@ class OneTimeKeyServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) + result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout) return 200, result @@ -263,17 +257,16 @@ class SigningKeyUploadServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - yield self.auth_handler.validate_user_via_ui_auth( + await self.auth_handler.validate_user_via_ui_auth( requester, body, self.hs.get_ip_from_request(request) ) - result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) + result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) return 200, result @@ -315,13 +308,12 @@ class SignaturesUploadServlet(RestServlet): self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.upload_signatures_for_device_keys( + result = await self.e2e_keys_handler.upload_signatures_for_device_keys( user_id, body ) return 200, result diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 10c1ad5b07..aa911d75ee 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet): self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() from_token = parse_string(request, "from", required=False) @@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet): limit = min(limit, 500) - push_actions = yield self.store.get_push_actions_for_user( + push_actions = await self.store.get_push_actions_for_user( user_id, from_token, limit, only_highlight=(only == "highlight") ) - receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( + receipts_by_room = await self.store.get_receipts_for_user_with_orderings( user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] - notif_events = yield self.store.get_events(notif_event_ids) + notif_events = await self.store.get_events(notif_event_ids) returned_push_actions = [] @@ -68,7 +65,7 @@ class NotificationsServlet(RestServlet): "actions": pa["actions"], "ts": pa["received_ts"], "event": ( - yield self._event_serializer.serialize_event( + await self._event_serializer.serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index b4925c0f59..6ae9a5a8e9 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string @@ -68,9 +66,8 @@ class IdTokenServlet(RestServlet): self.clock = hs.get_clock() self.server_name = hs.config.server_name - @defer.inlineCallbacks - def on_POST(self, request, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, user_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot request tokens for other users.") @@ -81,7 +78,7 @@ class IdTokenServlet(RestServlet): token = random_string(24) ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS - yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) + await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) return ( 200, diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 91db923814..66de16a1fa 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -20,8 +20,6 @@ from typing import List, Union from six import string_types -from twisted.internet import defer - import synapse import synapse.types from synapse.api.constants import LoginType @@ -102,8 +100,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): template_text=template_text, ) - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -129,7 +126,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", body["email"] ) @@ -140,7 +137,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request - ret = yield self.identity_handler.requestEmailToken( + ret = await self.identity_handler.requestEmailToken( self.hs.config.account_threepid_delegate_email, email, client_secret, @@ -149,7 +146,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) else: # Send registration emails from Synapse - sid = yield self.identity_handler.send_threepid_validation( + sid = await self.identity_handler.send_threepid_validation( email, client_secret, send_attempt, @@ -175,8 +172,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): self.hs = hs self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( @@ -197,7 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) @@ -215,7 +211,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Registration by phone number is not supported on this homeserver" ) - ret = yield self.identity_handler.requestMsisdnToken( + ret = await self.identity_handler.requestMsisdnToken( self.hs.config.account_threepid_delegate_msisdn, country, phone_number, @@ -258,8 +254,7 @@ class RegistrationSubmitTokenServlet(RestServlet): [self.config.email_registration_template_failure_html], ) - @defer.inlineCallbacks - def on_GET(self, request, medium): + async def on_GET(self, request, medium): if medium != "email": raise SynapseError( 400, "This medium is currently not supported for registration" @@ -280,7 +275,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Attempt to validate a 3PID session try: # Mark the session as valid - next_link = yield self.store.validate_threepid_session( + next_link = await self.store.validate_threepid_session( sid, client_secret, token, self.clock.time_msec() ) @@ -338,8 +333,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if not self.hs.config.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN @@ -347,11 +341,11 @@ class UsernameAvailabilityRestServlet(RestServlet): ip = self.hs.get_ip_from_request(request) with self.ratelimiter.ratelimit(ip) as wait_deferred: - yield wait_deferred + await wait_deferred username = parse_string(request, "username", required=True) - yield self.registration_handler.check_username(username) + await self.registration_handler.check_username(username) return 200, {"available": True} @@ -382,8 +376,7 @@ class RegisterRestServlet(RestServlet): ) @interactive_auth_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): body = parse_json_object_from_request(request) client_addr = request.getClientIP() @@ -408,7 +401,7 @@ class RegisterRestServlet(RestServlet): kind = request.args[b"kind"][0] if kind == b"guest": - ret = yield self._do_guest_registration(body, address=client_addr) + ret = await self._do_guest_registration(body, address=client_addr) return ret elif kind != b"user": raise UnrecognizedRequestError( @@ -435,7 +428,7 @@ class RegisterRestServlet(RestServlet): appservice = None if self.auth.has_access_token(request): - appservice = yield self.auth.get_appservice_by_req(request) + appservice = await self.auth.get_appservice_by_req(request) # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -455,7 +448,7 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) if isinstance(desired_username, string_types): - result = yield self._do_appservice_registration( + result = await self._do_appservice_registration( desired_username, access_token, body ) return 200, result # we throw for non 200 responses @@ -495,13 +488,13 @@ class RegisterRestServlet(RestServlet): ) if desired_username is not None: - yield self.registration_handler.check_username( + await self.registration_handler.check_username( desired_username, guest_access_token=guest_access_token, assigned_user_id=registered_user_id, ) - auth_result, params, session_id = yield self.auth_handler.check_auth( + auth_result, params, session_id = await self.auth_handler.check_auth( self._registration_flows, body, self.hs.get_ip_from_request(request) ) @@ -557,7 +550,7 @@ class RegisterRestServlet(RestServlet): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - existing_user_id = yield self.store.get_user_id_by_threepid( + existing_user_id = await self.store.get_user_id_by_threepid( medium, address ) @@ -568,7 +561,7 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) - registered_user_id = yield self.registration_handler.register_user( + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password=new_password, guest_access_token=guest_access_token, @@ -581,7 +574,7 @@ class RegisterRestServlet(RestServlet): if is_threepid_reserved( self.hs.config.mau_limits_reserved_threepids, threepid ): - yield self.store.upsert_monthly_active_user(registered_user_id) + await self.store.upsert_monthly_active_user(registered_user_id) # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) @@ -591,12 +584,12 @@ class RegisterRestServlet(RestServlet): registered = True - return_dict = yield self._create_registration_details( + return_dict = await self._create_registration_details( registered_user_id, params ) if registered: - yield self.registration_handler.post_registration_actions( + await self.registration_handler.post_registration_actions( user_id=registered_user_id, auth_result=auth_result, access_token=return_dict.get("access_token"), @@ -607,15 +600,13 @@ class RegisterRestServlet(RestServlet): def on_OPTIONS(self, _): return 200, {} - @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token, body): - user_id = yield self.registration_handler.appservice_register( + async def _do_appservice_registration(self, username, as_token, body): + user_id = await self.registration_handler.appservice_register( username, as_token ) - return (yield self._create_registration_details(user_id, body)) + return await self._create_registration_details(user_id, body) - @defer.inlineCallbacks - def _create_registration_details(self, user_id, params): + async def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -631,18 +622,17 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False ) result.update({"access_token": access_token, "device_id": device_id}) return result - @defer.inlineCallbacks - def _do_guest_registration(self, params, address=None): + async def _do_guest_registration(self, params, address=None): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") - user_id = yield self.registration_handler.register_user( + user_id = await self.registration_handler.register_user( make_guest=True, address=address ) @@ -650,7 +640,7 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self.registration_handler.register_device( + device_id, access_token = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 040b37c504..9be9a34b91 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -21,8 +21,6 @@ any time to reflect changes in the MSC. import logging -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( @@ -86,11 +84,10 @@ class RelationSendServlet(RestServlet): request, self.on_PUT_or_POST, request, *args, **kwargs ) - @defer.inlineCallbacks - def on_PUT_or_POST( + async def on_PUT_or_POST( self, request, room_id, parent_id, relation_type, event_type, txn_id=None ): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: # Add relations to a membership is meaningless, so we just deny it @@ -114,7 +111,7 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict=event_dict, txn_id=txn_id ) @@ -140,17 +137,18 @@ class RelationPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. - event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) limit = parse_integer(request, "limit", default=5) from_token = parse_string(request, "from") @@ -167,7 +165,7 @@ class RelationPaginationServlet(RestServlet): if to_token: to_token = RelationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_relations_for_event( + pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -176,7 +174,7 @@ class RelationPaginationServlet(RestServlet): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] ) @@ -184,13 +182,13 @@ class RelationPaginationServlet(RestServlet): # We set bundle_aggregations to False when retrieving the original # event because we want the content before relations were applied to # it. - original_event = yield self._event_serializer.serialize_event( + original_event = await self._event_serializer.serialize_event( event, now, bundle_aggregations=False ) # Similarly, we don't allow relations to be applied to relations, so we # return the original relations without any aggregations on top of them # here. - events = yield self._event_serializer.serialize_events( + events = await self._event_serializer.serialize_events( events, now, bundle_aggregations=False ) @@ -232,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet): self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET( + self, request, room_id, parent_id, relation_type=None, event_type=None + ): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This checks that a) the event exists and b) the user is allowed to # view it. - event = yield self.event_handler.get_event(requester.user, room_id, parent_id) + event = await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -262,7 +261,7 @@ class RelationAggregationPaginationServlet(RestServlet): if to_token: to_token = AggregationPaginationToken.from_string(to_token) - pagination_chunk = yield self.store.get_aggregation_groups_for_event( + pagination_chunk = await self.store.get_aggregation_groups_for_event( event_id=parent_id, event_type=event_type, limit=limit, @@ -311,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - @defer.inlineCallbacks - def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - yield self.auth.check_in_room_or_world_readable( + await self.auth.check_in_room_or_world_readable( room_id, requester.user.to_string() ) # This checks that a) the event exists and b) the user is allowed to # view it. - yield self.event_handler.get_event(requester.user, room_id, parent_id) + await self.event_handler.get_event(requester.user, room_id, parent_id) if relation_type != RelationTypes.ANNOTATION: raise SynapseError(400, "Relation type must be 'annotation'") @@ -336,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): if to_token: to_token = RelationPaginationToken.from_string(to_token) - result = yield self.store.get_relations_for_event( + result = await self.store.get_relations_for_event( event_id=parent_id, relation_type=relation_type, event_type=event_type, @@ -346,12 +344,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet): to_token=to_token, ) - events = yield self.store.get_events_as_list( + events = await self.store.get_events_as_list( [c["event_id"] for c in result.chunk] ) now = self.clock.time_msec() - events = yield self._event_serializer.serialize_events(events, now) + events = await self._event_serializer.serialize_events(events, now) return_value = result.to_dict() return_value["chunk"] = events diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index e7449864cd..f067b5edac 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -18,8 +18,6 @@ import logging from six import string_types from six.moves import http_client -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( RestServlet, @@ -42,9 +40,8 @@ class ReportEventRestServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() body = parse_json_object_from_request(request) @@ -63,7 +60,7 @@ class ReportEventRestServlet(RestServlet): Codes.BAD_JSON, ) - yield self.store.add_event_report( + await self.store.add_event_report( room_id=room_id, event_id=event_id, user_id=user_id, diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index d596786430..38952a1d27 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, session_id): + async def on_PUT(self, request, room_id, session_id): """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -123,7 +120,7 @@ class RoomKeysServlet(RestServlet): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) version = parse_string(request, "version") @@ -134,11 +131,10 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, {} + ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret - @defer.inlineCallbacks - def on_GET(self, request, room_id, session_id): + async def on_GET(self, request, room_id, session_id): """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. @@ -190,11 +186,11 @@ class RoomKeysServlet(RestServlet): } } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() version = parse_string(request, "version") - room_keys = yield self.e2e_room_keys_handler.get_room_keys( + room_keys = await self.e2e_room_keys_handler.get_room_keys( user_id, version, room_id, session_id ) @@ -220,8 +216,7 @@ class RoomKeysServlet(RestServlet): return 200, room_keys - @defer.inlineCallbacks - def on_DELETE(self, request, room_id, session_id): + async def on_DELETE(self, request, room_id, session_id): """ Deletes one or more encrypted E2E room keys for a user for backup purposes. @@ -235,14 +230,14 @@ class RoomKeysServlet(RestServlet): the version must already have been created via the /change_secret API. """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() version = parse_string(request, "version") - yield self.e2e_room_keys_handler.delete_room_keys( + ret = await self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - return 200, {} + return 200, ret class RoomKeysNewVersionServlet(RestServlet): @@ -257,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """ Create a new backup version for this user's room_keys with the given info. The version is allocated by the server and returned to the user @@ -288,11 +282,11 @@ class RoomKeysNewVersionServlet(RestServlet): "version": 12345 } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) + new_version = await self.e2e_room_keys_handler.create_version(user_id, info) return 200, {"version": new_version} # we deliberately don't have a PUT /version, as these things really should @@ -311,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - @defer.inlineCallbacks - def on_GET(self, request, version): + async def on_GET(self, request, version): """ Retrieve the version information about a given version of the user's room_keys backup. If the version part is missing, returns info about the @@ -330,18 +323,17 @@ class RoomKeysVersionServlet(RestServlet): "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) + info = await self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) return 200, info - @defer.inlineCallbacks - def on_DELETE(self, request, version): + async def on_DELETE(self, request, version): """ Delete the information about a given version of the user's room_keys backup. If the version part is missing, deletes the most @@ -354,14 +346,13 @@ class RoomKeysVersionServlet(RestServlet): if version is None: raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version(user_id, version) + await self.e2e_room_keys_handler.delete_version(user_id, version) return 200, {} - @defer.inlineCallbacks - def on_PUT(self, request, version): + async def on_PUT(self, request, version): """ Update the information about a given version of the user's room_keys backup. @@ -382,7 +373,7 @@ class RoomKeysVersionServlet(RestServlet): Content-Type: application/json {} """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() info = parse_json_object_from_request(request) @@ -391,7 +382,7 @@ class RoomKeysVersionServlet(RestServlet): 400, "No version specified to update", Codes.MISSING_PARAM ) - yield self.e2e_room_keys_handler.update_version(user_id, version, info) + await self.e2e_room_keys_handler.update_version(user_id, version, info) return 200, {} diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index d2c3316eb7..ca97330797 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( @@ -59,9 +57,8 @@ class RoomUpgradeRestServlet(RestServlet): self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request, room_id): - requester = yield self._auth.get_user_by_req(request) + async def on_POST(self, request, room_id): + requester = await self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) @@ -74,7 +71,7 @@ class RoomUpgradeRestServlet(RestServlet): Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = yield self._room_creation_handler.upgrade_room( + new_room_id = await self._room_creation_handler.upgrade_room( requester, room_id, new_version ) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index d90e52ed1a..501b52fb6c 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.http import servlet from synapse.http.servlet import parse_json_object_from_request from synapse.logging.opentracing import set_tag, trace @@ -51,15 +49,14 @@ class SendToDeviceRestServlet(servlet.RestServlet): request, self._put, request, message_type, txn_id ) - @defer.inlineCallbacks - def _put(self, request, message_type, txn_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def _put(self, request, message_type, txn_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) sender_user_id = requester.user.to_string() - yield self.device_message_handler.send_device_message( + await self.device_message_handler.send_device_message( sender_user_id, message_type, content["messages"] ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index ccd8b17b23..d8292ce29f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -18,8 +18,6 @@ import logging from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection @@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet): self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): if b"from" in request.args: # /events used to use 'from', but /sync uses 'since'. # Lets be helpful and whine if we see a 'from'. @@ -96,7 +93,7 @@ class SyncRestServlet(RestServlet): 400, "'from' is not a valid query parameter. Did you mean 'since'?" ) - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id @@ -138,7 +135,7 @@ class SyncRestServlet(RestServlet): filter_collection = FilterCollection(filter_object) else: try: - filter_collection = yield self.filtering.get_user_filter( + filter_collection = await self.filtering.get_user_filter( user.localpart, filter_id ) except StoreError as err: @@ -161,20 +158,20 @@ class SyncRestServlet(RestServlet): since_token = None # send any outstanding server notices to the user. - yield self._server_notices_sender.on_user_syncing(user.to_string()) + await self._server_notices_sender.on_user_syncing(user.to_string()) affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state( + await self.presence_handler.set_state( user, {"presence": set_presence}, True ) - context = yield self.presence_handler.user_syncing( + context = await self.presence_handler.user_syncing( user.to_string(), affect_presence=affect_presence ) with context: - sync_result = yield self.sync_handler.wait_for_sync_for_user( + sync_result = await self.sync_handler.wait_for_sync_for_user( sync_config, since_token=since_token, timeout=timeout, @@ -182,14 +179,13 @@ class SyncRestServlet(RestServlet): ) time_now = self.clock.time_msec() - response_content = yield self.encode_response( + response_content = await self.encode_response( time_now, sync_result, requester.access_token_id, filter_collection ) return 200, response_content - @defer.inlineCallbacks - def encode_response(self, time_now, sync_result, access_token_id, filter): + async def encode_response(self, time_now, sync_result, access_token_id, filter): if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == "federation": @@ -197,7 +193,7 @@ class SyncRestServlet(RestServlet): else: raise Exception("Unknown event format %s" % (filter.event_format,)) - joined = yield self.encode_joined( + joined = await self.encode_joined( sync_result.joined, time_now, access_token_id, @@ -205,11 +201,11 @@ class SyncRestServlet(RestServlet): event_formatter, ) - invited = yield self.encode_invited( + invited = await self.encode_invited( sync_result.invited, time_now, access_token_id, event_formatter ) - archived = yield self.encode_archived( + archived = await self.encode_archived( sync_result.archived, time_now, access_token_id, @@ -250,8 +246,9 @@ class SyncRestServlet(RestServlet): ] } - @defer.inlineCallbacks - def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_joined( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the joined rooms in a sync result @@ -272,7 +269,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -283,8 +280,7 @@ class SyncRestServlet(RestServlet): return joined - @defer.inlineCallbacks - def encode_invited(self, rooms, time_now, token_id, event_formatter): + async def encode_invited(self, rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -304,7 +300,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = yield self._event_serializer.serialize_event( + invite = await self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, @@ -319,8 +315,9 @@ class SyncRestServlet(RestServlet): return invited - @defer.inlineCallbacks - def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): + async def encode_archived( + self, rooms, time_now, token_id, event_fields, event_formatter + ): """ Encode the archived rooms in a sync result @@ -341,7 +338,7 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = yield self.encode_room( + joined[room.room_id] = await self.encode_room( room, time_now, token_id, @@ -352,8 +349,7 @@ class SyncRestServlet(RestServlet): return joined - @defer.inlineCallbacks - def encode_room( + async def encode_room( self, room, time_now, token_id, joined, only_fields, event_formatter ): """ @@ -401,8 +397,8 @@ class SyncRestServlet(RestServlet): event.room_id, ) - serialized_state = yield serialize(state_events) - serialized_timeline = yield serialize(timeline_events) + serialized_state = await serialize(state_events) + serialized_timeline = await serialize(timeline_events) account_data = room.account_data diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 3b555669a0..a3f12e8a77 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -37,13 +35,12 @@ class TagListServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() - @defer.inlineCallbacks - def on_GET(self, request, user_id, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, user_id, room_id): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get tags for other users.") - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) return 200, {"tags": tags} @@ -64,27 +61,25 @@ class TagServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() - @defer.inlineCallbacks - def on_PUT(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") body = parse_json_object_from_request(request) - max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) + max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) return 200, {} - @defer.inlineCallbacks - def on_DELETE(self, request, user_id, room_id, tag): - requester = yield self.auth.get_user_by_req(request) + async def on_DELETE(self, request, user_id, room_id, tag): + requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) + max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 2e8d672471..23709960ad 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -16,8 +16,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet @@ -35,11 +33,10 @@ class ThirdPartyProtocolsServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols() + protocols = await self.appservice_handler.get_3pe_protocols() return 200, protocols @@ -52,11 +49,10 @@ class ThirdPartyProtocolServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) - protocols = yield self.appservice_handler.get_3pe_protocols( + protocols = await self.appservice_handler.get_3pe_protocols( only_protocol=protocol ) if protocol in protocols: @@ -74,14 +70,13 @@ class ThirdPartyUserServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.USER, protocol, fields ) @@ -97,14 +92,13 @@ class ThirdPartyLocationServlet(RestServlet): self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() - @defer.inlineCallbacks - def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, protocol): + await self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop(b"access_token", None) - results = yield self.appservice_handler.query_3pe( + results = await self.appservice_handler.query_3pe( ThirdPartyEntityKind.LOCATION, protocol, fields ) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 2da0f55811..83f3b6b70a 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet @@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet): def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 2863affbab..bef91a2d3e 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -38,8 +36,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Searches for users in directory Returns: @@ -56,7 +53,7 @@ class UserDirectorySearchRestServlet(RestServlet): ] } """ - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: @@ -72,7 +69,7 @@ class UserDirectorySearchRestServlet(RestServlet): except Exception: raise SynapseError(400, "`search_term` is required field") - results = yield self.user_directory_handler.search_users( + results = await self.user_directory_handler.search_users( user_id, search_term, limit ) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index bb30ce3f34..2a477ad22e 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index fb0d02aa83..6b978be876 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeResource): logger.info("Running url preview cache expiry") - if not (yield self.store.has_completed_background_updates()): + if not (yield self.store.db.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 8cf415e29d..c234ea7421 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -129,5 +129,8 @@ class Thumbnailer(object): def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() - output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) + fmt = self.FORMATS[output_type] + if fmt == "JPEG": + output_image = output_image.convert("RGB") + output_image.save(output_bytes_io, fmt, quality=80) return output_bytes_io diff --git a/synapse/server.py b/synapse/server.py index be9af7f986..2db3dab221 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -238,8 +238,7 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") with self.get_db_conn() as conn: - datastore = self.DATASTORE_CLASS(conn, self) - self.datastores = DataStores(datastore, conn, self) + self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) conn.commit() self.start_time = int(self.get_clock().time()) logger.info("Finished setting up.") diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 0460fe8cc9..ec89f645d4 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -17,10 +17,10 @@ """ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple -databases). The `data_stores` are classes that talk directly to a single -database and have associated schemas, background updates, etc. On top of those -there are (or will be) classes that provide high level interfaces that combine -calls to multiple `data_stores`. +databases). The `Database` class represents a single physical database. The +`data_stores` are classes that talk directly to a `Database` instance and have +associated schemas, background updates, etc. On top of those there are classes +that provide high level interfaces that combine calls to multiple `data_stores`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are @@ -49,15 +49,3 @@ class Storage(object): self.persistence = EventsPersistenceStorage(hs, stores) self.purge_events = PurgeEventsStorage(hs, stores) self.state = StateGroupStorage(hs, stores) - - -def are_all_users_on_domain(txn, database_engine, domain): - sql = database_engine.convert_param_style( - "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" - ) - pat = "%:" + domain - txn.execute(sql, (pat,)) - num_not_matching = txn.fetchall()[0][0] - if num_not_matching == 0: - return True - return False diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 459901ac60..b7637b5dc0 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -14,1433 +14,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging import random -import sys -import threading -import time -from typing import Iterable, Tuple -from six import PY2, iteritems, iterkeys, itervalues -from six.moves import builtins, intern, range +from six import PY2 +from six.moves import builtins from canonicaljson import json -from prometheus_client import Histogram -from twisted.internet import defer - -from synapse.api.errors import StoreError -from synapse.logging.context import LoggingContext, make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.storage.database import LoggingTransaction # noqa: F401 +from synapse.storage.database import make_in_list_sql_clause # noqa: F401 +from synapse.storage.database import Database from synapse.types import get_domain_from_id -from synapse.util import batch_iter -from synapse.util.caches.descriptors import Cache -from synapse.util.stringutils import exception_to_unicode - -# import a function which will return a monotonic time, in seconds -try: - # on python 3, use time.monotonic, since time.clock can go backwards - from time import monotonic as monotonic_time -except ImportError: - # ... but python 2 doesn't have it - from time import clock as monotonic_time logger = logging.getLogger(__name__) -try: - MAX_TXN_ID = sys.maxint - 1 -except AttributeError: - # python 3 does not have a maximum int value - MAX_TXN_ID = 2 ** 63 - 1 - -sql_logger = logging.getLogger("synapse.storage.SQL") -transaction_logger = logging.getLogger("synapse.storage.txn") -perf_logger = logging.getLogger("synapse.storage.TIME") - -sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") - -sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) -sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) - - -# Unique indexes which have been added in background updates. Maps from table name -# to the name of the background update which added the unique index to that table. -# -# This is used by the upsert logic to figure out which tables are safe to do a proper -# UPSERT on: until the relevant background update has completed, we -# have to emulate an upsert by locking the table. -# -UNIQUE_INDEX_BACKGROUND_UPDATES = { - "user_ips": "user_ips_device_unique_index", - "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", - "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", - "event_search": "event_search_event_id_idx", -} - -# This is a special cache name we use to batch multiple invalidations of caches -# based on the current state when notifying workers over replication. -_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" - -class LoggingTransaction(object): - """An object that almost-transparently proxies for the 'txn' object - passed to the constructor. Adds logging and metrics to the .execute() - method. +class SQLBaseStore(object): + """Base class for data stores that holds helper functions. - Args: - txn: The database transcation object to wrap. - name (str): The name of this transactions for logging. - database_engine (Sqlite3Engine|PostgresEngine) - after_callbacks(list|None): A list that callbacks will be appended to - that have been added by `call_after` which should be run on - successful completion of the transaction. None indicates that no - callbacks should be allowed to be scheduled to run. - exception_callbacks(list|None): A list that callbacks will be appended - to that have been added by `call_on_exception` which should be run - if transaction ends with an error. None indicates that no callbacks - should be allowed to be scheduled to run. + Note that multiple instances of this class will exist as there will be one + per data store (and not one per physical database). """ - __slots__ = [ - "txn", - "name", - "database_engine", - "after_callbacks", - "exception_callbacks", - ] - - def __init__( - self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None - ): - object.__setattr__(self, "txn", txn) - object.__setattr__(self, "name", name) - object.__setattr__(self, "database_engine", database_engine) - object.__setattr__(self, "after_callbacks", after_callbacks) - object.__setattr__(self, "exception_callbacks", exception_callbacks) - - def call_after(self, callback, *args, **kwargs): - """Call the given callback on the main twisted thread after the - transaction has finished. Used to invalidate the caches on the - correct thread. - """ - self.after_callbacks.append((callback, args, kwargs)) - - def call_on_exception(self, callback, *args, **kwargs): - self.exception_callbacks.append((callback, args, kwargs)) - - def __getattr__(self, name): - return getattr(self.txn, name) - - def __setattr__(self, name, value): - setattr(self.txn, name, value) - - def __iter__(self): - return self.txn.__iter__() - - def execute_batch(self, sql, args): - if isinstance(self.database_engine, PostgresEngine): - from psycopg2.extras import execute_batch - - self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) - else: - for val in args: - self.execute(sql, val) - - def execute(self, sql, *args): - self._do_execute(self.txn.execute, sql, *args) - - def executemany(self, sql, *args): - self._do_execute(self.txn.executemany, sql, *args) - - def _make_sql_one_line(self, sql): - "Strip newlines out of SQL so that the loggers in the DB are on one line" - return " ".join(l.strip() for l in sql.splitlines() if l.strip()) - - def _do_execute(self, func, sql, *args): - sql = self._make_sql_one_line(sql) - - # TODO(paul): Maybe use 'info' and 'debug' for values? - sql_logger.debug("[SQL] {%s} %s", self.name, sql) - - sql = self.database_engine.convert_param_style(sql) - if args: - try: - sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) - except Exception: - # Don't let logging failures stop SQL from working - pass - - start = time.time() - - try: - return func(sql, *args) - except Exception as e: - logger.debug("[SQL FAIL] {%s} %s", self.name, e) - raise - finally: - secs = time.time() - start - sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) - sql_query_timer.labels(sql.split()[0]).observe(secs) - - -class PerformanceCounters(object): - def __init__(self): - self.current_counters = {} - self.previous_counters = {} - - def update(self, key, duration_secs): - count, cum_time = self.current_counters.get(key, (0, 0)) - count += 1 - cum_time += duration_secs - self.current_counters[key] = (count, cum_time) - - def interval(self, interval_duration_secs, limit=3): - counters = [] - for name, (count, cum_time) in iteritems(self.current_counters): - prev_count, prev_time = self.previous_counters.get(name, (0, 0)) - counters.append( - ( - (cum_time - prev_time) / interval_duration_secs, - count - prev_count, - name, - ) - ) - - self.previous_counters = dict(self.current_counters) - - counters.sort(reverse=True) - - top_n_counters = ", ".join( - "%s(%d): %.3f%%" % (name, count, 100 * ratio) - for ratio, count, name in counters[:limit] - ) - - return top_n_counters - - -class SQLBaseStore(object): - _TXN_ID = 0 - - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() - - self._previous_txn_total_time = 0 - self._current_txn_total_time = 0 - self._previous_loop_ts = 0 - - # TODO(paul): These can eventually be removed once the metrics code - # is running in mainline, and we have some nice monitoring frontends - # to watch it - self._txn_perf_counters = PerformanceCounters() - - self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size - ) - - self._event_fetch_lock = threading.Condition() - self._event_fetch_list = [] - self._event_fetch_ongoing = 0 - - self._pending_ds = [] - self.database_engine = hs.database_engine - - # A set of tables that are not safe to use native upserts in. - self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) - - self._account_validity = self.hs.config.account_validity - - # We add the user_directory_search table to the blacklist on SQLite - # because the existing search table does not have an index, making it - # unsafe to use native upserts. - if isinstance(self.database_engine, Sqlite3Engine): - self._unsafe_to_upsert_tables.add("user_directory_search") - - if self.database_engine.can_native_upsert: - # Check ASAP (and then later, every 1s) to see if we have finished - # background updates of tables that aren't safe to update. - self._clock.call_later( - 0.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - + self.db = database self.rand = random.SystemRandom() - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - - @defer.inlineCallbacks - def _check_safe_to_upsert(self): - """ - Is it safe to use native UPSERT? - - If there are background updates, we will need to wait, as they may be - the addition of indexes that set the UNIQUE constraint that we require. - - If the background updates have not completed, wait 15 sec and check again. - """ - updates = yield self._simple_select_list( - "background_updates", - keyvalues=None, - retcols=["update_name"], - desc="check_background_updates", - ) - updates = [x["update_name"] for x in updates] - - for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): - if update_name not in updates: - logger.debug("Now safe to upsert in %s", table) - self._unsafe_to_upsert_tables.discard(table) - - # If there's any updates still running, reschedule to run. - if updates: - self._clock.call_later( - 15.0, - run_as_background_process, - "upsert_safety_check", - self._check_safe_to_upsert, - ) - - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - yield self.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self._simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - - def start_profiling(self): - self._previous_loop_ts = monotonic_time() - - def loop(): - curr = self._current_txn_total_time - prev = self._previous_txn_total_time - self._previous_txn_total_time = curr - - time_now = monotonic_time() - time_then = self._previous_loop_ts - self._previous_loop_ts = time_now - - duration = time_now - time_then - ratio = (curr - prev) / duration - - top_three_counters = self._txn_perf_counters.interval(duration, limit=3) - - perf_logger.info( - "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters - ) - - self._clock.looping_call(loop, 10000) - - def _new_transaction( - self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs - ): - start = monotonic_time() - txn_id = self._TXN_ID - - # We don't really need these to be unique, so lets stop it from - # growing really large. - self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) - - name = "%s-%x" % (desc, txn_id) - - transaction_logger.debug("[TXN START] {%s}", name) - - try: - i = 0 - N = 5 - while True: - cursor = LoggingTransaction( - conn.cursor(), - name, - self.database_engine, - after_callbacks, - exception_callbacks, - ) - try: - r = func(cursor, *args, **kwargs) - conn.commit() - return r - except self.database_engine.module.OperationalError as e: - # This can happen if the database disappears mid - # transaction. - logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, - ) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) - continue - raise - except self.database_engine.module.DatabaseError as e: - if self.database_engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) - if i < N: - i += 1 - try: - conn.rollback() - except self.database_engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), - ) - continue - raise - finally: - # we're either about to retry with a new cursor, or we're about to - # release the connection. Once we release the connection, it could - # get used for another query, which might do a conn.rollback(). - # - # In the latter case, even though that probably wouldn't affect the - # results of this transaction, python's sqlite will reset all - # statements on the connection [1], which will make our cursor - # invalid [2]. - # - # In any case, continuing to read rows after commit()ing seems - # dubious from the PoV of ACID transactional semantics - # (sqlite explicitly says that once you commit, you may see rows - # from subsequent updates.) - # - # In psycopg2, cursors are essentially a client-side fabrication - - # all the data is transferred to the client side when the statement - # finishes executing - so in theory we could go on streaming results - # from the cursor, but attempting to do so would make us - # incompatible with sqlite, so let's make sure we're not doing that - # by closing the cursor. - # - # (*named* cursors in psycopg2 are different and are proper server- - # side things, but (a) we don't use them and (b) they are implicitly - # closed by ending the transaction anyway.) - # - # In short, if we haven't finished with the cursor yet, that's a - # problem waiting to bite us. - # - # TL;DR: we're done with the cursor, so we can close it. - # - # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 - # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 - cursor.close() - except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) - raise - finally: - end = monotonic_time() - duration = end - start - - LoggingContext.current_context().add_database_transaction(duration) - - transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) - - self._current_txn_total_time += duration - self._txn_perf_counters.update(desc, duration) - sql_txn_timer.labels(desc).observe(duration) - - @defer.inlineCallbacks - def runInteraction(self, desc, func, *args, **kwargs): - """Starts a transaction on the database and runs a given function - - Arguments: - desc (str): description of the transaction, for logging and metrics - func (func): callback function, which will be called with a - database transaction (twisted.enterprise.adbapi.Transaction) as - its first argument, followed by `args` and `kwargs`. - - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - after_callbacks = [] - exception_callbacks = [] - - if LoggingContext.current_context() == LoggingContext.sentinel: - logger.warning("Starting db txn '%s' from sentinel context", desc) - - try: - result = yield self.runWithConnection( - self._new_transaction, - desc, - after_callbacks, - exception_callbacks, - func, - *args, - **kwargs - ) - - for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) - except: # noqa: E722, as we reraise the exception this is fine. - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) - raise - - return result - - @defer.inlineCallbacks - def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runWithConnection() method on the underlying db_pool. - - Arguments: - func (func): callback function, which will be called with a - database connection (twisted.enterprise.adbapi.Connection) as - its first argument, followed by `args` and `kwargs`. - args (list): positional args to pass to `func` - kwargs (dict): named args to pass to `func` - - Returns: - Deferred: The result of func - """ - parent_context = LoggingContext.current_context() - if parent_context == LoggingContext.sentinel: - logger.warning( - "Starting db connection from sentinel context: metrics will be lost" - ) - parent_context = None - - start_time = monotonic_time() - - def inner_func(conn, *args, **kwargs): - with LoggingContext("runWithConnection", parent_context) as context: - sched_duration_sec = monotonic_time() - start_time - sql_scheduling_timer.observe(sched_duration_sec) - context.add_database_scheduled(sched_duration_sec) - - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() - - return func(conn, *args, **kwargs) - - result = yield make_deferred_yieldable( - self._db_pool.runWithConnection(inner_func, *args, **kwargs) - ) - - return result - - @staticmethod - def cursor_to_dict(cursor): - """Converts a SQL cursor into an list of dicts. - - Args: - cursor : The DBAPI cursor which has executed a query. - Returns: - A list of dicts where the key is the column header. - """ - col_headers = list(intern(str(column[0])) for column in cursor.description) - results = list(dict(zip(col_headers, row)) for row in cursor) - return results - - def _execute(self, desc, decoder, query, *args): - """Runs a single query for a result set. - - Args: - decoder - The function which can resolve the cursor results to - something meaningful. - query - The query string to execute - *args - Query args. - Returns: - The result of decoder(results) - """ - - def interaction(txn): - txn.execute(query, args) - if decoder: - return decoder(txn) - else: - return txn.fetchall() - - return self.runInteraction(desc, interaction) - - # "Simple" SQL API methods that operate on a single table with no JOINs, - # no complex WHERE clauses, just a dict of values for columns. - - @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): - """Executes an INSERT query on the named table. - - Args: - table : string giving the table name - values : dict of new column names and values for them - or_ignore : bool stating whether an exception should be raised - when a conflicting row already exists. If True, False will be - returned by the function instead - desc : string giving a description of the transaction - - Returns: - bool: Whether the row was inserted or not. Only useful when - `or_ignore` is True - """ - try: - yield self.runInteraction(desc, self._simple_insert_txn, table, values) - except self.database_engine.module.IntegrityError: - # We have to do or_ignore flag at this layer, since we can't reuse - # a cursor after we receive an error from the db. - if not or_ignore: - raise - return False - return True - - @staticmethod - def _simple_insert_txn(txn, table, values): - keys, vals = zip(*values.items()) - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys), - ", ".join("?" for _ in keys), - ) - - txn.execute(sql, vals) - - def _simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self._simple_insert_many_txn, table, values) - - @staticmethod - def _simple_insert_many_txn(txn, table, values): - if not values: - return - - # This is a *slight* abomination to get a list of tuples of key names - # and a list of tuples of value names. - # - # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] - # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] - # - # The sort is to ensure that we don't rely on dictionary iteration - # order. - keys, vals = zip( - *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] - ) - - for k in keys: - if k != keys[0]: - raise RuntimeError("All items must have the same keys") - - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), - ) - - txn.executemany(sql, vals) - - @defer.inlineCallbacks - def _simple_upsert( - self, - table, - keyvalues, - values, - insertion_values={}, - desc="_simple_upsert", - lock=True, - ): - """ - - `lock` should generally be set to True (the default), but can be set - to False if either of the following are true: - - * there is a UNIQUE INDEX on the key columns. In this case a conflict - will cause an IntegrityError in which case this function will retry - the update. - - * we somehow know that we are the only thread which will be updating - this table. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key columns and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - Deferred(None or bool): Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - attempts = 0 - while True: - try: - result = yield self.runInteraction( - desc, - self._simple_upsert_txn, - table, - keyvalues, - values, - insertion_values, - lock=lock, - ) - return result - except self.database_engine.module.IntegrityError as e: - attempts += 1 - if attempts >= 5: - # don't retry forever, because things other than races - # can cause IntegrityErrors - raise - - # presumably we raced with another transaction: let's retry. - logger.warning( - "IntegrityError when upserting into %s; retrying: %s", table, e - ) - - def _simple_upsert_txn( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Pick the UPSERT method which works best on the platform. Either the - native one (Pg9.5+, recent SQLites), or fall back to an emulated method. - - Args: - txn: The transaction to use. - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - None or bool: Native upserts always return None. Emulated - upserts return True if a new entry was created, False if an existing - one was updated. - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self._simple_upsert_txn_native_upsert( - txn, table, keyvalues, values, insertion_values=insertion_values - ) - else: - return self._simple_upsert_txn_emulated( - txn, - table, - keyvalues, - values, - insertion_values=insertion_values, - lock=lock, - ) - - def _simple_upsert_txn_emulated( - self, txn, table, keyvalues, values, insertion_values={}, lock=True - ): - """ - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - lock (bool): True to lock the table when doing the upsert. - Returns: - bool: Return True if a new entry was created, False if an existing - one was updated. - """ - # We need to lock the table :(, unless we're *really* careful - if lock: - self.database_engine.lock_table(txn, table) - - def _getwhere(key): - # If the value we're passing in is None (aka NULL), we need to use - # IS, not =, as NULL = NULL equals NULL (False). - if keyvalues[key] is None: - return "%s IS ?" % (key,) - else: - return "%s = ?" % (key,) - - if not values: - # If `values` is empty, then all of the values we care about are in - # the unique key, so there is nothing to UPDATE. We can just do a - # SELECT instead to see if it exists. - sql = "SELECT 1 FROM %s WHERE %s" % ( - table, - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(keyvalues.values()) - txn.execute(sql, sqlargs) - if txn.fetchall(): - # We have an existing record. - return False - else: - # First try to update. - sql = "UPDATE %s SET %s WHERE %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in values), - " AND ".join(_getwhere(k) for k in keyvalues), - ) - sqlargs = list(values.values()) + list(keyvalues.values()) - - txn.execute(sql, sqlargs) - if txn.rowcount > 0: - # successfully updated at least one row. - return False - - # We didn't find any existing rows, so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) - - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ) - txn.execute(sql, list(allvalues.values())) - # successfully inserted - return True - - def _simple_upsert_txn_native_upsert( - self, txn, table, keyvalues, values, insertion_values={} - ): - """ - Use the native UPSERT functionality in recent PostgreSQL versions. - - Args: - table (str): The table to upsert into - keyvalues (dict): The unique key tables and their new values - values (dict): The nonunique columns and their new values - insertion_values (dict): additional key/values to use only when - inserting - Returns: - None - """ - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(insertion_values) - - if not values: - latter = "NOTHING" - else: - allvalues.update(values) - latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - - sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues), - ", ".join(k for k in keyvalues), - latter, - ) - txn.execute(sql, list(allvalues.values())) - - def _simple_upsert_many_txn( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - if ( - self.database_engine.can_native_upsert - and table not in self._unsafe_to_upsert_tables - ): - return self._simple_upsert_many_txn_native_upsert( - txn, table, key_names, key_values, value_names, value_values - ) - else: - return self._simple_upsert_many_txn_emulated( - txn, table, key_names, key_values, value_names, value_values - ) - - def _simple_upsert_many_txn_emulated( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, but without native UPSERT support or batching. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - # No value columns, therefore make a blank list so that the following - # zip() works correctly. - if not value_names: - value_values = [() for x in range(len(key_values))] - - for keyv, valv in zip(key_values, value_values): - _keys = {x: y for x, y in zip(key_names, keyv)} - _vals = {x: y for x, y in zip(value_names, valv)} - - self._simple_upsert_txn_emulated(txn, table, _keys, _vals) - - def _simple_upsert_many_txn_native_upsert( - self, txn, table, key_names, key_values, value_names, value_values - ): - """ - Upsert, many times, using batching where possible. - - Args: - table (str): The table to upsert into - key_names (list[str]): The key column names. - key_values (list[list]): A list of each row's key column values. - value_names (list[str]): The value column names. If empty, no - values will be used, even if value_values is provided. - value_values (list[list]): A list of each row's value column values. - Returns: - None - """ - allnames = [] - allnames.extend(key_names) - allnames.extend(value_names) - - if not value_names: - # No value columns, therefore make a blank list so that the - # following zip() works correctly. - latter = "NOTHING" - value_values = [() for x in range(len(key_values))] - else: - latter = "UPDATE SET " + ", ".join( - k + "=EXCLUDED." + k for k in value_names - ) - - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( - table, - ", ".join(k for k in allnames), - ", ".join("?" for _ in allnames), - ", ".join(key_names), - latter, - ) - - args = [] - - for x, y in zip(key_values, value_values): - args.append(tuple(x) + tuple(y)) - - return txn.execute_batch(sql, args) - - def _simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning multiple columns from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcols : list of strings giving the names of the columns to return - - allow_none : If true, return None instead of failing if the SELECT - statement returns no rows - """ - return self.runInteraction( - desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none - ) - - def _simple_select_one_onecol( - self, - table, - keyvalues, - retcol, - allow_none=False, - desc="_simple_select_one_onecol", - ): - """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - retcol : string giving the name of the column to return - """ - return self.runInteraction( - desc, - self._simple_select_one_onecol_txn, - table, - keyvalues, - retcol, - allow_none=allow_none, - ) - - @classmethod - def _simple_select_one_onecol_txn( - cls, txn, table, keyvalues, retcol, allow_none=False - ): - ret = cls._simple_select_onecol_txn( - txn, table=table, keyvalues=keyvalues, retcol=retcol - ) - - if ret: - return ret[0] - else: - if allow_none: - return None - else: - raise StoreError(404, "No row found") - - @staticmethod - def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} - - if keyvalues: - sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - txn.execute(sql, list(keyvalues.values())) - else: - txn.execute(sql) - - return [r[0] for r in txn] - - def _simple_select_onecol( - self, table, keyvalues, retcol, desc="_simple_select_onecol" - ): - """Executes a SELECT query on the named table, which returns a list - comprising of the values of the named column from the selected rows. - - Args: - table (str): table name - keyvalues (dict|None): column names and values to select the rows with - retcol (str): column whos value we wish to retrieve. - - Returns: - Deferred: Results in a list - """ - return self.runInteraction( - desc, self._simple_select_onecol_txn, table, keyvalues, retcol - ) - - def _simple_select_list( - self, table, keyvalues, retcols, desc="_simple_select_list" - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, Any] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, self._simple_select_list_txn, table, keyvalues, retcols - ) - - @classmethod - def _simple_select_list_txn(cls, txn, table, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - retcols (iterable[str]): the names of the columns to return - """ - if keyvalues: - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - txn.execute(sql, list(keyvalues.values())) - else: - sql = "SELECT %s FROM %s" % (", ".join(retcols), table) - txn.execute(sql) - - return cls.cursor_to_dict(txn) - - @defer.inlineCallbacks - def _simple_select_many_batch( - self, - table, - column, - iterable, - retcols, - keyvalues={}, - desc="_simple_select_many_batch", - batch_size=100, - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - results = [] - - if not iterable: - return results - - # iterables can not be sliced, so convert it to a list first - it_list = list(iterable) - - chunks = [ - it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) - ] - for chunk in chunks: - rows = yield self.runInteraction( - desc, - self._simple_select_many_txn, - table, - column, - chunk, - keyvalues, - retcols, - ) - - results.extend(rows) - - return results - - @classmethod - def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - retcols : list of strings giving the names of the columns to return - """ - if not iterable: - return [] - - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) - clauses = [clause] - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join(clauses), - ) - - txn.execute(sql, values) - return cls.cursor_to_dict(txn) - - def _simple_update(self, table, keyvalues, updatevalues, desc): - return self.runInteraction( - desc, self._simple_update_txn, table, keyvalues, updatevalues - ) - - @staticmethod - def _simple_update_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in updatevalues), - where, - ) - - txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) - - return txn.rowcount - - def _simple_update_one( - self, table, keyvalues, updatevalues, desc="_simple_update_one" - ): - """Executes an UPDATE query on the named table, setting new values for - columns in a row matching the key values. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - updatevalues : dict giving column names and values to update - retcols : optional list of column names to return - - If present, retcols gives a list of column names on which to perform - a SELECT statement *before* performing the UPDATE statement. The values - of these will be returned in a dict. - - These are performed within the same transaction, allowing an atomic - get-and-set. This can be used to implement compare-and-set by putting - the update column in the 'keyvalues' dict as well. - """ - return self.runInteraction( - desc, self._simple_update_one_txn, table, keyvalues, updatevalues - ) - - @classmethod - def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) - - if rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): - select_sql = "SELECT %s FROM %s WHERE %s" % ( - ", ".join(retcols), - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(select_sql, list(keyvalues.values())) - row = txn.fetchone() - - if not row: - if allow_none: - return None - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - return dict(zip(retcols, row)) - - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) - - @staticmethod - def _simple_delete_one_txn(txn, table, keyvalues): - """Executes a DELETE query on the named table, expecting to delete a - single row. - - Args: - table : string giving the table name - keyvalues : dict of column names and values to select the row with - """ - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - if txn.rowcount == 0: - raise StoreError(404, "No row found (%s)" % (table,)) - if txn.rowcount > 1: - raise StoreError(500, "More than one row matched (%s)" % (table,)) - - def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) - - @staticmethod - def _simple_delete_txn(txn, table, keyvalues): - sql = "DELETE FROM %s WHERE %s" % ( - table, - " AND ".join("%s = ?" % (k,) for k in keyvalues), - ) - - txn.execute(sql, list(keyvalues.values())) - return txn.rowcount - - def _simple_delete_many(self, table, column, iterable, keyvalues, desc): - return self.runInteraction( - desc, self._simple_delete_many_txn, table, column, iterable, keyvalues - ) - - @staticmethod - def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): - """Executes a DELETE query on the named table. - - Filters rows by if value of `column` is in `iterable`. - - Args: - txn : Transaction object - table : string giving the table name - column : column name to test for inclusion against `iterable` - iterable : list - keyvalues : dict of column names and values to select the rows with - - Returns: - int: Number rows deleted - """ - if not iterable: - return 0 - - sql = "DELETE FROM %s" % table - - clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) - clauses = [clause] - - for key, value in iteritems(keyvalues): - clauses.append("%s = ?" % (key,)) - values.append(value) - - if clauses: - sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) - txn.execute(sql, values) - - return txn.rowcount - - def _get_cache_dict( - self, db_conn, table, entity_column, stream_column, max_value, limit=100000 - ): - # Fetch a mapping of room_id -> max stream position for "recent" rooms. - # It doesn't really matter how many we get, the StreamChangeCache will - # do the right thing to ensure it respects the max size of cache. - sql = ( - "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" - " WHERE %(stream)s > ? - %(limit)s" - " GROUP BY %(entity)s" - ) % { - "table": table, - "entity": entity_column, - "stream": stream_column, - "limit": limit, - } - - sql = self.database_engine.convert_param_style(sql) - - txn = db_conn.cursor() - txn.execute(sql, (int(max_value),)) - - cache = {row[0]: int(row[1]) for row in txn} - - txn.close() - - if cache: - min_val = min(itervalues(cache)) - else: - min_val = max_value - - return cache, min_val - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. - - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - txn.call_after(cache_func.invalidate, keys) - self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): - """Special case invalidation of caches based on current state. - - We special case this so that we can batch the cache invalidations into a - single replication poke. - - Args: - txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed - """ - txn.call_after(self._invalidate_state_caches, room_id, members_changed) - - if members_changed: - # We need to be careful that the size of the `members_changed` list - # isn't so large that it causes problems sending over replication, so we - # send them in chunks. - # Max line length is 16K, and max user ID length is 255, so 50 should - # be safe. - for chunk in batch_iter(members_changed, 50): - keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys - ) - else: - # if no members changed, we still need to invalidate the other caches. - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, [room_id] - ) - def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -1474,226 +77,6 @@ class SQLBaseStore(object): # which is fine. pass - def _send_invalidation_to_replication(self, txn, cache_name, keys): - """Notifies replication that given cache has been invalidated. - - Note that this does *not* invalidate the cache locally. - - Args: - txn - cache_name (str) - keys (iterable[str]) - """ - - if isinstance(self.database_engine, PostgresEngine): - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) - txn.call_after(self.hs.get_notifier().on_new_replication_data) - - self._simple_insert_txn( - txn, - table="cache_invalidation_stream", - values={ - "stream_id": stream_id, - "cache_func": cache_name, - "keys": list(keys), - "invalidation_ts": self.clock.time_msec(), - }, - ) - - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def _simple_select_list_paginate( - self, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - desc="_simple_select_list_paginate", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - return self.runInteraction( - desc, - self._simple_select_list_paginate_txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction=order_direction, - ) - - @classmethod - def _simple_select_list_paginate_txn( - cls, - txn, - table, - keyvalues, - orderby, - start, - limit, - retcols, - order_direction="ASC", - ): - """ - Executes a SELECT query on the named table with start and limit, - of row numbers, which may return zero or number of rows from start to limit, - returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - keyvalues (dict[str, T] | None): - column names and values to select the rows with, or None to not - apply a WHERE clause. - orderby (str): Column to order the results by. - start (int): Index to begin the query at. - limit (int): Number of results to return. - retcols (iterable[str]): the names of the columns to return - order_direction (str): Whether the results should be ordered "ASC" or "DESC". - Returns: - defer.Deferred: resolves to list[dict[str, Any]] - """ - if order_direction not in ["ASC", "DESC"]: - raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") - - if keyvalues: - where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) - else: - where_clause = "" - - sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( - ", ".join(retcols), - table, - where_clause, - orderby, - order_direction, - ) - txn.execute(sql, list(keyvalues.values()) + [limit, start]) - - return cls.cursor_to_dict(txn) - - def get_user_count_txn(self, txn): - """Get a total number of registered users in the users list. - - Args: - txn : Transaction object - Returns: - int : number of users - """ - sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" - txn.execute(sql_count) - return txn.fetchone()[0] - - def _simple_search_list( - self, table, term, col, retcols, desc="_simple_search_list" - ): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - - return self.runInteraction( - desc, self._simple_search_list_txn, table, term, col, retcols - ) - - @classmethod - def _simple_search_list_txn(cls, txn, table, term, col, retcols): - """Executes a SELECT query on the named table, which may return zero or - more rows, returning the result as a list of dicts. - - Args: - txn : Transaction object - table (str): the table name - term (str | None): - term for searching the table matched to a column. - col (str): column to query term should be matched to - retcols (iterable[str]): the names of the columns to return - Returns: - defer.Deferred: resolves to list[dict[str, Any]] or None - """ - if term: - sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) - termvalues = ["%%" + term + "%%"] - txn.execute(sql, termvalues) - else: - return 0 - - return cls.cursor_to_dict(txn) - - @property - def database_engine_name(self): - return self.database_engine.module.__name__ - - def get_server_version(self): - """Returns a string describing the server version number""" - return self.database_engine.server_version - - -class _RollbackButIsFineException(Exception): - """ This exception is used to rollback a transaction without implying - something went wrong. - """ - - pass - def db_to_json(db_content): """ @@ -1722,30 +105,3 @@ def db_to_json(db_content): except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise - - -def make_in_list_sql_clause( - database_engine, column: str, iterable: Iterable -) -> Tuple[str, Iterable]: - """Returns an SQL clause that checks the given column is in the iterable. - - On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres - it expands to `column = ANY(?)`. While both DBs support the `IN` form, - using the `ANY` form on postgres means that it views queries with - different length iterables as the same, helping the query stats. - - Args: - database_engine - column: Name of the column - iterable: The values to check the column against. - - Returns: - A tuple of SQL query and the args - """ - - if database_engine.supports_using_any_list: - # This should hopefully be faster, but also makes postgres query - # stats easier to understand. - return "%s = ANY(?)" % (column,), [list(iterable)] - else: - return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 37d469ffd7..4f97fd5ab6 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from . import engines -from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -74,7 +73,7 @@ class BackgroundUpdatePerformance(object): return float(self.total_item_count) / float(self.total_duration_ms) -class BackgroundUpdateStore(SQLBaseStore): +class BackgroundUpdater(object): """ Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to @@ -86,8 +85,10 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, db_conn, hs): - super(BackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, hs, database): + self._clock = hs.get_clock() + self.db = database + self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} @@ -101,9 +102,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.info("Starting background schema updates") while True: if sleep: - yield self.hs.get_clock().sleep( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0 - ) + yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0) try: result = yield self.do_next_background_update( @@ -139,7 +138,7 @@ class BackgroundUpdateStore(SQLBaseStore): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self._simple_select_onecol( + updates = yield self.db.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -161,7 +160,7 @@ class BackgroundUpdateStore(SQLBaseStore): if update_name in self._background_update_queue: return False - update_exists = await self._simple_select_one_onecol( + update_exists = await self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -184,7 +183,7 @@ class BackgroundUpdateStore(SQLBaseStore): no more work to do. """ if not self._background_update_queue: - updates = yield self._simple_select_list( + updates = yield self.db.simple_select_list( "background_updates", keyvalues=None, retcols=("update_name", "depends_on"), @@ -226,7 +225,7 @@ class BackgroundUpdateStore(SQLBaseStore): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = yield self._simple_select_one_onecol( + progress_json = yield self.db.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -380,7 +379,7 @@ class BackgroundUpdateStore(SQLBaseStore): logger.debug("[SQL] %s", sql) c.execute(sql) - if isinstance(self.database_engine, engines.PostgresEngine): + if isinstance(self.db.engine, engines.PostgresEngine): runner = create_index_psql elif psql_only: runner = None @@ -391,7 +390,7 @@ class BackgroundUpdateStore(SQLBaseStore): def updater(progress, batch_size): if runner is not None: logger.info("Adding index %s to %s", index_name, table) - yield self.runWithConnection(runner) + yield self.db.runWithConnection(runner) yield self._end_background_update(update_name) return 1 @@ -413,7 +412,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [] progress_json = json.dumps(progress) - return self._simple_insert( + return self.db.simple_insert( "background_updates", {"update_name": update_name, "progress_json": progress_json}, ) @@ -429,7 +428,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [ name for name in self._background_update_queue if name != update_name ] - return self._simple_delete_one( + return self.db.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -444,7 +443,7 @@ class BackgroundUpdateStore(SQLBaseStore): progress_json = json.dumps(progress) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cb184a98cc..cafedd5c0d 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.database import Database +from synapse.storage.prepare_database import prepare_database + class DataStores(object): """The various data stores. @@ -20,7 +23,14 @@ class DataStores(object): These are low level interfaces to physical databases. """ - def __init__(self, main_store, db_conn, hs): - # Note we pass in the main store here as workers use a different main + def __init__(self, main_store_class, db_conn, hs): + # Note we pass in the main store class here as workers use a different main # store. - self.main = main_store + database = Database(hs) + + # Check that db is correctly configured. + database.engine.check_database(db_conn.cursor()) + + prepare_database(db_conn, database.engine, config=hs.config) + + self.main = main_store_class(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 10c940df1e..c577c0df5f 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -19,9 +19,8 @@ import calendar import logging import time -from twisted.internet import defer - from synapse.api.constants import PresenceState +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import ( ChainedIdGenerator, @@ -32,6 +31,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .cache import CacheInvalidationStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -110,11 +110,22 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, + CacheInvalidationStore, ): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine + + all_users_native = are_all_users_on_domain( + db_conn.cursor(), database.engine, hs.hostname + ) + if not all_users_native: + raise Exception( + "Found users in database not native to %s!\n" + "You cannot changed a synapse server_name after it's been configured" + % (hs.hostname,) + ) self._stream_id_gen = StreamIdGenerator( db_conn, @@ -169,9 +180,11 @@ class DataStore( else: self._cache_id_gen = None + super(DataStore, self).__init__(database, db_conn, hs) + self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self._get_cache_dict( + presence_cache_prefill, min_presence_val = self.db.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -185,7 +198,7 @@ class DataStore( ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -200,7 +213,7 @@ class DataStore( ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -226,7 +239,7 @@ class DataStore( ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -240,7 +253,7 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -260,8 +273,6 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() - super(DataStore, self).__init__(db_conn, hs) - def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None @@ -281,7 +292,7 @@ class DataStore( txn = db_conn.cursor() txn.execute(sql, (PresenceState.OFFLINE,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) txn.close() for row in rows: @@ -294,7 +305,7 @@ class DataStore( Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.runInteraction("count_daily_users", self._count_users, yesterday) + return self.db.runInteraction("count_daily_users", self._count_users, yesterday) def count_monthly_users(self): """ @@ -304,7 +315,7 @@ class DataStore( amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.runInteraction( + return self.db.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -404,7 +415,7 @@ class DataStore( return results - return self.runInteraction("count_r30_users", _count_r30_users) + return self.db.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -469,50 +480,73 @@ class DataStore( # frequently self._last_user_visit_update = now - return self.runInteraction( + return self.db.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) def get_users(self): - """Function to reterive a list of users in users table. + """Function to retrieve a list of users in users table. Args: Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_select_list( + return self.db.simple_select_list( table="users", keyvalues={}, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + retcols=[ + "name", + "password_hash", + "is_guest", + "admin", + "user_type", + "deactivated", + ], desc="get_users", ) - @defer.inlineCallbacks - def get_users_paginate(self, order, start, limit): - """Function to reterive a paginated list of users from - users list. This will return a json object, which contains - list of users and the total number of users in users table. + def get_users_paginate( + self, start, limit, name=None, guests=True, deactivated=False + ): + """Function to retrieve a paginated list of users from + users list. This will return a json list of users. Args: - order (str): column name to order the select by this column start (int): start number to begin the query from - limit (int): number of rows to reterive + limit (int): number of rows to retrieve + name (string): filter for user names + guests (bool): whether to in include guest users + deactivated (bool): whether to include deactivated users Returns: - defer.Deferred: resolves to json object {list[dict[str, Any]], count} + defer.Deferred: resolves to list[dict[str, Any]] """ - users = yield self.runInteraction( - "get_users_paginate", - self._simple_select_list_paginate_txn, + name_filter = {} + if name: + name_filter["name"] = "%" + name + "%" + + attr_filter = {} + if not guests: + attr_filter["is_guest"] = False + if not deactivated: + attr_filter["deactivated"] = False + + return self.db.simple_select_list_paginate( + desc="get_users_paginate", table="users", - keyvalues={"is_guest": False}, - orderby=order, + orderby="name", start=start, limit=limit, - retcols=["name", "password_hash", "is_guest", "admin", "user_type"], + filters=name_filter, + keyvalues=attr_filter, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin", + "user_type", + "deactivated", + ], ) - count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn) - retval = {"users": users, "total": count} - return retval def search_users(self, term): """Function to search users list for one or more users with @@ -524,10 +558,22 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_search_list( + return self.db.simple_search_list( table="users", term=term, col="name", retcols=["name", "password_hash", "is_guest", "admin", "user_type"], desc="search_users", ) + + +def are_all_users_on_domain(txn, database_engine, domain): + sql = database_engine.convert_param_style( + "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" + ) + pat = "%:" + domain + txn.execute(sql, (pat,)) + num_not_matching = txn.fetchall()[0][0] + if num_not_matching == 0: + return True + return False diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 22093484ed..46b494b334 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max ) - super(AccountDataWorkerStore, self).__init__(db_conn, hs) + super(AccountDataWorkerStore, self).__init__(database, db_conn, hs) @abc.abstractmethod def get_max_account_data_stream_id(self): @@ -67,7 +68,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_user_txn(txn): - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -78,7 +79,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -92,7 +93,7 @@ class AccountDataWorkerStore(SQLBaseStore): return global_account_data, by_room - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_user", get_account_data_for_user_txn ) @@ -102,7 +103,7 @@ class AccountDataWorkerStore(SQLBaseStore): Returns: Deferred: A dict """ - result = yield self._simple_select_one_onecol( + result = yield self.db.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -127,7 +128,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_txn(txn): - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -138,7 +139,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room", get_account_data_for_room_txn ) @@ -156,7 +157,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_and_type_txn(txn): - content_json = self._simple_select_one_onecol_txn( + content_json = self.db.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -170,7 +171,7 @@ class AccountDataWorkerStore(SQLBaseStore): return json.loads(content_json) if content_json else None - return self.runInteraction( + return self.db.runInteraction( "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) @@ -207,7 +208,7 @@ class AccountDataWorkerStore(SQLBaseStore): room_results = txn.fetchall() return global_results, room_results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_account_data_txn", get_updated_account_data_txn ) @@ -250,9 +251,9 @@ class AccountDataWorkerStore(SQLBaseStore): user_id, int(stream_id) ) if not changed: - return {}, {} + return defer.succeed(({}, {})) - return self.runInteraction( + return self.db.runInteraction( "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) @@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore): class AccountDataStore(AccountDataWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self._account_data_id_gen = StreamIdGenerator( db_conn, "account_data_max_stream_id", "stream_id" ) - super(AccountDataStore, self).__init__(db_conn, hs) + super(AccountDataStore, self).__init__(database, db_conn, hs) def get_max_account_data_stream_id(self): """Get the current max stream id for the private user data stream @@ -300,9 +301,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so _simple_upsert will + # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self._simple_upsert( + yield self.db.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -346,9 +347,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so _simple_upsert will retry if + # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self._simple_upsert( + yield self.db.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, @@ -388,4 +389,4 @@ class AccountDataStore(AccountDataWorkerStore): ) txn.execute(update_max_id_sql, (next_id, next_id)) - return self.runInteraction("update_account_data_max_stream_id", _update) + return self.db.runInteraction("update_account_data_max_stream_id", _update) diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 81babf2029..b2f39649fd 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database logger = logging.getLogger(__name__) @@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache): class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) - super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) + super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs) def get_app_services(self): return self.services_cache @@ -133,7 +134,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves to a list of ApplicationServices, which may be empty. """ - results = yield self._simple_select_list( + results = yield self.db.simple_select_list( "application_services_state", dict(state=state), ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -155,7 +156,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves to ApplicationServiceState. """ - result = yield self._simple_select_one( + result = yield self.db.simple_select_one( "application_services_state", dict(as_id=service.id), ["state"], @@ -175,7 +176,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves when the state was set successfully. """ - return self._simple_upsert( + return self.db.simple_upsert( "application_services_state", dict(as_id=service.id), dict(state=state) ) @@ -216,7 +217,7 @@ class ApplicationServiceTransactionWorkerStore( ) return AppServiceTransaction(service=service, id=new_txn_id, events=events) - return self.runInteraction("create_appservice_txn", _create_appservice_txn) + return self.db.runInteraction("create_appservice_txn", _create_appservice_txn) def complete_appservice_txn(self, txn_id, service): """Completes an application service transaction. @@ -249,7 +250,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Set current txn_id for AS to 'txn_id' - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, "application_services_state", dict(as_id=service.id), @@ -257,11 +258,13 @@ class ApplicationServiceTransactionWorkerStore( ) # Delete txn - self._simple_delete_txn( + self.db.simple_delete_txn( txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) ) - return self.runInteraction("complete_appservice_txn", _complete_appservice_txn) + return self.db.runInteraction( + "complete_appservice_txn", _complete_appservice_txn + ) @defer.inlineCallbacks def get_oldest_unsent_txn(self, service): @@ -283,7 +286,7 @@ class ApplicationServiceTransactionWorkerStore( " ORDER BY txn_id ASC LIMIT 1", (service.id,), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None @@ -291,7 +294,7 @@ class ApplicationServiceTransactionWorkerStore( return entry - entry = yield self.runInteraction( + entry = yield self.db.runInteraction( "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn ) @@ -321,7 +324,7 @@ class ApplicationServiceTransactionWorkerStore( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) ) - return self.runInteraction( + return self.db.runInteraction( "set_appservice_last_pos", set_appservice_last_pos_txn ) @@ -350,7 +353,7 @@ class ApplicationServiceTransactionWorkerStore( return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py new file mode 100644 index 0000000000..54ed8574c4 --- /dev/null +++ b/synapse/storage/data_stores/main/cache.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine +from synapse.util import batch_iter + +logger = logging.getLogger(__name__) + + +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. +CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + + +class CacheInvalidationStore(SQLBaseStore): + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + if members_changed: + # We need to be careful that the size of the `members_changed` list + # isn't so large that it causes problems sending over replication, so we + # send them in chunks. + # Max line length is 16K, and max user ID length is 255, so 50 should + # be safe. + for chunk in batch_iter(members_changed, 50): + keys = itertools.chain([room_id], chunk) + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, keys + ) + else: + # if no members changed, we still need to invalidate the other caches. + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, [room_id] + ) + + def _send_invalidation_to_replication(self, txn, cache_name, keys): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name (str) + keys (iterable[str]) + """ + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_on_exception(ctx.__exit__, None, None, None) + txn.call_after(ctx.__exit__, None, None, None) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + self.db.simple_insert_txn( + txn, + table="cache_invalidation_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_name, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + }, + ) + + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.db.runInteraction( + "get_all_updated_caches", get_all_updated_caches_txn + ) + + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 62b8e06fb4..b3f1806c72 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -20,9 +20,10 @@ from six import iteritems from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage import background_updates -from synapse.storage._base import Cache +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -32,41 +33,41 @@ logger = logging.getLogger(__name__) LAST_SEEN_GRANULARITY = 10 * 60 * 1000 -class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs) +class ClientIpBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_index", index_name="user_ips_device_id", table="user_ips", columns=["user_id", "device_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_index", index_name="user_ips_last_seen", table="user_ips", columns=["user_id", "last_seen"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_last_seen_only_index", index_name="user_ips_last_seen_only", table="user_ips", columns=["last_seen"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_analyze", self._analyze_user_ip ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_remove_dupes", self._remove_user_ip_dupes ) # Register a unique index - self.register_background_index_update( + self.db.updates.register_background_index_update( "user_ips_device_unique_index", index_name="user_ips_user_token_ip_unique_index", table="user_ips", @@ -75,12 +76,12 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): ) # Drop the old non-unique index - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique ) # Update the last seen info in devices. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "devices_last_seen", self._devices_last_seen_update ) @@ -91,8 +92,8 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS user_ips_user_ip") txn.close() - yield self.runWithConnection(f) - yield self._end_background_update("user_ips_drop_nonunique_index") + yield self.db.runWithConnection(f) + yield self.db.updates._end_background_update("user_ips_drop_nonunique_index") return 1 @defer.inlineCallbacks @@ -106,9 +107,9 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): def user_ips_analyze(txn): txn.execute("ANALYZE user_ips") - yield self.runInteraction("user_ips_analyze", user_ips_analyze) + yield self.db.runInteraction("user_ips_analyze", user_ips_analyze) - yield self._end_background_update("user_ips_analyze") + yield self.db.updates._end_background_update("user_ips_analyze") return 1 @@ -140,7 +141,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): return None # Get a last seen that has roughly `batch_size` since `begin_last_seen` - end_last_seen = yield self.runInteraction( + end_last_seen = yield self.db.runInteraction( "user_ips_dups_get_last_seen", get_last_seen ) @@ -271,14 +272,14 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): (user_id, access_token, ip, device_id, user_agent, last_seen), ) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "user_ips_remove_dupes", {"last_seen": end_last_seen} ) - yield self.runInteraction("user_ips_dups_remove", remove) + yield self.db.runInteraction("user_ips_dups_remove", remove) if last: - yield self._end_background_update("user_ips_remove_dupes") + yield self.db.updates._end_background_update("user_ips_remove_dupes") return batch_size @@ -344,7 +345,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): txn.execute_batch(sql, rows) _, _, _, user_id, device_id = rows[-1] - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "devices_last_seen", {"last_user_id": user_id, "last_device_id": device_id}, @@ -352,24 +353,24 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore): return len(rows) - updated = yield self.runInteraction( + updated = yield self.db.runInteraction( "_devices_last_seen_update", _devices_last_seen_update_txn ) if not updated: - yield self._end_background_update("devices_last_seen") + yield self.db.updates._end_background_update("devices_last_seen") return updated class ClientIpStore(ClientIpBackgroundUpdateStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR ) - super(ClientIpStore, self).__init__(db_conn, hs) + super(ClientIpStore, self).__init__(database, db_conn, hs) self.user_ips_max_age = hs.config.user_ips_max_age @@ -417,12 +418,12 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.runInteraction( + return self.db.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) def _update_client_ips_batch_txn(self, txn, to_update): - if "user_ips" in self._unsafe_to_upsert_tables or ( + if "user_ips" in self.db._unsafe_to_upsert_tables or ( not self.database_engine.can_native_upsert ): self.database_engine.lock_table(txn, "user_ips") @@ -431,7 +432,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -450,7 +451,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -483,7 +484,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = yield self._simple_select_list( + res = yield self.db.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -516,7 +517,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], @@ -546,7 +547,9 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Nothing to do return - if not await self.has_completed_background_update("devices_last_seen"): + if not await self.db.updates.has_completed_background_update( + "devices_last_seen" + ): # Only start pruning if we have finished populating the devices # last seen info. return @@ -577,4 +580,4 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _prune_old_user_ips_txn(txn): txn.execute(sql, (timestamp,)) - await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) + await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn) diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index a23744f11c..85cfa16850 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache logger = logging.getLogger(__name__) @@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_messages_for_device", get_new_messages_for_device_txn ) @@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - count = yield self.runInteraction( + count = yield self.db.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) @@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): stream_pos = current_stream_id return messages, stream_pos - return self.runInteraction( + return self.db.runInteraction( "get_new_device_msgs_for_remote", get_new_messages_for_remote_destination_txn, ) @@ -203,25 +203,25 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) txn.execute(sql, (destination, up_to_stream_id)) - return self.runInteraction( + return self.db.runInteraction( "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn ) -class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_inbox_stream_index", index_name="device_inbox_stream_id_user_id", table="device_inbox", columns=["stream_id", "user_id"], ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox ) @@ -232,9 +232,9 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.close() - yield self.runWithConnection(reindex_txn) + yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID) + yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) return 1 @@ -242,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore): class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, db_conn, hs): - super(DeviceInboxStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceInboxStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. @@ -294,7 +294,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) for user_id in local_messages_by_user_then_device.keys(): @@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self._simple_select_one_txn( + already_inserted = self.db.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Add an entry for this message_id so that we know we've processed # it. - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="device_federation_inbox", values={ @@ -344,7 +344,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) with self._device_inbox_id_gen.get_next() as stream_id: now_ms = self.clock.time_msec() - yield self.runInteraction( + yield self.db.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, now_ms, @@ -465,6 +465,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) return rows - return self.runInteraction( + return self.db.runInteraction( "get_all_new_device_messages", get_all_new_device_messages_txn ) diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 71f62036c0..9a828231c4 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -30,16 +30,16 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - Cache, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import ( + Cache, + cached, + cachedInlineCallbacks, + cachedList, +) logger = logging.getLogger(__name__) @@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - return self._simple_select_one( + return self.db.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore): containing "device_id", "user_id" and "display_name" for each device. """ - devices = yield self._simple_select_list( + devices = yield self.db.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -122,7 +122,7 @@ class DeviceWorkerStore(SQLBaseStore): # consider the device update to be too large, and simply skip the # stream_id; the rationale being that such a large device list update # is likely an error. - updates = yield self.runInteraction( + updates = yield self.db.runInteraction( "get_device_updates_by_remote", self._get_device_updates_by_remote_txn, destination, @@ -283,7 +283,7 @@ class DeviceWorkerStore(SQLBaseStore): """ devices = ( - yield self.runInteraction( + yield self.db.runInteraction( "_get_e2e_device_keys_txn", self._get_e2e_device_keys_txn, query_map.keys(), @@ -340,12 +340,12 @@ class DeviceWorkerStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] - return self.runInteraction("get_last_device_update_for_remote_user", f) + return self.db.runInteraction("get_last_device_update_for_remote_user", f) def mark_as_sent_devices_by_remote(self, destination, stream_id): """Mark that updates have successfully been sent to the destination. """ - return self.runInteraction( + return self.db.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -399,7 +399,7 @@ class DeviceWorkerStore(SQLBaseStore): """ with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( + yield self.db.runInteraction( "add_user_sig_change_to_streams", self._add_user_signature_change_txn, from_user_id, @@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore): from_user_id, stream_id, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "user_signature_stream", values={ @@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( + content = yield self.db.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( + devices = yield self.db.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -492,7 +492,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: (stream_id, devices) """ - return self.runInteraction( + return self.db.runInteraction( "get_devices_with_keys_by_user", self._get_devices_with_keys_by_user_txn, user_id, @@ -565,7 +565,7 @@ class DeviceWorkerStore(SQLBaseStore): return changes - return self.runInteraction( + return self.db.runInteraction( "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) @@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self._execute( + rows = yield self.db.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return set(user for row in rows for user in json.loads(row[0])) @@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination """ - return self._execute( + return self.db.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -642,11 +642,11 @@ class DeviceWorkerStore(SQLBaseStore): return results -class DeviceBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs) +class DeviceBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_stream_idx", index_name="device_lists_stream_user_id", table="device_lists_stream", @@ -654,7 +654,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # create a unique index on device_lists_remote_cache - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_cache_unique_idx", index_name="device_lists_remote_cache_unique_id", table="device_lists_remote_cache", @@ -663,7 +663,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # And one on device_lists_remote_extremeties - self.register_background_index_update( + self.db.updates.register_background_index_update( "device_lists_remote_extremeties_unique_idx", index_name="device_lists_remote_extremeties_unique_idx", table="device_lists_remote_extremeties", @@ -672,7 +672,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): ) # once they complete, we can remove the old non-unique indexes. - self.register_background_update_handler( + self.db.updates.register_background_update_handler( DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES, self._drop_device_list_streams_non_unique_indexes, ) @@ -685,14 +685,16 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore): txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.close() - yield self.runWithConnection(f) - yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES) + yield self.db.runWithConnection(f) + yield self.db.updates._end_background_update( + DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES + ) return 1 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(DeviceStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(DeviceStore, self).__init__(database, db_conn, hs) # Map of (user_id, device_id) -> bool. If there is an entry that implies # the device exists. @@ -722,7 +724,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self._simple_insert( + inserted = yield self.db.simple_insert( "devices", values={ "user_id": user_id, @@ -736,7 +738,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self._simple_select_one_onecol( + hidden = yield self.db.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -771,7 +773,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_one( + yield self.db.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -789,7 +791,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_many( + yield self.db.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -818,7 +820,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - return self._simple_update_one( + return self.db.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -829,7 +831,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - yield self._simple_delete( + yield self.db.simple_delete( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, desc="mark_remote_user_device_list_as_unsubscribed", @@ -853,7 +855,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -866,7 +868,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn, user_id, device_id, content, stream_id ): if content.get("deleted"): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -874,7 +876,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -890,7 +892,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -914,7 +916,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: Deferred[None] """ - return self.runInteraction( + return self.db.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -923,11 +925,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ @@ -946,7 +948,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -962,7 +964,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): (if any) should be poked. """ with self._device_list_id_gen.get_next() as stream_id: - yield self.runInteraction( + yield self.db.runInteraction( "add_device_change_to_streams", self._add_device_change_txn, user_id, @@ -995,7 +997,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [(user_id, device_id, stream_id) for device_id in device_ids], ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1006,7 +1008,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ @@ -1069,7 +1071,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return run_as_background_process( "prune_old_outbound_device_pokes", - self.runInteraction, + self.db.runInteraction, "_prune_old_outbound_device_pokes", _prune_txn, ) diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index 297966d9f4..c9e7de7d12 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore): Deferred: results in namedtuple with keys "room_id" and "servers" or None if no association can be found """ - room_id = yield self._simple_select_one_onecol( + room_id = yield self.db.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self._simple_select_onecol( + servers = yield self.db.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( + return self.db.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore): """ def alias_txn(txn): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "room_aliases", { @@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore): }, ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_alias_servers", values=[ @@ -117,7 +117,9 @@ class DirectoryStore(DirectoryWorkerStore): ) try: - ret = yield self.runInteraction("create_room_alias_association", alias_txn) + ret = yield self.db.runInteraction( + "create_room_alias_association", alias_txn + ) except self.database_engine.module.IntegrityError: raise SynapseError( 409, "Room alias %s already exists" % room_alias.to_string() @@ -126,7 +128,7 @@ class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def delete_room_alias(self, room_alias): - room_id = yield self.runInteraction( + room_id = yield self.db.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias ) @@ -168,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.runInteraction( + return self.db.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 1cbbae5b63..84594cf0a9 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,49 +25,8 @@ from synapse.storage._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks - def get_e2e_room_key(self, user_id, version, room_id, session_id): - """Get the encrypted E2E room key for a given session from a given - backup version of room_keys. We only store the 'best' room key for a given - session at a given time, as determined by the handler. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): the ID of the room whose keys we're querying. - This is a bit redundant as it's implied by the session_id, but - we include for consistency with the rest of the API. - session_id(str): the session whose room_key we're querying. - - Returns: - A deferred dict giving the session_data and message metadata for - this room key. - """ - - row = yield self._simple_select_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - retcols=( - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_key", - ) - - row["session_data"] = json.loads(row["session_data"]) - - return row - - @defer.inlineCallbacks - def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces or inserts the encrypted E2E room key for a given session in - a given backup + def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces the encrypted E2E room key for a given session in a given backup Args: user_id(str): the user whose backup we're setting @@ -78,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_upsert( + yield self.db.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -86,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore): "room_id": room_id, "session_id": session_id, }, - values={ + updatevalues={ "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], "session_data": json.dumps(room_key["session_data"]), }, - lock=False, + desc="update_e2e_room_key", ) - log_kv( - { - "message": "Set room key", - "room_id": room_id, - "session_id": session_id, - "room_key": room_key, - } + + @defer.inlineCallbacks + def add_e2e_room_keys(self, user_id, version, room_keys): + """Bulk add room keys to a given backup. + + Args: + user_id (str): the user whose backup we're adding to + version (str): the version ID of the backup for the set of keys we're adding to + room_keys (iterable[(str, str, dict)]): the keys to add, in the form + (roomID, sessionID, keyData) + """ + + values = [] + for (room_id, session_id, room_key) in room_keys: + values.append( + { + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + } + ) + log_kv( + { + "message": "Set room key", + "room_id": room_id, + "session_id": session_id, + "room_key": room_key, + } + ) + + yield self.db.simple_insert_many( + table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace @@ -110,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore): room, or a given session. Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): Optional. the ID of the room whose keys we're querying, if any. + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup for the set of keys we're querying + room_id (str): Optional. the ID of the room whose keys we're querying, if any. If not specified, we return the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. + session_id (str): Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we return all the keys in this version of the backup (or for the specified room) @@ -135,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -162,6 +152,95 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions + def get_e2e_room_keys_multi(self, user_id, version, room_keys): + """Get multiple room keys at a time. The difference between this function and + get_e2e_room_keys is that this function can be used to retrieve + multiple specific keys at a time, whereas get_e2e_room_keys is used for + getting all the keys in a backup version, all the keys for a room, or a + specific key. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + room_keys (dict[str, dict[str, iterable[str]]]): a map from + room ID -> {"session": [session ids]} indicating the session IDs + that we want to query + + Returns: + Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + """ + + return self.db.runInteraction( + "get_e2e_room_keys_multi", + self._get_e2e_room_keys_multi_txn, + user_id, + version, + room_keys, + ) + + @staticmethod + def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + if not room_keys: + return {} + + where_clauses = [] + params = [user_id, version] + for room_id, room in room_keys.items(): + sessions = list(room["sessions"]) + if not sessions: + continue + params.append(room_id) + params.extend(sessions) + where_clauses.append( + "(room_id = ? AND session_id IN (%s))" + % (",".join(["?" for _ in sessions]),) + ) + + # check if we're actually querying something + if not where_clauses: + return {} + + sql = """ + SELECT room_id, session_id, first_message_index, forwarded_count, + is_verified, session_data + FROM e2e_room_keys + WHERE user_id = ? AND version = ? AND (%s) + """ % ( + " OR ".join(where_clauses) + ) + + txn.execute(sql, params) + + ret = {} + + for row in txn: + room_id = row[0] + session_id = row[1] + ret.setdefault(room_id, {}) + ret[room_id][session_id] = { + "first_message_index": row[2], + "forwarded_count": row[3], + "is_verified": row[4], + "session_data": json.loads(row[5]), + } + + return ret + + def count_e2e_room_keys(self, user_id, version): + """Get the number of keys in a backup version. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + """ + + return self.db.simple_select_one_onecol( + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": version}, + retcol="COUNT(*)", + desc="count_e2e_room_keys", + ) + @trace @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): @@ -188,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self._simple_delete( + yield self.db.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -219,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version(str) algorithm(str) auth_data(object): opaque dict supplied by the client + etag(int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn): @@ -232,17 +312,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): # it isn't there. raise StoreError(404, "No row found") - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data"), + retcols=("version", "algorithm", "auth_data", "etag"), ) result["auth_data"] = json.loads(result["auth_data"]) result["version"] = str(result["version"]) + if result["etag"] is None: + result["etag"] = 0 return result - return self.runInteraction( + return self.db.runInteraction( "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn ) @@ -270,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): new_version = str(int(current_version) + 1) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ @@ -283,26 +365,38 @@ class EndToEndRoomKeyStore(SQLBaseStore): return new_version - return self.runInteraction( + return self.db.runInteraction( "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn ) @trace - def update_e2e_room_keys_version(self, user_id, version, info): + def update_e2e_room_keys_version( + self, user_id, version, info=None, version_etag=None + ): """Update a given backup version Args: user_id(str): the user whose backup version we're updating version(str): the version ID of the backup version we're updating - info(dict): the new backup version info to store + info (dict): the new backup version info to store. If None, then + the backup version info is not updated + version_etag (Optional[int]): etag of the keys in the backup. If + None, then the etag is not updated """ + updatevalues = {} - return self._simple_update( - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, - updatevalues={"auth_data": json.dumps(info["auth_data"])}, - desc="update_e2e_room_keys_version", - ) + if info is not None and "auth_data" in info: + updatevalues["auth_data"] = json.dumps(info["auth_data"]) + if version_etag is not None: + updatevalues["etag"] = version_etag + + if updatevalues: + return self.db.simple_update( + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": version}, + updatevalues=updatevalues, + desc="update_e2e_room_keys_version", + ) @trace def delete_e2e_room_keys_version(self, user_id, version=None): @@ -326,19 +420,19 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self._simple_update_one_txn( + return self.db.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, updatevalues={"deleted": 1}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn ) diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index d8ad59ad93..38cd0ca9b8 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -48,7 +48,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): if not query_list: return {} - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list, @@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(sql, query_params) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) result = {} for row in rows: @@ -143,15 +143,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore): ) txn.execute(signature_sql, signature_query_params) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) + # add each cross-signing signature to the correct device in the result dict. for row in rows: + signing_user_id = row["user_id"] + signing_key_id = row["key_id"] target_user_id = row["target_user_id"] target_device_id = row["target_device_id"] - if target_user_id in result and target_device_id in result[target_user_id]: - result[target_user_id][target_device_id].setdefault( - "signatures", {} - ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"] + signature = row["signature"] + + target_user_result = result.get(target_user_id) + if not target_user_result: + continue + + target_device_result = target_user_result.get(target_device_id) + if not target_device_result: + # note that target_device_result will be None for deleted devices. + continue + + target_device_signatures = target_device_result.setdefault("signatures", {}) + signing_user_signatures = target_device_signatures.setdefault( + signing_user_id, {} + ) + signing_user_signatures[signing_key_id] = signature log_kv(result) return result @@ -171,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): key_id) to json string for key """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -204,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -223,7 +238,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - yield self.runInteraction( + yield self.db.runInteraction( "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) @@ -246,7 +261,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): result[algorithm] = key_count return result - return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys) + return self.db.runInteraction( + "count_e2e_one_time_keys", _count_e2e_one_time_keys + ) def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): """Returns a user's cross-signing key. @@ -307,7 +324,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Returns: dict of the key data or None if not found """ - return self.runInteraction( + return self.db.runInteraction( "get_e2e_cross_signing_key", self._get_e2e_cross_signing_key_txn, user_id, @@ -335,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id """ - return self._execute( + return self.db.execute( "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key ) @@ -352,7 +369,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self._simple_select_one_onecol_txn( + old_key_json = self.db.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -368,7 +385,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"Message": "Device key already stored."}) return False - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -377,7 +394,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"message": "Device keys stored."}) return True - return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) + return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn) def claim_e2e_one_time_keys(self, query_list): """Take a list of one time keys out of the database""" @@ -416,7 +433,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) return result - return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys) + return self.db.runInteraction( + "claim_e2e_one_time_keys", _claim_e2e_one_time_keys + ) def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device_txn(txn): @@ -427,12 +446,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "user_id": user_id, } ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -441,7 +460,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): txn, self.count_e2e_one_time_keys, (user_id, device_id) ) - return self.runInteraction( + return self.db.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) @@ -477,7 +496,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "devices", values={ @@ -490,7 +509,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ @@ -509,7 +528,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): key_type (str): the type of cross-signing key to set key (dict): the key data """ - return self.runInteraction( + return self.db.runInteraction( "add_e2e_cross_signing_key", self._set_e2e_cross_signing_key_txn, user_id, @@ -524,7 +543,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): user_id (str): the user who made the signatures signatures (iterable[SignatureListItem]): signatures to add """ - return self._simple_insert_many( + return self.db.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 90bef0cd2c..1f517e8fad 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.signatures import SignatureWorkerStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -58,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: list of event_ids """ - return self.runInteraction( + return self.db.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given ) @@ -90,12 +91,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) def get_oldest_events_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id ) def get_oldest_events_with_depth_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_oldest_events_with_depth_in_room", self.get_oldest_events_with_depth_in_room_txn, room_id, @@ -126,7 +127,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns Deferred[int] """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -140,7 +141,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self._simple_select_onecol_txn( + return self.db.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, @@ -188,7 +189,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas where *hashes* is a map from algorithm to hash. """ - return self.runInteraction( + return self.db.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, room_id, @@ -229,13 +230,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, query_args) return [room_id for room_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn ) @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", @@ -266,12 +267,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas def get_min_depth(self, room_id): """ For hte given room, get the minimum depth we have seen for it. """ - return self.runInteraction( + return self.db.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self._simple_select_one_onecol_txn( + min_depth = self.db.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -337,7 +338,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] - return self.runInteraction( + return self.db.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) @@ -352,7 +353,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas limit (int) """ return ( - self.runInteraction( + self.db.runInteraction( "get_backfill_events", self._get_backfill_events, room_id, @@ -383,7 +384,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas queue = PriorityQueue() for event_id in event_list: - depth = self._simple_select_one_onecol_txn( + depth = self.db.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -415,7 +416,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas @defer.inlineCallbacks def get_missing_events(self, room_id, earliest_events, latest_events, limit): - ids = yield self.runInteraction( + ids = yield self.db.runInteraction( "get_missing_events", self._get_missing_events, room_id, @@ -468,7 +469,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: Deferred[list[str]] """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -491,10 +492,10 @@ class EventFederationStore(EventFederationWorkerStore): EVENT_AUTH_STATE_ONLY = "event_auth_state_only" - def __init__(self, db_conn, hs): - super(EventFederationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventFederationStore, self).__init__(database, db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) @@ -508,7 +509,7 @@ class EventFederationStore(EventFederationWorkerStore): if min_depth and depth >= min_depth: return - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -520,7 +521,7 @@ class EventFederationStore(EventFederationWorkerStore): For the given event, update the event edges table and forward and backward extremities tables. """ - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_edges", values=[ @@ -604,13 +605,13 @@ class EventFederationStore(EventFederationWorkerStore): return run_as_background_process( "delete_old_forward_extrem_cache", - self.runInteraction, + self.db.runInteraction, "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, ) def clean_room_for_join(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) @@ -654,17 +655,17 @@ class EventFederationStore(EventFederationWorkerStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_AUTH_STATE_ONLY, new_progress ) return min_stream_id >= target_min_stream_id - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_AUTH_STATE_ONLY, delete_event_auth ) if not result: - yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY) return batch_size diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 04ce21ac66..9988a6d3fc 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cachedInlineCallbacks logger = logging.getLogger(__name__) @@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight): class EventPushActionsWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs) # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago = None @@ -93,7 +94,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - ret = yield self.runInteraction( + ret = yield self.db.runInteraction( "get_unread_event_push_actions_by_room", self._get_unread_counts_by_receipt_txn, room_id, @@ -177,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.runInteraction("get_push_action_users_in_range", f) + ret = yield self.db.runInteraction("get_push_action_users_in_range", f) return ret @defer.inlineCallbacks @@ -229,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -257,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -329,7 +330,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.runInteraction( + after_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -357,7 +358,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.runInteraction( + no_read_receipt = yield self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -407,7 +408,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id, min_stream_ordering)) return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_maybe_push_in_range_for_user", _get_if_maybe_push_in_range_for_user_txn, ) @@ -441,7 +442,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _add_push_actions_to_staging_txn(txn): - # We don't use _simple_insert_many here to avoid the overhead + # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. sql = """ @@ -458,7 +459,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - return self.runInteraction( + return self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) @@ -472,7 +473,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = yield self._simple_delete( + res = yield self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -489,7 +490,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): def _find_stream_orderings_for_times(self): return run_as_background_process( "event_push_action_stream_orderings", - self.runInteraction, + self.db.runInteraction, "_find_stream_orderings_for_times", self._find_stream_orderings_for_times_txn, ) @@ -525,7 +526,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): Deferred[int]: stream ordering of the first event received on/after the timestamp """ - return self.runInteraction( + return self.db.runInteraction( "_find_first_stream_ordering_after_ts_txn", self._find_first_stream_ordering_after_ts_txn, ts, @@ -611,17 +612,17 @@ class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" - def __init__(self, db_conn, hs): - super(EventPushActionsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventPushActionsStore, self).__init__(database, db_conn, hs) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EPA_HIGHLIGHT_INDEX, index_name="event_push_actions_u_highlight", table="event_push_actions", columns=["user_id", "stream_ordering"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_push_actions_highlights_index", index_name="event_push_actions_highlights_index", table="event_push_actions", @@ -677,7 +678,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) for event, _ in events_and_contexts: - user_ids = self._simple_select_onecol_txn( + user_ids = self.db.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -727,9 +728,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): " LIMIT ?" % (before_clause,) ) txn.execute(sql, args) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - push_actions = yield self.runInteraction("get_push_actions_for_user", f) + push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions @@ -748,7 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.runInteraction("get_time_of_last_push_action_before", f) + result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) return result[0] if result else None @defer.inlineCallbacks @@ -757,7 +758,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.runInteraction("get_latest_push_action_stream_ordering", f) + result = yield self.db.runInteraction( + "get_latest_push_action_stream_ordering", f + ) return result[0] or 0 def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): @@ -830,7 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.runInteraction( + caught_up = yield self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: @@ -844,7 +847,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -880,7 +883,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -912,7 +915,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 627c0b67f1..da1529f6ea 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -38,10 +38,10 @@ from synapse.logging.utils import log_function from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.event_federation import EventFederationStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.state import StateGroupWorkerStore +from synapse.storage.database import Database from synapse.types import RoomStreamToken, get_domain_from_id from synapse.util import batch_iter from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -94,13 +94,10 @@ def _retry_on_integrity_error(func): # inherits from EventFederationStore so that we can call _update_backward_extremities # and _handle_mult_prev_events (though arguably those could both be moved in here) class EventsStore( - StateGroupWorkerStore, - EventFederationStore, - EventsWorkerStore, - BackgroundUpdateStore, + StateGroupWorkerStore, EventFederationStore, EventsWorkerStore, ): - def __init__(self, db_conn, hs): - super(EventsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsStore, self).__init__(database, db_conn, hs) # Collect metrics on the number of forward extremities that exist. # Counter of number of extremities to count @@ -130,6 +127,8 @@ class EventsStore( if self.hs.config.redaction_retention_period is not None: hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def _read_forward_extremities(self): def fetch(txn): @@ -141,7 +140,7 @@ class EventsStore( ) return txn.fetchall() - res = yield self.runInteraction("read_forward_extremities", fetch) + res = yield self.db.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = c_counter(list(x[0] for x in res)) @_retry_on_integrity_error @@ -206,7 +205,7 @@ class EventsStore( for (event, context), stream in zip(events_and_contexts, stream_orderings): event.internal_metadata.stream_ordering = stream - yield self.runInteraction( + yield self.db.runInteraction( "persist_events", self._persist_events_txn, events_and_contexts=events_and_contexts, @@ -279,7 +278,7 @@ class EventsStore( results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed")) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( + yield self.db.runInteraction( "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk ) @@ -343,7 +342,7 @@ class EventsStore( existing_prevs.add(prev_event_id) for chunk in batch_iter(event_ids, 100): - yield self.runInteraction( + yield self.db.runInteraction( "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk ) @@ -430,7 +429,7 @@ class EventsStore( # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -578,12 +577,12 @@ class EventsStore( self, txn, new_forward_extremities, max_stream_order ): for room_id, new_extrem in iteritems(new_forward_extremities): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ @@ -596,7 +595,7 @@ class EventsStore( # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -720,7 +719,7 @@ class EventsStore( # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -792,7 +791,7 @@ class EventsStore( d.pop("redacted_because", None) return d - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_json", values=[ @@ -809,7 +808,7 @@ class EventsStore( ], ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="events", values=[ @@ -839,7 +838,7 @@ class EventsStore( # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self._simple_update_txn( + self.db.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -927,6 +926,9 @@ class EventsStore( elif event.type == EventTypes.Redaction: # Insert into the redactions table. self._store_redaction(txn, event) + elif event.type == EventTypes.Retention: + # Update the room_retention table. + self._store_retention_policy_for_room_txn(txn, event) self._handle_event_relations(txn, event) @@ -937,6 +939,12 @@ class EventsStore( txn, event.event_id, labels, event.room_id, event.depth ) + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -972,7 +980,7 @@ class EventsStore( state_values.append(vals) - self._simple_insert_many_txn(txn, table="state_events", values=state_values) + self.db.simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1003,7 +1011,7 @@ class EventsStore( ) txn.execute(sql + clause, args) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) for row in rows: event = ev_map[row["event_id"]] if not row["rejects"] and not row["redacts"]: @@ -1021,7 +1029,7 @@ class EventsStore( # invalidate the cache for the redacted event txn.call_after(self._invalidate_get_event_cache, event.redacts) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="redactions", values={ @@ -1066,7 +1074,7 @@ class EventsStore( LIMIT ? """ - rows = yield self._execute( + rows = yield self.db.execute( "_censor_redactions_fetch", None, sql, before_ts, 100 ) @@ -1098,21 +1106,32 @@ class EventsStore( def _update_censor_txn(txn): for redaction_id, event_id, pruned_json in updates: if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) + self._censor_event_txn(txn, event_id, pruned_json) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, updatevalues={"have_censored": True}, ) - yield self.runInteraction("_update_censor_txn", _update_censor_txn) + yield self.db.runInteraction("_update_censor_txn", _update_censor_txn) + + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self.db.simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) @defer.inlineCallbacks def count_daily_messages(self): @@ -1133,7 +1152,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_messages", _count_messages) + ret = yield self.db.runInteraction("count_messages", _count_messages) return ret @defer.inlineCallbacks @@ -1154,7 +1173,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) + ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages) return ret @defer.inlineCallbacks @@ -1169,7 +1188,7 @@ class EventsStore( (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_daily_active_rooms", _count) + ret = yield self.db.runInteraction("count_daily_active_rooms", _count) return ret def get_current_backfill_token(self): @@ -1221,7 +1240,7 @@ class EventsStore( return new_event_updates - return self.runInteraction( + return self.db.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) @@ -1266,7 +1285,7 @@ class EventsStore( return new_event_updates - return self.runInteraction( + return self.db.runInteraction( "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows ) @@ -1359,7 +1378,7 @@ class EventsStore( backward_ex_outliers, ) - return self.runInteraction("get_all_new_events", get_all_new_events_txn) + return self.db.runInteraction("get_all_new_events", get_all_new_events_txn) def purge_history(self, room_id, token, delete_local_events): """Deletes room history before a certain point @@ -1379,7 +1398,7 @@ class EventsStore( deleted events. """ - return self.runInteraction( + return self.db.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -1627,7 +1646,7 @@ class EventsStore( Deferred[List[int]]: The list of state groups to delete. """ - return self.runInteraction("purge_room", self._purge_room_txn, room_id) + return self.db.runInteraction("purge_room", self._purge_room_txn, room_id) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before @@ -1746,7 +1765,7 @@ class EventsStore( to delete. """ - return self.runInteraction( + return self.db.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, @@ -1758,7 +1777,7 @@ class EventsStore( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self._simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -1785,15 +1804,15 @@ class EventsStore( curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1830,7 +1849,7 @@ class EventsStore( state group. """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -1849,7 +1868,7 @@ class EventsStore( state_groups_to_delete (list[int]): State groups to delete """ - return self.runInteraction( + return self.db.runInteraction( "purge_room_state", self._purge_room_state_txn, room_id, @@ -1860,7 +1879,7 @@ class EventsStore( # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -1871,7 +1890,7 @@ class EventsStore( # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -1882,7 +1901,7 @@ class EventsStore( # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn, table="state_groups", column="id", @@ -1899,7 +1918,7 @@ class EventsStore( @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): - res = yield self._simple_select_one( + res = yield self.db.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1922,7 +1941,7 @@ class EventsStore( txn.execute(sql, (from_token, to_token, limit)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_current_state_deltas", get_all_updated_current_state_deltas_txn, ) @@ -1940,7 +1959,7 @@ class EventsStore( room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self._simple_insert_many_txn( + return self.db.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1954,6 +1973,101 @@ class EventsStore( ], ) + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self.db.simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json(prune_event_dict(event.get_dict())) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self.db.simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.db.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index aa87f9abc5..efee17b929 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -22,30 +22,30 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.constants import EventContentFields -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database logger = logging.getLogger(__name__) -class EventsBackgroundUpdatesStore(BackgroundUpdateStore): +class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, db_conn, hs): - super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_contains_url_index", index_name="event_contains_url_index", table="events", @@ -56,7 +56,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): # an event_id index on event_search is useful for the purge_history # api. Plus it means we get to enforce some integrity with a UNIQUE # clause - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_search_event_id_idx", index_name="event_search_event_id_idx", table="event_search", @@ -65,16 +65,16 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): psql_only=True, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "redactions_received_ts", self._redactions_received_ts ) # This index gets deleted in `event_fix_redactions_bytes` update - self.register_background_index_update( + self.db.updates.register_background_index_update( "event_fix_redactions_bytes_create_index", index_name="redactions_censored_redacts", table="redactions", @@ -82,11 +82,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): where_clause="have_censored", ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_fix_redactions_bytes", self._event_fix_redactions_bytes ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "event_store_labels", self._event_store_labels ) @@ -145,18 +145,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: - yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + ) return result @@ -189,7 +191,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self._simple_select_many_txn( + ev_rows = self.db.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -222,18 +224,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(rows_to_update), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress ) return len(rows_to_update) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: - yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + yield self.db.updates._end_background_update( + self.EVENT_ORIGIN_SERVER_TS_NAME + ) return result @@ -366,7 +370,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): to_delete.intersection_update(original_set) - deleted = self._simple_delete_many_txn( + deleted = self.db.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -382,7 +386,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="events", column="event_id", @@ -396,7 +400,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self._simple_delete_many_txn( + self.db.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -406,17 +410,19 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): return len(original_set) - num_handled = yield self.runInteraction( + num_handled = yield self.db.runInteraction( "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn ) if not num_handled: - yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES) + yield self.db.updates._end_background_update( + self.DELETE_SOFT_FAILED_EXTREMITIES + ) def _drop_table_txn(txn): txn.execute("DROP TABLE _extremities_to_check") - yield self.runInteraction( + yield self.db.runInteraction( "_cleanup_extremities_bg_update_drop_table", _drop_table_txn ) @@ -464,18 +470,18 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} ) return len(rows) - count = yield self.runInteraction( + count = yield self.db.runInteraction( "_redactions_received_ts", _redactions_received_ts_txn ) if not count: - yield self._end_background_update("redactions_received_ts") + yield self.db.updates._end_background_update("redactions_received_ts") return count @@ -501,11 +507,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): txn.execute("DROP INDEX redactions_censored_redacts") - yield self.runInteraction( + yield self.db.runInteraction( "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn ) - yield self._end_background_update("event_fix_redactions_bytes") + yield self.db.updates._end_background_update("event_fix_redactions_bytes") return 1 @@ -533,7 +539,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): try: event_json = json.loads(event_json_raw) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -559,17 +565,17 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): nbrows += 1 last_row_event_id = event_id - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "event_store_labels", {"last_event_id": last_row_event_id} ) return nbrows - num_rows = yield self.runInteraction( + num_rows = yield self.db.runInteraction( desc="event_store_labels", func=_event_store_labels_txn ) if not num_rows: - yield self._end_background_update("event_store_labels") + yield self.db.updates._end_background_update("event_store_labels") return num_rows diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 4c4b76bd93..9ee117ce0f 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -17,6 +17,7 @@ from __future__ import division import itertools import logging +import threading from collections import namedtuple from canonicaljson import json @@ -32,8 +33,10 @@ from synapse.events.utils import prune_event from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.types import get_domain_from_id from synapse.util import batch_iter +from synapse.util.caches.descriptors import Cache from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -53,6 +56,17 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) class EventsWorkerStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(EventsWorkerStore, self).__init__(database, db_conn, hs) + + self._get_event_cache = Cache( + "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + ) + + self._event_fetch_lock = threading.Condition() + self._event_fetch_list = [] + self._event_fetch_ongoing = 0 + def get_received_ts(self, event_id): """Get received_ts (when it was persisted) for the event. @@ -65,7 +79,7 @@ class EventsWorkerStore(SQLBaseStore): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -104,7 +118,7 @@ class EventsWorkerStore(SQLBaseStore): return ts - return self.runInteraction( + return self.db.runInteraction( "get_approximate_received_ts", _get_approximate_received_ts_txn ) @@ -439,7 +453,7 @@ class EventsWorkerStore(SQLBaseStore): event_id for events, _ in event_list for event_id in events ) - row_dict = self._new_transaction( + row_dict = self.db.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -571,7 +585,7 @@ class EventsWorkerStore(SQLBaseStore): if should_start: run_as_background_process( - "fetch_events", self.runWithConnection, self._do_fetch + "fetch_events", self.db.runWithConnection, self._do_fetch ) logger.debug("Loading %d events: %s", len(events), events) @@ -732,7 +746,7 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -767,42 +781,10 @@ class EventsWorkerStore(SQLBaseStore): # break the input up into chunks of 100 input_iterator = iter(event_ids) for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): - yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) - return results - - def get_seen_events_with_rejections(self, event_ids): - """Given a list of event ids, check if we rejected them. - - Args: - event_ids (list[str]) - - Returns: - Deferred[dict[str, str|None): - Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps - to None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" + yield self.db.runInteraction( + "have_seen_events", have_seen_events_txn, chunk ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction("get_seen_events_with_rejections", f) + return results def _get_total_state_event_counts_txn(self, txn, room_id): """ @@ -828,7 +810,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_total_state_event_counts", self._get_total_state_event_counts_txn, room_id, @@ -853,7 +835,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred[int] """ - return self.runInteraction( + return self.db.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py index f05ace299a..342d6622a4 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self._simple_select_one_onecol( + def_json = yield self.db.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", @@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.runInteraction("add_user_filter", _do_txn) + return self.db.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 5ded539af8..6acd45e9f3 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore): * "invite" * "open" """ - return self._simple_update_one( + return self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore): ) def get_group(self, group_id): - return self._simple_select_one( + return self.db.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.db.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.db.simple_select_list( table="group_rooms", keyvalues=keyvalues, retcols=("room_id", "is_public"), @@ -153,10 +153,12 @@ class GroupServerStore(SQLBaseStore): return rooms, categories - return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn) + return self.db.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_room_to_summary", self._add_room_to_summary_txn, group_id, @@ -180,7 +182,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the room first. Otherwise, the room gets added to the end. """ - room_in_group = self._simple_select_one_onecol_txn( + room_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -193,7 +195,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -204,7 +206,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -224,7 +226,7 @@ class GroupServerStore(SQLBaseStore): (group_id, category_id, group_id, category_id), ) - existing = self._simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -257,7 +259,7 @@ class GroupServerStore(SQLBaseStore): to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -271,7 +273,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -287,7 +289,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self._simple_delete( + return self.db.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -299,7 +301,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_categories(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -316,7 +318,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_category(self, group_id, category_id): - category = yield self._simple_select_one( + category = yield self.db.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -343,7 +345,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.db.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -352,7 +354,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_category(self, group_id, category_id): - return self._simple_delete( + return self.db.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -360,7 +362,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_roles(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -377,7 +379,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_role(self, group_id, role_id): - role = yield self._simple_select_one( + role = yield self.db.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -404,7 +406,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.db.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -413,14 +415,14 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_role(self, group_id, role_id): - return self._simple_delete( + return self.db.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", ) def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): - return self.runInteraction( + return self.db.runInteraction( "add_user_to_summary", self._add_user_to_summary_txn, group_id, @@ -444,7 +446,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self._simple_select_one_onecol_txn( + user_in_group = self.db.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -457,7 +459,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -468,7 +470,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.db.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -488,7 +490,7 @@ class GroupServerStore(SQLBaseStore): (group_id, role_id, group_id, role_id), ) - existing = self._simple_select_one_txn( + existing = self.db.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -517,7 +519,7 @@ class GroupServerStore(SQLBaseStore): to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.db.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -531,7 +533,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_summary_users", values={ @@ -547,7 +549,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self._simple_delete( + return self.db.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -561,7 +563,7 @@ class GroupServerStore(SQLBaseStore): Deferred[list[str]]: A twisted.Deferred containing a list of group ids containing this room """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -625,12 +627,12 @@ class GroupServerStore(SQLBaseStore): return users, roles - return self.runInteraction( + return self.db.runInteraction( "get_users_for_summary_by_role", _get_users_for_summary_txn ) def is_user_in_group(self, user_id, group_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -639,7 +641,7 @@ class GroupServerStore(SQLBaseStore): ).addCallback(lambda r: bool(r)) def is_user_admin_in_group(self, group_id, user_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -650,7 +652,7 @@ class GroupServerStore(SQLBaseStore): def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self._simple_insert( + return self.db.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -659,7 +661,7 @@ class GroupServerStore(SQLBaseStore): def is_user_invited_to_local_group(self, group_id, user_id): """Has the group server invited a user? """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -682,7 +684,7 @@ class GroupServerStore(SQLBaseStore): """ def _get_users_membership_in_group_txn(txn): - row = self._simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -697,7 +699,7 @@ class GroupServerStore(SQLBaseStore): "is_privileged": row["is_admin"], } - row = self._simple_select_one_onecol_txn( + row = self.db.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -710,7 +712,7 @@ class GroupServerStore(SQLBaseStore): return {} - return self.runInteraction( + return self.db.runInteraction( "get_users_membership_info_in_group", _get_users_membership_in_group_txn ) @@ -738,7 +740,7 @@ class GroupServerStore(SQLBaseStore): """ def _add_user_to_group_txn(txn): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_users", values={ @@ -749,14 +751,14 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -766,7 +768,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -777,49 +779,49 @@ class GroupServerStore(SQLBaseStore): }, ) - return self.runInteraction("add_user_to_group", _add_user_to_group_txn) + return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn) def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_from_group", _remove_user_from_group_txn ) def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( + return self.db.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self._simple_update( + return self.db.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -828,26 +830,26 @@ class GroupServerStore(SQLBaseStore): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_room_from_group", _remove_room_from_group_txn ) def get_publicised_groups_for_user(self, user_id): """Get all groups a user is publicising """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -857,7 +859,7 @@ class GroupServerStore(SQLBaseStore): def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self._simple_update_one( + return self.db.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -893,12 +895,12 @@ class GroupServerStore(SQLBaseStore): def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_membership", values={ @@ -911,7 +913,7 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_group_updates", values={ @@ -930,7 +932,7 @@ class GroupServerStore(SQLBaseStore): if membership == "join": if local_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -940,7 +942,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -951,12 +953,12 @@ class GroupServerStore(SQLBaseStore): }, ) else: - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -965,7 +967,7 @@ class GroupServerStore(SQLBaseStore): return next_id with self._group_updates_id_gen.get_next() as next_id: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "register_user_group_membership", _register_user_group_membership_txn, next_id, @@ -976,7 +978,7 @@ class GroupServerStore(SQLBaseStore): def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description ): - yield self._simple_insert( + yield self.db.simple_insert( table="groups", values={ "group_id": group_id, @@ -991,7 +993,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def update_group_profile(self, group_id, profile): - yield self._simple_update_one( + yield self.db.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, @@ -1008,16 +1010,16 @@ class GroupServerStore(SQLBaseStore): WHERE valid_until_ms <= ? """ txn.execute(sql, (valid_until_ms,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_attestations_need_renewals", _get_attestations_need_renewals_txn ) def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self._simple_update_one( + return self.db.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1027,7 +1029,7 @@ class GroupServerStore(SQLBaseStore): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self._simple_update_one( + return self.db.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ @@ -1046,7 +1048,7 @@ class GroupServerStore(SQLBaseStore): group_id (str) user_id (str) """ - return self._simple_delete( + return self.db.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1057,7 +1059,7 @@ class GroupServerStore(SQLBaseStore): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self._simple_select_one( + row = yield self.db.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -1072,7 +1074,7 @@ class GroupServerStore(SQLBaseStore): return None def get_joined_groups(self, user_id): - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -1099,7 +1101,7 @@ class GroupServerStore(SQLBaseStore): for row in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_all_groups_for_user", _get_all_groups_for_user_txn ) @@ -1109,7 +1111,7 @@ class GroupServerStore(SQLBaseStore): user_id, from_token ) if not has_changed: - return [] + return defer.succeed([]) def _get_groups_changes_for_user_txn(txn): sql = """ @@ -1129,7 +1131,7 @@ class GroupServerStore(SQLBaseStore): for group_id, membership, gtype, content_json in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_groups_changes_for_user", _get_groups_changes_for_user_txn ) @@ -1139,7 +1141,7 @@ class GroupServerStore(SQLBaseStore): from_token ) if not has_changed: - return [] + return defer.succeed([]) def _get_all_groups_changes_txn(txn): sql = """ @@ -1154,7 +1156,7 @@ class GroupServerStore(SQLBaseStore): for stream_id, group_id, user_id, gtype, content_json in txn ] - return self.runInteraction( + return self.db.runInteraction( "get_all_groups_changes", _get_all_groups_changes_txn ) @@ -1188,8 +1190,8 @@ class GroupServerStore(SQLBaseStore): ] for table in tables: - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) - return self.runInteraction("delete_group", _delete_group_txn) + return self.db.runInteraction("delete_group", _delete_group_txn) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index ebc7db3ed6..6b12f5a75f 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore): _get_keys(txn, batch) return keys - return self.runInteraction("get_server_verify_keys", _txn) + return self.db.runInteraction("get_server_verify_keys", _txn) def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): """Stores NACL verification keys for remote servers. @@ -127,9 +127,9 @@ class KeyStore(SQLBaseStore): f((i,)) return res - return self.runInteraction( + return self.db.runInteraction( "store_server_verify_keys", - self._simple_upsert_many_txn, + self.db.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore): ts_valid_until_ms (int): The time when this json stops being valid. key_json (bytes): The encoded JSON. """ - return self._simple_upsert( + return self.db.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self._simple_select_list_txn( + rows = self.db.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, @@ -211,4 +211,4 @@ class KeyStore(SQLBaseStore): results[(server_name, key_id, from_server)] = rows return results - return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn) + return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn) diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 0f2887bdce..80ca36dedf 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -12,14 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database -class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs) +class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryBackgroundUpdateStore, self).__init__( + database, db_conn, hs + ) - self.register_background_index_update( + self.db.updates.register_background_index_update( update_name="local_media_repository_url_idx", index_name="local_media_repository_url_idx", table="local_media_repository", @@ -31,15 +34,15 @@ class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore): class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): """Persistence for attachments and avatars""" - def __init__(self, db_conn, hs): - super(MediaRepositoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(MediaRepositoryStore, self).__init__(database, db_conn, hs) def get_local_media(self, media_id): """Get the metadata for a local piece of media Returns: None if the media_id doesn't exist. """ - return self._simple_select_one( + return self.db.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -64,7 +67,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id, url_cache=None, ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository", { "media_id": media_id, @@ -124,12 +127,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) ) - return self.runInteraction("get_url_cache", get_url_cache_txn) + return self.db.runInteraction("get_url_cache", get_url_cache_txn) def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -144,7 +147,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_local_media_thumbnails(self, media_id): - return self._simple_select_list( + return self.db.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -166,7 +169,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.db.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -180,7 +183,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_cached_remote_media(self, origin, media_id): - return self._simple_select_one( + return self.db.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -205,7 +208,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self._simple_insert( + return self.db.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -250,10 +253,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) - return self.runInteraction("update_cached_last_access_time", update_cache_txn) + return self.db.runInteraction( + "update_cached_last_access_time", update_cache_txn + ) def get_remote_media_thumbnails(self, origin, media_id): - return self._simple_select_list( + return self.db.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -278,7 +283,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.db.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -300,24 +305,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE last_access_ts < ?" ) - return self._execute( - "get_remote_media_before", self.cursor_to_dict, sql, before_ts + return self.db.execute( + "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - return self.runInteraction("delete_remote_media", delete_remote_media_txn) + return self.db.runInteraction("delete_remote_media", delete_remote_media_txn) def get_expired_url_cache(self, now_ts): sql = ( @@ -331,7 +336,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (now_ts,)) return [row[0] for row in txn] - return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + return self.db.runInteraction( + "get_expired_url_cache", _get_expired_url_cache_txn + ) def delete_url_cache(self, media_ids): if len(media_ids) == 0: @@ -342,7 +349,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn) def get_url_cache_media_before(self, before_ts): sql = ( @@ -356,7 +363,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.execute(sql, (before_ts,)) return [row[0] for row in txn] - return self.runInteraction( + return self.db.runInteraction( "get_url_cache_media_before", _get_url_cache_media_before_txn ) @@ -373,6 +380,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): txn.executemany(sql, [(media_id,) for media_id in media_ids]) - return self.runInteraction( + return self.db.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn ) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index b41c3d317a..27158534cb 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 class MonthlyActiveUsersStore(SQLBaseStore): - def __init__(self, dbconn, hs): - super(MonthlyActiveUsersStore, self).__init__(None, hs) + def __init__(self, database: Database, db_conn, hs): + super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs) self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number - self._new_transaction( - dbconn, + self.db.new_transaction( + db_conn, "initialise_mau_threepids", [], [], @@ -146,7 +147,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): txn.execute(sql, query_args) reserved_users = yield self.get_registered_reserved_users() - yield self.runInteraction( + yield self.db.runInteraction( "reap_monthly_active_users", _reap_users, reserved_users ) # It seems poor to invalidate the whole cache, Postgres supports @@ -174,7 +175,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): (count,) = txn.fetchone() return count - return self.runInteraction("count_users", _count_users) + return self.db.runInteraction("count_users", _count_users) @defer.inlineCallbacks def get_registered_reserved_users(self): @@ -217,7 +218,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): if is_support: return - yield self.runInteraction( + yield self.db.runInteraction( "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id ) @@ -261,7 +262,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self._simple_upsert_txn( + is_insert = self.db.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -281,7 +282,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py index 79b40044d9..cc21437e92 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self._simple_insert( + return self.db.simple_insert( table="open_id_tokens", values={ "token": token, @@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn) + return self.db.runInteraction( + "get_user_id_for_token", get_user_id_for_token_txn + ) diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index 523ed6575e..a2c83e0867 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -29,7 +29,7 @@ class PresenceStore(SQLBaseStore): ) with stream_ordering_manager as stream_orderings: - yield self.runInteraction( + yield self.db.runInteraction( "update_presence", self._update_presence_txn, stream_orderings, @@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -88,7 +88,7 @@ class PresenceStore(SQLBaseStore): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_presence_updates", get_all_presence_updates_txn ) @@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore): inlineCallbacks=True, ) def get_presence_for_users(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore): return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_insert( + return self.db.simple_insert( table="presence_allow_inbound", values={ "observed_user_id": observed_localpart, @@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore): ) def disallow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_delete_one( + return self.db.simple_delete_one( table="presence_allow_inbound", keyvalues={ "observed_user_id": observed_localpart, diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index e4e8a1c1d6..2b52cf9c1a 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: - profile = yield self._simple_select_one( + profile = yield self.db.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_from_remote_profile_cache(self, user_id): - return self._simple_select_one( + return self.db.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore): ) def create_profile(self, user_localpart): - return self._simple_insert( + return self.db.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( + return self.db.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self._simple_upsert( + return self.db.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self._simple_update( + return self.db.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore): """ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self._simple_delete( + yield self.db.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore): txn.execute(sql, (last_checked,)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) @@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index b520062d84..5ba13aa973 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker from synapse.storage.data_stores.main.pusher import PusherWorkerStore from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore +from synapse.storage.database import Database from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -72,10 +73,10 @@ class PushRulesWorkerStore( # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(PushRulesWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(database, db_conn, hs) - push_rules_prefill, push_rules_id = self._get_cache_dict( + push_rules_prefill, push_rules_id = self.db.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -100,7 +101,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -124,7 +125,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): - results = yield self._simple_select_list( + results = yield self.db.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -146,7 +147,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.runInteraction( + return self.db.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) @@ -162,7 +163,7 @@ class PushRulesWorkerStore( results = {user_id: [] for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -320,7 +321,7 @@ class PushRulesWorkerStore( results = {user_id: {} for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -350,7 +351,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids if before or after: - yield self.runInteraction( + yield self.db.runInteraction( "_add_push_rule_relative_txn", self._add_push_rule_relative_txn, stream_id, @@ -364,7 +365,7 @@ class PushRuleStore(PushRulesWorkerStore): after, ) else: - yield self.runInteraction( + yield self.db.runInteraction( "_add_push_rule_highest_priority_txn", self._add_push_rule_highest_priority_txn, stream_id, @@ -395,7 +396,7 @@ class PushRuleStore(PushRulesWorkerStore): relative_to_rule = before or after - res = self._simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -499,7 +500,7 @@ class PushRuleStore(PushRulesWorkerStore): actions_json, update_stream=True, ): - """Specialised version of _simple_upsert_txn that picks a push_rule_id + """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -518,7 +519,7 @@ class PushRuleStore(PushRulesWorkerStore): # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="push_rules", values={ @@ -561,7 +562,7 @@ class PushRuleStore(PushRulesWorkerStore): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self._simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -571,7 +572,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "delete_push_rule", delete_push_rule_txn, stream_id, @@ -582,7 +583,7 @@ class PushRuleStore(PushRulesWorkerStore): def set_push_rule_enabled(self, user_id, rule_id, enabled): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "_set_push_rule_enabled_txn", self._set_push_rule_enabled_txn, stream_id, @@ -596,7 +597,7 @@ class PushRuleStore(PushRulesWorkerStore): self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -636,7 +637,7 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -655,7 +656,7 @@ class PushRuleStore(PushRulesWorkerStore): with self._push_rules_stream_id_gen.get_next() as ids: stream_id, event_stream_ordering = ids - yield self.runInteraction( + yield self.db.runInteraction( "set_push_rule_actions", set_push_rule_actions_txn, stream_id, @@ -675,7 +676,7 @@ class PushRuleStore(PushRulesWorkerStore): if data is not None: values.update(data) - self._simple_insert_txn(txn, "push_rules_stream", values=values) + self.db.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) @@ -699,7 +700,7 @@ class PushRuleStore(PushRulesWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.runInteraction( + return self.db.runInteraction( "get_all_push_rule_updates", get_all_push_rule_updates_txn ) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index d76861cdc0..f07309ef09 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self._simple_select_one_onecol( + ret = yield self.db.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self._simple_select_list( + ret = yield self.db.simple_select_list( "pushers", keyvalues, [ @@ -100,11 +100,11 @@ class PusherWorkerStore(SQLBaseStore): def get_all_pushers(self): def get_pushers(txn): txn.execute("SELECT * FROM pushers") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - rows = yield self.runInteraction("get_all_pushers", get_pushers) + rows = yield self.db.runInteraction("get_all_pushers", get_pushers) return rows def get_all_updated_pushers(self, last_id, current_id, limit): @@ -134,7 +134,7 @@ class PusherWorkerStore(SQLBaseStore): return updated, deleted - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers", get_all_updated_pushers_txn ) @@ -177,7 +177,7 @@ class PusherWorkerStore(SQLBaseStore): return results - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn ) @@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -229,8 +229,8 @@ class PusherStore(PusherWorkerStore): ): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so _simple_upsert will retry - yield self._simple_upsert( + # (app_id, pushkey, user_name) so simple_upsert will retry + yield self.db.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -255,7 +255,7 @@ class PusherStore(PusherWorkerStore): if user_has_pusher is not True: # invalidate, since we the user might not have had a pusher before - yield self.runInteraction( + yield self.db.runInteraction( "add_pusher", self._invalidate_cache_and_stream, self.get_if_user_has_pusher, @@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self._simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -290,13 +290,13 @@ class PusherStore(PusherWorkerStore): ) with self._pushers_id_gen.get_next() as stream_id: - yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id) + yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id) @defer.inlineCallbacks def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ): - yield self._simple_update_one( + yield self.db.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, @@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore): Returns: Deferred[bool]: True if the pusher still exists; False if it has been deleted. """ - updated = yield self._simple_update( + updated = yield self.db.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self._simple_update( + yield self.db.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, @@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def get_throttle_params_by_room(self, pusher_id): - res = yield self._simple_select_list( + res = yield self.db.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -361,8 +361,8 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so _simple_upsert will retry - yield self._simple_upsert( + # (pusher, room_id) so simple_upsert will retry + yield self.db.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 8b17334ff4..96e54d145e 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -22,6 +22,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore): # the abstract methods being implemented. __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(ReceiptsWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() @@ -61,7 +62,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self._simple_select_list( + return self.db.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -70,7 +71,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -84,7 +85,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self._simple_select_list( + rows = yield self.db.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -108,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return txn.fetchall() - rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f) + rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f) return { row[0]: { "event_id": row[1], @@ -187,11 +188,11 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql, (room_id, to_key)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) return rows - rows = yield self.runInteraction("get_linearized_receipts_for_room", f) + rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f) if not rows: return [] @@ -237,9 +238,11 @@ class ReceiptsWorkerStore(SQLBaseStore): txn.execute(sql + clause, [to_key] + list(args)) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f) + txn_results = yield self.db.runInteraction( + "_get_linearized_receipts_for_rooms", f + ) results = {} for row in txn_results: @@ -282,7 +285,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return list(r[0:5] + (json.loads(r[5]),) for r in txn) - return self.runInteraction( + return self.db.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn ) @@ -313,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore): class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, db_conn, hs): + def __init__(self, database: Database, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = StreamIdGenerator( db_conn, "receipts_linearized", "stream_id" ) - super(ReceiptsStore, self).__init__(db_conn, hs) + super(ReceiptsStore, self).__init__(database, db_conn, hs) def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @@ -335,7 +338,7 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self._simple_select_one_txn( + res = self.db.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -388,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_linearized", keyvalues={ @@ -398,7 +401,7 @@ class ReceiptsStore(ReceiptsWorkerStore): }, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_linearized", values={ @@ -453,13 +456,13 @@ class ReceiptsStore(ReceiptsWorkerStore): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = yield self.runInteraction( + linearized_event_id = yield self.db.runInteraction( "insert_receipt_conv", graph_to_linear ) stream_id_manager = self._receipts_id_gen.get_next() with stream_id_manager as stream_id: - event_ts = yield self.runInteraction( + event_ts = yield self.db.runInteraction( "insert_linearized_receipt", self.insert_linearized_receipt_txn, room_id, @@ -488,7 +491,7 @@ class ReceiptsStore(ReceiptsWorkerStore): return stream_id, max_persisted_id def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.runInteraction( + return self.db.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -514,7 +517,7 @@ class ReceiptsStore(ReceiptsWorkerStore): self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -523,7 +526,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, }, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="receipts_graph", values={ diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 98cf6427c3..5e8ecac0ea 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -26,8 +26,8 @@ from twisted.internet.defer import Deferred from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage import background_updates from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.types import UserID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -37,15 +37,15 @@ logger = logging.getLogger(__name__) class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, db_conn, hs): - super(RegistrationWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationWorkerStore, self).__init__(database, db_conn, hs) self.config = hs.config self.clock = hs.get_clock() @cached() def get_user_by_id(self, user_id): - return self._simple_select_one( + return self.db.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -94,7 +94,7 @@ class RegistrationWorkerStore(SQLBaseStore): including the keys `name`, `is_guest`, `device_id`, `token_id`, `valid_until_ms`. """ - return self.runInteraction( + return self.db.runInteraction( "get_user_by_access_token", self._query_for_auth, token ) @@ -109,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore): otherwise int representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", @@ -137,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ def set_account_validity_for_user_txn(txn): - self._simple_update_txn( + self.db.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -151,7 +151,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn, self.get_expiration_ts_for_user, (user_id,) ) - yield self.runInteraction( + yield self.db.runInteraction( "set_account_validity_for_user", set_account_validity_for_user_txn ) @@ -167,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore): Raises: StoreError: The provided token is already set for another user. """ - yield self._simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, @@ -184,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The ID of the user to which the token belongs. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", @@ -203,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The renewal token associated with this user ID. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", @@ -229,9 +229,9 @@ class RegistrationWorkerStore(SQLBaseStore): ) values = [False, now_ms, renew_at] txn.execute(sql, values) - return self.cursor_to_dict(txn) + return self.db.cursor_to_dict(txn) - res = yield self.runInteraction( + res = yield self.db.runInteraction( "get_users_expiring_soon", select_users_txn, self.clock.time_msec(), @@ -250,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore): email_sent (bool): Flag which indicates whether a renewal email has been sent to this user. """ - yield self._simple_update_one( + yield self.db.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, @@ -265,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore): Args: user_id (str): ID of the user to remove from the account validity table. """ - yield self._simple_delete_one( + yield self.db.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", @@ -281,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -299,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self._simple_update_one( + return self.db.simple_update_one( table="users", keyvalues={"name": user.to_string()}, updatevalues={"admin": 1 if admin else 0}, @@ -316,7 +316,7 @@ class RegistrationWorkerStore(SQLBaseStore): ) txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0] @@ -332,7 +332,9 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user 'user_type' is null or empty string """ - res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id) + res = yield self.db.runInteraction( + "is_real_user", self.is_real_user_txn, user_id + ) return res @cachedInlineCallbacks() @@ -345,13 +347,13 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if user is of type UserTypes.SUPPORT """ - res = yield self.runInteraction( + res = yield self.db.runInteraction( "is_support_user", self.is_support_user_txn, user_id ) return res def is_real_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -361,7 +363,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res is None def is_support_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.db.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -380,7 +382,7 @@ class RegistrationWorkerStore(SQLBaseStore): txn.execute(sql, (user_id,)) return dict(txn) - return self.runInteraction("get_users_by_id_case_insensitive", f) + return self.db.runInteraction("get_users_by_id_case_insensitive", f) async def get_user_by_external_id( self, auth_provider: str, external_id: str @@ -394,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: the mxid of the user, or None if they are not known """ - return await self._simple_select_one_onecol( + return await self.db.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -408,12 +410,12 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret def count_daily_user_type(self): @@ -445,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore): results[row[0]] = row[1] return results - return self.runInteraction("count_daily_user_type", _count_daily_user_type) + return self.db.runInteraction("count_daily_user_type", _count_daily_user_type) @defer.inlineCallbacks def count_nonbridged_users(self): @@ -459,7 +461,7 @@ class RegistrationWorkerStore(SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.runInteraction("count_users", _count_users) + ret = yield self.db.runInteraction("count_users", _count_users) return ret @defer.inlineCallbacks @@ -468,12 +470,12 @@ class RegistrationWorkerStore(SQLBaseStore): def _count_users(txn): txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null") - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if rows: return rows[0]["users"] return 0 - ret = yield self.runInteraction("count_real_users", _count_users) + ret = yield self.db.runInteraction("count_real_users", _count_users) return ret @defer.inlineCallbacks @@ -503,7 +505,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ( ( - yield self.runInteraction( + yield self.db.runInteraction( "find_next_generated_user_id", _find_next_generated_user_id ) ) @@ -520,7 +522,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[str|None]: user id or None if no user id/threepid mapping exists """ - user_id = yield self.runInteraction( + user_id = yield self.db.runInteraction( "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address ) return user_id @@ -536,7 +538,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self._simple_select_one_txn( + ret = self.db.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -549,7 +551,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert( + yield self.db.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, @@ -557,7 +559,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_get_threepids(self, user_id): - ret = yield self._simple_select_list( + ret = yield self.db.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], @@ -566,7 +568,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ret def user_delete_threepid(self, user_id, medium, address): - return self._simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", @@ -579,7 +581,7 @@ class RegistrationWorkerStore(SQLBaseStore): user_id: The user id to delete all threepids of """ - return self._simple_delete( + return self.db.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -601,7 +603,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ # We need to use an upsert, in case they user had already bound the # threepid - return self._simple_upsert( + return self.db.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -627,7 +629,7 @@ class RegistrationWorkerStore(SQLBaseStore): medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self._simple_select_list( + return self.db.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -648,7 +650,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred """ - return self._simple_delete( + return self.db.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -671,7 +673,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -689,7 +691,7 @@ class RegistrationWorkerStore(SQLBaseStore): defer.Deferred(bool): The requested value. """ - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -756,13 +758,13 @@ class RegistrationWorkerStore(SQLBaseStore): sql += " LIMIT 1" txn.execute(sql, list(keyvalues.values())) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return None return rows[0] - return self.runInteraction( + return self.db.runInteraction( "get_threepid_validation_session", get_threepid_validation_session_txn ) @@ -776,39 +778,37 @@ class RegistrationWorkerStore(SQLBaseStore): """ def delete_threepid_session_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, ) - return self.runInteraction( + return self.db.runInteraction( "delete_threepid_session", delete_threepid_session_txn ) -class RegistrationBackgroundUpdateStore( - RegistrationWorkerStore, background_updates.BackgroundUpdateStore -): - def __init__(self, db_conn, hs): - super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs) +class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): + def __init__(self, database: Database, db_conn, hs): + super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.clock = hs.get_clock() self.config = hs.config - self.register_background_index_update( + self.db.updates.register_background_index_update( "access_tokens_device_index", index_name="access_tokens_device_id", table="access_tokens", columns=["user_id", "device_id"], ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "users_creation_ts", index_name="users_creation_ts", table="users", @@ -818,13 +818,13 @@ class RegistrationBackgroundUpdateStore( # we no longer use refresh tokens, but it's possible that some people # might have a background update queued to build this index. Just # clear the background update. - self.register_noop_background_update("refresh_tokens_device_index") + self.db.updates.register_noop_background_update("refresh_tokens_device_index") - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "user_threepids_grandfather", self._bg_user_threepids_grandfather ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) @@ -857,7 +857,7 @@ class RegistrationBackgroundUpdateStore( (last_user, batch_size), ) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return True, 0 @@ -871,7 +871,7 @@ class RegistrationBackgroundUpdateStore( logger.info("Marked %d rows as deactivated", rows_processed_nb) - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]} ) @@ -880,12 +880,12 @@ class RegistrationBackgroundUpdateStore( else: return False, len(rows) - end, nb_processed = yield self.runInteraction( + end, nb_processed = yield self.db.runInteraction( "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn ) if end: - yield self._end_background_update("users_set_deactivated_flag") + yield self.db.updates._end_background_update("users_set_deactivated_flag") return nb_processed @@ -911,21 +911,29 @@ class RegistrationBackgroundUpdateStore( txn.executemany(sql, [(id_server,) for id_server in id_servers]) if id_servers: - yield self.runInteraction( + yield self.db.runInteraction( "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn ) - yield self._end_background_update("user_threepids_grandfather") + yield self.db.updates._end_background_update("user_threepids_grandfather") return 1 class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RegistrationStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RegistrationStore, self).__init__(database, db_conn, hs) self._account_validity = hs.config.account_validity + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + # Create a background job for culling expired 3PID validity tokens def start_cull(): # run as a background process to make sure that the database transactions @@ -953,7 +961,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ next_id = self._access_tokens_id_gen.get_next() - yield self._simple_insert( + yield self.db.simple_insert( "access_tokens", { "id": next_id, @@ -995,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Raises: StoreError if the user_id could not be registered. """ - return self.runInteraction( + return self.db.runInteraction( "register_user", self._register_user, user_id, @@ -1029,7 +1037,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self._simple_select_one_txn( + self.db.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1037,7 +1045,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): allow_none=False, ) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1051,7 +1059,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) else: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "users", values={ @@ -1106,7 +1114,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self._simple_insert( + return self.db.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1124,12 +1132,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def user_set_password_hash_txn(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_password_hash", user_set_password_hash_txn) + return self.db.runInteraction( + "user_set_password_hash", user_set_password_hash_txn + ) def user_set_consent_version(self, user_id, consent_version): """Updates the user table to record privacy policy consent @@ -1144,7 +1154,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1152,7 +1162,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_version", f) + return self.db.runInteraction("user_set_consent_version", f) def user_set_consent_server_notice_sent(self, user_id, consent_version): """Updates the user table to record that we have sent the user a server @@ -1168,7 +1178,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1176,7 +1186,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) - return self.runInteraction("user_set_consent_server_notice_sent", f) + return self.db.runInteraction("user_set_consent_server_notice_sent", f) def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): """ @@ -1222,11 +1232,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return tokens_and_devices - return self.runInteraction("user_delete_access_tokens", f) + return self.db.runInteraction("user_delete_access_tokens", f) def delete_access_token(self, access_token): def f(txn): - self._simple_delete_one_txn( + self.db.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1234,11 +1244,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, self.get_user_by_access_token, (access_token,) ) - return self.runInteraction("delete_access_token", f) + return self.db.runInteraction("delete_access_token", f) @cachedInlineCallbacks() def is_guest(self, user_id): - res = yield self._simple_select_one_onecol( + res = yield self.db.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1253,7 +1263,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self._simple_insert( + return self.db.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1266,7 +1276,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self._simple_delete( + return self.db.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1277,7 +1287,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1307,7 +1317,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self._simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1325,7 +1335,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): 400, "This client_secret does not match the provided session_id" ) - row = self._simple_select_one_txn( + row = self.db.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1350,7 +1360,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Looks good. Validate the session - self._simple_update_txn( + self.db.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1360,7 +1370,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): return next_link # Return next_link if it exists - return self.runInteraction( + return self.db.runInteraction( "validate_threepid_session_txn", validate_threepid_session_txn ) @@ -1393,7 +1403,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): if validated_at: insertion_values["validated_at"] = validated_at - return self._simple_upsert( + return self.db.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1431,7 +1441,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1444,7 +1454,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Create a new validation token with this session ID - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1455,7 +1465,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) - return self.runInteraction( + return self.db.runInteraction( "start_or_continue_validation_session", start_or_continue_validation_session_txn, ) @@ -1470,7 +1480,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ return txn.execute(sql, (ts,)) - return self.runInteraction( + return self.db.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, self.clock.time_msec(), @@ -1485,7 +1495,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): deactivated (bool): The value to set for `deactivated`. """ - yield self.runInteraction( + yield self.db.runInteraction( "set_user_deactivated_status", self.set_user_deactivated_status_txn, user_id, @@ -1493,7 +1503,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1502,3 +1512,59 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.db.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + yield self.db.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.db.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index 7d5de0ea2e..1c07c7a425 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): def _store_rejections_txn(self, txn, event_id, reason): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="rejections", values={ @@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore): ) def get_rejection_reason(self, event_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index 858f65582b..046c2b4845 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.runInteraction( + return self.db.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.runInteraction( + return self.db.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) @@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore): if row: return row[0] - edit_id = yield self.runInteraction( + edit_id = yield self.db.runInteraction( "get_applicable_edit", _get_applicable_edit_txn ) @@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore): return bool(txn.fetchone()) - return self.runInteraction( + return self.db.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) @@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore): aggregation_key = relation.get("key") - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="event_relations", values={ @@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore): redacted_event_id (str): The event that was redacted. """ - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 67bb1b6f60..0148be20d3 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -19,13 +19,17 @@ import logging import re from typing import Optional, Tuple +from six import integer_types + from canonicaljson import json from twisted.internet import defer +from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.search import SearchStore +from synapse.storage.database import Database from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -50,7 +54,7 @@ class RoomWorkerStore(SQLBaseStore): Returns: A dict containing the room information, or None if the room is unknown. """ - return self._simple_select_one( + return self.db.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -59,7 +63,7 @@ class RoomWorkerStore(SQLBaseStore): ) def get_public_room_ids(self): - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", @@ -116,7 +120,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.runInteraction("count_public_rooms", _count_public_rooms_txn) + return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) @defer.inlineCallbacks def get_largest_public_rooms( @@ -249,21 +253,21 @@ class RoomWorkerStore(SQLBaseStore): def _get_largest_public_rooms_txn(txn): txn.execute(sql, query_args) - results = self.cursor_to_dict(txn) + results = self.db.cursor_to_dict(txn) if not forwards: results.reverse() return results - ret_val = yield self.runInteraction( + ret_val = yield self.db.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) defer.returnValue(ret_val) @cached(max_entries=10000) def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -284,7 +288,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self._simple_select_one( + row = yield self.db.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -300,8 +304,148 @@ class RoomWorkerStore(SQLBaseStore): else: return None + @cachedInlineCallbacks() + def get_retention_policy_for_room(self, room_id): + """Get the retention policy for a given room. + + If no retention policy has been found for this room, returns a policy defined + by the configured default policy (which has None as both the 'min_lifetime' and + the 'max_lifetime' if no default policy has been defined in the server's + configuration). + + Args: + room_id (str): The ID of the room to get the retention policy of. + + Returns: + dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + """ + + def get_retention_policy_for_room_txn(txn): + txn.execute( + """ + SELECT min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + WHERE room_id = ?; + """, + (room_id,), + ) + + return self.db.cursor_to_dict(txn) + + ret = yield self.db.runInteraction( + "get_retention_policy_for_room", get_retention_policy_for_room_txn, + ) + + # If we don't know this room ID, ret will be None, in this case return the default + # policy. + if not ret: + defer.returnValue( + { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } + ) + + row = ret[0] + + # If one of the room's policy's attributes isn't defined, use the matching + # attribute from the default policy. + # The default values will be None if no default policy has been defined, or if one + # of the attributes is missing from the default policy. + if row["min_lifetime"] is None: + row["min_lifetime"] = self.config.retention_default_min_lifetime + + if row["max_lifetime"] is None: + row["max_lifetime"] = self.config.retention_default_max_lifetime + + defer.returnValue(row) + + +class RoomBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs) + + self.config = hs.config + + self.db.updates.register_background_update_handler( + "insert_room_retention", self._background_insert_retention, + ) + + @defer.inlineCallbacks + def _background_insert_retention(self, progress, batch_size): + """Retrieves a list of all rooms within a range and inserts an entry for each of + them into the room_retention table. + NULLs the property's columns if missing from the retention event in the room's + state (or NULLs all of them if there's no retention event in the room's state), + so that we fall back to the server's retention policy. + """ + + last_room = progress.get("room_id", "") + + def _background_insert_retention_txn(txn): + txn.execute( + """ + SELECT state.room_id, state.event_id, events.json + FROM current_state_events as state + LEFT JOIN event_json AS events ON (state.event_id = events.event_id) + WHERE state.room_id > ? AND state.type = '%s' + ORDER BY state.room_id ASC + LIMIT ?; + """ + % EventTypes.Retention, + (last_room, batch_size), + ) + + rows = self.db.cursor_to_dict(txn) + + if not rows: + return True + + for row in rows: + if not row["json"]: + retention_policy = {} + else: + ev = json.loads(row["json"]) + retention_policy = json.dumps(ev["content"]) + + self.db.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": row["room_id"], + "event_id": row["event_id"], + "min_lifetime": retention_policy.get("min_lifetime"), + "max_lifetime": retention_policy.get("max_lifetime"), + }, + ) + + logger.info("Inserted %d rows into room_retention", len(rows)) + + self.db.updates._background_update_progress_txn( + txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} + ) + + if batch_size > len(rows): + return True + else: + return False + + end = yield self.db.runInteraction( + "insert_room_retention", _background_insert_retention_txn, + ) + + if end: + yield self.db.updates._end_background_update("insert_room_retention") + + defer.returnValue(batch_size) + + +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomStore, self).__init__(database, db_conn, hs) + + self.config = hs.config -class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): """Stores a room. @@ -317,7 +461,7 @@ class RoomStore(RoomWorkerStore, SearchStore): try: def store_room_txn(txn, next_id): - self._simple_insert_txn( + self.db.simple_insert_txn( txn, "rooms", { @@ -327,7 +471,7 @@ class RoomStore(RoomWorkerStore, SearchStore): }, ) if is_public: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -338,7 +482,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction("store_room_txn", store_room_txn, next_id) + yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -346,14 +490,14 @@ class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self._simple_select_list_txn( + entries = self.db.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -371,7 +515,7 @@ class RoomStore(RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -384,7 +528,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( + yield self.db.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() @@ -411,7 +555,7 @@ class RoomStore(RoomWorkerStore, SearchStore): def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -424,7 +568,7 @@ class RoomStore(RoomWorkerStore, SearchStore): # We've already inserted, nothing to do. return else: - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -434,7 +578,7 @@ class RoomStore(RoomWorkerStore, SearchStore): }, ) - entries = self._simple_select_list_txn( + entries = self.db.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -452,7 +596,7 @@ class RoomStore(RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -465,7 +609,7 @@ class RoomStore(RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.runInteraction( + yield self.db.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -482,7 +626,7 @@ class RoomStore(RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.runInteraction("get_rooms", f) + return self.db.runInteraction("get_rooms", f) def _store_room_topic_txn(self, txn, event): if hasattr(event, "content") and "topic" in event.content: @@ -502,11 +646,40 @@ class RoomStore(RoomWorkerStore, SearchStore): txn, event, "content.body", event.content["body"] ) + def _store_retention_policy_for_room_txn(self, txn, event): + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content + ): + if ( + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), integer_types) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), integer_types) + ): + # Ignore the event if one of the value isn't an integer. + return + + self.db.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "min_lifetime": event.content.get("min_lifetime"), + "max_lifetime": event.content.get("max_lifetime"), + }, + ) + + self._invalidate_cache_and_stream( + txn, self.get_retention_policy_for_room, (event.room_id,) + ) + def add_event_report( self, room_id, event_id, user_id, reason, content, received_ts ): next_id = self._event_reports_id_gen.get_next() - return self._simple_insert( + return self.db.simple_insert( table="event_reports", values={ "id": next_id, @@ -539,7 +712,9 @@ class RoomStore(RoomWorkerStore, SearchStore): if prev_id == current_id: return defer.succeed([]) - return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms) + return self.db.runInteraction( + "get_all_new_public_rooms", get_all_new_public_rooms + ) @defer.inlineCallbacks def block_room(self, room_id, user_id): @@ -552,14 +727,14 @@ class RoomStore(RoomWorkerStore, SearchStore): Returns: Deferred """ - yield self._simple_upsert( + yield self.db.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.runInteraction( + yield self.db.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, @@ -590,7 +765,9 @@ class RoomStore(RoomWorkerStore, SearchStore): return local_media_mxcs, remote_media_mxcs - return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + return self.db.runInteraction( + "get_media_ids_in_room", _get_media_mxcs_in_room_txn + ) def quarantine_media_ids_in_room(self, room_id, quarantined_by): """For a room loops through all events with media and quarantines @@ -629,7 +806,7 @@ class RoomStore(RoomWorkerStore, SearchStore): return total_media_quarantined - return self.runInteraction( + return self.db.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -683,3 +860,89 @@ class RoomStore(RoomWorkerStore, SearchStore): remote_media_mxcs.append((hostname, media_id)) return local_media_mxcs, remote_media_mxcs + + @defer.inlineCallbacks + def get_rooms_for_retention_period_in_range( + self, min_ms, max_ms, include_null=False + ): + """Retrieves all of the rooms within the given retention range. + + Optionally includes the rooms which don't have a retention policy. + + Args: + min_ms (int|None): Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, doesn't set a lower limit. + max_ms (int|None): Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, doesn't set an upper limit. + include_null (bool): Whether to include rooms which retention policy is NULL + in the returned set. + + Returns: + dict[str, dict]: The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). + """ + + def get_rooms_for_retention_period_in_range_txn(txn): + range_conditions = [] + args = [] + + if min_ms is not None: + range_conditions.append("max_lifetime > ?") + args.append(min_ms) + + if max_ms is not None: + range_conditions.append("max_lifetime <= ?") + args.append(max_ms) + + # Do a first query which will retrieve the rooms that have a retention policy + # in their current state. + sql = """ + SELECT room_id, min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + """ + + if len(range_conditions): + sql += " WHERE (" + " AND ".join(range_conditions) + ")" + + if include_null: + sql += " OR max_lifetime IS NULL" + + txn.execute(sql, args) + + rows = self.db.cursor_to_dict(txn) + rooms_dict = {} + + for row in rows: + rooms_dict[row["room_id"]] = { + "min_lifetime": row["min_lifetime"], + "max_lifetime": row["max_lifetime"], + } + + if include_null: + # If required, do a second query that retrieves all of the rooms we know + # of so we can handle rooms with no retention policy. + sql = "SELECT DISTINCT room_id FROM current_state_events" + + txn.execute(sql) + + rows = self.db.cursor_to_dict(txn) + + # If a room isn't already in the dict (i.e. it doesn't have a retention + # policy in its state), add it with a null policy. + for row in rows: + if row["room_id"] not in rooms_dict: + rooms_dict[row["room_id"]] = { + "min_lifetime": None, + "max_lifetime": None, + } + + return rooms_dict + + rooms = yield self.db.runInteraction( + "get_rooms_for_retention_period_in_range", + get_rooms_for_retention_period_in_range_txn, + ) + + defer.returnValue(rooms) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 2af24a20b7..92e3b9c512 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import Iterable, List from six import iteritems, itervalues @@ -25,9 +26,13 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import ( + LoggingTransaction, + SQLBaseStore, + make_in_list_sql_clause, +) from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import Sqlite3Engine from synapse.storage.roommember import ( GetRoomsForUserWithStreamOrdering, @@ -50,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" class RoomMemberWorkerStore(EventsWorkerStore): - def __init__(self, db_conn, hs): - super(RoomMemberWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs) # Is the current_state_events.membership up to date? Or is the # background update still running? @@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(query) return list(txn)[0][0] - count = yield self.runInteraction("get_known_servers", _transact) + count = yield self.db.runInteraction("get_known_servers", _transact) # We always know about ourselves, even if we have nothing in # room_memberships (for example, the server is new). @@ -127,7 +132,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership column is up to date """ - pending_update = self._simple_select_one_txn( + pending_update = self.db.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -143,7 +148,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): 15.0, run_as_background_process, "_check_safe_current_state_events_membership_updated", - self.runInteraction, + self.db.runInteraction, "_check_safe_current_state_events_membership_updated", self._check_safe_current_state_events_membership_updated_txn, ) @@ -160,7 +165,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): - return self.runInteraction( + return self.db.runInteraction( "get_users_in_room", self.get_users_in_room_txn, room_id ) @@ -268,7 +273,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return res - return self.runInteraction("get_room_summary", _get_room_summary_txn) + return self.db.runInteraction("get_room_summary", _get_room_summary_txn) def _get_user_counts_in_room_txn(self, txn, room_id): """ @@ -338,7 +343,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): if not membership_list: return defer.succeed(None) - rooms = yield self.runInteraction( + rooms = yield self.db.runInteraction( "get_rooms_for_user_where_membership_is", self._get_rooms_for_user_where_membership_is_txn, user_id, @@ -391,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) txn.execute(sql, (user_id, *args)) - results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)] + results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)] if do_invite: sql = ( @@ -411,7 +416,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): stream_ordering=r["stream_ordering"], membership=Membership.INVITE, ) - for r in self.cursor_to_dict(txn) + for r in self.db.cursor_to_dict(txn) ) return results @@ -602,7 +607,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -642,7 +647,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -682,7 +687,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -752,7 +757,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): rows = txn.fetchall() return rows[0][0] - count = yield self.runInteraction("did_forget_membership", f) + count = yield self.db.runInteraction("did_forget_membership", f) return count == 0 @cached() @@ -789,7 +794,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (user_id,)) return set(row[0] for row in txn if row[1] == 0) - return self.runInteraction( + return self.db.runInteraction( "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn ) @@ -804,7 +809,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred[set[str]]: Set of room IDs. """ - room_ids = yield self._simple_select_onecol( + room_ids = yield self.db.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -813,18 +818,34 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) + def get_membership_from_event_ids( + self, member_event_ids: Iterable[str] + ) -> List[dict]: + """Get user_id and membership of a set of event IDs. + """ + + return self.db.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=("user_id", "membership", "event_id"), + keyvalues={}, + batch_size=500, + desc="get_membership_from_event_ids", + ) + -class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( +class RoomMemberBackgroundUpdateStore(SQLBaseStore): + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, self._background_current_state_membership, ) - self.register_background_index_update( + self.db.updates.register_background_index_update( "room_membership_forgotten_idx", index_name="room_memberships_user_room_forgotten", table="room_memberships", @@ -857,7 +878,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -892,18 +913,20 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): "max_stream_id_exclusive": min_stream_id, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress ) return len(rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn ) if not result: - yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) + yield self.db.updates._end_background_update( + _MEMBERSHIP_PROFILE_UPDATE_NAME + ) return result @@ -942,7 +965,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): last_processed_room = next_room - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME, {"last_processed_room": last_processed_room}, @@ -954,26 +977,28 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): # string, which will compare before all room IDs correctly. last_processed_room = progress.get("last_processed_room", "") - row_count, finished = yield self.runInteraction( + row_count, finished = yield self.db.runInteraction( "_background_current_state_membership_update", _background_current_state_membership_txn, last_processed_room, ) if finished: - yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME) + yield self.db.updates._end_background_update( + _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME + ) return row_count class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(RoomMemberStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(RoomMemberStore, self).__init__(database, db_conn, hs) def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1011,7 +1036,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="local_invites", values={ @@ -1051,7 +1076,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn.execute(sql, (stream_ordering, True, room_id, user_id)) with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) + yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" @@ -1074,7 +1099,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): txn, self.get_forgotten_rooms_for_user, (user_id,) ) - return self.runInteraction("forget_membership", f) + return self.db.runInteraction("forget_membership", f) class _JoinedHostsCache(object): diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql new file mode 100644 index 0000000000..81a36a8b1d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); + +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql new file mode 100644 index 0000000000..7d70dd071e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- store the current etag of backup version +ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql new file mode 100644 index 0000000000..ee6cdf7a14 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql @@ -0,0 +1,33 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks the retention policy of a room. +-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in +-- the room's retention policy state event. +-- If a room doesn't have a retention policy state event in its state, both max_lifetime +-- and min_lifetime are NULL. +CREATE TABLE IF NOT EXISTS room_retention( + room_id TEXT, + event_id TEXT, + min_lifetime BIGINT, + max_lifetime BIGINT, + + PRIMARY KEY(room_id, event_id) +); + +CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('insert_room_retention', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql index 27a96123e3..5c5fffcafb 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql @@ -40,7 +40,8 @@ CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( signature TEXT NOT NULL ); -CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); +-- replaced by the index created in signing_keys_nonunique_signatures.sql +-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); -- stream of user signature updates CREATE TABLE IF NOT EXISTS user_signature_stream ( diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql new file mode 100644 index 0000000000..0aa90ebf0c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* The cross-signing signatures index should not be a unique index, because a + * user may upload multiple signatures for the same target user. The previous + * index was unique, so delete it if it's there and create a new non-unique + * index. */ + +DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx; CREATE INDEX IF NOT +EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 9ea2bc0b84..260eff81cc 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -24,8 +24,8 @@ from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.storage._base import make_in_list_sql_clause -from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine logger = logging.getLogger(__name__) @@ -36,23 +36,23 @@ SearchEntry = namedtuple( ) -class SearchBackgroundUpdateStore(BackgroundUpdateStore): +class SearchBackgroundUpdateStore(SQLBaseStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, db_conn, hs): - super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs) if not hs.config.enable_search: return - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) @@ -61,9 +61,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): # a GIN index. However, it's possible that some people might still have # the background update queued, so we register a handler to clear the # background update. - self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + self.db.updates.register_noop_background_update( + self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME + ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search ) @@ -93,7 +95,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): # store_search_entries_txn with a generator function, but that # would mean having two cursors open on the database at once. # Instead we just build a list of results. - rows = self.cursor_to_dict(txn) + rows = self.db.cursor_to_dict(txn) if not rows: return 0 @@ -153,18 +155,18 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): "rows_inserted": rows_inserted + len(event_search_rows), } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_UPDATE_NAME, progress ) return len(event_search_rows) - result = yield self.runInteraction( + result = yield self.db.runInteraction( self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn ) if not result: - yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) + yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME) return result @@ -206,9 +208,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME + ) return 1 @defer.inlineCallbacks @@ -237,14 +241,14 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): ) conn.set_session(autocommit=False) - yield self.runWithConnection(create_index) + yield self.db.runWithConnection(create_index) pg = dict(progress) pg["have_added_indexes"] = True - yield self.runInteraction( + yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg, ) @@ -274,18 +278,20 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): "have_added_indexes": True, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress ) return len(rows), True - num_rows, finished = yield self.runInteraction( + num_rows, finished = yield self.db.runInteraction( self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn ) if not finished: - yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME) + yield self.db.updates._end_background_update( + self.EVENT_SEARCH_ORDER_UPDATE_NAME + ) return num_rows @@ -337,8 +343,8 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore): class SearchStore(SearchBackgroundUpdateStore): - def __init__(self, db_conn, hs): - super(SearchStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(SearchStore, self).__init__(database, db_conn, hs) def store_event_search_txn(self, txn, event, key, value): """Add event to the search table @@ -441,7 +447,9 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_msgs", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -455,8 +463,8 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -586,7 +594,9 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) + results = yield self.db.execute( + "search_rooms", self.db.cursor_to_dict, sql, *args + ) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -600,8 +610,8 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( - "search_rooms_count", self.cursor_to_dict, count_sql, *count_args + count_results = yield self.db.execute( + "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args ) count = sum(row["count"] for row in count_results if row["room_id"] in room_ids) @@ -686,7 +696,7 @@ class SearchStore(SearchBackgroundUpdateStore): return highlight_words - return self.runInteraction("_find_highlights", f) + return self.db.runInteraction("_find_highlights", f) def _to_postgres_options(options_dict): diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index 556191b76f..563216b63c 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -48,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore): for event_id in event_ids } - return self.runInteraction("get_event_reference_hashes", f) + return self.db.runInteraction("get_event_reference_hashes", f) @defer.inlineCallbacks def add_event_hashes(self, event_ids): @@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore): } ) - self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 6a90daea31..9ef7b48c74 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -27,8 +27,8 @@ from synapse.api.errors import NotFoundError from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter from synapse.util.caches import get_cache_factor_for, intern_string @@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): count = 0 while next_group: - next_group = self._simple_select_one_onecol_txn( + next_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): ): break - next_group = self._simple_select_one_onecol_txn( + next_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -214,8 +214,8 @@ class StateGroupWorkerStore( STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, db_conn, hs): - super(StateGroupWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -348,7 +348,9 @@ class StateGroupWorkerStore( (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn } - return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn) + return self.db.runInteraction( + "get_current_state_ids", _get_current_state_ids_txn + ) # FIXME: how should this be cached? def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): @@ -392,7 +394,7 @@ class StateGroupWorkerStore( return results - return self.runInteraction( + return self.db.runInteraction( "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) @@ -431,7 +433,7 @@ class StateGroupWorkerStore( """ def _get_state_group_delta_txn(txn): - prev_group = self._simple_select_one_onecol_txn( + prev_group = self.db.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -442,7 +444,7 @@ class StateGroupWorkerStore( if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self._simple_select_list_txn( + delta_ids = self.db.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -454,7 +456,9 @@ class StateGroupWorkerStore( {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, ) - return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn) + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -540,7 +544,7 @@ class StateGroupWorkerStore( chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.runInteraction( + res = yield self.db.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -644,7 +648,7 @@ class StateGroupWorkerStore( @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -661,7 +665,7 @@ class StateGroupWorkerStore( def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -902,7 +906,7 @@ class StateGroupWorkerStore( state_group = self.database_engine.get_next_state_group_id(txn) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -911,7 +915,7 @@ class StateGroupWorkerStore( # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self._simple_select_one_onecol_txn( + is_in_db = self.db.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -926,13 +930,13 @@ class StateGroupWorkerStore( potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -947,7 +951,7 @@ class StateGroupWorkerStore( ], ) else: - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -993,7 +997,7 @@ class StateGroupWorkerStore( return state_group - return self.runInteraction("store_state_group", _store_state_group_txn) + return self.db.runInteraction("store_state_group", _store_state_group_txn) @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): @@ -1007,7 +1011,7 @@ class StateGroupWorkerStore( referenced. """ - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -1019,32 +1023,30 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore( - StateGroupBackgroundUpdateStore, BackgroundUpdateStore -): +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" - def __init__(self, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(db_conn, hs) - self.register_background_update_handler( + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._background_deduplicate_state, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", table="current_state_events", columns=["state_key"], where_clause="type='m.room.member'", ) - self.register_background_index_update( + self.db.updates.register_background_index_update( self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME, index_name="event_to_state_groups_sg_index", table="event_to_state_groups", @@ -1065,7 +1067,7 @@ class StateBackgroundUpdateStore( batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self._execute( + rows = yield self.db.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -1135,13 +1137,13 @@ class StateBackgroundUpdateStore( if prev_state.get(key, None) != value } - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="state_group_edges", values={ @@ -1150,13 +1152,13 @@ class StateBackgroundUpdateStore( }, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1177,18 +1179,18 @@ class StateBackgroundUpdateStore( "max_group": max_group, } - self._background_update_progress_txn( + self.db.updates._background_update_progress_txn( txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress ) return False, batch_size - finished, result = yield self.runInteraction( + finished, result = yield self.db.runInteraction( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn ) if finished: - yield self._end_background_update( + yield self.db.updates._end_background_update( self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME ) @@ -1218,9 +1220,9 @@ class StateBackgroundUpdateStore( ) txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - yield self.runWithConnection(reindex_txn) + yield self.db.runWithConnection(reindex_txn) - yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) return 1 @@ -1244,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): * `state_groups_state`: Maps state group to state events. """ - def __init__(self, db_conn, hs): - super(StateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StateStore, self).__init__(database, db_conn, hs) def _store_event_state_mappings_txn( self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]] @@ -1263,7 +1265,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): state_groups[event.event_id] = context.state_group - self._simple_insert_many_txn( + self.db.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 28f33ec18f..12c982cb26 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -98,14 +98,14 @@ class StateDeltasStore(SQLBaseStore): ORDER BY stream_id ASC """ txn.execute(sql, (prev_stream_id, clipped_stream_id)) - return clipped_stream_id, self.cursor_to_dict(txn) + return clipped_stream_id, self.db.cursor_to_dict(txn) - return self.runInteraction( + return self.db.runInteraction( "get_current_state_deltas", get_current_state_deltas_txn ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self._simple_select_one_onecol_txn( + return self.db.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, @@ -113,7 +113,7 @@ class StateDeltasStore(SQLBaseStore): ) def get_max_stream_id_in_current_state_deltas(self): - return self.runInteraction( + return self.db.runInteraction( "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 45b3de7d56..7bc186e9a1 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.util.caches.descriptors import cached @@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")} class StatsStore(StateDeltasStore): - def __init__(self, db_conn, hs): - super(StatsStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StatsStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname self.clock = self.hs.get_clock() @@ -68,17 +69,17 @@ class StatsStore(StateDeltasStore): self.stats_delta_processing_lock = DeferredLock() - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_rooms", self._populate_stats_process_rooms ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_stats_process_users", self._populate_stats_process_users ) # we no longer need to perform clean-up, but we will give ourselves # the potential to reintroduce it in the future – so documentation # will still encourage the use of this no-op handler. - self.register_noop_background_update("populate_stats_cleanup") - self.register_noop_background_update("populate_stats_prepare") + self.db.updates.register_noop_background_update("populate_stats_cleanup") + self.db.updates.register_noop_background_update("populate_stats_prepare") def quantise_stats_time(self, ts): """ @@ -102,7 +103,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -117,22 +118,22 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_stats_process_users") + yield self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: yield self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.runInteraction( + yield self.db.runInteraction( "populate_stats_process_users", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_users", progress, ) @@ -145,7 +146,7 @@ class StatsStore(StateDeltasStore): This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -160,22 +161,22 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_stats_process_rooms") + yield self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: yield self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.runInteraction( + yield self.db.runInteraction( "_populate_stats_process_rooms", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", progress, ) @@ -186,7 +187,7 @@ class StatsStore(StateDeltasStore): """ Returns the stats processor positions. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -215,7 +216,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self._simple_upsert( + return self.db.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -236,7 +237,7 @@ class StatsStore(StateDeltasStore): Deferred[list[dict]], where the dict has the keys of ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts". """ - return self.runInteraction( + return self.db.runInteraction( "get_statistics_for_subject", self._get_statistics_for_subject_txn, stats_type, @@ -257,14 +258,14 @@ class StatsStore(StateDeltasStore): ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self._simple_select_list_paginate_txn( + slice_list = self.db.simple_select_list_paginate_txn( txn, table + "_historical", - {id_col: stats_id}, "end_ts", start, size, retcols=selected_columns + ["bucket_size", "end_ts"], + keyvalues={id_col: stats_id}, order_direction="DESC", ) @@ -282,7 +283,7 @@ class StatsStore(StateDeltasStore): "name", "topic", "canonical_alias", "avatar", "join_rules", "history_visibility" """ - return self._simple_select_one( + return self.db.simple_select_one( "room_stats_state", {"room_id": room_id}, retcols=( @@ -308,7 +309,7 @@ class StatsStore(StateDeltasStore): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -344,14 +345,14 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=stream_id, ) - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": stream_id}, ) - return self.runInteraction( + return self.db.runInteraction( "bulk_update_stats_delta", _bulk_update_stats_delta_txn ) @@ -382,7 +383,7 @@ class StatsStore(StateDeltasStore): Does not work with per-slice fields. """ - return self.runInteraction( + return self.db.runInteraction( "update_stats_delta", self._update_stats_delta_txn, ts, @@ -517,17 +518,17 @@ class StatsStore(StateDeltasStore): else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self._simple_select_one_txn( + current_row = self.db.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self._simple_insert_txn(txn, table, merged_dict) + self.db.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self._simple_update_one_txn(txn, table, keyvalues, current_row) + self.db.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -614,11 +615,11 @@ class StatsStore(StateDeltasStore): txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self._simple_select_one_txn( + src_row = self.db.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self._simple_select_one_txn( + dest_current_row = self.db.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -634,11 +635,11 @@ class StatsStore(StateDeltasStore): **src_row, **additive_relatives, } - self._simple_insert_txn(txn, into_table, merged_dict) + self.db.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -652,7 +653,7 @@ class StatsStore(StateDeltasStore): changes. """ - return self.runInteraction( + return self.db.runInteraction( "stats_incremental_total_events_and_bytes", self.get_changes_room_total_events_and_bytes_txn, min_pos, @@ -735,7 +736,7 @@ class StatsStore(StateDeltasStore): def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self._simple_select_many_txn( + rows = self.db.simple_select_many_txn( txn, table="current_state_events", column="type", @@ -791,7 +792,7 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.runInteraction( + ) = yield self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) @@ -866,7 +867,7 @@ class StatsStore(StateDeltasStore): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.runInteraction( + joined_rooms, pos = yield self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 9ae4a913a1..140da8dad6 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -44,6 +47,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -248,11 +252,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): __metaclass__ = abc.ABCMeta - def __init__(self, db_conn, hs): - super(StreamWorkerStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(StreamWorkerStore, self).__init__(database, db_conn, hs) events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self._get_cache_dict( + event_cache_prefill, min_event_val = self.db.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -397,7 +401,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] return rows - rows = yield self.runInteraction("get_room_events_stream_for_room", f) + rows = yield self.db.runInteraction("get_room_events_stream_for_room", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -447,7 +451,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return rows - rows = yield self.runInteraction("get_membership_changes_for_user", f) + rows = yield self.db.runInteraction("get_membership_changes_for_user", f) ret = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True @@ -508,7 +512,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): end_token = RoomStreamToken.parse(end_token) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, @@ -545,7 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn.execute(sql, (room_id, stream_ordering)) return txn.fetchone() - return self.runInteraction("get_room_event_after_stream_ordering", _f) + return self.db.runInteraction("get_room_event_after_stream_ordering", _f) @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): @@ -559,7 +563,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if room_id is None: return "s%d" % (token,) else: - topo = yield self.runInteraction( + topo = yield self.db.runInteraction( "_get_max_topological_txn", self._get_max_topological_txn, room_id ) return "t%d-%d" % (topo, token) @@ -573,7 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "s%d" stream token. """ - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -586,7 +590,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "t%d-%d" topological token. """ - return self._simple_select_one( + return self.db.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -610,7 +614,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self._execute( + return self.db.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) @@ -664,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = yield self.runInteraction( + results = yield self.db.runInteraction( "get_events_around", self._get_events_around_txn, room_id, @@ -706,7 +710,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self._simple_select_one_txn( + results = self.db.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -785,7 +789,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, [row[1] for row in rows] - upper_bound, event_ids = yield self.runInteraction( + upper_bound, event_ids = yield self.db.runInteraction( "get_all_new_events_stream", get_all_new_events_stream_txn ) @@ -794,7 +798,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events def get_federation_out_pos(self, typ): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, @@ -802,7 +806,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) def update_federation_out_pos(self, typ, stream_id): - return self._simple_update_one( + return self.db.simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, @@ -953,7 +957,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if to_key: to_key = RoomStreamToken.parse(to_key) - rows, token = yield self.runInteraction( + rows, token = yield self.db.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index aa24339717..2aa1bafd48 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tag strings to tag content. """ - deferred = self._simple_select_list( + deferred = self.db.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @@ -78,7 +78,7 @@ class TagsWorkerStore(AccountDataWorkerStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - tag_ids = yield self.runInteraction( + tag_ids = yield self.db.runInteraction( "get_all_updated_tags", get_all_updated_tags_txn ) @@ -98,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore): batch_size = 50 results = [] for i in range(0, len(tag_ids), batch_size): - tags = yield self.runInteraction( + tags = yield self.db.runInteraction( "get_all_updated_tag_content", get_tag_content, tag_ids[i : i + batch_size], @@ -135,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: return {} - room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn) + room_ids = yield self.db.runInteraction( + "get_updated_tags", get_updated_tags_txn + ) results = {} if room_ids: @@ -153,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A deferred list of string tags. """ - return self._simple_select_list( + return self.db.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), @@ -178,7 +180,7 @@ class TagsStore(TagsWorkerStore): content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, @@ -187,7 +189,7 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("add_tag", add_tag_txn, next_id) + yield self.db.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) @@ -210,7 +212,7 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction("remove_tag", remove_tag_txn, next_id) + yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index 01b1be5e14..5b07c2fbc0 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -24,6 +24,7 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache # py2 sqlite has buffer hardcoded as only binary type, so we must use it, @@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, db_conn, hs): - super(TransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TransactionStore, self).__init__(database, db_conn, hs) self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000) @@ -77,7 +78,7 @@ class TransactionStore(SQLBaseStore): this transaction or a 2-tuple of (int, dict) """ - return self.runInteraction( + return self.db.runInteraction( "get_received_txn_response", self._get_received_txn_response, transaction_id, @@ -85,7 +86,7 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -119,7 +120,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self._simple_insert( + return self.db.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -148,7 +149,7 @@ class TransactionStore(SQLBaseStore): if result is not SENTINEL: return result - result = yield self.runInteraction( + result = yield self.db.runInteraction( "get_destination_retry_timings", self._get_destination_retry_timings, destination, @@ -160,7 +161,7 @@ class TransactionStore(SQLBaseStore): return result def _get_destination_retry_timings(self, txn, destination): - result = self._simple_select_one_txn( + result = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -187,7 +188,7 @@ class TransactionStore(SQLBaseStore): """ self._destination_retry_cache.pop(destination, None) - return self.runInteraction( + return self.db.runInteraction( "set_destination_retry_timings", self._set_destination_retry_timings, destination, @@ -227,7 +228,7 @@ class TransactionStore(SQLBaseStore): # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self._simple_select_one_txn( + prev_row = self.db.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -236,7 +237,7 @@ class TransactionStore(SQLBaseStore): ) if not prev_row: - self._simple_insert_txn( + self.db.simple_insert_txn( txn, table="destinations", values={ @@ -247,7 +248,7 @@ class TransactionStore(SQLBaseStore): }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self._simple_update_one_txn( + self.db.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, @@ -270,4 +271,6 @@ class TransactionStore(SQLBaseStore): def _cleanup_transactions_txn(txn): txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,)) - return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn) + return self.db.runInteraction( + "_cleanup_transactions", _cleanup_transactions_txn + ) diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 652abe0e6a..90c180ec6d 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -19,9 +19,9 @@ import re from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules -from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.state import StateFilter from synapse.storage.data_stores.main.state_deltas import StateDeltasStore +from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id from synapse.util.caches.descriptors import cached @@ -32,30 +32,30 @@ logger = logging.getLogger(__name__) TEMP_TABLE = "_temp_populate_user_directory" -class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore): +class UserDirectoryBackgroundUpdateStore(StateDeltasStore): # How many records do we calculate before sending it to # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs) self.server_name = hs.hostname - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_createtables", self._populate_user_directory_createtables, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_rooms", self._populate_user_directory_process_rooms, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_process_users", self._populate_user_directory_process_users, ) - self.register_background_update_handler( + self.db.updates.register_background_update_handler( "populate_user_directory_cleanup", self._populate_user_directory_cleanup ) @@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,15 +100,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = yield self.get_max_stream_id_in_current_state_deltas() - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) - yield self._end_background_update("populate_user_directory_createtables") + yield self.db.updates._end_background_update( + "populate_user_directory_createtables" + ) return 1 @defer.inlineCallbacks @@ -116,7 +118,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ Update the user directory stream position, then clean up the old tables. """ - position = yield self._simple_select_one_onecol( + position = yield self.db.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) yield self.update_user_directory_stream_pos(position) @@ -126,11 +128,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users") txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position") - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory_cleanup", _delete_staging_area ) - yield self._end_background_update("populate_user_directory_cleanup") + yield self.db.updates._end_background_update("populate_user_directory_cleanup") return 1 @defer.inlineCallbacks @@ -170,13 +172,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore return rooms_to_work_on - rooms_to_work_on = yield self.runInteraction( + rooms_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self._end_background_update("populate_user_directory_process_rooms") + yield self.db.updates._end_background_update( + "populate_user_directory_process_rooms" + ) return 1 logger.info( @@ -243,12 +247,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore to_insert.clear() # We've finished a room. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_rooms", progress, ) @@ -267,7 +271,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore If search_all_users is enabled, add all of the users to the user directory. """ if not self.hs.config.user_directory_search_all_users: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 def _get_next_batch(txn): @@ -291,13 +297,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore return users_to_work_on - users_to_work_on = yield self.runInteraction( + users_to_work_on = yield self.db.runInteraction( "populate_user_directory_temp_read", _get_next_batch ) # No more users -- complete the transaction. if not users_to_work_on: - yield self._end_background_update("populate_user_directory_process_users") + yield self.db.updates._end_background_update( + "populate_user_directory_process_users" + ) return 1 logger.info( @@ -312,12 +320,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) # We've finished processing a user. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) # Update the remaining counter. progress["remaining"] -= 1 - yield self.runInteraction( + yield self.db.runInteraction( "populate_user_directory", - self._background_update_progress_txn, + self.db.updates._background_update_progress_txn, "populate_user_directory_process_users", progress, ) @@ -361,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _update_profile_in_user_dir_txn(txn): - new_entry = self._simple_upsert_txn( + new_entry = self.db.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -435,7 +443,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self._simple_upsert_txn( + self.db.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -448,7 +456,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction( + return self.db.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) @@ -462,7 +470,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _add_users_who_share_room_txn(txn): - self._simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -474,7 +482,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) @@ -489,7 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore def _add_users_in_public_rooms_txn(txn): - self._simple_upsert_many_txn( + self.db.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -498,7 +506,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore value_values=None, ) - return self.runInteraction( + return self.db.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) @@ -513,13 +521,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.runInteraction( + return self.db.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @cached() def get_user_in_directory(self, user_id): - return self._simple_select_one( + return self.db.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -528,7 +536,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) def update_user_directory_stream_pos(self, stream_id): - return self._simple_update_one( + return self.db.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -542,47 +550,47 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # add_users_who_share_private_rooms? SHARE_PRIVATE_WORKING_SET = 500 - def __init__(self, db_conn, hs): - super(UserDirectoryStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(UserDirectoryStore, self).__init__(database, db_conn, hs) def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) + return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn) @defer.inlineCallbacks def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self._simple_select_onecol( + user_ids_share_pub = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self._simple_select_onecol( + user_ids_share_priv = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -605,23 +613,23 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ def _remove_user_who_share_room_txn(txn): - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.db.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.runInteraction( + return self.db.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) @@ -636,14 +644,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self._simple_select_onecol( + rows = yield self.db.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self._simple_select_onecol( + pub_rows = yield self.db.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,14 +682,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) f2 USING (room_id) """ - rows = yield self._execute( + rows = yield self.db.execute( "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) return [room_id for room_id, in rows] def get_user_directory_stream_pos(self): - return self._simple_select_one_onecol( + return self.db.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", @@ -786,8 +794,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self._execute( - "search_user_dir", self.cursor_to_dict, sql, *args + results = yield self.db.execute( + "search_user_dir", self.db.cursor_to_dict, sql, *args ) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index aa4f0da5f0..af8025bc17 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if the user has requested erasure """ - return self._simple_select_onecol( + return self.db.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self._simple_select_many_batch( + rows = yield self.db.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, @@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.runInteraction("mark_user_erased", f) + return self.db.runInteraction("mark_user_erased", f) diff --git a/synapse/storage/database.py b/synapse/storage/database.py new file mode 100644 index 0000000000..ec19ae1d9d --- /dev/null +++ b/synapse/storage/database.py @@ -0,0 +1,1490 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import sys +import time +from typing import Iterable, Tuple + +from six import iteritems, iterkeys, itervalues +from six.moves import intern, range + +from prometheus_client import Histogram + +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.logging.context import LoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.background_updates import BackgroundUpdater +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.util.stringutils import exception_to_unicode + +# import a function which will return a monotonic time, in seconds +try: + # on python 3, use time.monotonic, since time.clock can go backwards + from time import monotonic as monotonic_time +except ImportError: + # ... but python 2 doesn't have it + from time import clock as monotonic_time + +logger = logging.getLogger(__name__) + +try: + MAX_TXN_ID = sys.maxint - 1 +except AttributeError: + # python 3 does not have a maximum int value + MAX_TXN_ID = 2 ** 63 - 1 + +sql_logger = logging.getLogger("synapse.storage.SQL") +transaction_logger = logging.getLogger("synapse.storage.txn") +perf_logger = logging.getLogger("synapse.storage.TIME") + +sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") + +sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) +sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) + + +# Unique indexes which have been added in background updates. Maps from table name +# to the name of the background update which added the unique index to that table. +# +# This is used by the upsert logic to figure out which tables are safe to do a proper +# UPSERT on: until the relevant background update has completed, we +# have to emulate an upsert by locking the table. +# +UNIQUE_INDEX_BACKGROUND_UPDATES = { + "user_ips": "user_ips_device_unique_index", + "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", + "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", + "event_search": "event_search_event_id_idx", +} + + +class LoggingTransaction(object): + """An object that almost-transparently proxies for the 'txn' object + passed to the constructor. Adds logging and metrics to the .execute() + method. + + Args: + txn: The database transcation object to wrap. + name (str): The name of this transactions for logging. + database_engine (Sqlite3Engine|PostgresEngine) + after_callbacks(list|None): A list that callbacks will be appended to + that have been added by `call_after` which should be run on + successful completion of the transaction. None indicates that no + callbacks should be allowed to be scheduled to run. + exception_callbacks(list|None): A list that callbacks will be appended + to that have been added by `call_on_exception` which should be run + if transaction ends with an error. None indicates that no callbacks + should be allowed to be scheduled to run. + """ + + __slots__ = [ + "txn", + "name", + "database_engine", + "after_callbacks", + "exception_callbacks", + ] + + def __init__( + self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None + ): + object.__setattr__(self, "txn", txn) + object.__setattr__(self, "name", name) + object.__setattr__(self, "database_engine", database_engine) + object.__setattr__(self, "after_callbacks", after_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) + + def call_after(self, callback, *args, **kwargs): + """Call the given callback on the main twisted thread after the + transaction has finished. Used to invalidate the caches on the + correct thread. + """ + self.after_callbacks.append((callback, args, kwargs)) + + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) + + def __getattr__(self, name): + return getattr(self.txn, name) + + def __setattr__(self, name, value): + setattr(self.txn, name, value) + + def __iter__(self): + return self.txn.__iter__() + + def execute_batch(self, sql, args): + if isinstance(self.database_engine, PostgresEngine): + from psycopg2.extras import execute_batch + + self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) + else: + for val in args: + self.execute(sql, val) + + def execute(self, sql, *args): + self._do_execute(self.txn.execute, sql, *args) + + def executemany(self, sql, *args): + self._do_execute(self.txn.executemany, sql, *args) + + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + + # TODO(paul): Maybe use 'info' and 'debug' for values? + sql_logger.debug("[SQL] {%s} %s", self.name, sql) + + sql = self.database_engine.convert_param_style(sql) + if args: + try: + sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) + except Exception: + # Don't let logging failures stop SQL from working + pass + + start = time.time() + + try: + return func(sql, *args) + except Exception as e: + logger.debug("[SQL FAIL] {%s} %s", self.name, e) + raise + finally: + secs = time.time() - start + sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) + sql_query_timer.labels(sql.split()[0]).observe(secs) + + +class PerformanceCounters(object): + def __init__(self): + self.current_counters = {} + self.previous_counters = {} + + def update(self, key, duration_secs): + count, cum_time = self.current_counters.get(key, (0, 0)) + count += 1 + cum_time += duration_secs + self.current_counters[key] = (count, cum_time) + + def interval(self, interval_duration_secs, limit=3): + counters = [] + for name, (count, cum_time) in iteritems(self.current_counters): + prev_count, prev_time = self.previous_counters.get(name, (0, 0)) + counters.append( + ( + (cum_time - prev_time) / interval_duration_secs, + count - prev_count, + name, + ) + ) + + self.previous_counters = dict(self.current_counters) + + counters.sort(reverse=True) + + top_n_counters = ", ".join( + "%s(%d): %.3f%%" % (name, count, 100 * ratio) + for ratio, count, name in counters[:limit] + ) + + return top_n_counters + + +class Database(object): + """Wraps a single physical database and connection pool. + + A single database may be used by multiple data stores. + """ + + _TXN_ID = 0 + + def __init__(self, hs): + self.hs = hs + self._clock = hs.get_clock() + self._db_pool = hs.get_db_pool() + + self.updates = BackgroundUpdater(hs, self) + + self._previous_txn_total_time = 0 + self._current_txn_total_time = 0 + self._previous_loop_ts = 0 + + # TODO(paul): These can eventually be removed once the metrics code + # is running in mainline, and we have some nice monitoring frontends + # to watch it + self._txn_perf_counters = PerformanceCounters() + + self.engine = hs.database_engine + + # A set of tables that are not safe to use native upserts in. + self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) + + # We add the user_directory_search table to the blacklist on SQLite + # because the existing search table does not have an index, making it + # unsafe to use native upserts. + if isinstance(self.engine, Sqlite3Engine): + self._unsafe_to_upsert_tables.add("user_directory_search") + + if self.engine.can_native_upsert: + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates of tables that aren't safe to update. + self._clock.call_later( + 0.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + @defer.inlineCallbacks + def _check_safe_to_upsert(self): + """ + Is it safe to use native UPSERT? + + If there are background updates, we will need to wait, as they may be + the addition of indexes that set the UNIQUE constraint that we require. + + If the background updates have not completed, wait 15 sec and check again. + """ + updates = yield self.simple_select_list( + "background_updates", + keyvalues=None, + retcols=["update_name"], + desc="check_background_updates", + ) + updates = [x["update_name"] for x in updates] + + for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): + if update_name not in updates: + logger.debug("Now safe to upsert in %s", table) + self._unsafe_to_upsert_tables.discard(table) + + # If there's any updates still running, reschedule to run. + if updates: + self._clock.call_later( + 15.0, + run_as_background_process, + "upsert_safety_check", + self._check_safe_to_upsert, + ) + + def start_profiling(self): + self._previous_loop_ts = monotonic_time() + + def loop(): + curr = self._current_txn_total_time + prev = self._previous_txn_total_time + self._previous_txn_total_time = curr + + time_now = monotonic_time() + time_then = self._previous_loop_ts + self._previous_loop_ts = time_now + + duration = time_now - time_then + ratio = (curr - prev) / duration + + top_three_counters = self._txn_perf_counters.interval(duration, limit=3) + + perf_logger.info( + "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters + ) + + self._clock.looping_call(loop, 10000) + + def new_transaction( + self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs + ): + start = monotonic_time() + txn_id = self._TXN_ID + + # We don't really need these to be unique, so lets stop it from + # growing really large. + self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) + + name = "%s-%x" % (desc, txn_id) + + transaction_logger.debug("[TXN START] {%s}", name) + + try: + i = 0 + N = 5 + while True: + cursor = LoggingTransaction( + conn.cursor(), + name, + self.engine, + after_callbacks, + exception_callbacks, + ) + try: + r = func(cursor, *args, **kwargs) + conn.commit() + return r + except self.engine.module.OperationalError as e: + # This can happen if the database disappears mid + # transaction. + logger.warning( + "[TXN OPERROR] {%s} %s %d/%d", + name, + exception_to_unicode(e), + i, + N, + ) + if i < N: + i += 1 + try: + conn.rollback() + except self.engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) + ) + continue + raise + except self.engine.module.DatabaseError as e: + if self.engine.is_deadlock(e): + logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + if i < N: + i += 1 + try: + conn.rollback() + except self.engine.module.Error as e1: + logger.warning( + "[TXN EROLL] {%s} %s", + name, + exception_to_unicode(e1), + ) + continue + raise + finally: + # we're either about to retry with a new cursor, or we're about to + # release the connection. Once we release the connection, it could + # get used for another query, which might do a conn.rollback(). + # + # In the latter case, even though that probably wouldn't affect the + # results of this transaction, python's sqlite will reset all + # statements on the connection [1], which will make our cursor + # invalid [2]. + # + # In any case, continuing to read rows after commit()ing seems + # dubious from the PoV of ACID transactional semantics + # (sqlite explicitly says that once you commit, you may see rows + # from subsequent updates.) + # + # In psycopg2, cursors are essentially a client-side fabrication - + # all the data is transferred to the client side when the statement + # finishes executing - so in theory we could go on streaming results + # from the cursor, but attempting to do so would make us + # incompatible with sqlite, so let's make sure we're not doing that + # by closing the cursor. + # + # (*named* cursors in psycopg2 are different and are proper server- + # side things, but (a) we don't use them and (b) they are implicitly + # closed by ending the transaction anyway.) + # + # In short, if we haven't finished with the cursor yet, that's a + # problem waiting to bite us. + # + # TL;DR: we're done with the cursor, so we can close it. + # + # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 + # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 + cursor.close() + except Exception as e: + logger.debug("[TXN FAIL] {%s} %s", name, e) + raise + finally: + end = monotonic_time() + duration = end - start + + LoggingContext.current_context().add_database_transaction(duration) + + transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) + + self._current_txn_total_time += duration + self._txn_perf_counters.update(desc, duration) + sql_txn_timer.labels(desc).observe(duration) + + @defer.inlineCallbacks + def runInteraction(self, desc, func, *args, **kwargs): + """Starts a transaction on the database and runs a given function + + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + after_callbacks = [] + exception_callbacks = [] + + if LoggingContext.current_context() == LoggingContext.sentinel: + logger.warning("Starting db txn '%s' from sentinel context", desc) + + try: + result = yield self.runWithConnection( + self.new_transaction, + desc, + after_callbacks, + exception_callbacks, + func, + *args, + **kwargs + ) + + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise + + return result + + @defer.inlineCallbacks + def runWithConnection(self, func, *args, **kwargs): + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + parent_context = LoggingContext.current_context() + if parent_context == LoggingContext.sentinel: + logger.warning( + "Starting db connection from sentinel context: metrics will be lost" + ) + parent_context = None + + start_time = monotonic_time() + + def inner_func(conn, *args, **kwargs): + with LoggingContext("runWithConnection", parent_context) as context: + sched_duration_sec = monotonic_time() - start_time + sql_scheduling_timer.observe(sched_duration_sec) + context.add_database_scheduled(sched_duration_sec) + + if self.engine.is_connection_closed(conn): + logger.debug("Reconnecting closed database connection") + conn.reconnect() + + return func(conn, *args, **kwargs) + + result = yield make_deferred_yieldable( + self._db_pool.runWithConnection(inner_func, *args, **kwargs) + ) + + return result + + @staticmethod + def cursor_to_dict(cursor): + """Converts a SQL cursor into an list of dicts. + + Args: + cursor : The DBAPI cursor which has executed a query. + Returns: + A list of dicts where the key is the column header. + """ + col_headers = list(intern(str(column[0])) for column in cursor.description) + results = list(dict(zip(col_headers, row)) for row in cursor) + return results + + def execute(self, desc, decoder, query, *args): + """Runs a single query for a result set. + + Args: + decoder - The function which can resolve the cursor results to + something meaningful. + query - The query string to execute + *args - Query args. + Returns: + The result of decoder(results) + """ + + def interaction(txn): + txn.execute(query, args) + if decoder: + return decoder(txn) + else: + return txn.fetchall() + + return self.runInteraction(desc, interaction) + + # "Simple" SQL API methods that operate on a single table with no JOINs, + # no complex WHERE clauses, just a dict of values for columns. + + @defer.inlineCallbacks + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): + """Executes an INSERT query on the named table. + + Args: + table : string giving the table name + values : dict of new column names and values for them + or_ignore : bool stating whether an exception should be raised + when a conflicting row already exists. If True, False will be + returned by the function instead + desc : string giving a description of the transaction + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True + """ + try: + yield self.runInteraction(desc, self.simple_insert_txn, table, values) + except self.engine.module.IntegrityError: + # We have to do or_ignore flag at this layer, since we can't reuse + # a cursor after we receive an error from the db. + if not or_ignore: + raise + return False + return True + + @staticmethod + def simple_insert_txn(txn, table, values): + keys, vals = zip(*values.items()) + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys), + ", ".join("?" for _ in keys), + ) + + txn.execute(sql, vals) + + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) + + @staticmethod + def simple_insert_many_txn(txn, table, values): + if not values: + return + + # This is a *slight* abomination to get a list of tuples of key names + # and a list of tuples of value names. + # + # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] + # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] + # + # The sort is to ensure that we don't rely on dictionary iteration + # order. + keys, vals = zip( + *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i] + ) + + for k in keys: + if k != keys[0]: + raise RuntimeError("All items must have the same keys") + + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys[0]), + ", ".join("?" for _ in keys[0]), + ) + + txn.executemany(sql, vals) + + @defer.inlineCallbacks + def simple_upsert( + self, + table, + keyvalues, + values, + insertion_values={}, + desc="simple_upsert", + lock=True, + ): + """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key columns and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + Deferred(None or bool): Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self.simple_upsert_txn, + table, + keyvalues, + values, + insertion_values, + lock=lock, + ) + return result + except self.engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warning( + "IntegrityError when upserting into %s; retrying: %s", table, e + ) + + def simple_upsert_txn( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Pick the UPSERT method which works best on the platform. Either the + native one (Pg9.5+, recent SQLites), or fall back to an emulated method. + + Args: + txn: The transaction to use. + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + None or bool: Native upserts always return None. Emulated + upserts return True if a new entry was created, False if an existing + one was updated. + """ + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: + return self.simple_upsert_txn_native_upsert( + txn, table, keyvalues, values, insertion_values=insertion_values + ) + else: + return self.simple_upsert_txn_emulated( + txn, + table, + keyvalues, + values, + insertion_values=insertion_values, + lock=lock, + ) + + def simple_upsert_txn_emulated( + self, txn, table, keyvalues, values, insertion_values={}, lock=True + ): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. + Returns: + bool: Return True if a new entry was created, False if an existing + one was updated. + """ + # We need to lock the table :(, unless we're *really* careful + if lock: + self.engine.lock_table(txn, table) + + def _getwhere(key): + # If the value we're passing in is None (aka NULL), we need to use + # IS, not =, as NULL = NULL equals NULL (False). + if keyvalues[key] is None: + return "%s IS ?" % (key,) + else: + return "%s = ?" % (key,) + + if not values: + # If `values` is empty, then all of the values we care about are in + # the unique key, so there is nothing to UPDATE. We can just do a + # SELECT instead to see if it exists. + sql = "SELECT 1 FROM %s WHERE %s" % ( + table, + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(keyvalues.values()) + txn.execute(sql, sqlargs) + if txn.fetchall(): + # We have an existing record. + return False + else: + # First try to update. + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join(_getwhere(k) for k in keyvalues), + ) + sqlargs = list(values.values()) + list(keyvalues.values()) + + txn.execute(sql, sqlargs) + if txn.rowcount > 0: + # successfully updated at least one row. + return False + + # We didn't find any existing rows, so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ) + txn.execute(sql, list(allvalues.values())) + # successfully inserted + return True + + def simple_upsert_txn_native_upsert( + self, txn, table, keyvalues, values, insertion_values={} + ): + """ + Use the native UPSERT functionality in recent PostgreSQL versions. + + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + insertion_values (dict): additional key/values to use only when + inserting + Returns: + None + """ + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(insertion_values) + + if not values: + latter = "NOTHING" + else: + allvalues.update(values) + latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) + + sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues), + ", ".join(k for k in keyvalues), + latter, + ) + txn.execute(sql, list(allvalues.values())) + + def simple_upsert_many_txn( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: + return self.simple_upsert_many_txn_native_upsert( + txn, table, key_names, key_values, value_names, value_values + ) + else: + return self.simple_upsert_many_txn_emulated( + txn, table, key_names, key_values, value_names, value_values + ) + + def simple_upsert_many_txn_emulated( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, but without native UPSERT support or batching. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + # No value columns, therefore make a blank list so that the following + # zip() works correctly. + if not value_names: + value_values = [() for x in range(len(key_values))] + + for keyv, valv in zip(key_values, value_values): + _keys = {x: y for x, y in zip(key_names, keyv)} + _vals = {x: y for x, y in zip(value_names, valv)} + + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) + + def simple_upsert_many_txn_native_upsert( + self, txn, table, key_names, key_values, value_names, value_values + ): + """ + Upsert, many times, using batching where possible. + + Args: + table (str): The table to upsert into + key_names (list[str]): The key column names. + key_values (list[list]): A list of each row's key column values. + value_names (list[str]): The value column names. If empty, no + values will be used, even if value_values is provided. + value_values (list[list]): A list of each row's value column values. + Returns: + None + """ + allnames = [] + allnames.extend(key_names) + allnames.extend(value_names) + + if not value_names: + # No value columns, therefore make a blank list so that the + # following zip() works correctly. + latter = "NOTHING" + value_values = [() for x in range(len(key_values))] + else: + latter = "UPDATE SET " + ", ".join( + k + "=EXCLUDED." + k for k in value_names + ) + + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + args = [] + + for x, y in zip(key_values, value_values): + args.append(tuple(x) + tuple(y)) + + return txn.execute_batch(sql, args) + + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning multiple columns from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcols : list of strings giving the names of the columns to return + + allow_none : If true, return None instead of failing if the SELECT + statement returns no rows + """ + return self.runInteraction( + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none + ) + + def simple_select_one_onecol( + self, + table, + keyvalues, + retcol, + allow_none=False, + desc="simple_select_one_onecol", + ): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning a single column from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcol : string giving the name of the column to return + """ + return self.runInteraction( + desc, + self.simple_select_one_onecol_txn, + table, + keyvalues, + retcol, + allow_none=allow_none, + ) + + @classmethod + def simple_select_one_onecol_txn( + cls, txn, table, keyvalues, retcol, allow_none=False + ): + ret = cls.simple_select_onecol_txn( + txn, table=table, keyvalues=keyvalues, retcol=retcol + ) + + if ret: + return ret[0] + else: + if allow_none: + return None + else: + raise StoreError(404, "No row found") + + @staticmethod + def simple_select_onecol_txn(txn, table, keyvalues, retcol): + sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} + + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + txn.execute(sql, list(keyvalues.values())) + else: + txn.execute(sql) + + return [r[0] for r in txn] + + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" + ): + """Executes a SELECT query on the named table, which returns a list + comprising of the values of the named column from the selected rows. + + Args: + table (str): table name + keyvalues (dict|None): column names and values to select the rows with + retcol (str): column whos value we wish to retrieve. + + Returns: + Deferred: Results in a list + """ + return self.runInteraction( + desc, self.simple_select_onecol_txn, table, keyvalues, retcol + ) + + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, self.simple_select_list_txn, table, keyvalues, retcols + ) + + @classmethod + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + txn.execute(sql, list(keyvalues.values())) + else: + sql = "SELECT %s FROM %s" % (", ".join(retcols), table) + txn.execute(sql) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def simple_select_many_batch( + self, + table, + column, + iterable, + retcols, + keyvalues={}, + desc="simple_select_many_batch", + batch_size=100, + ): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + results = [] + + if not iterable: + return results + + # iterables can not be sliced, so convert it to a list first + it_list = list(iterable) + + chunks = [ + it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size) + ] + for chunk in chunks: + rows = yield self.runInteraction( + desc, + self.simple_select_many_txn, + table, + column, + chunk, + keyvalues, + retcols, + ) + + results.extend(rows) + + return results + + @classmethod + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + if not iterable: + return [] + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join(clauses), + ) + + txn.execute(sql, values) + return cls.cursor_to_dict(txn) + + def simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, self.simple_update_txn, table, keyvalues, updatevalues + ) + + @staticmethod + def simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) + + return txn.rowcount + + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" + ): + """Executes an UPDATE query on the named table, setting new values for + columns in a row matching the key values. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + updatevalues : dict giving column names and values to update + retcols : optional list of column names to return + + If present, retcols gives a list of column names on which to perform + a SELECT statement *before* performing the UPDATE statement. The values + of these will be returned in a dict. + + These are performed within the same transaction, allowing an atomic + get-and-set. This can be used to implement compare-and-set by putting + the update column in the 'keyvalues' dict as well. + """ + return self.runInteraction( + desc, self.simple_update_one_txn, table, keyvalues, updatevalues + ) + + @classmethod + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) + + if rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + @staticmethod + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + select_sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(select_sql, list(keyvalues.values())) + row = txn.fetchone() + + if not row: + if allow_none: + return None + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + return dict(zip(retcols, row)) + + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) + + @staticmethod + def simple_delete_one_txn(txn, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + if txn.rowcount == 0: + raise StoreError(404, "No row found (%s)" % (table,)) + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched (%s)" % (table,)) + + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) + + @staticmethod + def simple_delete_txn(txn, table, keyvalues): + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + ) + + txn.execute(sql, list(keyvalues.values())) + return txn.rowcount + + def simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + + Returns: + int: Number rows deleted + """ + if not iterable: + return 0 + + sql = "DELETE FROM %s" % table + + clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) + clauses = [clause] + + for key, value in iteritems(keyvalues): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) + txn.execute(sql, values) + + return txn.rowcount + + def get_cache_dict( + self, db_conn, table, entity_column, stream_column, max_value, limit=100000 + ): + # Fetch a mapping of room_id -> max stream position for "recent" rooms. + # It doesn't really matter how many we get, the StreamChangeCache will + # do the right thing to ensure it respects the max size of cache. + sql = ( + "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" + " WHERE %(stream)s > ? - %(limit)s" + " GROUP BY %(entity)s" + ) % { + "table": table, + "entity": entity_column, + "stream": stream_column, + "limit": limit, + } + + sql = self.engine.convert_param_style(sql) + + txn = db_conn.cursor() + txn.execute(sql, (int(max_value),)) + + cache = {row[0]: int(row[1]) for row in txn} + + txn.close() + + if cache: + min_val = min(itervalues(cache)) + else: + min_val = max_value + + return cache, min_val + + def simple_select_list_paginate( + self, + table, + orderby, + start, + limit, + retcols, + filters=None, + keyvalues=None, + order_direction="ASC", + desc="simple_select_list_paginate", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + filters (dict[str, T] | None): + column names and values to filter the rows with, or None to not + apply a WHERE ? LIKE ? clause. + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self.simple_select_list_paginate_txn, + table, + orderby, + start, + limit, + retcols, + filters=filters, + keyvalues=keyvalues, + order_direction=order_direction, + ) + + @classmethod + def simple_select_list_paginate_txn( + cls, + txn, + table, + orderby, + start, + limit, + retcols, + filters=None, + keyvalues=None, + order_direction="ASC", + ): + """ + Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to + select attributes with exact matches. All constraints are joined together + using 'AND'. + + Args: + txn : Transaction object + table (str): the table name + orderby (str): Column to order the results by. + start (int): Index to begin the query at. + limit (int): Number of results to return. + retcols (iterable[str]): the names of the columns to return + filters (dict[str, T] | None): + column names and values to filter the rows with, or None to not + apply a WHERE ? LIKE ? clause. + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + order_direction (str): Whether the results should be ordered "ASC" or "DESC". + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + if order_direction not in ["ASC", "DESC"]: + raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") + + where_clause = "WHERE " if filters or keyvalues else "" + arg_list = [] + if filters: + where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) + arg_list += list(filters.values()) + where_clause += " AND " if filters and keyvalues else "" + if keyvalues: + where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues) + arg_list += list(keyvalues.values()) + + sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( + ", ".join(retcols), + table, + where_clause, + orderby, + order_direction, + ) + txn.execute(sql, arg_list + [limit, start]) + + return cls.cursor_to_dict(txn) + + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, self.simple_search_list_txn, table, term, col, retcols + ) + + @classmethod + def simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + + +def make_in_list_sql_clause( + database_engine, column: str, iterable: Iterable +) -> Tuple[str, Iterable]: + """Returns an SQL clause that checks the given column is in the iterable. + + On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres + it expands to `column = ANY(?)`. While both DBs support the `IN` form, + using the `ANY` form on postgres means that it views queries with + different length iterables as the same, helping the query stats. + + Args: + database_engine + column: Name of the column + iterable: The values to check the column against. + + Returns: + A tuple of SQL query and the args + """ + + if database_engine.supports_using_any_list: + # This should hopefully be faster, but also makes postgres query + # stats easier to understand. + return "%s = ANY(?)" % (column,), [list(iterable)] + else: + return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index 3286804322..7b18455469 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging from functools import wraps @@ -64,12 +65,22 @@ def measure_func(name=None): def wrapper(func): block_name = func.__name__ if name is None else name - @wraps(func) - @defer.inlineCallbacks - def measured_func(self, *args, **kwargs): - with Measure(self.clock, block_name): - r = yield func(self, *args, **kwargs) - return r + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = await func(self, *args, **kwargs) + return r + + else: + + @wraps(func) + @defer.inlineCallbacks + def measured_func(self, *args, **kwargs): + with Measure(self.clock, block_name): + r = yield func(self, *args, **kwargs) + return r return measured_func @@ -80,72 +91,48 @@ class Measure(object): __slots__ = [ "clock", "name", - "start_context", + "_logging_context", "start", - "created_context", - "start_usage", ] def __init__(self, clock, name): self.clock = clock self.name = name - self.start_context = None + self._logging_context = None self.start = None - self.created_context = False def __enter__(self): - self.start = self.clock.time() - self.start_context = LoggingContext.current_context() - if not self.start_context: - self.start_context = LoggingContext("Measure") - self.start_context.__enter__() - self.created_context = True - - self.start_usage = self.start_context.get_resource_usage() + if self._logging_context: + raise RuntimeError("Measure() objects cannot be re-used") + self.start = self.clock.time() + parent_context = LoggingContext.current_context() + self._logging_context = LoggingContext( + "Measure[%s]" % (self.name,), parent_context + ) + self._logging_context.__enter__() in_flight.register((self.name,), self._update_in_flight) def __exit__(self, exc_type, exc_val, exc_tb): - if isinstance(exc_type, Exception) or not self.start_context: - return - - in_flight.unregister((self.name,), self._update_in_flight) + if not self._logging_context: + raise RuntimeError("Measure() block exited without being entered") duration = self.clock.time() - self.start + usage = self._logging_context.get_resource_usage() - block_counter.labels(self.name).inc() - block_timer.labels(self.name).inc(duration) - - context = LoggingContext.current_context() - - if context != self.start_context: - logger.warning( - "Context has unexpectedly changed from '%s' to '%s'. (%r)", - self.start_context, - context, - self.name, - ) - return - - if not context: - logger.warning("Expected context. (%r)", self.name) - return + in_flight.unregister((self.name,), self._update_in_flight) + self._logging_context.__exit__(exc_type, exc_val, exc_tb) - current = context.get_resource_usage() - usage = current - self.start_usage try: + block_counter.labels(self.name).inc() + block_timer.labels(self.name).inc(duration) block_ru_utime.labels(self.name).inc(usage.ru_utime) block_ru_stime.labels(self.name).inc(usage.ru_stime) block_db_txn_count.labels(self.name).inc(usage.db_txn_count) block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec) block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec) except ValueError: - logger.warning( - "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current - ) - - if self.created_context: - self.start_context.__exit__(exc_type, exc_val, exc_tb) + logger.warning("Failed to save metrics! Usage: %s", usage) def _update_in_flight(self, metrics): """Gets called when processing in flight metrics diff --git a/synapse/visibility.py b/synapse/visibility.py index 8c843febd8..dffe943b28 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -44,7 +44,12 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, + user_id, + events, + is_peeking=False, + always_include_ids=frozenset(), + apply_retention_policies=True, ): """ Check which events a user is allowed to see @@ -59,6 +64,10 @@ def filter_events_for_client( events always_include_ids (set(event_id)): set of event ids to specifically include (unless sender is ignored) + apply_retention_policies (bool): Whether to filter out events that's older than + allowed by the room's retention policy. Useful when this function is called + to e.g. check whether a user should be allowed to see the state at a given + event rather than to know if it should send an event to a user's client(s). Returns: Deferred[list[synapse.events.EventBase]] @@ -86,6 +95,15 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + if apply_retention_policies: + room_ids = set(e.room_id for e in events) + retention_policies = {} + + for room_id in room_ids: + retention_policies[ + room_id + ] = yield storage.main.get_retention_policy_for_room(room_id) + def allowed(event): """ Args: @@ -103,6 +121,18 @@ def filter_events_for_client( if not event.is_state() and event.sender in ignore_list: return None + # Don't try to apply the room's retention policy if the event is a state event, as + # MSC1763 states that retention is only considered for non-state events. + if apply_retention_policies and not event.is_state(): + retention_policy = retention_policies[event.room_id] + max_lifetime = retention_policy.get("max_lifetime") + + if max_lifetime is not None: + oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime + + if event.origin_server_ts < oldest_allowed_ts: + return None + if event.event_id in always_include_ids: return event diff --git a/synmark/__init__.py b/synmark/__init__.py new file mode 100644 index 0000000000..afe4fad8cb --- /dev/null +++ b/synmark/__init__.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from twisted.internet import epollreactor +from twisted.internet.main import installReactor + +from synapse.config.homeserver import HomeServerConfig +from synapse.util import Clock + +from tests.utils import default_config, setup_test_homeserver + + +async def make_homeserver(reactor, config=None): + """ + Make a Homeserver suitable for running benchmarks against. + + Args: + reactor: A Twisted reactor to run under. + config: A HomeServerConfig to use, or None. + """ + cleanup_tasks = [] + clock = Clock(reactor) + + if not config: + config = default_config("test") + + config_obj = HomeServerConfig() + config_obj.parse_config_dict(config, "", "") + + hs = await setup_test_homeserver( + cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock + ) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor.db.updates, "do_next_background_update"): + while not await stor.db.updates.has_completed_background_updates(): + await stor.db.updates.do_next_background_update(1) + + def cleanup(): + for i in cleanup_tasks: + i() + + return hs, clock.sleep, cleanup + + +def make_reactor(): + """ + Instantiate and install a Twisted reactor suitable for testing (i.e. not the + default global one). + """ + reactor = epollreactor.EPollReactor() + + if "twisted.internet.reactor" in sys.modules: + del sys.modules["twisted.internet.reactor"] + installReactor(reactor) + + return reactor diff --git a/synmark/__main__.py b/synmark/__main__.py new file mode 100644 index 0000000000..ac59befbd4 --- /dev/null +++ b/synmark/__main__.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from contextlib import redirect_stderr +from io import StringIO + +import pyperf +from synmark import make_reactor +from synmark.suites import SUITES + +from twisted.internet.defer import ensureDeferred +from twisted.logger import globalLogBeginner, textFileLogObserver +from twisted.python.failure import Failure + +from tests.utils import setupdb + + +def make_test(main): + """ + Take a benchmark function and wrap it in a reactor start and stop. + """ + + def _main(loops): + + reactor = make_reactor() + + file_out = StringIO() + with redirect_stderr(file_out): + + d = ensureDeferred(main(reactor, loops)) + + def on_done(_): + if isinstance(_, Failure): + _.printTraceback() + print(file_out.getvalue()) + reactor.stop() + return _ + + d.addBoth(on_done) + reactor.run() + + return d.result + + return _main + + +if __name__ == "__main__": + + def add_cmdline_args(cmd, args): + if args.log: + cmd.extend(["--log"]) + + runner = pyperf.Runner( + processes=3, min_time=2, show_name=True, add_cmdline_args=add_cmdline_args + ) + runner.argparser.add_argument("--log", action="store_true") + runner.parse_args() + + orig_loops = runner.args.loops + runner.args.inherit_environ = ["SYNAPSE_POSTGRES"] + + if runner.args.worker: + if runner.args.log: + globalLogBeginner.beginLoggingTo( + [textFileLogObserver(sys.__stdout__)], redirectStandardIO=False + ) + setupdb() + + for suite, loops in SUITES: + if loops: + runner.args.loops = loops + else: + runner.args.loops = orig_loops + loops = "auto" + runner.bench_time_func( + suite.__name__ + "_" + str(loops), make_test(suite.main), + ) diff --git a/synmark/suites/__init__.py b/synmark/suites/__init__.py new file mode 100644 index 0000000000..cfa3b0ba38 --- /dev/null +++ b/synmark/suites/__init__.py @@ -0,0 +1,3 @@ +from . import logging + +SUITES = [(logging, 1000), (logging, 10000), (logging, None)] diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py new file mode 100644 index 0000000000..d8e4c7d58f --- /dev/null +++ b/synmark/suites/logging.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from io import StringIO + +from mock import Mock + +from pyperf import perf_counter +from synmark import make_homeserver + +from twisted.internet.defer import Deferred +from twisted.internet.protocol import ServerFactory +from twisted.logger import LogBeginner, Logger, LogPublisher +from twisted.protocols.basic import LineOnlyReceiver + +from synapse.logging._structured import setup_structured_logging + + +class LineCounter(LineOnlyReceiver): + + delimiter = b"\n" + + def __init__(self, *args, **kwargs): + self.count = 0 + super().__init__(*args, **kwargs) + + def lineReceived(self, line): + self.count += 1 + + if self.count >= self.factory.wait_for and self.factory.on_done: + on_done = self.factory.on_done + self.factory.on_done = None + on_done.callback(True) + + +async def main(reactor, loops): + """ + Benchmark how long it takes to send `loops` messages. + """ + servers = [] + + def protocol(): + p = LineCounter() + servers.append(p) + return p + + logger_factory = ServerFactory.forProtocol(protocol) + logger_factory.wait_for = loops + logger_factory.on_done = Deferred() + port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1") + + hs, wait, cleanup = await make_homeserver(reactor) + + errors = StringIO() + publisher = LogPublisher() + mock_sys = Mock() + beginner = LogBeginner( + publisher, errors, mock_sys, warnings, initialBufferSize=loops + ) + + log_config = { + "loggers": {"synapse": {"level": "DEBUG"}}, + "drains": { + "tersejson": { + "type": "network_json_terse", + "host": "127.0.0.1", + "port": port.getHost().port, + "maximum_buffer": 100, + } + }, + } + + logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher) + logging_system = setup_structured_logging( + hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False + ) + + # Wait for it to connect... + await logging_system._observers[0]._service.whenConnected() + + start = perf_counter() + + # Send a bunch of useful messages + for i in range(0, loops): + logger.info("test message %s" % (i,)) + + if ( + len(logging_system._observers[0]._buffer) + == logging_system._observers[0].maximum_buffer + ): + while ( + len(logging_system._observers[0]._buffer) + > logging_system._observers[0].maximum_buffer / 2 + ): + await wait(0.01) + + await logger_factory.on_done + + end = perf_counter() - start + + logging_system.stop() + port.stopListening() + cleanup() + + return end diff --git a/sytest-blacklist b/sytest-blacklist index 11785fd43f..79b2d4402a 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,6 +1,6 @@ # This file serves as a blacklist for SyTest tests that we expect will fail in # Synapse. -# +# # Each line of this file is scanned by sytest during a run and if the line # exactly matches the name of a test, it will be marked as "expected fail", # meaning the test will still run, but failure will not mark the entire test @@ -29,3 +29,10 @@ Enabling an unknown default rule fails with 404 # Blacklisted due to https://github.com/matrix-org/synapse/issues/1663 New federated private chats get full presence information (SYN-115) + +# Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing +# this requirement from the spec +Inbound federation of state requires event_id as a mandatory paramater + +# Blacklisted until https://github.com/matrix-org/synapse/pull/6486 lands +Can upload self-signing keys diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2dc5052249..63d8633582 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 51714a2b06..24fa8dbb45 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -18,17 +18,14 @@ from mock import Mock from twisted.internet import defer from synapse.api.errors import Codes, SynapseError -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import UserID -from synapse.util.ratelimitutils import FederationRateLimiter from tests import unittest -class RoomComplexityTests(unittest.HomeserverTestCase): +class RoomComplexityTests(unittest.FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -41,25 +38,6 @@ class RoomComplexityTests(unittest.HomeserverTestCase): config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config - def prepare(self, reactor, clock, homeserver): - class Authenticator(object): - def authenticate_request(self, request, content): - return defer.succeed("otherserver.nottld") - - ratelimiter = FederationRateLimiter( - clock, - FederationRateLimitConfig( - window_size=1, - sleep_limit=1, - sleep_msec=1, - reject_limit=1000, - concurrent_requests=1000, - ), - ) - server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - def test_complexity_simple(self): u1 = self.register_user("u1", "pass") @@ -105,7 +83,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], "roomid", UserID.from_string(u1), {"membership": "join"}, @@ -146,7 +124,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], room_1, UserID.from_string(u1), {"membership": "join"}, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index cce8d8c6de..d456267b87 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.types import ReadReceipt -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class FederationSenderTestCases(HomeserverTestCase): @@ -29,6 +29,7 @@ class FederationSenderTestCases(HomeserverTestCase): federation_transport_client=Mock(spec=["send_transaction"]), ) + @override_config({"send_federation": True}) def test_send_receipts(self): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] @@ -69,6 +70,7 @@ class FederationSenderTestCases(HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b08be451aa..1ec8c40901 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +17,8 @@ import logging from synapse.events import FrozenEvent from synapse.federation.federation_server import server_matches_acl_event +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from tests import unittest @@ -41,6 +44,66 @@ class ServerACLsTestCase(unittest.TestCase): self.assertTrue(server_matches_acl_event("1:2:3:4", e)) +class StateQueryTests(unittest.FederatingHomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def test_without_event_id(self): + """ + Querying v1/state/<room_id> without an event ID will return the current + known state. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.inject_room_member(room_1, "@user:other.example.com", "join") + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertEqual( + channel.json_body["room_version"], + self.hs.config.default_room_version.identifier, + ) + + members = set( + map( + lambda x: x["state_key"], + filter( + lambda x: x["type"] == "m.room.member", channel.json_body["pdus"] + ), + ) + ) + + self.assertEqual(members, set(["@user:other.example.com", u1])) + self.assertEqual(len(channel.json_body["pdus"]), 6) + + def test_needs_to_be_in_room(self): + """ + Querying v1/state/<room_id> requires the server + be in the room to provide data. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + + def _create_acl_event(content): return FrozenEvent( { diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py new file mode 100644 index 0000000000..27d83bb7d9 --- /dev/null +++ b/tests/federation/transport/test_server.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.unittest import override_config + + +class RoomDirectoryFederationTests(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return defer.succeed("otherserver.nottld") + + ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) + server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + @override_config({"allow_public_rooms_over_federation": False}) + def test_blocked_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(403, channel.code) + + @override_config({"allow_public_rooms_over_federation": True}) + def test_open_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 854eb6c024..fdfa2cbbc4 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,6 +183,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + test_replace_master_key.skip = ( + "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" + ) + @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" @@ -503,3 +507,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) + + test_upload_signatures.skip = ( + "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" + ) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 0bb96674a2..70f172eb02 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + version_etag = res["etag"] + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") + self.assertEqual(res["etag"], version_etag) + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) @@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "version": "2", "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", + "count": 0, }, ) @@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -207,12 +218,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -409,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) + # get the etag to compare to future versions + res = yield self.handler.get_version_info(self.local_user) + backup_etag = res["etag"] + self.assertEqual(res["count"], 1) + new_room_keys = copy.deepcopy(room_keys) new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] @@ -423,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "SSBBTSBBIEZJU0gK", ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) @@ -432,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should NOT be equal now, since the key changed + res = yield self.handler.get_version_info(self.local_user) + self.assertNotEqual(res["etag"], backup_etag) + backup_etag = res["etag"] + # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count new_room_key["forwarded_count"] = 2 @@ -443,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e0075ccd32..d9d312f0fb 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -42,16 +42,16 @@ class StatsRoomTests(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store._simple_select_list( + return self.store.db.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store._simple_select_one( + self.store.db.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -108,8 +108,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the stats via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_initial_room(self): """ @@ -141,8 +145,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r = self.get_success(self.get_all_room_state()) @@ -178,9 +186,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # the position that the deltas should begin at, once they take over. self.hs.config.stats_enabled = True self.handler.stats_enabled = True - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_update_one( + self.store.db.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,14 +196,18 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now, before the table is actually ingested, add some more events. self.helper.invite(room=room_1, src=u1, targ=u2, tok=u1_token) @@ -205,13 +217,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -221,9 +233,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - self.store._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + self.store.db.updates._all_done = False + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) self.reactor.advance(86401) @@ -653,15 +669,15 @@ class StatsRoomTests(unittest.HomeserverTestCase): # preparation stage of the initial background update # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_delete( + self.store.db.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store._simple_delete( + self.store.db.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -673,9 +689,9 @@ class StatsRoomTests(unittest.HomeserverTestCase): # now do the background updates - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -685,7 +701,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -695,7 +711,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -705,8 +721,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) r1stats_complete = self._get_current_stats("room", r1) u1stats_complete = self._get_current_stats("user", u1) diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 31f54bbd7d..758ee071a5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,54 +12,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer from synapse.api.errors import Codes, ResourceLimitError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION -from synapse.handlers.sync import SyncConfig, SyncHandler +from synapse.handlers.sync import SyncConfig from synapse.types import UserID import tests.unittest import tests.utils -from tests.utils import setup_test_homeserver -class SyncTestCase(tests.unittest.TestCase): +class SyncTestCase(tests.unittest.HomeserverTestCase): """ Tests Sync Handler. """ - @defer.inlineCallbacks - def setUp(self): - self.hs = yield setup_test_homeserver(self.addCleanup) - self.sync_handler = SyncHandler(self.hs) + def prepare(self, reactor, clock, hs): + self.hs = hs + self.sync_handler = self.hs.get_sync_handler() self.store = self.hs.get_datastore() - @defer.inlineCallbacks def test_wait_for_sync_for_user_auth_blocking(self): user_id1 = "@user1:server" user_id2 = "@user2:server" sync_config = self._generate_sync_config(user_id1) + self.reactor.advance(100) # So we get not 0 time self.hs.config.limit_usage_by_mau = True self.hs.config.max_mau_value = 1 # Check that the happy case does not throw errors - yield self.store.upsert_monthly_active_user(user_id1) - yield self.sync_handler.wait_for_sync_for_user(sync_config) + self.get_success(self.store.upsert_monthly_active_user(user_id1)) + self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) # Test that global lock works self.hs.config.hs_disabled = True - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.hs.config.hs_disabled = False sync_config = self._generate_sync_config(user_id2) - with self.assertRaises(ResourceLimitError) as e: - yield self.sync_handler.wait_for_sync_for_user(sync_config) - self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + e = self.get_failure( + self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError + ) + self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) def _generate_sync_config(self, user_id): return SyncConfig( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5ec568f4e6..92b8726093 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.unittest import override_config from tests.utils import register_federation_servlets # Some local users to test with @@ -162,7 +163,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -174,6 +177,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] @@ -225,7 +229,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -237,6 +243,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_stopped_typing(self): self.room_members = [U_APPLE, U_BANANA, U_ONION] @@ -276,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -297,7 +306,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ @@ -314,7 +325,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) self.assertEquals(self.event_source.get_current_key(), 2) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1) + ) self.assertEquals( events[0], [{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}], @@ -332,7 +345,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) - events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + events = self.get_success( + self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0) + ) self.assertEquals( events[0], [ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c5e91a8c41..26071059d2 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -181,10 +181,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): Add the background updates we need to run. """ # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", @@ -255,8 +255,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() @@ -290,8 +294,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Do the initial population of the user directory via the background update self._add_background_updates() - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) shares_private = self.get_users_who_share_private_rooms() public_users = self.get_users_in_public_rooms() diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 4f924ce451..3dae83c543 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,6 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.storage.database import Database from tests import unittest from tests.server import FakeTransport @@ -42,13 +43,18 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() - self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) + self.slaved_store = self.STORE_TYPE( + Database(hs), self.hs.get_db_conn(), self.hs + ) self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer + handler_factory = Mock() self.replication_handler = ReplicationClientHandler(self.slaved_store) + self.replication_handler.factory = handler_factory + client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index ce3835ae6a..1d14e77255 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from synapse.replication.tcp.commands import ReplicateCommand from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -30,7 +32,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server = server_factory.buildProtocol(None) # build a replication client, with a dummy handler + handler_factory = Mock() self.test_handler = TestReplicationClientHandler() + self.test_handler.factory = handler_factory self.client = ClientReplicationStreamProtocol( "client", "test", clock, self.test_handler ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 9575058252..0ed2594381 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 0000000000..5e9c07ebf3 --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventContentFields, EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import room + +from tests import unittest + + +class EphemeralMessageTestCase(unittest.HomeserverTestCase): + + user_id = "@user:test" + + servlets = [ + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["enable_ephemeral_messages"] = True + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_message_expiry_no_delay(self): + """Tests that sending a message sent with a m.self_destruct_after field set to the + past results in that event being deleted right away. + """ + # Send a message in the room that has expired. From here, the reactor clock is + # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock + # is at 0ms the code path is the same if the event's expiry timestamp is the + # current timestamp. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: 0, + }, + ) + event_id = res["event_id"] + + # Check that we can't retrieve the content of the event. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def test_message_expiry_delay(self): + """Tests that sending a message with a m.self_destruct_after field set to the + future results in that event not being deleted right away, but advancing the + clock to after that expiry timestamp causes the event to be deleted. + """ + # Send a message in the room that'll expire in 1s. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000, + }, + ) + event_id = res["event_id"] + + # Check that we can retrieve the content of the event before it has expired. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertTrue(bool(event_content), event_content) + + # Advance the clock to after the deletion. + self.reactor.advance(1) + + # Check that we can't retrieve the content of the event anymore. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py new file mode 100644 index 0000000000..95475bb651 --- /dev/null +++ b/tests/rest/client/test_retention.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.api.constants import EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.visibility import filter_events_for_client + +from tests import unittest + +one_hour_ms = 3600000 +one_day_ms = one_hour_ms * 24 + + +class RetentionTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["retention"] = { + "enabled": True, + "default_policy": { + "min_lifetime": one_day_ms, + "max_lifetime": one_day_ms * 3, + }, + "allowed_lifetime_min": one_day_ms, + "allowed_lifetime_max": one_day_ms * 3, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_retention_state_event(self): + """Tests that the server configuration can limit the values a user can set to the + room's retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms * 4}, + tok=self.token, + expect_code=400, + ) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_hour_ms}, + tok=self.token, + expect_code=400, + ) + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": lifetime}, + tok=self.token, + ) + + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_without_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by the server's configuration's default retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention_event_purged(room_id, one_day_ms * 2) + + def test_visibility(self): + """Tests that synapse.visibility.filter_events_for_client correctly filters out + outdated events + """ + store = self.hs.get_datastore() + storage = self.hs.get_storage() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + events = [] + + # Send a first event, which should be filtered out at the end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + # Get the event from the store so that we end up with a FrozenEvent that we can + # give to filter_events_for_client. We need to do this now because the event won't + # be in the database anymore after it has expired. + events.append(self.get_success(store.get_event(resp.get("event_id")))) + + # Advance the time by 2 days. We're using the default retention policy, therefore + # after this the first event will still be valid. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Send another event, which shouldn't get filtered out. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + valid_event_id = resp.get("event_id") + + events.append(self.get_success(store.get_event(valid_event_id))) + + # Advance the time by anothe 2 days. After this, the first event should be + # outdated but not the second one. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Run filter_events_for_client with our list of FrozenEvents. + filtered_events = self.get_success( + filter_events_for_client(storage, self.user_id, events) + ) + + # We should only get one event back. + self.assertEqual(len(filtered_events), 1, filtered_events) + # That event should be the second, not outdated event. + self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) + + def _test_retention_event_purged(self, room_id, increment): + # Get the create event to, later, check that we can still access it. + message_handler = self.hs.get_message_handler() + create_event = self.get_success( + message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + ) + + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + expired_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, expired_event_id) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) + + # Advance the time. + self.reactor.advance(increment / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + valid_event_id = resp.get("event_id") + + # Advance the time again. Now our first event should have expired but our second + # one should still be kept. + self.reactor.advance(increment / 1000) + + # Check that the event has been purged from the database. + self.get_event(room_id, expired_event_id, expected_code=404) + + # Check that the event that hasn't been purged can still be retrieved. + valid_event = self.get_event(room_id, valid_event_id) + self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + + # Check that we can still access state events that were sent before the event that + # has been purged. + self.get_event(room_id, create_event.event_id) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body + + +class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["retention"] = { + "enabled": True, + } + + mock_federation_client = Mock(spec=["backfill"]) + + self.hs = self.setup_test_homeserver( + config=config, federation_client=mock_federation_client, + ) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_no_default_policy(self): + """Tests that an event doesn't get expired if there is neither a default retention + policy nor a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention(room_id) + + def test_state_policy(self): + """Tests that an event gets correctly expired if there is no default retention + policy but there's a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the maximum lifetime to 35 days so that the first event gets expired but not + # the second one. + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms * 35}, + tok=self.token, + ) + + self._test_retention(room_id, expected_code_for_first_event=404) + + def _test_retention(self, room_id, expected_code_for_first_event=200): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + first_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, first_event_id) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) + + # Advance the time by a month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + second_event_id = resp.get("event_id") + + # Advance the time by another month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Check if the event has been purged from the database. + first_event = self.get_event( + room_id, first_event_id, expected_code=expected_code_for_first_event + ) + + if expected_code_for_first_event == 200: + self.assertEqual( + first_event.get("content", {}).get("body"), "1", first_event + ) + + # Check that the event that hasn't been purged can still be retrieved. + second_event = self.get_event(room_id, second_event_id) + self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 66c2b68707..0fdff79aa7 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -15,6 +15,8 @@ from mock import Mock +from twisted.internet import defer + from synapse.rest.client.v1 import presence from synapse.types import UserID @@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): ) hs.presence_handler = Mock() + hs.presence_handler.set_state.return_value = defer.succeed(None) return hs diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 140d8b3772..12c5e95cb5 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase): ] ) + self.mock_handler.get_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.set_displayname.return_value = defer.succeed(Mock()) + self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock()) + self.mock_handler.check_profile_query_allowed.return_value = defer.succeed( + Mock() + ) + hs = yield setup_test_homeserver( self.addCleanup, "test", @@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None, allow_guest=False): - return synapse.types.create_requester(myid) + return defer.succeed(synapse.types.create_requester(myid)) hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index e84e578f99..1ca7fa742f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -813,105 +815,6 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_filter_labels(self): - """Test that we can filter by a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) - - def test_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 3, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual( - events[2]["content"]["body"], "with two wrong labels", events[2] - ) - - def test_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) - - events = self._test_filter_labels(sync_filter) - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def _test_filter_labels(self, message_filter): - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, - ) - - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) - - token = "s0_0_0_0_0_0_0_0_0" - request, channel = self.make_request( - "GET", - "/rooms/%s/messages?access_token=x&from=%s&filter=%s" - % (self.room_id, token, message_filter), - ) - self.render(request) - - return channel.json_body["chunk"] - def test_room_messages_purge(self): store = self.hs.get_datastore() pagination_handler = self.hs.get_pagination_handler() @@ -1180,3 +1083,517 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) + + +class LabelsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". + FILTER_LABELS_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_context_filter_labels(self): + """Test that we can filter by a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "with right label", events_before[0] + ) + + events_after = channel.json_body["events_before"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with right label", events_after[0] + ) + + def test_context_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "without label", events_before[0] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 2, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + self.assertEqual( + events_after[1]["content"]["body"], "with two wrong labels", events_after[1] + ) + + def test_context_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /context request. + """ + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. + """ + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 2, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 4, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 1, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + Returns: + The ID of the event to use if we're testing filtering on /context. + """ + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + # Return this event's ID when we test filtering in /context requests. + event_id = res["event_id"] + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + return event_id diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 30fb77bac8..4bc3aaf02d 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code) self.assertEquals(self.event_source.get_current_key(), 1) - events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + events = self.get_success( + self.event_source.get_new_events(from_key=0, room_ids=[self.room_id]) + ) self.assertEquals( events[0], [ diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8ea0cb05ea..e7417b3d14 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 3283c0e47b..661c1f88b9 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 9b81b536f5..d491ea2924 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -323,7 +323,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.table_name = "table_" + hs.get_secrets().token_hex(6) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "create", lambda x, *a: x.execute(*a), "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)" @@ -331,7 +331,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "index", lambda x, *a: x.execute(*a), "CREATE UNIQUE INDEX %sindex ON %s(id, username)" @@ -354,9 +354,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["hello"], ["there"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -381,9 +381,9 @@ class UpsertManyTests(unittest.HomeserverTestCase): value_values = [["bleb"]] self.get_success( - self.storage.runInteraction( + self.storage.db.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.db.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.db.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index dfeea24599..2e521e9ab7 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,6 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) +from synapse.storage.database import Database from tests import unittest from tests.utils import setup_test_homeserver @@ -54,7 +55,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - self.store = ApplicationServiceStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -123,7 +125,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - self.store = TestTransactionStore(hs.get_db_conn(), hs) + database = Database(hs) + self.store = TestTransactionStore(database, hs.get_db_conn(), hs) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -382,8 +385,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): # required for ApplicationServiceTransactionStoreTestCase tests class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore): - def __init__(self, db_conn, hs): - super(TestTransactionStore, self).__init__(db_conn, hs) + def __init__(self, database: Database, db_conn, hs): + super(TestTransactionStore, self).__init__(database, db_conn, hs) class ApplicationServiceStoreConfigTestCase(unittest.TestCase): @@ -416,7 +419,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -432,7 +435,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) @@ -453,7 +456,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(hs.get_db_conn(), hs) + ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 9fabe3fbc0..aec76f4ab1 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -15,7 +15,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler = Mock() - yield self.store.register_background_update_handler( + yield self.store.db.updates.register_background_update_handler( "test_update", self.update_handler ) @@ -23,7 +23,7 @@ class BackgroundUpdateTestCase(unittest.TestCase): # (perhaps we should run them as part of the test HS setup, since we # run all of the other schema setup stuff there?) while True: - res = yield self.store.do_next_background_update(1000) + res = yield self.store.db.updates.do_next_background_update(1000) if res is None: break @@ -37,9 +37,9 @@ class BackgroundUpdateTestCase(unittest.TestCase): def update(progress, count): self.clock.advance_time_msec(count * duration_ms) progress = {"my_key": progress["my_key"] + 1} - yield self.store.runInteraction( + yield self.store.db.runInteraction( "update_progress", - self.store._background_update_progress_txn, + self.store.db.updates._background_update_progress_txn, "test_update", progress, ) @@ -47,29 +47,37 @@ class BackgroundUpdateTestCase(unittest.TestCase): self.update_handler.side_effect = update - yield self.store.start_background_update("test_update", {"my_key": 1}) + yield self.store.db.updates.start_background_update( + "test_update", {"my_key": 1} + ) self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with( - {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE + {"my_key": 1}, self.store.db.updates.DEFAULT_BACKGROUND_BATCH_SIZE ) # second step: complete the update @defer.inlineCallbacks def update(progress, count): - yield self.store._end_background_update("test_update") + yield self.store.db.updates._end_background_update("test_update") return count self.update_handler.side_effect = update self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNotNone(result) self.update_handler.assert_called_once_with({"my_key": 2}, desired_count) # third step: we don't expect to be called any more self.update_handler.reset_mock() - result = yield self.store.do_next_background_update(duration_ms * desired_count) + result = yield self.store.db.updates.do_next_background_update( + duration_ms * desired_count + ) self.assertIsNone(result) self.assertFalse(self.update_handler.called) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index c778de1f0c..537cfe9f64 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -21,6 +21,7 @@ from mock import Mock from twisted.internet import defer from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database from synapse.storage.engines import create_engine from tests import unittest @@ -59,13 +60,13 @@ class SQLBaseStoreTestCase(unittest.TestCase): "test", db_pool=self.db_pool, config=config, database_engine=fake_engine ) - self.datastore = SQLBaseStore(None, hs) + self.datastore = SQLBaseStore(Database(hs), None, hs) @defer.inlineCallbacks def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +78,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.db.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +93,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore._simple_select_one_onecol( + value = yield self.datastore.db.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +107,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +123,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.db.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +138,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore._simple_select_list( + ret = yield self.datastore.db.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +151,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +166,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.db.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +181,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_delete_one( + yield self.datastore.db.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 69dcaa63d5..029ac26454 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -46,7 +46,9 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): """Re run the background update to clean up the extremities. """ # Make sure we don't clash with in progress updates. - self.assertTrue(self.store._all_done, "Background updates are still ongoing") + self.assertTrue( + self.store.db.updates._all_done, "Background updates are still ongoing" + ) schema_path = os.path.join( prepare_database.dir_path, @@ -62,14 +64,20 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): prepare_database.executescript(txn, schema_path) self.get_success( - self.store.runInteraction("test_delete_forward_extremities", run_delta_file) + self.store.db.runInteraction( + "test_delete_forward_extremities", run_delta_file + ) ) # Ugh, have to reset this flag - self.store._all_done = False + self.store.db.updates._all_done = False - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) def test_soft_failed_extremities_handled_correctly(self): """Test that extremities are correctly calculated in the presence of diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index afac5dec7f..fc279340d4 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -81,7 +81,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -112,7 +112,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -202,8 +202,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_devices_last_seen_bg_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Insert a user IP user_id = "@user:id" @@ -218,7 +222,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store._simple_update( + self.store.db.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": "device_id"}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -245,7 +249,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -256,11 +260,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # We should now get the correct result again result = self.get_success( @@ -281,8 +289,12 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): def test_old_user_ips_pruned(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Insert a user IP user_id = "@user:id" @@ -297,7 +309,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -323,7 +335,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store._simple_select_list( + self.store.db.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index d128fde441..35dafbb904 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -39,8 +39,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version1, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version1, [("room", "session", room_key)] ) ) @@ -51,8 +51,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version2, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version2, [("room", "session", room_key)] ) ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 2fe50377f8..eadfb90a22 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -61,7 +61,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 11): - yield self.store.runInteraction("insert", insert_event, i) + yield self.store.db.runInteraction("insert", insert_event, i) # this should get the last five and five others r = yield self.store.get_prev_events_for_room(room_id) @@ -93,9 +93,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase): ) for i in range(0, 20): - yield self.store.runInteraction("insert", insert_event, i, room1) - yield self.store.runInteraction("insert", insert_event, i, room2) - yield self.store.runInteraction("insert", insert_event, i, room3) + yield self.store.db.runInteraction("insert", insert_event, i, room1) + yield self.store.db.runInteraction("insert", insert_event, i, room2) + yield self.store.db.runInteraction("insert", insert_event, i, room3) # Test simple case r = yield self.store.get_rooms_with_many_extremities(5, 5, []) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b114c6fb1d..d4bcf1821e 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -55,7 +55,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def _assert_counts(noitf_count, highlight_count): - counts = yield self.store.runInteraction( + counts = yield self.store.db.runInteraction( "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 ) self.assertEquals( @@ -74,7 +74,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield self.store.add_push_actions_to_staging( event.event_id, {user_id: action} ) - yield self.store.runInteraction( + yield self.store.db.runInteraction( "", self.store._set_push_actions_for_event_and_users_txn, [(event, None)], @@ -82,12 +82,12 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): ) def _rotate(stream): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._rotate_notifs_before_txn, stream ) def _mark_read(stream, depth): - return self.store.runInteraction( + return self.store.db.runInteraction( "", self.store._remove_old_push_actions_before_txn, room_id, @@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store._simple_delete( + yield self.store.db.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert( + return self.store.db.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 90a63dc477..3c78faab45 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -65,7 +65,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now) - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) self.pump() @@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) count = self.store.get_monthly_active_count() @@ -244,7 +244,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user2_email}, ] self.hs.config.mau_limits_reserved_threepids = threepids - self.store.runInteraction( + self.store.db.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 24c7fe16c3..9b6f7211ae 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -16,7 +16,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.profile import ProfileStore from synapse.types import UserID from tests import unittest @@ -28,7 +27,7 @@ class ProfileStoreTestCase(unittest.TestCase): def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.store = ProfileStore(hs.get_db_conn(), hs) + self.store = hs.get_datastore() self.u_frank = UserID.from_string("@frank:test") diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4561c3e383..dc45173355 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.db.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 9ddd17f73d..7840f63fe3 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -16,8 +16,7 @@ from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import RoomVersions +from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID @@ -44,9 +43,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() - self.storage = hs.get_storage() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -55,26 +51,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - - self.get_success(self.storage.persistence.persist_event(event, context)) - - return event - def test_one_member(self): # Alice creates the room, and is automatically joined @@ -146,8 +122,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): def test_can_rerun_update(self): # First make sure we have completed all updates. - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) # Now let's create a room, which will insert a membership user = UserID("alice", "test") @@ -156,7 +136,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.db.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", @@ -167,8 +147,12 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): ) # ... and tell the DataStore that it hasn't finished all updates yet - self.store._all_done = False + self.store.db.updates._all_done = False # Now let's actually drive the updates to completion - while not self.get_success(self.store.has_completed_background_updates()): - self.get_success(self.store.do_next_background_update(100), by=0.1) + while not self.get_success( + self.store.db.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db.updates.do_next_background_update(100), by=0.1 + ) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7eea57c0e2..6a545d2eb0 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -15,8 +15,6 @@ from twisted.internet import defer -from synapse.storage.data_stores.main.user_directory import UserDirectoryStore - from tests import unittest from tests.utils import setup_test_homeserver @@ -29,7 +27,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): self.hs = yield setup_test_homeserver(self.addCleanup) - self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs) + self.store = self.hs.get_datastore() # alice and bob are both in !room_id. bobby is not but shares # a homeserver with alice. diff --git a/tests/test_federation.py b/tests/test_federation.py index 7d82b58466..ad165d7295 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -33,6 +33,8 @@ class MessageAcceptTests(unittest.TestCase): self.reactor.advance(0.1) self.room_id = self.successResultOf(room)["room_id"] + self.store = self.homeserver.get_datastore() + # Figure out what the most recent event is most_recent = self.successResultOf( maybeDeferred( @@ -77,10 +79,7 @@ class MessageAcceptTests(unittest.TestCase): # Make sure we actually joined the room self.assertEqual( self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0], "$join:test.serv", ) @@ -100,10 +99,7 @@ class MessageAcceptTests(unittest.TestCase): # Figure out what the most recent event is most_recent = self.successResultOf( - maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, - self.room_id, - ) + maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) )[0] # Now lie about an event @@ -141,7 +137,5 @@ class MessageAcceptTests(unittest.TestCase): ) # Make sure the invalid event isn't there - extrem = maybeDeferred( - self.homeserver.get_datastore().get_latest_event_ids_in_room, self.room_id - ) + extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") diff --git a/tests/unittest.py b/tests/unittest.py index 561cebc223..b30b7d1718 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,9 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import hashlib import hmac +import inspect import logging import time @@ -23,17 +26,21 @@ from mock import Mock from canonicaljson import json -from twisted.internet.defer import Deferred, succeed +from twisted.internet.defer import Deferred, ensureDeferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester +from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import get_clock, make_request, render, setup_test_homeserver from tests.test_utils.logging_setup import setup_logging @@ -395,10 +402,12 @@ class HomeserverTestCase(TestCase): hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) stor = hs.get_datastore() - # Run the database background updates. - if hasattr(stor, "do_next_background_update"): - while not self.get_success(stor.has_completed_background_updates()): - self.get_success(stor.do_next_background_update(1)) + # Run the database background updates, when running against "master". + if hs.__class__.__name__ == "TestHomeServer": + while not self.get_success( + stor.db.updates.has_completed_background_updates() + ): + self.get_success(stor.db.updates.do_next_background_update(1)) return hs @@ -409,6 +418,8 @@ class HomeserverTestCase(TestCase): self.reactor.pump([by] * 100) def get_success(self, d, by=0.0): + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump(by=by) @@ -418,6 +429,8 @@ class HomeserverTestCase(TestCase): """ Run a Deferred and get a Failure from it. The failure must be of the type `exc`. """ + if inspect.isawaitable(d): + d = ensureDeferred(d) if not isinstance(d, Deferred): return d self.pump() @@ -538,7 +551,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore()._simple_insert( + self.hs.get_datastore().db.simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", @@ -559,6 +572,66 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 403, channel.result) + def inject_room_member(self, room: str, user: str, membership: Membership) -> None: + """ + Inject a membership event into a room. + + Args: + room: Room ID to inject the event into. + user: MXID of the user to inject the membership for. + membership: The membership type. + """ + event_builder_factory = self.hs.get_event_builder_factory() + event_creation_handler = self.hs.get_event_creation_handler() + + room_version = self.get_success(self.hs.get_datastore().get_room_version(room)) + + builder = event_builder_factory.for_room_version( + KNOWN_ROOM_VERSIONS[room_version], + { + "type": EventTypes.Member, + "sender": user, + "state_key": user, + "room_id": room, + "content": {"membership": membership}, + }, + ) + + event, context = self.get_success( + event_creation_handler.create_new_client_event(builder) + ) + + self.get_success( + self.hs.get_storage().persistence.persist_event(event, context) + ) + + +class FederatingHomeserverTestCase(HomeserverTestCase): + """ + A federating homeserver that authenticates incoming requests as `other.example.com`. + """ + + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return succeed("other.example.com") + + ratelimiter = FederationRateLimiter( + clock, + FederationRateLimitConfig( + window_size=1, + sleep_limit=1, + sleep_msec=1, + reject_limit=1000, + concurrent_requests=1000, + ), + ) + federation_server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + return super().prepare(reactor, clock, homeserver) + def override_config(extra_config): """A decorator which can be applied to test functions to give additional HS config diff --git a/tests/utils.py b/tests/utils.py index 7dc9bdc505..c57da59191 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -109,6 +109,7 @@ def default_config(name, parse=False): """ config_dict = { "server_name": name, + "send_federation": False, "media_store_path": "media", "uploads_path": "uploads", # the test signing key is just an arbitrary ed25519 key to keep the config @@ -460,7 +461,9 @@ class MockHttpResource(HttpServer): try: args = [urlparse.unquote(u) for u in matcher.groups()] - (code, response) = yield func(mock_request, *args) + (code, response) = yield defer.ensureDeferred( + func(mock_request, *args) + ) return code, response except CodeMessageException as e: return (e.code, cs_error(e.msg, code=e.errcode)) diff --git a/tox.ini b/tox.ini index 62b350ea6a..903a245fb0 100644 --- a/tox.ini +++ b/tox.ini @@ -102,6 +102,15 @@ commands = {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} +[testenv:benchmark] +deps = + {[base]deps} + pyperf +setenv = + SYNAPSE_POSTGRES = 1 +commands = + python -m synmark {posargs:} + [testenv:packaging] skip_install=True deps = |