diff options
183 files changed, 4593 insertions, 1540 deletions
diff --git a/.buildkite/merge_base_branch.sh b/.buildkite/merge_base_branch.sh index eb7219a56d..361440fd1a 100755 --- a/.buildkite/merge_base_branch.sh +++ b/.buildkite/merge_base_branch.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash -set -ex +set -e if [[ "$BUILDKITE_BRANCH" =~ ^(develop|master|dinsic|shhs|release-.*)$ ]]; then echo "Not merging forward, as this is a release branch" @@ -18,6 +18,8 @@ else GITBASE=$BUILDKITE_PULL_REQUEST_BASE_BRANCH fi +echo "--- merge_base_branch $GITBASE" + # Show what we are before git --no-pager show -s diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist index cda5c84e94..7950d19db3 100644 --- a/.buildkite/worker-blacklist +++ b/.buildkite/worker-blacklist @@ -28,3 +28,39 @@ User sees updates to presence from other users in the incremental sync. Gapped incremental syncs include all state changes Old members are included in gappy incr LL sync if they start speaking + +# new failures as of https://github.com/matrix-org/sytest/pull/732 +Device list doesn't change if remote server is down +Remote servers cannot set power levels in rooms without existing powerlevels +Remote servers should reject attempts by non-creators to set the power levels + +# new failures as of https://github.com/matrix-org/sytest/pull/753 +GET /rooms/:room_id/messages returns a message +GET /rooms/:room_id/messages lazy loads members correctly +Read receipts are sent as events +Only original members of the room can see messages from erased users +Device deletion propagates over federation +If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes +Changing user-signing key notifies local users +Newly updated tags appear in an incremental v2 /sync +Server correctly handles incoming m.device_list_update +Local device key changes get to remote servers with correct prev_id +AS-ghosted users can use rooms via AS +Ghost user must register before joining room +Test that a message is pushed +Invites are pushed +Rooms with aliases are correctly named in pushed +Rooms with names are correctly named in pushed +Rooms with canonical alias are correctly named in pushed +Rooms with many users are correctly pushed +Don't get pushed for rooms you've muted +Rejected events are not pushed +Test that rejected pushers are removed. +Events come down the correct room + +# https://buildkite.com/matrix-dot-org/sytest/builds/326#cca62404-a88a-4fcb-ad41-175fd3377603 +Presence changes to UNAVAILABLE are reported to remote room members +If remote user leaves room, changes device and rejoins we see update in sync +uploading self-signing key notifies over federation +Inbound federation can receive redacted events +Outbound federation can request missing events diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8939fda67d..11fb05ca96 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,8 +1,8 @@ ### Pull Request Checklist -<!-- Please read CONTRIBUTING.rst before submitting your pull request --> +<!-- Please read CONTRIBUTING.md before submitting your pull request --> * [ ] Pull request is based on the develop branch -* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#changelog) -* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#sign-off) -* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.rst#code-style)) +* [ ] Pull request includes a [changelog file](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#changelog) +* [ ] Pull request includes a [sign off](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#sign-off) +* [ ] Code style is correct (run the [linters](https://github.com/matrix-org/synapse/blob/master/CONTRIBUTING.md#code-style)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..c0091346f3 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,210 @@ +# Contributing code to Matrix + +Everyone is welcome to contribute code to Matrix +(https://github.com/matrix-org), provided that they are willing to license +their contributions under the same license as the project itself. We follow a +simple 'inbound=outbound' model for contributions: the act of submitting an +'inbound' contribution means that the contributor agrees to license the code +under the same terms as the project's overall 'outbound' license - in our +case, this is almost always Apache Software License v2 (see [LICENSE](LICENSE)). + +## How to contribute + +The preferred and easiest way to contribute changes to Matrix is to fork the +relevant project on github, and then [create a pull request]( +https://help.github.com/articles/using-pull-requests/) to ask us to pull +your changes into our repo. + +**The single biggest thing you need to know is: please base your changes on +the develop branch - *not* master.** + +We use the master branch to track the most recent release, so that folks who +blindly clone the repo and automatically check out master get something that +works. Develop is the unstable branch where all the development actually +happens: the workflow is that contributors should fork the develop branch to +make a 'feature' branch for a particular contribution, and then make a pull +request to merge this back into the matrix.org 'official' develop branch. We +use github's pull request workflow to review the contribution, and either ask +you to make any refinements needed or merge it and make them ourselves. The +changes will then land on master when we next do a release. + +We use [Buildkite](https://buildkite.com/matrix-dot-org/synapse) for continuous +integration. If your change breaks the build, this will be shown in GitHub, so +please keep an eye on the pull request for feedback. + +To run unit tests in a local development environment, you can use: + +- ``tox -e py35`` (requires tox to be installed by ``pip install tox``) + for SQLite-backed Synapse on Python 3.5. +- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6. +- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6 + (requires a running local PostgreSQL with access to create databases). +- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5 + (requires Docker). Entirely self-contained, recommended if you don't want to + set up PostgreSQL yourself. + +Docker images are available for running the integration tests (SyTest) locally, +see the [documentation in the SyTest repo]( +https://github.com/matrix-org/sytest/blob/develop/docker/README.md) for more +information. + +## Code style + +All Matrix projects have a well-defined code-style - and sometimes we've even +got as far as documenting it... For instance, synapse's code style doc lives +[here](docs/code_style.md). + +To facilitate meeting these criteria you can run `scripts-dev/lint.sh` +locally. Since this runs the tools listed in the above document, you'll need +python 3.6 and to install each tool: + +``` +# Install the dependencies +pip install -U black flake8 isort + +# Run the linter script +./scripts-dev/lint.sh +``` + +**Note that the script does not just test/check, but also reformats code, so you +may wish to ensure any new code is committed first**. By default this script +checks all files and can take some time; if you alter only certain files, you +might wish to specify paths as arguments to reduce the run-time: + +``` +./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder +``` + +Before pushing new changes, ensure they don't produce linting errors. Commit any +files that were corrected. + +Please ensure your changes match the cosmetic style of the existing project, +and **never** mix cosmetic and functional changes in the same commit, as it +makes it horribly hard to review otherwise. + + +## Changelog + +All changes, even minor ones, need a corresponding changelog / newsfragment +entry. These are managed by [Towncrier](https://github.com/hawkowl/towncrier). + +To create a changelog entry, make a new file in the `changelog.d` directory named +in the format of `PRnumber.type`. The type can be one of the following: + +* `feature` +* `bugfix` +* `docker` (for updates to the Docker image) +* `doc` (for updates to the documentation) +* `removal` (also used for deprecations) +* `misc` (for internal-only changes) + +The content of the file is your changelog entry, which should be a short +description of your change in the same style as the rest of our [changelog]( +https://github.com/matrix-org/synapse/blob/master/CHANGES.md). The file can +contain Markdown formatting, and should end with a full stop ('.') for +consistency. + +Adding credits to the changelog is encouraged, we value your +contributions and would like to have you shouted out in the release notes! + +For example, a fix in PR #1234 would have its changelog entry in +`changelog.d/1234.bugfix`, and contain content like "The security levels of +Florbs are now validated when received over federation. Contributed by Jane +Matrix.". + +## Debian changelog + +Changes which affect the debian packaging files (in `debian`) are an +exception. + +In this case, you will need to add an entry to the debian changelog for the +next release. For this, run the following command: + +``` +dch +``` + +This will make up a new version number (if there isn't already an unreleased +version in flight), and open an editor where you can add a new changelog entry. +(Our release process will ensure that the version number and maintainer name is +corrected for the release.) + +If your change affects both the debian packaging *and* files outside the debian +directory, you will need both a regular newsfragment *and* an entry in the +debian changelog. (Though typically such changes should be submitted as two +separate pull requests.) + +## Sign off + +In order to have a concrete record that your contribution is intentional +and you agree to license it under the same terms as the project's license, we've adopted the +same lightweight approach that the Linux Kernel +[submitting patches process]( +https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin>), +[Docker](https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other +projects use: the DCO (Developer Certificate of Origin: +http://developercertificate.org/). This is a simple declaration that you wrote +the contribution or otherwise have the right to contribute it to Matrix: + +``` +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +660 York Street, Suite 102, +San Francisco, CA 94110 USA + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. +``` + +If you agree to this for your contribution, then all that's needed is to +include the line in your commit or pull request comment: + +``` +Signed-off-by: Your Name <your@email.example.org> +``` + +We accept contributions under a legally identifiable name, such as +your name on government documentation or common-law names (names +claimed by legitimate usage or repute). Unfortunately, we cannot +accept anonymous contributions at this time. + +Git allows you to add this signoff automatically when using the `-s` +flag to `git commit`, which uses the name and email set in your +`user.name` and `user.email` git configs. + +## Conclusion + +That's it! Matrix is a very open and collaborative project as you might expect +given our obsession with open communication. If we're going to successfully +matrix together all the fragmented communication technologies out there we are +reliant on contributions and collaboration from the community to do so. So +please get involved - and we hope you have as much fun hacking on Matrix as we +do! diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst deleted file mode 100644 index df81f6e54f..0000000000 --- a/CONTRIBUTING.rst +++ /dev/null @@ -1,206 +0,0 @@ -Contributing code to Matrix -=========================== - -Everyone is welcome to contribute code to Matrix -(https://github.com/matrix-org), provided that they are willing to license -their contributions under the same license as the project itself. We follow a -simple 'inbound=outbound' model for contributions: the act of submitting an -'inbound' contribution means that the contributor agrees to license the code -under the same terms as the project's overall 'outbound' license - in our -case, this is almost always Apache Software License v2 (see LICENSE). - -How to contribute -~~~~~~~~~~~~~~~~~ - -The preferred and easiest way to contribute changes to Matrix is to fork the -relevant project on github, and then create a pull request to ask us to pull -your changes into our repo -(https://help.github.com/articles/using-pull-requests/) - -**The single biggest thing you need to know is: please base your changes on -the develop branch - /not/ master.** - -We use the master branch to track the most recent release, so that folks who -blindly clone the repo and automatically check out master get something that -works. Develop is the unstable branch where all the development actually -happens: the workflow is that contributors should fork the develop branch to -make a 'feature' branch for a particular contribution, and then make a pull -request to merge this back into the matrix.org 'official' develop branch. We -use github's pull request workflow to review the contribution, and either ask -you to make any refinements needed or merge it and make them ourselves. The -changes will then land on master when we next do a release. - -We use `Buildkite <https://buildkite.com/matrix-dot-org/synapse>`_ for -continuous integration. Buildkite builds need to be authorised by a -maintainer. If your change breaks the build, this will be shown in GitHub, so -please keep an eye on the pull request for feedback. - -To run unit tests in a local development environment, you can use: - -- ``tox -e py35`` (requires tox to be installed by ``pip install tox``) - for SQLite-backed Synapse on Python 3.5. -- ``tox -e py36`` for SQLite-backed Synapse on Python 3.6. -- ``tox -e py36-postgres`` for PostgreSQL-backed Synapse on Python 3.6 - (requires a running local PostgreSQL with access to create databases). -- ``./test_postgresql.sh`` for PostgreSQL-backed Synapse on Python 3.5 - (requires Docker). Entirely self-contained, recommended if you don't want to - set up PostgreSQL yourself. - -Docker images are available for running the integration tests (SyTest) locally, -see the `documentation in the SyTest repo -<https://github.com/matrix-org/sytest/blob/develop/docker/README.md>`_ for more -information. - -Code style -~~~~~~~~~~ - -All Matrix projects have a well-defined code-style - and sometimes we've even -got as far as documenting it... For instance, synapse's code style doc lives -at https://github.com/matrix-org/synapse/tree/master/docs/code_style.md. - -To facilitate meeting these criteria you can run ``scripts-dev/lint.sh`` -locally. Since this runs the tools listed in the above document, you'll need -python 3.6 and to install each tool. **Note that the script does not just -test/check, but also reformats code, so you may wish to ensure any new code is -committed first**. By default this script checks all files and can take some -time; if you alter only certain files, you might wish to specify paths as -arguments to reduce the run-time. - -Please ensure your changes match the cosmetic style of the existing project, -and **never** mix cosmetic and functional changes in the same commit, as it -makes it horribly hard to review otherwise. - -Before doing a commit, ensure the changes you've made don't produce -linting errors. You can do this by running the linters as follows. Ensure to -commit any files that were corrected. - -:: - # Install the dependencies - pip install -U black flake8 isort - - # Run the linter script - ./scripts-dev/lint.sh - -Changelog -~~~~~~~~~ - -All changes, even minor ones, need a corresponding changelog / newsfragment -entry. These are managed by Towncrier -(https://github.com/hawkowl/towncrier). - -To create a changelog entry, make a new file in the ``changelog.d`` file named -in the format of ``PRnumber.type``. The type can be one of the following: - -* ``feature``. -* ``bugfix``. -* ``docker`` (for updates to the Docker image). -* ``doc`` (for updates to the documentation). -* ``removal`` (also used for deprecations). -* ``misc`` (for internal-only changes). - -The content of the file is your changelog entry, which should be a short -description of your change in the same style as the rest of our `changelog -<https://github.com/matrix-org/synapse/blob/master/CHANGES.md>`_. The file can -contain Markdown formatting, and should end with a full stop ('.') for -consistency. - -Adding credits to the changelog is encouraged, we value your -contributions and would like to have you shouted out in the release notes! - -For example, a fix in PR #1234 would have its changelog entry in -``changelog.d/1234.bugfix``, and contain content like "The security levels of -Florbs are now validated when recieved over federation. Contributed by Jane -Matrix.". - -Debian changelog ----------------- - -Changes which affect the debian packaging files (in ``debian``) are an -exception. - -In this case, you will need to add an entry to the debian changelog for the -next release. For this, run the following command:: - - dch - -This will make up a new version number (if there isn't already an unreleased -version in flight), and open an editor where you can add a new changelog entry. -(Our release process will ensure that the version number and maintainer name is -corrected for the release.) - -If your change affects both the debian packaging *and* files outside the debian -directory, you will need both a regular newsfragment *and* an entry in the -debian changelog. (Though typically such changes should be submitted as two -separate pull requests.) - -Sign off -~~~~~~~~ - -In order to have a concrete record that your contribution is intentional -and you agree to license it under the same terms as the project's license, we've adopted the -same lightweight approach that the Linux Kernel -`submitting patches process <https://www.kernel.org/doc/html/latest/process/submitting-patches.html#sign-your-work-the-developer-s-certificate-of-origin>`_, Docker -(https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other -projects use: the DCO (Developer Certificate of Origin: -http://developercertificate.org/). This is a simple declaration that you wrote -the contribution or otherwise have the right to contribute it to Matrix:: - - Developer Certificate of Origin - Version 1.1 - - Copyright (C) 2004, 2006 The Linux Foundation and its contributors. - 660 York Street, Suite 102, - San Francisco, CA 94110 USA - - Everyone is permitted to copy and distribute verbatim copies of this - license document, but changing it is not allowed. - - Developer's Certificate of Origin 1.1 - - By making a contribution to this project, I certify that: - - (a) The contribution was created in whole or in part by me and I - have the right to submit it under the open source license - indicated in the file; or - - (b) The contribution is based upon previous work that, to the best - of my knowledge, is covered under an appropriate open source - license and I have the right under that license to submit that - work with modifications, whether created in whole or in part - by me, under the same open source license (unless I am - permitted to submit under a different license), as indicated - in the file; or - - (c) The contribution was provided directly to me by some other - person who certified (a), (b) or (c) and I have not modified - it. - - (d) I understand and agree that this project and the contribution - are public and that a record of the contribution (including all - personal information I submit with it, including my sign-off) is - maintained indefinitely and may be redistributed consistent with - this project or the open source license(s) involved. - -If you agree to this for your contribution, then all that's needed is to -include the line in your commit or pull request comment:: - - Signed-off-by: Your Name <your@email.example.org> - -We accept contributions under a legally identifiable name, such as -your name on government documentation or common-law names (names -claimed by legitimate usage or repute). Unfortunately, we cannot -accept anonymous contributions at this time. - -Git allows you to add this signoff automatically when using the ``-s`` -flag to ``git commit``, which uses the name and email set in your -``user.name`` and ``user.email`` git configs. - -Conclusion -~~~~~~~~~~ - -That's it! Matrix is a very open and collaborative project as you might expect -given our obsession with open communication. If we're going to successfully -matrix together all the fragmented communication technologies out there we are -reliant on contributions and collaboration from the community to do so. So -please get involved - and we hope you have as much fun hacking on Matrix as we -do! diff --git a/INSTALL.md b/INSTALL.md index 29e0abafd3..9da2e3c734 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -109,8 +109,8 @@ Installing prerequisites on Ubuntu or Debian: ``` sudo apt-get install build-essential python3-dev libffi-dev \ - python-pip python-setuptools sqlite3 \ - libssl-dev python-virtualenv libjpeg-dev libxslt1-dev + python3-pip python3-setuptools sqlite3 \ + libssl-dev python3-virtualenv libjpeg-dev libxslt1-dev ``` #### ArchLinux @@ -133,9 +133,9 @@ sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ sudo yum groupinstall "Development Tools" ``` -#### Mac OS X +#### macOS -Installing prerequisites on Mac OS X: +Installing prerequisites on macOS: ``` xcode-select --install @@ -144,6 +144,14 @@ sudo pip install virtualenv brew install pkg-config libffi ``` +On macOS Catalina (10.15) you may need to explicitly install OpenSSL +via brew and inform `pip` about it so that `psycopg2` builds: + +``` +brew install openssl@1.1 +export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/ +``` + #### OpenSUSE Installing prerequisites on openSUSE: diff --git a/UPGRADE.rst b/UPGRADE.rst index 5ebf16a73e..d9020f2663 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,23 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.7.0 +=================== + +In an attempt to configure Synapse in a privacy preserving way, the default +behaviours of ``allow_public_rooms_without_auth`` and +``allow_public_rooms_over_federation`` have been inverted. This means that by +default, only authenticated users querying the Client/Server API will be able +to query the room directory, and relatedly that the server will not share +room directory information with other servers over federation. + +If your installation does not explicitly set these settings one way or the other +and you want either setting to be ``true`` then it will necessary to update +your homeserver configuration file accordingly. + +For more details on the surrounding context see our `explainer +<https://matrix.org/blog/2019/11/09/avoiding-unwelcome-visitors-on-private-matrix-servers>`_. + Upgrading to v1.5.0 =================== diff --git a/changelog.d/5815.feature b/changelog.d/5815.feature new file mode 100644 index 0000000000..ca4df4e7f6 --- /dev/null +++ b/changelog.d/5815.feature @@ -0,0 +1 @@ +Implement per-room message retention policies. diff --git a/changelog.d/5858.feature b/changelog.d/5858.feature new file mode 100644 index 0000000000..55ee93051e --- /dev/null +++ b/changelog.d/5858.feature @@ -0,0 +1 @@ +Add etag and count fields to key backup endpoints to help clients guess if there are new keys. diff --git a/changelog.d/6119.feature b/changelog.d/6119.feature new file mode 100644 index 0000000000..1492e83c5a --- /dev/null +++ b/changelog.d/6119.feature @@ -0,0 +1 @@ +Require User-Interactive Authentication for `/account/3pid/add`, meaning the user's password will be required to add a third-party ID to their account. \ No newline at end of file diff --git a/changelog.d/6176.feature b/changelog.d/6176.feature new file mode 100644 index 0000000000..3c66d689d4 --- /dev/null +++ b/changelog.d/6176.feature @@ -0,0 +1 @@ +Implement the `/_matrix/federation/unstable/net.atleastfornow/state/<context>` API as drafted in MSC2314. diff --git a/changelog.d/6237.bugfix b/changelog.d/6237.bugfix new file mode 100644 index 0000000000..9285600b00 --- /dev/null +++ b/changelog.d/6237.bugfix @@ -0,0 +1 @@ +Transfer non-standard power levels on room upgrade. \ No newline at end of file diff --git a/changelog.d/6241.bugfix b/changelog.d/6241.bugfix new file mode 100644 index 0000000000..25109ca4a6 --- /dev/null +++ b/changelog.d/6241.bugfix @@ -0,0 +1 @@ +Fix error from the Pillow library when uploading RGBA images. diff --git a/changelog.d/6266.misc b/changelog.d/6266.misc new file mode 100644 index 0000000000..634e421a79 --- /dev/null +++ b/changelog.d/6266.misc @@ -0,0 +1 @@ +Add benchmarks for structured logging and improve output performance. diff --git a/changelog.d/6322.misc b/changelog.d/6322.misc new file mode 100644 index 0000000000..70ef36ca80 --- /dev/null +++ b/changelog.d/6322.misc @@ -0,0 +1 @@ +Improve the performance of outputting structured logging. diff --git a/changelog.d/6329.bugfix b/changelog.d/6329.bugfix new file mode 100644 index 0000000000..e558d13b7d --- /dev/null +++ b/changelog.d/6329.bugfix @@ -0,0 +1 @@ +Correctly apply the event filter to the `state`, `events_before` and `events_after` fields in the response to `/context` requests. \ No newline at end of file diff --git a/changelog.d/6332.bugfix b/changelog.d/6332.bugfix new file mode 100644 index 0000000000..67d5170ba0 --- /dev/null +++ b/changelog.d/6332.bugfix @@ -0,0 +1 @@ +Fix caching devices for remote users when using workers, so that we don't attempt to refetch (and potentially fail) each time a user requests devices. diff --git a/changelog.d/6333.bugfix b/changelog.d/6333.bugfix new file mode 100644 index 0000000000..a25d6ef3cb --- /dev/null +++ b/changelog.d/6333.bugfix @@ -0,0 +1 @@ +Prevent account data syncs getting lost across TCP replication. \ No newline at end of file diff --git a/changelog.d/6343.misc b/changelog.d/6343.misc new file mode 100644 index 0000000000..d9a44389b9 --- /dev/null +++ b/changelog.d/6343.misc @@ -0,0 +1 @@ +Refactor some code in the event authentication path for clarity. diff --git a/changelog.d/6354.feature b/changelog.d/6354.feature new file mode 100644 index 0000000000..fed9db884b --- /dev/null +++ b/changelog.d/6354.feature @@ -0,0 +1 @@ +Configure privacy preserving settings by default for the room directory. diff --git a/changelog.d/6362.misc b/changelog.d/6362.misc new file mode 100644 index 0000000000..b79a5bea99 --- /dev/null +++ b/changelog.d/6362.misc @@ -0,0 +1 @@ +Clean up some unnecessary quotation marks around the codebase. \ No newline at end of file diff --git a/changelog.d/6379.misc b/changelog.d/6379.misc new file mode 100644 index 0000000000..725c2e7d87 --- /dev/null +++ b/changelog.d/6379.misc @@ -0,0 +1 @@ +Complain on startup instead of 500'ing during runtime when `public_baseurl` isn't set when necessary. \ No newline at end of file diff --git a/changelog.d/6388.doc b/changelog.d/6388.doc new file mode 100644 index 0000000000..c777cb6b8f --- /dev/null +++ b/changelog.d/6388.doc @@ -0,0 +1 @@ +Fix link in the user directory documentation. diff --git a/changelog.d/6390.doc b/changelog.d/6390.doc new file mode 100644 index 0000000000..093411bec1 --- /dev/null +++ b/changelog.d/6390.doc @@ -0,0 +1 @@ +Add build instructions to the docker readme. \ No newline at end of file diff --git a/changelog.d/6392.misc b/changelog.d/6392.misc new file mode 100644 index 0000000000..a00257944f --- /dev/null +++ b/changelog.d/6392.misc @@ -0,0 +1 @@ +Add a test scenario to make sure room history purges don't break `/messages` in the future. diff --git a/changelog.d/6406.bugfix b/changelog.d/6406.bugfix new file mode 100644 index 0000000000..ca9bee084b --- /dev/null +++ b/changelog.d/6406.bugfix @@ -0,0 +1 @@ +Fix bug: TypeError in `register_user()` while using LDAP auth module. diff --git a/changelog.d/6408.bugfix b/changelog.d/6408.bugfix new file mode 100644 index 0000000000..c9babe599b --- /dev/null +++ b/changelog.d/6408.bugfix @@ -0,0 +1 @@ +Fix an intermittent exception when handling read-receipts. diff --git a/changelog.d/6409.feature b/changelog.d/6409.feature new file mode 100644 index 0000000000..653ff5a5ad --- /dev/null +++ b/changelog.d/6409.feature @@ -0,0 +1 @@ +Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). diff --git a/changelog.d/6420.bugfix b/changelog.d/6420.bugfix new file mode 100644 index 0000000000..aef47cccaa --- /dev/null +++ b/changelog.d/6420.bugfix @@ -0,0 +1 @@ +Fix broken guest registration when there are existing blocks of numeric user IDs. diff --git a/changelog.d/6421.bugfix b/changelog.d/6421.bugfix new file mode 100644 index 0000000000..7969f7f71d --- /dev/null +++ b/changelog.d/6421.bugfix @@ -0,0 +1 @@ +Fix startup error when http proxy is defined. diff --git a/changelog.d/6423.misc b/changelog.d/6423.misc new file mode 100644 index 0000000000..9bcd5d36c1 --- /dev/null +++ b/changelog.d/6423.misc @@ -0,0 +1 @@ +Clarifications for the email configuration settings. diff --git a/changelog.d/6426.bugfix b/changelog.d/6426.bugfix new file mode 100644 index 0000000000..3acfde4211 --- /dev/null +++ b/changelog.d/6426.bugfix @@ -0,0 +1 @@ +Clean up local threepids from user on account deactivation. \ No newline at end of file diff --git a/changelog.d/6429.misc b/changelog.d/6429.misc new file mode 100644 index 0000000000..4b32cdeac6 --- /dev/null +++ b/changelog.d/6429.misc @@ -0,0 +1 @@ +Add more tests to the blacklist when running in worker mode. diff --git a/changelog.d/6434.feature b/changelog.d/6434.feature new file mode 100644 index 0000000000..affa5d50c1 --- /dev/null +++ b/changelog.d/6434.feature @@ -0,0 +1 @@ +Add support for MSC 2367, which allows specifying a reason on all membership events. diff --git a/changelog.d/6436.bugfix b/changelog.d/6436.bugfix new file mode 100644 index 0000000000..954a4e1d84 --- /dev/null +++ b/changelog.d/6436.bugfix @@ -0,0 +1 @@ +Fix a bug where a room could become unusable with a low retention policy and a low activity. diff --git a/changelog.d/6443.doc b/changelog.d/6443.doc new file mode 100644 index 0000000000..67c59f92ee --- /dev/null +++ b/changelog.d/6443.doc @@ -0,0 +1 @@ +Switch Ubuntu package install recommendation to use python3 packages in INSTALL.md. \ No newline at end of file diff --git a/changelog.d/6449.bugfix b/changelog.d/6449.bugfix new file mode 100644 index 0000000000..002f33c450 --- /dev/null +++ b/changelog.d/6449.bugfix @@ -0,0 +1 @@ +Fix error when using synapse_port_db on a vanilla synapse db. diff --git a/changelog.d/6451.bugfix b/changelog.d/6451.bugfix new file mode 100644 index 0000000000..23b67583ec --- /dev/null +++ b/changelog.d/6451.bugfix @@ -0,0 +1 @@ +Fix uploading multiple cross signing signatures for the same user. diff --git a/changelog.d/6454.misc b/changelog.d/6454.misc new file mode 100644 index 0000000000..9e5259157c --- /dev/null +++ b/changelog.d/6454.misc @@ -0,0 +1 @@ +Move data store specific code out of `SQLBaseStore`. diff --git a/changelog.d/6458.doc b/changelog.d/6458.doc new file mode 100644 index 0000000000..3a9f831d89 --- /dev/null +++ b/changelog.d/6458.doc @@ -0,0 +1 @@ +Write some docs for the quarantine_media api. diff --git a/changelog.d/6461.doc b/changelog.d/6461.doc new file mode 100644 index 0000000000..1502fa2855 --- /dev/null +++ b/changelog.d/6461.doc @@ -0,0 +1 @@ +Convert CONTRIBUTING.rst to markdown (among other small fixes). \ No newline at end of file diff --git a/changelog.d/6462.bugfix b/changelog.d/6462.bugfix new file mode 100644 index 0000000000..c435939526 --- /dev/null +++ b/changelog.d/6462.bugfix @@ -0,0 +1 @@ +Fix bug which lead to exceptions being thrown in a loop when a cross-signed device is deleted. diff --git a/changelog.d/6464.misc b/changelog.d/6464.misc new file mode 100644 index 0000000000..bd65276ef6 --- /dev/null +++ b/changelog.d/6464.misc @@ -0,0 +1 @@ +Prepare SQLBaseStore functions being moved out of the stores. diff --git a/changelog.d/6468.misc b/changelog.d/6468.misc new file mode 100644 index 0000000000..d9a44389b9 --- /dev/null +++ b/changelog.d/6468.misc @@ -0,0 +1 @@ +Refactor some code in the event authentication path for clarity. diff --git a/changelog.d/6470.bugfix b/changelog.d/6470.bugfix new file mode 100644 index 0000000000..c08b34c14c --- /dev/null +++ b/changelog.d/6470.bugfix @@ -0,0 +1 @@ +Fix `synapse_port_db` not exiting with a 0 code if something went wrong during the port process. diff --git a/changelog.d/6472.bugfix b/changelog.d/6472.bugfix new file mode 100644 index 0000000000..598efb79fc --- /dev/null +++ b/changelog.d/6472.bugfix @@ -0,0 +1 @@ +Improve sanity-checking when receiving events over federation. diff --git a/changelog.d/6480.misc b/changelog.d/6480.misc new file mode 100644 index 0000000000..d9a44389b9 --- /dev/null +++ b/changelog.d/6480.misc @@ -0,0 +1 @@ +Refactor some code in the event authentication path for clarity. diff --git a/docker/README.md b/docker/README.md index 24dfa77dcc..9f112a01d0 100644 --- a/docker/README.md +++ b/docker/README.md @@ -130,3 +130,15 @@ docker run -it --rm \ This will generate the same configuration file as the legacy mode used, but will store it in `/data/homeserver.yaml` instead of a temporary location. You can then use it as shown above at [Running synapse](#running-synapse). + +## Building the image + +If you need to build the image from a Synapse checkout, use the following `docker + build` command from the repo's root: + +``` +docker build -t matrixdotorg/synapse -f docker/Dockerfile . +``` + +You can choose to build a different docker image by changing the value of the `-f` flag to +point to another Dockerfile. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 5e9f8e5d84..8b3666d5f5 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -21,3 +21,20 @@ It returns a JSON body like the following: ] } ``` + +# Quarantine media in a room + +This API 'quarantines' all the media in a room. + +The API is: + +``` +POST /_synapse/admin/v1/quarantine_media/<room_id> + +{} +``` + +Quarantining media means that it is marked as inaccessible by users. It applies +to any local media, and any locally-cached copies of remote media. + +The media file itself (and any thumbnails) is not deleted from the server. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 896159394c..10664ae8f7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,15 +54,16 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true -# If set to 'false', requires authentication to access the server's public rooms -# directory through the client API. Defaults to 'true'. +# If set to 'true', removes the need for authentication to access the server's +# public rooms directory through the client API, meaning that anyone can +# query the room directory. Defaults to 'false'. # -#allow_public_rooms_without_auth: false +#allow_public_rooms_without_auth: true -# If set to 'false', forbids any other homeserver to fetch the server's public -# rooms directory via federation. Defaults to 'true'. +# If set to 'true', allows any other homeserver to fetch the server's public +# rooms directory via federation. Defaults to 'false'. # -#allow_public_rooms_over_federation: false +#allow_public_rooms_over_federation: true # The default room version for newly created rooms. # @@ -328,6 +329,69 @@ listeners: # #user_ips_max_age: 14d +# Message retention policy at the server level. +# +# Room admins and mods can define a retention period for their rooms using the +# 'm.room.retention' state event, and server admins can cap this period by setting +# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. +# +# If this feature is enabled, Synapse will regularly look for and purge events +# which are older than the room's maximum retention period. Synapse will also +# filter events received over federation so that events that should have been +# purged are ignored and not stored again. +# +retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h + ## TLS ## @@ -1270,8 +1334,23 @@ password_config: # smtp_user: "exampleusername" # smtp_pass: "examplepassword" # require_transport_security: false +# +# # notif_from defines the "From" address to use when sending emails. +# # It must be set if email sending is enabled. +# # +# # The placeholder '%(app)s' will be replaced by the application name, +# # which is normally 'app_name' (below), but may be overridden by the +# # Matrix client application. +# # +# # Note that the placeholder must be written '%(app)s', including the +# # trailing 's'. +# # # notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" -# app_name: Matrix +# +# # app_name defines the default value for '%(app)s' in notif_from. It +# # defaults to 'Matrix'. +# # +# #app_name: my_branded_matrix_server # # # Enable email notifications by default # # diff --git a/docs/user_directory.md b/docs/user_directory.md index e64aa453cc..37dc71e751 100644 --- a/docs/user_directory.md +++ b/docs/user_directory.md @@ -7,7 +7,6 @@ who are present in a publicly viewable room present on the server. The directory info is stored in various tables, which can (typically after DB corruption) get stale or out of sync. If this happens, for now the -solution to fix it is to execute the SQL here -https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/delta/53/user_dir_populate.sql +solution to fix it is to execute the SQL [here](../synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql) and then restart synapse. This should then start a background task to flush the current tables and regenerate the directory. diff --git a/scripts-dev/hash_history.py b/scripts-dev/hash_history.py index d20f6db176..bf3862a386 100644 --- a/scripts-dev/hash_history.py +++ b/scripts-dev/hash_history.py @@ -27,7 +27,7 @@ class Store(object): "_store_pdu_reference_hash_txn" ] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] - _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] + simple_insert_txn = SQLBaseStore.__dict__["simple_insert_txn"] store = Store() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 0d3321682c..c4cf11d19a 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -47,6 +47,7 @@ from synapse.storage.data_stores.main.media_repository import ( from synapse.storage.data_stores.main.registration import ( RegistrationBackgroundUpdateStore, ) +from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore from synapse.storage.data_stores.main.state import StateBackgroundUpdateStore @@ -131,6 +132,7 @@ class Store( EventsBackgroundUpdatesStore, MediaRepositoryBackgroundUpdateStore, RegistrationBackgroundUpdateStore, + RoomBackgroundUpdateStore, RoomMemberBackgroundUpdateStore, SearchBackgroundUpdateStore, StateBackgroundUpdateStore, @@ -221,7 +223,7 @@ class Porter(object): def setup_table(self, table): if table in APPEND_ONLY_TABLES: # It's safe to just carry on inserting. - row = yield self.postgres_store._simple_select_one( + row = yield self.postgres_store.simple_select_one( table="port_from_sqlite3", keyvalues={"table_name": table}, retcols=("forward_rowid", "backward_rowid"), @@ -236,7 +238,7 @@ class Porter(object): ) backward_chunk = 0 else: - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={ "table_name": table, @@ -266,7 +268,7 @@ class Porter(object): yield self.postgres_store.execute(delete_all) - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0}, ) @@ -320,7 +322,7 @@ class Porter(object): if table == "user_directory_stream_pos": # We need to make sure there is a single row, `(X, null), as that is # what synapse expects to be there. - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table=table, values={"stream_id": None} ) self.progress.update(table, table_size) # Mark table as done @@ -375,7 +377,7 @@ class Porter(object): def insert(txn): self.postgres_store.insert_many_txn(txn, table, headers[1:], rows) - self.postgres_store._simple_update_one_txn( + self.postgres_store.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": table}, @@ -452,7 +454,7 @@ class Porter(object): ], ) - self.postgres_store._simple_update_one_txn( + self.postgres_store.simple_update_one_txn( txn, table="port_from_sqlite3", keyvalues={"table_name": "event_search"}, @@ -591,11 +593,11 @@ class Porter(object): # Step 2. Get tables. self.progress.set_state("Fetching tables") - sqlite_tables = yield self.sqlite_store._simple_select_onecol( + sqlite_tables = yield self.sqlite_store.simple_select_onecol( table="sqlite_master", keyvalues={"type": "table"}, retcol="name" ) - postgres_tables = yield self.postgres_store._simple_select_onecol( + postgres_tables = yield self.postgres_store.simple_select_onecol( table="information_schema.tables", keyvalues={}, retcol="distinct table_name", @@ -722,7 +724,7 @@ class Porter(object): next_chunk = yield self.sqlite_store.execute(get_start_id) next_chunk = max(max_inserted_rowid + 1, next_chunk) - yield self.postgres_store._simple_insert( + yield self.postgres_store.simple_insert( table="port_from_sqlite3", values={ "table_name": "sent_transactions", @@ -782,7 +784,10 @@ class Porter(object): def _setup_state_group_id_seq(self): def r(txn): txn.execute("SELECT MAX(id) FROM state_groups") - next_id = txn.fetchone()[0] + 1 + curr_id = txn.fetchone()[0] + if not curr_id: + return + next_id = curr_id + 1 txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,)) return self.postgres_store.runInteraction("setup_state_group_id_seq", r) @@ -1052,3 +1057,4 @@ if __name__ == "__main__": if end_error_exec_info: exc_type, exc_value, exc_traceback = end_error_exec_info traceback.print_exception(exc_type, exc_value, exc_traceback) + sys.exit(5) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 49c4b85054..0ade47e624 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,6 +95,8 @@ class EventTypes(object): ServerACL = "m.room.server_acl" Pinned = "m.room.pinned_events" + Retention = "m.room.retention" + class RejectedReason(object): AUTH_ERROR = "auth_error" @@ -145,3 +148,7 @@ class EventContentFields(object): # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 LABELS = "org.matrix.labels" + + # Timestamp to delete the event after + # cf https://github.com/matrix-org/matrix-doc/pull/2228 + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index bec13f08d8..6eab1f13f0 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index 139221ad34..448e45e00f 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -69,7 +69,7 @@ class FederationSenderSlaveStore( self.federation_out_pos_startup = self._get_federation_out_pos(db_conn) def _get_federation_out_pos(self, db_conn): - sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?" + sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?" sql = self.database_engine.convert_param_style(sql) txn = db_conn.cursor() diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 883b3fb70b..267aebaae9 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -542,8 +542,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # Database version # - stats["database_engine"] = hs.get_datastore().database_engine_name - stats["database_server_version"] = hs.get_datastore().get_server_version() + stats["database_engine"] = hs.database_engine.module.__name__ + stats["database_server_version"] = hs.database_engine.server_version logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: yield hs.get_proxied_http_client().put_json( diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index 6cb100319f..0fa2b50999 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -64,7 +64,7 @@ class UserDirectorySlaveStore( super(UserDirectorySlaveStore, self).__init__(db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 3e25bf5747..57174da021 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -185,7 +185,7 @@ class ApplicationServiceApi(SimpleHttpClient): if not _is_valid_3pe_metadata(info): logger.warning( - "query_3pe_protocol to %s did not return a" " valid result", uri + "query_3pe_protocol to %s did not return a valid result", uri ) return None diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index e77d3387ff..ca43e96bd1 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -134,7 +134,7 @@ def _load_appservice(hostname, as_info, config_filename): for regex_obj in as_info["namespaces"][ns]: if not isinstance(regex_obj, dict): raise ValueError( - "Expected namespace entry in %s to be an object," " but got %s", + "Expected namespace entry in %s to be an object, but got %s", ns, regex_obj, ) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 43fad0bf8b..18f42a87f9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -146,6 +146,8 @@ class EmailConfig(Config): if k not in email_config: missing.append("email." + k) + # public_baseurl is required to build password reset and validation links that + # will be emailed to users if config.get("public_baseurl") is None: missing.append("public_baseurl") @@ -305,8 +307,23 @@ class EmailConfig(Config): # smtp_user: "exampleusername" # smtp_pass: "examplepassword" # require_transport_security: false + # + # # notif_from defines the "From" address to use when sending emails. + # # It must be set if email sending is enabled. + # # + # # The placeholder '%(app)s' will be replaced by the application name, + # # which is normally 'app_name' (below), but may be overridden by the + # # Matrix client application. + # # + # # Note that the placeholder must be written '%(app)s', including the + # # trailing 's'. + # # # notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>" - # app_name: Matrix + # + # # app_name defines the default value for '%(app)s' in notif_from. It + # # defaults to 'Matrix'. + # # + # #app_name: my_branded_matrix_server # # # Enable email notifications by default # # diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 1f6dac69da..ee9614c5f7 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -106,6 +106,13 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if self.account_threepid_delegate_msisdn and not self.public_baseurl: + raise ConfigError( + "The configuration option `public_baseurl` is required if " + "`account_threepid_delegate.msisdn` is set, such that " + "clients know where to submit validation tokens to. Please " + "configure `public_baseurl`." + ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 7c9f05bde4..7ac7699676 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -170,7 +170,7 @@ class _RoomDirectoryRule(object): self.action = action else: raise ConfigError( - "%s rules can only have action of 'allow'" " or 'deny'" % (option_name,) + "%s rules can only have action of 'allow' or 'deny'" % (option_name,) ) self._alias_matches_all = alias == "*" diff --git a/synapse/config/server.py b/synapse/config/server.py index 00d01c43af..a4bef00936 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -19,7 +19,7 @@ import logging import os.path import re from textwrap import indent -from typing import List +from typing import Dict, List, Optional import attr import yaml @@ -118,15 +118,16 @@ class ServerConfig(Config): self.allow_public_rooms_without_auth = False self.allow_public_rooms_over_federation = False else: - # If set to 'False', requires authentication to access the server's public - # rooms directory through the client API. Defaults to 'True'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. self.allow_public_rooms_without_auth = config.get( - "allow_public_rooms_without_auth", True + "allow_public_rooms_without_auth", False ) - # If set to 'False', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'True'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. self.allow_public_rooms_over_federation = config.get( - "allow_public_rooms_over_federation", True + "allow_public_rooms_over_federation", False ) default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION) @@ -223,7 +224,7 @@ class ServerConfig(Config): self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) except Exception as e: raise ConfigError( - "Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e + "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e ) if self.public_baseurl is not None: @@ -246,6 +247,124 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) + retention_config = config.get("retention") + if retention_config is None: + retention_config = {} + + self.retention_enabled = retention_config.get("enabled", False) + + retention_default_policy = retention_config.get("default_policy") + + if retention_default_policy is not None: + self.retention_default_min_lifetime = retention_default_policy.get( + "min_lifetime" + ) + if self.retention_default_min_lifetime is not None: + self.retention_default_min_lifetime = self.parse_duration( + self.retention_default_min_lifetime + ) + + self.retention_default_max_lifetime = retention_default_policy.get( + "max_lifetime" + ) + if self.retention_default_max_lifetime is not None: + self.retention_default_max_lifetime = self.parse_duration( + self.retention_default_max_lifetime + ) + + if ( + self.retention_default_min_lifetime is not None + and self.retention_default_max_lifetime is not None + and ( + self.retention_default_min_lifetime + > self.retention_default_max_lifetime + ) + ): + raise ConfigError( + "The default retention policy's 'min_lifetime' can not be greater" + " than its 'max_lifetime'" + ) + else: + self.retention_default_min_lifetime = None + self.retention_default_max_lifetime = None + + self.retention_allowed_lifetime_min = retention_config.get( + "allowed_lifetime_min" + ) + if self.retention_allowed_lifetime_min is not None: + self.retention_allowed_lifetime_min = self.parse_duration( + self.retention_allowed_lifetime_min + ) + + self.retention_allowed_lifetime_max = retention_config.get( + "allowed_lifetime_max" + ) + if self.retention_allowed_lifetime_max is not None: + self.retention_allowed_lifetime_max = self.parse_duration( + self.retention_allowed_lifetime_max + ) + + if ( + self.retention_allowed_lifetime_min is not None + and self.retention_allowed_lifetime_max is not None + and self.retention_allowed_lifetime_min + > self.retention_allowed_lifetime_max + ): + raise ConfigError( + "Invalid retention policy limits: 'allowed_lifetime_min' can not be" + " greater than 'allowed_lifetime_max'" + ) + + self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] + for purge_job_config in retention_config.get("purge_jobs", []): + interval_config = purge_job_config.get("interval") + + if interval_config is None: + raise ConfigError( + "A retention policy's purge jobs configuration must have the" + " 'interval' key set." + ) + + interval = self.parse_duration(interval_config) + + shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime") + + if shortest_max_lifetime is not None: + shortest_max_lifetime = self.parse_duration(shortest_max_lifetime) + + longest_max_lifetime = purge_job_config.get("longest_max_lifetime") + + if longest_max_lifetime is not None: + longest_max_lifetime = self.parse_duration(longest_max_lifetime) + + if ( + shortest_max_lifetime is not None + and longest_max_lifetime is not None + and shortest_max_lifetime > longest_max_lifetime + ): + raise ConfigError( + "A retention policy's purge jobs configuration's" + " 'shortest_max_lifetime' value can not be greater than its" + " 'longest_max_lifetime' value." + ) + + self.retention_purge_jobs.append( + { + "interval": interval, + "shortest_max_lifetime": shortest_max_lifetime, + "longest_max_lifetime": longest_max_lifetime, + } + ) + + if not self.retention_purge_jobs: + self.retention_purge_jobs = [ + { + "interval": self.parse_duration("1d"), + "shortest_max_lifetime": None, + "longest_max_lifetime": None, + } + ] + self.listeners = [] # type: List[dict] for listener in config.get("listeners", []): if not isinstance(listener.get("port", None), int): @@ -372,6 +491,8 @@ class ServerConfig(Config): "cleanup_extremities_with_dummy_events", True ) + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + def has_tls_listener(self) -> bool: return any(l["tls"] for l in self.listeners) @@ -500,15 +621,16 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true - # If set to 'false', requires authentication to access the server's public rooms - # directory through the client API. Defaults to 'true'. + # If set to 'true', removes the need for authentication to access the server's + # public rooms directory through the client API, meaning that anyone can + # query the room directory. Defaults to 'false'. # - #allow_public_rooms_without_auth: false + #allow_public_rooms_without_auth: true - # If set to 'false', forbids any other homeserver to fetch the server's public - # rooms directory via federation. Defaults to 'true'. + # If set to 'true', allows any other homeserver to fetch the server's public + # rooms directory via federation. Defaults to 'false'. # - #allow_public_rooms_over_federation: false + #allow_public_rooms_over_federation: true # The default room version for newly created rooms. # @@ -761,6 +883,69 @@ class ServerConfig(Config): # Defaults to `28d`. Set to `null` to disable clearing out of old rows. # #user_ips_max_age: 14d + + # Message retention policy at the server level. + # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h """ % locals() ) @@ -787,14 +972,14 @@ class ServerConfig(Config): "--print-pidfile", action="store_true", default=None, - help="Print the path to the pidfile just" " before daemonizing", + help="Print the path to the pidfile just before daemonizing", ) server_group.add_argument( "--manhole", metavar="PORT", dest="manhole", type=int, - help="Turn on the twisted telnet manhole" " service on the given port.", + help="Turn on the twisted telnet manhole service on the given port.", ) diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 272426e105..9b90c9ce04 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from six import string_types +from six import integer_types, string_types from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError @@ -22,11 +22,12 @@ from synapse.types import EventID, RoomID, UserID class EventValidator(object): - def validate_new(self, event): + def validate_new(self, event, config): """Validates the event has roughly the right format Args: - event (FrozenEvent) + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. """ self.validate_builder(event) @@ -67,6 +68,99 @@ class EventValidator(object): Codes.INVALID_PARAM, ) + if event.type == EventTypes.Retention: + self._validate_retention(event, config) + + def _validate_retention(self, event, config): + """Checks that an event that defines the retention policy for a room respects the + boundaries imposed by the server's administrator. + + Args: + event (FrozenEvent): The event to validate. + config (Config): The homeserver's configuration. + """ + min_lifetime = event.content.get("min_lifetime") + max_lifetime = event.content.get("max_lifetime") + + if min_lifetime is not None: + if not isinstance(min_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'min_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and min_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be lower than the minimum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and min_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'min_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if max_lifetime is not None: + if not isinstance(max_lifetime, integer_types): + raise SynapseError( + code=400, + msg="'max_lifetime' must be an integer", + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_min is not None + and max_lifetime < config.retention_allowed_lifetime_min + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be lower than the minimum allowed value" + " enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + config.retention_allowed_lifetime_max is not None + and max_lifetime > config.retention_allowed_lifetime_max + ): + raise SynapseError( + code=400, + msg=( + "'max_lifetime' can't be greater than the maximum allowed" + " value enforced by the server's administrator" + ), + errcode=Codes.BAD_JSON, + ) + + if ( + min_lifetime is not None + and max_lifetime is not None + and min_lifetime > max_lifetime + ): + raise SynapseError( + code=400, + msg="'min_lifetime' can't be greater than 'max_lifetime", + errcode=Codes.BAD_JSON, + ) + def validate_builder(self, event): """Validates that the builder/event has roughly the right format. Only checks values that we expect a proto event to have, rather than all the diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d942d77a72..84d4eca041 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,6 +74,7 @@ class FederationServer(FederationBase): self.auth = hs.get_auth() self.handler = hs.get_handlers().federation_handler + self.state = hs.get_state_handler() self._server_linearizer = Linearizer("fed_server") self._transaction_linearizer = Linearizer("fed_txn_handler") @@ -264,9 +266,6 @@ class FederationServer(FederationBase): await self.registry.on_edu(edu_type, origin, content) async def on_context_state_request(self, origin, room_id, event_id): - if not event_id: - raise NotImplementedError("Specify an event") - origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -280,13 +279,18 @@ class FederationServer(FederationBase): # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. with (await self._server_linearizer.queue((origin, room_id))): - resp = await self._state_resp_cache.wrap( - (room_id, event_id), - self._on_context_state_request_compute, - room_id, - event_id, + resp = dict( + await self._state_resp_cache.wrap( + (room_id, event_id), + self._on_context_state_request_compute, + room_id, + event_id, + ) ) + room_version = await self.store.get_room_version(room_id) + resp["room_version"] = room_version + return 200, resp async def on_state_ids_request(self, origin, room_id, event_id): @@ -306,7 +310,11 @@ class FederationServer(FederationBase): return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute(self, room_id, event_id): - pdus = await self.handler.get_state_for_pdu(room_id, event_id) + if event_id: + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + else: + pdus = (await self.state.get_current_state(room_id)).values() + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 44edcabed4..d68b4bd670 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -44,7 +44,7 @@ class TransactionActions(object): response code and response body. """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " "transaction_id") + raise RuntimeError("Cannot persist a transaction with no transaction_id") return self.store.get_received_txn_response(transaction.transaction_id, origin) @@ -56,7 +56,7 @@ class TransactionActions(object): Deferred """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " "transaction_id") + raise RuntimeError("Cannot persist a transaction with no transaction_id") return self.store.set_received_txn_response( transaction.transaction_id, origin, code, response diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 2b2ee8612a..4ebb0e8bc0 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -49,7 +49,7 @@ sent_pdus_destination_dist_count = Counter( sent_pdus_destination_dist_total = Counter( "synapse_federation_client_sent_pdu_destinations:total", - "" "Total number of PDUs queued for sending across all destinations", + "Total number of PDUs queued for sending across all destinations", ) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 67b3e1ab6e..5fed626d5b 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -84,7 +84,7 @@ class TransactionManager(object): txn_id = str(self._next_txn_id) logger.debug( - "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)", + "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)", destination, txn_id, len(pdus), @@ -103,7 +103,7 @@ class TransactionManager(object): self._next_txn_id += 1 logger.info( - "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)", + "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)", destination, txn_id, transaction.transaction_id, diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 09baa9c57d..fefc789c85 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -421,7 +421,7 @@ class FederationEventServlet(BaseFederationServlet): return await self.handler.on_pdu_request(origin, event_id) -class FederationStateServlet(BaseFederationServlet): +class FederationStateV1Servlet(BaseFederationServlet): PATH = "/state/(?P<context>[^/]*)/?" # This is when someone asks for all data for a given context. @@ -429,7 +429,7 @@ class FederationStateServlet(BaseFederationServlet): return await self.handler.on_context_state_request( origin, context, - parse_string_from_args(query, "event_id", None, required=True), + parse_string_from_args(query, "event_id", None, required=False), ) @@ -1360,7 +1360,7 @@ class RoomComplexityServlet(BaseFederationServlet): FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, FederationEventServlet, - FederationStateServlet, + FederationStateV1Servlet, FederationStateIdsServlet, FederationBackfillServlet, FederationQueryServlet, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 69051101a6..a07d2f1a17 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -119,7 +119,7 @@ class DirectoryHandler(BaseHandler): if not service.is_interested_in_alias(room_alias.to_string()): raise SynapseError( 400, - "This application service has not reserved" " this kind of alias.", + "This application service has not reserved this kind of alias.", errcode=Codes.EXCLUSIVE, ) else: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index f09a0b73c8..28c12753c1 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -30,6 +30,7 @@ from twisted.internet import defer from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace +from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( UserID, get_domain_from_id, @@ -53,6 +54,12 @@ class E2eKeysHandler(object): self._edu_updater = SigningKeyEduUpdater(hs, self) + self._is_master = hs.config.worker_app is None + if not self._is_master: + self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( + hs + ) + federation_registry = hs.get_federation_registry() # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec @@ -191,9 +198,15 @@ class E2eKeysHandler(object): # probably be tracking their device lists. However, we haven't # done an initial sync on the device list so we do it now. try: - user_devices = yield self.device_handler.device_list_updater.user_device_resync( - user_id - ) + if self._is_master: + user_devices = yield self.device_handler.device_list_updater.user_device_resync( + user_id + ) + else: + user_devices = yield self._user_device_resync_client( + user_id=user_id + ) + user_devices = user_devices["devices"] for device in user_devices: results[user_id] = {device["device_id"]: device["keys"]} diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 0cea445f0d..f1b4424a02 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017, 2018 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -103,14 +104,35 @@ class E2eRoomKeysHandler(object): rooms session_id(string): session ID to delete keys for, for None to delete keys for all sessions + Raises: + NotFoundError: if the backup version does not exist Returns: - A deferred of the deletion transaction + A dict containing the count and etag for the backup version """ # lock for consistency with uploading with (yield self._upload_linearizer.queue(user_id)): + # make sure the backup version exists + try: + version_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version + ) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown backup version") + else: + raise + yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id) + version_etag = version_info["etag"] + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag + ) + + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} + @trace @defer.inlineCallbacks def upload_room_keys(self, user_id, version, room_keys): @@ -138,6 +160,9 @@ class E2eRoomKeysHandler(object): } } + Returns: + A dict containing the count and etag for the backup version + Raises: NotFoundError: if there are no versions defined RoomKeysVersionError: if the uploaded version is not the current version @@ -171,59 +196,62 @@ class E2eRoomKeysHandler(object): else: raise - # go through the room_keys. - # XXX: this should/could be done concurrently, given we're in a lock. + # Fetch any existing room keys for the sessions that have been + # submitted. Then compare them with the submitted keys. If the + # key is new, insert it; if the key should be updated, then update + # it; otherwise, drop it. + existing_keys = yield self.store.get_e2e_room_keys_multi( + user_id, version, room_keys["rooms"] + ) + to_insert = [] # batch the inserts together + changed = False # if anything has changed, we need to update the etag for room_id, room in iteritems(room_keys["rooms"]): - for session_id, session in iteritems(room["sessions"]): - yield self._upload_room_key( - user_id, version, room_id, session_id, session + for session_id, room_key in iteritems(room["sessions"]): + log_kv( + { + "message": "Trying to upload room key", + "room_id": room_id, + "session_id": session_id, + "user_id": user_id, + } ) - - @defer.inlineCallbacks - def _upload_room_key(self, user_id, version, room_id, session_id, room_key): - """Upload a given room_key for a given room and session into a given - version of the backup. Merges the key with any which might already exist. - - Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_id(str): the ID of the room whose keys we're setting - session_id(str): the session whose room_key we're setting - room_key(dict): the room_key being set - """ - log_kv( - { - "message": "Trying to upload room key", - "room_id": room_id, - "session_id": session_id, - "user_id": user_id, - } - ) - # get the room_key for this particular row - current_room_key = None - try: - current_room_key = yield self.store.get_e2e_room_key( - user_id, version, room_id, session_id - ) - except StoreError as e: - if e.code == 404: - log_kv( - { - "message": "Room key not found.", - "room_id": room_id, - "user_id": user_id, - } + current_room_key = existing_keys.get(room_id, {}).get(session_id) + if current_room_key: + if self._should_replace_room_key(current_room_key, room_key): + log_kv({"message": "Replacing room key."}) + # updates are done one at a time in the DB, so send + # updates right away rather than batching them up, + # like we do with the inserts + yield self.store.update_e2e_room_key( + user_id, version, room_id, session_id, room_key + ) + changed = True + else: + log_kv({"message": "Not replacing room_key."}) + else: + log_kv( + { + "message": "Room key not found.", + "room_id": room_id, + "user_id": user_id, + } + ) + log_kv({"message": "Replacing room key."}) + to_insert.append((room_id, session_id, room_key)) + changed = True + + if len(to_insert): + yield self.store.add_e2e_room_keys(user_id, version, to_insert) + + version_etag = version_info["etag"] + if changed: + version_etag = version_etag + 1 + yield self.store.update_e2e_room_keys_version( + user_id, version, None, version_etag ) - else: - raise - if self._should_replace_room_key(current_room_key, room_key): - log_kv({"message": "Replacing room key."}) - yield self.store.set_e2e_room_key( - user_id, version, room_id, session_id, room_key - ) - else: - log_kv({"message": "Not replacing room_key."}) + count = yield self.store.count_e2e_room_keys(user_id, version) + return {"etag": str(version_etag), "count": count} @staticmethod def _should_replace_room_key(current_room_key, room_key): @@ -314,6 +342,8 @@ class E2eRoomKeysHandler(object): raise NotFoundError("Unknown backup version") else: raise + + res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"]) return res @trace diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0e904f2da0..bc26921768 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,11 +19,13 @@ import itertools import logging +from typing import Dict, Iterable, Optional, Sequence, Tuple import six from six import iteritems, itervalues from six.moves import http_client, zip +import attr from signedjson.key import decode_verify_key_bytes from signedjson.sign import verify_signed_json from unpaddedbase64 import decode_base64 @@ -45,6 +47,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import auth_types_for_event +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.logging.context import ( @@ -72,6 +75,23 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) +@attr.s +class _NewEventInfo: + """Holds information about a received event, ready for passing to _handle_new_events + + Attributes: + event: the received event + + state: the state at that event + + auth_events: the auth_event map for that event + """ + + event = attr.ib(type=EventBase) + state = attr.ib(type=Optional[Sequence[EventBase]], default=None) + auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None) + + def shortstr(iterable, maxitems=5): """If iterable has maxitems or fewer, return the stringification of a list containing those items. @@ -121,6 +141,7 @@ class FederationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() @@ -141,6 +162,8 @@ class FederationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or @@ -594,14 +617,14 @@ class FederationHandler(BaseHandler): for e in auth_chain if e.event_id in auth_ids or e.type == EventTypes.Create } - event_infos.append({"event": e, "auth_events": auth}) + event_infos.append(_NewEventInfo(event=e, auth_events=auth)) seen_ids.add(e.event_id) logger.info( "[%s %s] persisting newly-received auth/state events %s", room_id, event_id, - [e["event"].event_id for e in event_infos], + [e.event.event_id for e in event_infos], ) yield self._handle_new_events(origin, event_infos) @@ -792,9 +815,9 @@ class FederationHandler(BaseHandler): a.internal_metadata.outlier = True ev_infos.append( - { - "event": a, - "auth_events": { + _NewEventInfo( + event=a, + auth_events={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -802,7 +825,7 @@ class FederationHandler(BaseHandler): for a_id in a.auth_event_ids() if a_id in auth_events }, - } + ) ) # Step 1b: persist the events in the chunk we fetched state for (i.e. @@ -814,10 +837,10 @@ class FederationHandler(BaseHandler): assert not ev.internal_metadata.is_outlier() ev_infos.append( - { - "event": ev, - "state": events_to_state[e_id], - "auth_events": { + _NewEventInfo( + event=ev, + state=events_to_state[e_id], + auth_events={ ( auth_events[a_id].type, auth_events[a_id].state_key, @@ -825,7 +848,7 @@ class FederationHandler(BaseHandler): for a_id in ev.auth_event_ids() if a_id in auth_events }, - } + ) ) yield self._handle_new_events(dest, ev_infos, backfilled=True) @@ -1428,9 +1451,9 @@ class FederationHandler(BaseHandler): return event @defer.inlineCallbacks - def do_remotely_reject_invite(self, target_hosts, room_id, user_id): + def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, room_id, user_id, "leave" + target_hosts, room_id, user_id, "leave", content=content, ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. @@ -1710,7 +1733,12 @@ class FederationHandler(BaseHandler): return context @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False): + def _handle_new_events( + self, + origin: str, + event_infos: Iterable[_NewEventInfo], + backfilled: bool = False, + ): """Creates the appropriate contexts and persists events. The events should not depend on one another, e.g. this should be used to persist a bunch of outliers, but not a chunk of individual events that depend @@ -1720,14 +1748,14 @@ class FederationHandler(BaseHandler): """ @defer.inlineCallbacks - def prep(ev_info): - event = ev_info["event"] + def prep(ev_info: _NewEventInfo): + event = ev_info.event with nested_logging_context(suffix=event.event_id): res = yield self._prep_event( origin, event, - state=ev_info.get("state"), - auth_events=ev_info.get("auth_events"), + state=ev_info.state, + auth_events=ev_info.auth_events, backfilled=backfilled, ) return res @@ -1741,7 +1769,7 @@ class FederationHandler(BaseHandler): yield self.persist_events_and_notify( [ - (ev_info["event"], context) + (ev_info.event, context) for ev_info, context in zip(event_infos, contexts) ], backfilled=backfilled, @@ -1843,7 +1871,14 @@ class FederationHandler(BaseHandler): yield self.persist_events_and_notify([(event, new_event_context)]) @defer.inlineCallbacks - def _prep_event(self, origin, event, state, auth_events, backfilled): + def _prep_event( + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], + auth_events: Optional[Dict[Tuple[str, str], EventBase]], + backfilled: bool, + ): """ Args: @@ -1851,7 +1886,7 @@ class FederationHandler(BaseHandler): event: state: auth_events: - backfilled (bool) + backfilled: Returns: Deferred, which resolves to synapse.events.snapshot.EventContext @@ -1887,15 +1922,16 @@ class FederationHandler(BaseHandler): return context @defer.inlineCallbacks - def _check_for_soft_fail(self, event, state, backfilled): + def _check_for_soft_fail( + self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool + ): """Checks if we should soft fail the event, if so marks the event as such. Args: - event (FrozenEvent) - state (dict|None): The state at the event if we don't have all the - event's prev events - backfilled (bool): Whether the event is from backfill + event + state: The state at the event if we don't have all the event's prev events + backfilled: Whether the event is from backfill Returns: Deferred @@ -2040,8 +2076,10 @@ class FederationHandler(BaseHandler): auth_events (dict[(str, str)->synapse.events.EventBase]): Map from (event_type, state_key) to event - What we expect the event's auth_events to be, based on the event's - position in the dag. I think? maybe?? + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. Also NB that this function adds entries to it. Returns: @@ -2091,35 +2129,35 @@ class FederationHandler(BaseHandler): origin (str): event (synapse.events.EventBase): context (synapse.events.snapshot.EventContext): + auth_events (dict[(str, str)->synapse.events.EventBase]): + Map from (event_type, state_key) to event + + Normally, our calculated auth_events based on the state of the room + at the event's position in the DAG, though occasionally (eg if the + event is an outlier), may be the auth events claimed by the remote + server. + + Also NB that this function adds entries to it. Returns: defer.Deferred[EventContext]: updated context """ event_auth_events = set(event.auth_event_ids()) - if event.is_state(): - event_key = (event.type, event.state_key) - else: - event_key = None - - # if the event's auth_events refers to events which are not in our - # calculated auth_events, we need to fetch those events from somewhere. - # - # we start by fetching them from the store, and then try calling /event_auth/. + # missing_auth is the set of the event's auth_events which we don't yet have + # in auth_events. missing_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) + # if we have missing events, we need to fetch those events from somewhere. + # + # we start by checking if they are in the store, and then try calling /event_auth/. if missing_auth: - # TODO: can we use store.have_seen_events here instead? - have_events = yield self.store.get_seen_events_with_rejections(missing_auth) - logger.debug("Got events %s from store", have_events) - missing_auth.difference_update(have_events.keys()) - else: - have_events = {} - - have_events.update({e.event_id: "" for e in auth_events.values()}) + have_events = yield self.store.have_seen_events(missing_auth) + logger.debug("Events %s are in the store", have_events) + missing_auth.difference_update(have_events) if missing_auth: # If we don't have all the auth events, we need to get them. @@ -2165,19 +2203,18 @@ class FederationHandler(BaseHandler): except AuthError: pass - have_events = yield self.store.get_seen_events_with_rejections( - event.auth_event_ids() - ) except Exception: - # FIXME: logger.exception("Failed to get auth chain") if event.internal_metadata.is_outlier(): + # XXX: given that, for an outlier, we'll be working with the + # event's *claimed* auth events rather than those we calculated: + # (a) is there any point in this test, since different_auth below will + # obviously be empty + # (b) alternatively, why don't we do it earlier? logger.info("Skipping auth_event fetch for outlier") return context - # FIXME: Assumes we have and stored all the state for all the - # prev_events different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) @@ -2191,53 +2228,58 @@ class FederationHandler(BaseHandler): different_auth, ) - room_version = yield self.store.get_room_version(event.room_id) + # XXX: currently this checks for redactions but I'm not convinced that is + # necessary? + different_events = yield self.store.get_events_as_list(different_auth) - different_events = yield make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.get_event, d, allow_none=True, allow_rejected=False - ) - for d in different_auth - if d in have_events and not have_events[d] - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) + for d in different_events: + if d.room_id != event.room_id: + logger.warning( + "Event %s refers to auth_event %s which is in a different room", + event.event_id, + d.event_id, + ) - if different_events: - local_view = dict(auth_events) - remote_view = dict(auth_events) - remote_view.update( - {(d.type, d.state_key): d for d in different_events if d} - ) + # don't attempt to resolve the claimed auth events against our own + # in this case: just use our own auth events. + # + # XXX: should we reject the event in this case? It feels like we should, + # but then shouldn't we also do so if we've failed to fetch any of the + # auth events? + return context - new_state = yield self.state_handler.resolve_events( - room_version, - [list(local_view.values()), list(remote_view.values())], - event, - ) + # now we state-resolve between our own idea of the auth events, and the remote's + # idea of them. - logger.info( - "After state res: updating auth_events with new state %s", - { - (d.type, d.state_key): d.event_id - for d in new_state.values() - if auth_events.get((d.type, d.state_key)) != d - }, - ) + local_state = auth_events.values() + remote_auth_events = dict(auth_events) + remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) + remote_state = remote_auth_events.values() + + room_version = yield self.store.get_room_version(event.room_id) + new_state = yield self.state_handler.resolve_events( + room_version, (local_state, remote_state), event + ) + + logger.info( + "After state res: updating auth_events with new state %s", + { + (d.type, d.state_key): d.event_id + for d in new_state.values() + if auth_events.get((d.type, d.state_key)) != d + }, + ) - auth_events.update(new_state) + auth_events.update(new_state) - context = yield self._update_context_for_auth_events( - event, context, auth_events, event_key - ) + context = yield self._update_context_for_auth_events( + event, context, auth_events + ) return context @defer.inlineCallbacks - def _update_context_for_auth_events(self, event, context, auth_events, event_key): + def _update_context_for_auth_events(self, event, context, auth_events): """Update the state_ids in an event context after auth event resolution, storing the changes as a new state group. @@ -2246,18 +2288,21 @@ class FederationHandler(BaseHandler): context (synapse.events.snapshot.EventContext): initial event context - auth_events (dict[(str, str)->str]): Events to update in the event + auth_events (dict[(str, str)->EventBase]): Events to update in the event context. - event_key ((str, str)): (type, state_key) for the current event. - this will not be included in the current_state in the context. - Returns: Deferred[EventContext]: new event context """ + # exclude the state key of the new event from the current_state in the context. + if event.is_state(): + event_key = (event.type, event.state_key) + else: + event_key = None state_updates = { k: a.event_id for k, a in iteritems(auth_events) if k != event_key } + current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = dict(current_state_ids) @@ -2459,7 +2504,7 @@ class FederationHandler(BaseHandler): room_version, event_dict, event, context ) - EventValidator().validate_new(event) + EventValidator().validate_new(event, self.config) # We need to tell the transaction queue to send this out, even # though the sender isn't a local user. @@ -2574,7 +2619,7 @@ class FederationHandler(BaseHandler): event, context = yield self.event_creation_handler.create_new_client_event( builder=builder ) - EventValidator().validate_new(event) + EventValidator().validate_new(event, self.config) return (event, context) @defer.inlineCallbacks @@ -2708,6 +2753,11 @@ class FederationHandler(BaseHandler): event_and_contexts, backfilled=backfilled ) + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: yield self._notify_persisted_event(event, max_stream_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d682dc2b7a..4f53a5f5dc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from six import iteritems, itervalues, string_types @@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json from twisted.internet import defer from twisted.internet.defer import succeed +from twisted.internet.interfaces import IDelayedCall from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -62,6 +70,17 @@ class MessageHandler(object): self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._is_worker_app = bool(hs.config.worker_app) + + # The scheduled call to self._expire_event. None if no call is currently + # scheduled. + self._scheduled_expiry = None # type: Optional[IDelayedCall] + + if not hs.config.worker_app: + run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry + ) @defer.inlineCallbacks def get_room_data( @@ -138,7 +157,7 @@ class MessageHandler(object): raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = yield filter_events_for_client( - self.storage, user_id, last_events + self.storage, user_id, last_events, apply_retention_policies=False ) event = last_events[0] @@ -225,6 +244,100 @@ class MessageHandler(object): for user_id, profile in iteritems(users_with_profile) } + def maybe_schedule_expiry(self, event): + """Schedule the expiry of an event if there's not already one scheduled, + or if the one running is for an event that will expire after the provided + timestamp. + + This function needs to invalidate the event cache, which is only possible on + the master process, and therefore needs to be run on there. + + Args: + event (EventBase): The event to schedule the expiry of. + """ + assert not self._is_worker_app + + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if not isinstance(expiry_ts, int) or event.is_state(): + return + + # _schedule_expiry_for_event won't actually schedule anything if there's already + # a task scheduled for a timestamp that's sooner than the provided one. + self._schedule_expiry_for_event(event.event_id, expiry_ts) + + @defer.inlineCallbacks + def _schedule_next_expiry(self): + """Retrieve the ID and the expiry timestamp of the next event to be expired, + and schedule an expiry task for it. + + If there's no event left to expire, set _expiry_scheduled to None so that a + future call to save_expiry_ts can schedule a new expiry task. + """ + # Try to get the expiry timestamp of the next event to expire. + res = yield self.store.get_next_event_to_expire() + if res: + event_id, expiry_ts = res + self._schedule_expiry_for_event(event_id, expiry_ts) + + def _schedule_expiry_for_event(self, event_id, expiry_ts): + """Schedule an expiry task for the provided event if there's not already one + scheduled at a timestamp that's sooner than the provided one. + + Args: + event_id (str): The ID of the event to expire. + expiry_ts (int): The timestamp at which to expire the event. + """ + if self._scheduled_expiry: + # If the provided timestamp refers to a time before the scheduled time of the + # next expiry task, cancel that task and reschedule it for this timestamp. + next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000 + if expiry_ts < next_scheduled_expiry_ts: + self._scheduled_expiry.cancel() + else: + return + + # Figure out how many seconds we need to wait before expiring the event. + now_ms = self.clock.time_msec() + delay = (expiry_ts - now_ms) / 1000 + + # callLater doesn't support negative delays, so trim the delay to 0 if we're + # in that case. + if delay < 0: + delay = 0 + + logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay) + + self._scheduled_expiry = self.clock.call_later( + delay, + run_as_background_process, + "_expire_event", + self._expire_event, + event_id, + ) + + @defer.inlineCallbacks + def _expire_event(self, event_id): + """Retrieve and expire an event that needs to be expired from the database. + + If the event doesn't exist in the database, log it and delete the expiry date + from the database (so that we don't try to expire it again). + """ + assert self._ephemeral_events_enabled + + self._scheduled_expiry = None + + logger.info("Expiring event %s", event_id) + + try: + # Expire the event if we know about it. This function also deletes the expiry + # date from the database in the same database transaction. + yield self.store.expire_event(event_id) + except Exception as e: + logger.error("Could not expire event %s: %r", event_id, e) + + # Schedule the expiry of the next event to expire. + yield self._schedule_next_expiry() + # The duration (in ms) after which rooms should be removed # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try @@ -295,6 +408,10 @@ class EventCreationHandler(object): 5 * 60 * 1000, ) + self._message_handler = hs.get_message_handler() + + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def create_event( self, @@ -417,7 +534,7 @@ class EventCreationHandler(object): 403, "You must be in the room to create an alias for it" ) - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) return (event, context) @@ -634,7 +751,7 @@ class EventCreationHandler(object): if requester: context.app_service = requester.app_service - self.validator.validate_new(event) + self.validator.validate_new(event, self.config) # If this event is an annotation then we check that that the sender # can't annotate the same way twice (e.g. stops users from liking an @@ -877,6 +994,10 @@ class EventCreationHandler(object): event, context=context ) + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 260a4351ca..8514ddc600 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -15,12 +15,15 @@ # limitations under the License. import logging +from six import iteritems + from twisted.internet import defer from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock @@ -80,6 +83,109 @@ class PaginationHandler(object): self._purges_by_id = {} self._event_serializer = hs.get_event_client_serializer() + self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + + if hs.config.retention_enabled: + # Run the purge jobs described in the configuration file. + for job in hs.config.retention_purge_jobs: + self.clock.looping_call( + run_as_background_process, + job["interval"], + "purge_history_for_rooms_in_range", + self.purge_history_for_rooms_in_range, + job["shortest_max_lifetime"], + job["longest_max_lifetime"], + ) + + @defer.inlineCallbacks + def purge_history_for_rooms_in_range(self, min_ms, max_ms): + """Purge outdated events from rooms within the given retention range. + + If a default retention policy is defined in the server's configuration and its + 'max_lifetime' is within this range, also targets rooms which don't have a + retention policy. + + Args: + min_ms (int|None): Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, it means that the range has no + lower limit. + max_ms (int|None): Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, it means that the range has no + upper limit. + """ + # We want the storage layer to to include rooms with no retention policy in its + # return value only if a default retention policy is defined in the server's + # configuration and that policy's 'max_lifetime' is either lower (or equal) than + # max_ms or higher than min_ms (or both). + if self._retention_default_max_lifetime is not None: + include_null = True + + if min_ms is not None and min_ms >= self._retention_default_max_lifetime: + # The default max_lifetime is lower than (or equal to) min_ms. + include_null = False + + if max_ms is not None and max_ms < self._retention_default_max_lifetime: + # The default max_lifetime is higher than max_ms. + include_null = False + else: + include_null = False + + rooms = yield self.store.get_rooms_for_retention_period_in_range( + min_ms, max_ms, include_null + ) + + for room_id, retention_policy in iteritems(rooms): + if room_id in self._purges_in_progress_by_room: + logger.warning( + "[purge] not purging room %s as there's an ongoing purge running" + " for this room", + room_id, + ) + continue + + max_lifetime = retention_policy["max_lifetime"] + + if max_lifetime is None: + # If max_lifetime is None, it means that include_null equals True, + # therefore we can safely assume that there is a default policy defined + # in the server's configuration. + max_lifetime = self._retention_default_max_lifetime + + # Figure out what token we should start purging at. + ts = self.clock.time_msec() - max_lifetime + + stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) + + r = yield self.store.get_room_event_after_stream_ordering( + room_id, stream_ordering, + ) + if not r: + logger.warning( + "[purge] purging events not possible: No event found " + "(ts %i => stream_ordering %i)", + ts, + stream_ordering, + ) + continue + + (stream, topo, _event_id) = r + token = "t%d-%d" % (topo, stream) + + purge_id = random_string(16) + + self._purges_by_id[purge_id] = PurgeStatus() + + logger.info( + "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id) + ) + + # We want to purge everything, including local events, and to run the purge in + # the background so that it's not blocking any other operation apart from + # other purges in the same room. + run_as_background_process( + "_purge_history", self._purge_history, purge_id, room_id, token, True, + ) + def start_purge_history(self, room_id, token, delete_local_events=False): """Start off a history purge on a room. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 95806af41e..8a7d965feb 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -266,7 +266,7 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - yield self._register_email_threepid(user_id, threepid_dict, None, False) + yield self._register_email_threepid(user_id, threepid_dict, None) return user_id diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e92b2eafd5..22768e97ff 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -198,21 +199,21 @@ class RoomCreationHandler(BaseHandler): # finally, shut down the PLs in the old room, and update them in the new # room. yield self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state + requester, old_room_id, new_room_id, old_room_state, ) return new_room_id @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state + self, requester, old_room_id, new_room_id, old_room_state, ): """Send updated power levels in both rooms after an upgrade Args: requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (unicode): the id of the room to be replaced - new_room_id (unicode): the id of the replacement room + old_room_id (str): the id of the room to be replaced + new_room_id (str): the id of the replacement room old_room_state (dict[tuple[str, str], str]): the state map for the old room Returns: @@ -298,7 +299,7 @@ class RoomCreationHandler(BaseHandler): tombstone_event_id (unicode|str): the ID of the tombstone event in the old room. Returns: - Deferred[None] + Deferred """ user_id = requester.user.to_string() @@ -333,6 +334,7 @@ class RoomCreationHandler(BaseHandler): (EventTypes.Encryption, ""), (EventTypes.ServerACL, ""), (EventTypes.RelatedGroups, ""), + (EventTypes.PowerLevels, ""), ) old_room_state_ids = yield self.store.get_filtered_current_state_ids( @@ -346,6 +348,31 @@ class RoomCreationHandler(BaseHandler): if old_event: initial_state[k] = old_event.content + # Resolve the minimum power level required to send any state event + # We will give the upgrading user this power level temporarily (if necessary) such that + # they are able to copy all of the state events over, then revert them back to their + # original power level afterwards in _update_upgraded_room_pls + + # Copy over user power levels now as this will not be possible with >100PL users once + # the room has been created + + power_levels = initial_state[(EventTypes.PowerLevels, "")] + + # Calculate the minimum power level needed to clone the room + event_power_levels = power_levels.get("events", {}) + state_default = power_levels.get("state_default", 0) + ban = power_levels.get("ban") + needed_power_level = max(state_default, ban, max(event_power_levels.values())) + + # Raise the requester's power level in the new room if necessary + current_power_level = power_levels["users"][requester.user.to_string()] + if current_power_level < needed_power_level: + # Assign this power level to the requester + power_levels["users"][requester.user.to_string()] = needed_power_level + + # Set the power levels to the modified state + initial_state[(EventTypes.PowerLevels, "")] = power_levels + yield self._send_events_for_new_room( requester, new_room_id, @@ -874,6 +901,10 @@ class RoomContextHandler(object): room_id, event_id, before_limit, after_limit, event_filter ) + if event_filter: + results["events_before"] = event_filter.filter(results["events_before"]) + results["events_after"] = event_filter.filter(results["events_after"]) + results["events_before"] = yield filter_evts(results["events_before"]) results["events_after"] = yield filter_evts(results["events_after"]) results["event"] = event @@ -902,7 +933,12 @@ class RoomContextHandler(object): state = yield self.state_store.get_state_for_events( [last_event_id], state_filter=state_filter ) - results["state"] = list(state[last_event_id].values()) + + state_events = list(state[last_event_id].values()) + if event_filter: + state_events = event_filter.filter(state_events) + + results["state"] = state_events # We use a dummy token here as we only care about the room portion of # the token, which we replace. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6cfee4b361..7b7270fc61 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -94,7 +94,9 @@ class RoomMemberHandler(object): raise NotImplementedError() @abc.abstractmethod - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. @@ -104,6 +106,7 @@ class RoomMemberHandler(object): reject invite room_id (str) target (UserID): The user rejecting the invite + content (dict): The content for the rejection event Returns: Deferred[dict]: A dictionary to be returned to the client, may @@ -471,7 +474,7 @@ class RoomMemberHandler(object): # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] res = yield self._remote_reject_invite( - requester, remote_room_hosts, room_id, target + requester, remote_room_hosts, room_id, target, content, ) return res @@ -971,13 +974,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) @defer.inlineCallbacks - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler try: ret = yield fed_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, target.to_string() + remote_room_hosts, room_id, target.to_string(), content=content, ) return ret except Exception as e: diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 75e96ae1a2..69be86893b 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -55,7 +55,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret - def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target): + def _remote_reject_invite( + self, requester, remote_room_hosts, room_id, target, content + ): """Implements RoomMemberHandler._remote_reject_invite """ return self._remote_reject_client( @@ -63,6 +65,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): remote_room_hosts=remote_room_hosts, room_id=room_id, user_id=target.to_string(), + content=content, ) def _user_joined_room(self, target, room_id): diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index e9a5e46ced..13fcb408a6 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -96,7 +96,7 @@ def parse_boolean_from_args(args, name, default=None, required=False): return {b"true": True, b"false": False}[args[name][0]] except Exception: message = ( - "Boolean query parameter %r must be one of" " ['true', 'false']" + "Boolean query parameter %r must be one of ['true', 'false']" ) % (name,) raise SynapseError(400, message) else: diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 334ddaf39a..ffa7b20ca8 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -261,6 +261,18 @@ def parse_drain_configs( ) +class StoppableLogPublisher(LogPublisher): + """ + A log publisher that can tell its observers to shut down any external + communications. + """ + + def stop(self): + for obs in self._observers: + if hasattr(obs, "stop"): + obs.stop() + + def setup_structured_logging( hs, config, @@ -336,7 +348,7 @@ def setup_structured_logging( # We should never get here, but, just in case, throw an error. raise ConfigError("%s drain type cannot be configured" % (observer.type,)) - publisher = LogPublisher(*observers) + publisher = StoppableLogPublisher(*observers) log_filter = LogLevelFilterPredicate() for namespace, namespace_config in log_config.get( diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 76ce7d8808..03934956f4 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -17,25 +17,29 @@ Log formatters that output terse JSON. """ +import json import sys +import traceback from collections import deque from ipaddress import IPv4Address, IPv6Address, ip_address from math import floor -from typing import IO +from typing import IO, Optional import attr -from simplejson import dumps from zope.interface import implementer from twisted.application.internet import ClientService +from twisted.internet.defer import Deferred from twisted.internet.endpoints import ( HostnameEndpoint, TCP4ClientEndpoint, TCP6ClientEndpoint, ) +from twisted.internet.interfaces import IPushProducer, ITransport from twisted.internet.protocol import Factory, Protocol from twisted.logger import FileLogObserver, ILogObserver, Logger -from twisted.python.failure import Failure + +_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) def flatten_event(event: dict, metadata: dict, include_time: bool = False): @@ -141,12 +145,50 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb def formatEvent(_event: dict) -> str: flattened = flatten_event(_event, metadata) - return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n" + return _encoder.encode(flattened) + "\n" return FileLogObserver(outFile, formatEvent) @attr.s +@implementer(IPushProducer) +class LogProducer(object): + """ + An IPushProducer that writes logs from its buffer to its transport when it + is resumed. + + Args: + buffer: Log buffer to read logs from. + transport: Transport to write to. + """ + + transport = attr.ib(type=ITransport) + _buffer = attr.ib(type=deque) + _paused = attr.ib(default=False, type=bool, init=False) + + def pauseProducing(self): + self._paused = True + + def stopProducing(self): + self._paused = True + self._buffer = None + + def resumeProducing(self): + self._paused = False + + while self._paused is False and (self._buffer and self.transport.connected): + try: + event = self._buffer.popleft() + self.transport.write(_encoder.encode(event).encode("utf8")) + self.transport.write(b"\n") + except Exception: + # Something has gone wrong writing to the transport -- log it + # and break out of the while. + traceback.print_exc(file=sys.__stderr__) + break + + +@attr.s @implementer(ILogObserver) class TerseJSONToTCPLogObserver(object): """ @@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object): metadata = attr.ib(type=dict) maximum_buffer = attr.ib(type=int) _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _writer = attr.ib(default=None) + _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) _logger = attr.ib(default=attr.Factory(Logger)) + _producer = attr.ib(default=None, type=Optional[LogProducer]) def start(self) -> None: @@ -187,38 +230,44 @@ class TerseJSONToTCPLogObserver(object): factory = Factory.forProtocol(Protocol) self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) self._service.startService() + self._connect() - def _write_loop(self) -> None: + def stop(self): + self._service.stopService() + + def _connect(self) -> None: """ - Implement the write loop. + Triggers an attempt to connect then write to the remote if not already writing. """ - if self._writer: + if self._connection_waiter: return - self._writer = self._service.whenConnected() + self._connection_waiter = self._service.whenConnected(failAfterFailures=1) + + @self._connection_waiter.addErrback + def fail(r): + r.printTraceback(file=sys.__stderr__) + self._connection_waiter = None + self._connect() - @self._writer.addBoth + @self._connection_waiter.addCallback def writer(r): - if isinstance(r, Failure): - r.printTraceback(file=sys.__stderr__) - self._writer = None - self.hs.get_reactor().callLater(1, self._write_loop) + # We have a connection. If we already have a producer, and its + # transport is the same, just trigger a resumeProducing. + if self._producer and r.transport is self._producer.transport: + self._producer.resumeProducing() + self._connection_waiter = None return - try: - for event in self._buffer: - r.transport.write( - dumps(event, ensure_ascii=False, separators=(",", ":")).encode( - "utf8" - ) - ) - r.transport.write(b"\n") - self._buffer.clear() - except Exception as e: - sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),)) - - self._writer = False - self.hs.get_reactor().callLater(1, self._write_loop) + # If the producer is still producing, stop it. + if self._producer: + self._producer.stopProducing() + + # Make a new producer and start it. + self._producer = LogProducer(buffer=self._buffer, transport=r.transport) + r.transport.registerProducer(self._producer, True) + self._producer.resumeProducing() + self._connection_waiter = None def _handle_pressure(self) -> None: """ @@ -277,4 +326,4 @@ class TerseJSONToTCPLogObserver(object): self._logger.failure("Failed clearing backpressure") # Try and write immediately. - self._write_loop() + self._connect() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 1ba7bcd4d8..7881780760 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -386,15 +386,7 @@ class RulesForRoom(object): """ sequence = self.sequence - rows = yield self.store._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids.values(), - retcols=("user_id", "membership", "event_id"), - keyvalues={}, - batch_size=500, - desc="_get_rules_for_member_event_ids", - ) + rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index e994037be6..d0879b0490 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -246,7 +246,7 @@ class HttpPusher(object): # fixed, we don't suddenly deliver a load # of old notifications. logger.warning( - "Giving up on a notification to user %s, " "pushkey %s", + "Giving up on a notification to user %s, pushkey %s", self.user_id, self.pushkey, ) @@ -299,8 +299,7 @@ class HttpPusher(object): # for sanity, we only remove the pushkey if it # was the one we actually sent... logger.warning( - ("Ignoring rejected pushkey %s because we" " didn't send it"), - pk, + ("Ignoring rejected pushkey %s because we didn't send it"), pk, ) else: logger.info("Pushkey %s was rejected: removing", pk) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 1d15a06a58..b13b646bfd 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -43,7 +43,7 @@ logger = logging.getLogger(__name__) MESSAGE_FROM_PERSON_IN_ROOM = ( - "You have a message on %(app)s from %(person)s " "in the %(room)s room..." + "You have a message on %(app)s from %(person)s in the %(room)s room..." ) MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." @@ -55,7 +55,7 @@ MESSAGES_FROM_PERSON_AND_OTHERS = ( "You have messages on %(app)s from %(person)s and others..." ) INVITE_FROM_PERSON_TO_ROOM = ( - "%(person)s has invited you to join the " "%(room)s room on %(app)s..." + "%(person)s has invited you to join the %(room)s room on %(app)s..." ) INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 81b85352b1..28dbc6fcba 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -14,7 +14,14 @@ # limitations under the License. from synapse.http.server import JsonResource -from synapse.replication.http import federation, login, membership, register, send_event +from synapse.replication.http import ( + devices, + federation, + login, + membership, + register, + send_event, +) REPLICATION_PREFIX = "/_synapse/replication" @@ -30,3 +37,4 @@ class ReplicationRestResource(JsonResource): federation.register_servlets(hs, self) login.register_servlets(hs, self) register.register_servlets(hs, self) + devices.register_servlets(hs, self) diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py new file mode 100644 index 0000000000..e32aac0a25 --- /dev/null +++ b/synapse/replication/http/devices.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): + """Ask master to resync the device list for a user by contacting their + server. + + This must happen on master so that the results can be correctly cached in + the database and streamed to workers. + + Request format: + + POST /_synapse/replication/user_device_resync/:user_id + + {} + + Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id` + response, e.g.: + + { + "user_id": "@alice:example.org", + "devices": [ + { + "device_id": "JLAFKJWSCS", + "keys": { ... }, + "device_display_name": "Alice's Mobile Phone" + } + ] + } + """ + + NAME = "user_device_resync" + PATH_ARGS = ("user_id",) + CACHE = False + + def __init__(self, hs): + super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs) + + self.device_list_updater = hs.get_device_handler().device_list_updater + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @staticmethod + def _serialize_payload(user_id): + return {} + + async def _handle_request(self, request, user_id): + user_devices = await self.device_list_updater.user_device_resync(user_id) + + return 200, user_devices + + +def register_servlets(hs, http_server): + ReplicationUserDevicesResyncRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index cc1f249740..3577611fd7 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -93,6 +93,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): { "requester": ..., "remote_room_hosts": [...], + "content": { ... } } """ @@ -107,7 +108,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts): + def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): """ Args: requester(Requester) @@ -118,12 +119,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): return { "requester": requester.serialize(), "remote_room_hosts": remote_room_hosts, + "content": content, } async def _handle_request(self, request, room_id, user_id): content = parse_json_object_from_request(request) remote_room_hosts = content["remote_room_hosts"] + event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) @@ -134,7 +137,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): try: event = await self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, user_id + remote_room_hosts, room_id, user_id, event_content, ) ret = event.get_pdu_json() except Exception as e: diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 456bc005a0..6ece1d6745 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,7 +18,8 @@ from typing import Dict import six -from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME from synapse.storage.engines import PostgresEngine from ._slaved_id_tracker import SlavedIdTracker @@ -62,7 +63,7 @@ class BaseSlavedStore(SQLBaseStore): if stream_name == "caches": self._cache_id_gen.advance(token) for row in rows: - if row.cache_func == _CURRENT_STATE_CACHE_NAME: + if row.cache_func == CURRENT_STATE_CACHE_NAME: room_id = row.keys[0] members_changed = set(row.keys[1:]) self._invalidate_state_caches(room_id, members_changed) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9e45429d49..8512923eae 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -88,8 +88,7 @@ TagAccountDataStreamRow = namedtuple( "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict ) AccountDataStreamRow = namedtuple( - "AccountDataStream", - ("user_id", "room_id", "data_type", "data"), # str # str # str # dict + "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str ) GroupsStreamRow = namedtuple( "GroupsStreamRow", @@ -421,8 +420,8 @@ class AccountDataStream(Stream): results = list(room_results) results.extend( - (stream_id, user_id, None, account_data_type, content) - for stream_id, user_id, account_data_type, content in global_results + (stream_id, user_id, None, account_data_type) + for stream_id, user_id, account_data_type in global_results ) return results diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 86bbcc0eea..711d4ad304 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -714,7 +714,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): target = UserID.from_string(content["user_id"]) event_content = None - if "reason" in content and membership_action in ["kick", "ban"]: + if "reason" in content: event_content = {"reason": content["reason"]} await self.room_member_handler.update_membership( diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index f26eae794c..ad674239ab 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -642,6 +642,7 @@ class ThreepidAddRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): requester = yield self.auth.get_user_by_req(request) @@ -652,6 +653,10 @@ class ThreepidAddRestServlet(RestServlet): client_secret = body["client_secret"] sid = body["sid"] + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request) + ) + validation_session = yield self.identity_handler.validate_threepid_session( client_secret, sid ) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index d596786430..d83ac8e3c5 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -134,8 +134,8 @@ class RoomKeysServlet(RestServlet): if room_id: body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) - return 200, {} + ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) + return 200, ret @defer.inlineCallbacks def on_GET(self, request, room_id, session_id): @@ -239,10 +239,10 @@ class RoomKeysServlet(RestServlet): user_id = requester.user.to_string() version = parse_string(request, "version") - yield self.e2e_room_keys_handler.delete_room_keys( + ret = yield self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id ) - return 200, {} + return 200, ret class RoomKeysNewVersionServlet(RestServlet): diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index bb30ce3f34..2a477ad22e 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 87343d9db9..fb0d02aa83 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -122,7 +122,7 @@ class PreviewUrlResource(DirectServeResource): pattern = entry[attrib] value = getattr(url_tuple, attrib) logger.debug( - "Matching attrib '%s' with value '%s' against" " pattern '%s'", + "Matching attrib '%s' with value '%s' against pattern '%s'", attrib, value, pattern, diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 8cf415e29d..c234ea7421 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -129,5 +129,8 @@ class Thumbnailer(object): def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() - output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) + fmt = self.FORMATS[output_type] + if fmt == "JPEG": + output_image = output_image.convert("RGB") + output_image.save(output_bytes_io, fmt, quality=80) return output_bytes_io diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 415e9c17d8..5736c56032 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -54,7 +54,7 @@ class ConsentServerNotices(object): ) if "body" not in self._server_notice_content: raise ConfigError( - "user_consent server_notice_consent must contain a 'body' " "key." + "user_consent server_notice_consent must contain a 'body' key." ) self._consent_uri_builder = ConsentURIBuilder(hs.config) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index ab596fa68d..9205e550bb 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -14,11 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging import random import sys -import threading import time from typing import Iterable, Tuple @@ -35,8 +33,6 @@ from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id -from synapse.util import batch_iter -from synapse.util.caches.descriptors import Cache from synapse.util.stringutils import exception_to_unicode # import a function which will return a monotonic time, in seconds @@ -79,10 +75,6 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_search": "event_search_event_id_idx", } -# This is a special cache name we use to batch multiple invalidations of caches -# based on the current state when notifying workers over replication. -_CURRENT_STATE_CACHE_NAME = "cs_cache_fake" - class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object @@ -237,23 +229,11 @@ class SQLBaseStore(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size - ) - - self._event_fetch_lock = threading.Condition() - self._event_fetch_list = [] - self._event_fetch_ongoing = 0 - - self._pending_ds = [] - self.database_engine = hs.database_engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) - self._account_validity = self.hs.config.account_validity - # We add the user_directory_search table to the blacklist on SQLite # because the existing search table does not have an index, making it # unsafe to use native upserts. @@ -272,14 +252,6 @@ class SQLBaseStore(object): self.rand = random.SystemRandom() - if self._account_validity.enabled: - self._clock.call_later( - 0.0, - run_as_background_process, - "account_validity_set_expiration_dates", - self._set_expiration_date_when_missing, - ) - @defer.inlineCallbacks def _check_safe_to_upsert(self): """ @@ -290,7 +262,7 @@ class SQLBaseStore(object): If the background updates have not completed, wait 15 sec and check again. """ - updates = yield self._simple_select_list( + updates = yield self.simple_select_list( "background_updates", keyvalues=None, retcols=["update_name"], @@ -312,62 +284,6 @@ class SQLBaseStore(object): self._check_safe_to_upsert, ) - @defer.inlineCallbacks - def _set_expiration_date_when_missing(self): - """ - Retrieves the list of registered users that don't have an expiration date, and - adds an expiration date for each of them. - """ - - def select_users_with_no_expiration_date_txn(txn): - """Retrieves the list of registered users with no expiration date from the - database, filtering out deactivated users. - """ - sql = ( - "SELECT users.name FROM users" - " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" - " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" - ) - txn.execute(sql, []) - - res = self.cursor_to_dict(txn) - if res: - for user in res: - self.set_expiration_date_for_user_txn( - txn, user["name"], use_delta=True - ) - - yield self.runInteraction( - "get_users_with_no_expiration_date", - select_users_with_no_expiration_date_txn, - ) - - def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): - """Sets an expiration date to the account with the given user ID. - - Args: - user_id (str): User ID to set an expiration date for. - use_delta (bool): If set to False, the expiration date for the user will be - now + validity period. If set to True, this expiration date will be a - random value in the [now + period - d ; now + period] range, d being a - delta equal to 10% of the validity period. - """ - now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period - - if use_delta: - expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, - expiration_ts, - ) - - self._simple_upsert_txn( - txn, - "account_validity", - keyvalues={"user_id": user_id}, - values={"expiration_ts_ms": expiration_ts, "email_sent": False}, - ) - def start_profiling(self): self._previous_loop_ts = monotonic_time() @@ -391,7 +307,7 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction( + def new_transaction( self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs ): start = monotonic_time() @@ -409,16 +325,15 @@ class SQLBaseStore(object): i = 0 N = 5 while True: + cursor = LoggingTransaction( + conn.cursor(), + name, + self.database_engine, + after_callbacks, + exception_callbacks, + ) try: - txn = conn.cursor() - txn = LoggingTransaction( - txn, - name, - self.database_engine, - after_callbacks, - exception_callbacks, - ) - r = func(txn, *args, **kwargs) + r = func(cursor, *args, **kwargs) conn.commit() return r except self.database_engine.module.OperationalError as e: @@ -456,6 +371,40 @@ class SQLBaseStore(object): ) continue raise + finally: + # we're either about to retry with a new cursor, or we're about to + # release the connection. Once we release the connection, it could + # get used for another query, which might do a conn.rollback(). + # + # In the latter case, even though that probably wouldn't affect the + # results of this transaction, python's sqlite will reset all + # statements on the connection [1], which will make our cursor + # invalid [2]. + # + # In any case, continuing to read rows after commit()ing seems + # dubious from the PoV of ACID transactional semantics + # (sqlite explicitly says that once you commit, you may see rows + # from subsequent updates.) + # + # In psycopg2, cursors are essentially a client-side fabrication - + # all the data is transferred to the client side when the statement + # finishes executing - so in theory we could go on streaming results + # from the cursor, but attempting to do so would make us + # incompatible with sqlite, so let's make sure we're not doing that + # by closing the cursor. + # + # (*named* cursors in psycopg2 are different and are proper server- + # side things, but (a) we don't use them and (b) they are implicitly + # closed by ending the transaction anyway.) + # + # In short, if we haven't finished with the cursor yet, that's a + # problem waiting to bite us. + # + # TL;DR: we're done with the cursor, so we can close it. + # + # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 + # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 + cursor.close() except Exception as e: logger.debug("[TXN FAIL] {%s} %s", name, e) raise @@ -495,7 +444,7 @@ class SQLBaseStore(object): try: result = yield self.runWithConnection( - self._new_transaction, + self.new_transaction, desc, after_callbacks, exception_callbacks, @@ -567,7 +516,7 @@ class SQLBaseStore(object): results = list(dict(zip(col_headers, row)) for row in cursor) return results - def _execute(self, desc, decoder, query, *args): + def execute(self, desc, decoder, query, *args): """Runs a single query for a result set. Args: @@ -592,7 +541,7 @@ class SQLBaseStore(object): # no complex WHERE clauses, just a dict of values for columns. @defer.inlineCallbacks - def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"): + def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"): """Executes an INSERT query on the named table. Args: @@ -608,7 +557,7 @@ class SQLBaseStore(object): `or_ignore` is True """ try: - yield self.runInteraction(desc, self._simple_insert_txn, table, values) + yield self.runInteraction(desc, self.simple_insert_txn, table, values) except self.database_engine.module.IntegrityError: # We have to do or_ignore flag at this layer, since we can't reuse # a cursor after we receive an error from the db. @@ -618,7 +567,7 @@ class SQLBaseStore(object): return True @staticmethod - def _simple_insert_txn(txn, table, values): + def simple_insert_txn(txn, table, values): keys, vals = zip(*values.items()) sql = "INSERT INTO %s (%s) VALUES(%s)" % ( @@ -629,11 +578,11 @@ class SQLBaseStore(object): txn.execute(sql, vals) - def _simple_insert_many(self, table, values, desc): - return self.runInteraction(desc, self._simple_insert_many_txn, table, values) + def simple_insert_many(self, table, values, desc): + return self.runInteraction(desc, self.simple_insert_many_txn, table, values) @staticmethod - def _simple_insert_many_txn(txn, table, values): + def simple_insert_many_txn(txn, table, values): if not values: return @@ -662,13 +611,13 @@ class SQLBaseStore(object): txn.executemany(sql, vals) @defer.inlineCallbacks - def _simple_upsert( + def simple_upsert( self, table, keyvalues, values, insertion_values={}, - desc="_simple_upsert", + desc="simple_upsert", lock=True, ): """ @@ -700,7 +649,7 @@ class SQLBaseStore(object): try: result = yield self.runInteraction( desc, - self._simple_upsert_txn, + self.simple_upsert_txn, table, keyvalues, values, @@ -720,7 +669,7 @@ class SQLBaseStore(object): "IntegrityError when upserting into %s; retrying: %s", table, e ) - def _simple_upsert_txn( + def simple_upsert_txn( self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ @@ -744,11 +693,11 @@ class SQLBaseStore(object): self.database_engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ): - return self._simple_upsert_txn_native_upsert( + return self.simple_upsert_txn_native_upsert( txn, table, keyvalues, values, insertion_values=insertion_values ) else: - return self._simple_upsert_txn_emulated( + return self.simple_upsert_txn_emulated( txn, table, keyvalues, @@ -757,7 +706,7 @@ class SQLBaseStore(object): lock=lock, ) - def _simple_upsert_txn_emulated( + def simple_upsert_txn_emulated( self, txn, table, keyvalues, values, insertion_values={}, lock=True ): """ @@ -826,7 +775,7 @@ class SQLBaseStore(object): # successfully inserted return True - def _simple_upsert_txn_native_upsert( + def simple_upsert_txn_native_upsert( self, txn, table, keyvalues, values, insertion_values={} ): """ @@ -851,7 +800,7 @@ class SQLBaseStore(object): allvalues.update(values) latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) - sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % ( + sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( table, ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues), @@ -860,7 +809,7 @@ class SQLBaseStore(object): ) txn.execute(sql, list(allvalues.values())) - def _simple_upsert_many_txn( + def simple_upsert_many_txn( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -880,15 +829,15 @@ class SQLBaseStore(object): self.database_engine.can_native_upsert and table not in self._unsafe_to_upsert_tables ): - return self._simple_upsert_many_txn_native_upsert( + return self.simple_upsert_many_txn_native_upsert( txn, table, key_names, key_values, value_names, value_values ) else: - return self._simple_upsert_many_txn_emulated( + return self.simple_upsert_many_txn_emulated( txn, table, key_names, key_values, value_names, value_values ) - def _simple_upsert_many_txn_emulated( + def simple_upsert_many_txn_emulated( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -913,9 +862,9 @@ class SQLBaseStore(object): _keys = {x: y for x, y in zip(key_names, keyv)} _vals = {x: y for x, y in zip(value_names, valv)} - self._simple_upsert_txn_emulated(txn, table, _keys, _vals) + self.simple_upsert_txn_emulated(txn, table, _keys, _vals) - def _simple_upsert_many_txn_native_upsert( + def simple_upsert_many_txn_native_upsert( self, txn, table, key_names, key_values, value_names, value_values ): """ @@ -960,8 +909,8 @@ class SQLBaseStore(object): return txn.execute_batch(sql, args) - def _simple_select_one( - self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one" + def simple_select_one( + self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one" ): """Executes a SELECT query on the named table, which is expected to return a single row, returning multiple columns from it. @@ -975,16 +924,16 @@ class SQLBaseStore(object): statement returns no rows """ return self.runInteraction( - desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none + desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none ) - def _simple_select_one_onecol( + def simple_select_one_onecol( self, table, keyvalues, retcol, allow_none=False, - desc="_simple_select_one_onecol", + desc="simple_select_one_onecol", ): """Executes a SELECT query on the named table, which is expected to return a single row, returning a single column from it. @@ -996,7 +945,7 @@ class SQLBaseStore(object): """ return self.runInteraction( desc, - self._simple_select_one_onecol_txn, + self.simple_select_one_onecol_txn, table, keyvalues, retcol, @@ -1004,10 +953,10 @@ class SQLBaseStore(object): ) @classmethod - def _simple_select_one_onecol_txn( + def simple_select_one_onecol_txn( cls, txn, table, keyvalues, retcol, allow_none=False ): - ret = cls._simple_select_onecol_txn( + ret = cls.simple_select_onecol_txn( txn, table=table, keyvalues=keyvalues, retcol=retcol ) @@ -1020,7 +969,7 @@ class SQLBaseStore(object): raise StoreError(404, "No row found") @staticmethod - def _simple_select_onecol_txn(txn, table, keyvalues, retcol): + def simple_select_onecol_txn(txn, table, keyvalues, retcol): sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} if keyvalues: @@ -1031,8 +980,8 @@ class SQLBaseStore(object): return [r[0] for r in txn] - def _simple_select_onecol( - self, table, keyvalues, retcol, desc="_simple_select_onecol" + def simple_select_onecol( + self, table, keyvalues, retcol, desc="simple_select_onecol" ): """Executes a SELECT query on the named table, which returns a list comprising of the values of the named column from the selected rows. @@ -1046,12 +995,10 @@ class SQLBaseStore(object): Deferred: Results in a list """ return self.runInteraction( - desc, self._simple_select_onecol_txn, table, keyvalues, retcol + desc, self.simple_select_onecol_txn, table, keyvalues, retcol ) - def _simple_select_list( - self, table, keyvalues, retcols, desc="_simple_select_list" - ): + def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1065,11 +1012,11 @@ class SQLBaseStore(object): defer.Deferred: resolves to list[dict[str, Any]] """ return self.runInteraction( - desc, self._simple_select_list_txn, table, keyvalues, retcols + desc, self.simple_select_list_txn, table, keyvalues, retcols ) @classmethod - def _simple_select_list_txn(cls, txn, table, keyvalues, retcols): + def simple_select_list_txn(cls, txn, table, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1095,14 +1042,14 @@ class SQLBaseStore(object): return cls.cursor_to_dict(txn) @defer.inlineCallbacks - def _simple_select_many_batch( + def simple_select_many_batch( self, table, column, iterable, retcols, keyvalues={}, - desc="_simple_select_many_batch", + desc="simple_select_many_batch", batch_size=100, ): """Executes a SELECT query on the named table, which may return zero or @@ -1131,7 +1078,7 @@ class SQLBaseStore(object): for chunk in chunks: rows = yield self.runInteraction( desc, - self._simple_select_many_txn, + self.simple_select_many_txn, table, column, chunk, @@ -1144,7 +1091,7 @@ class SQLBaseStore(object): return results @classmethod - def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): + def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1177,13 +1124,13 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) - def _simple_update(self, table, keyvalues, updatevalues, desc): + def simple_update(self, table, keyvalues, updatevalues, desc): return self.runInteraction( - desc, self._simple_update_txn, table, keyvalues, updatevalues + desc, self.simple_update_txn, table, keyvalues, updatevalues ) @staticmethod - def _simple_update_txn(txn, table, keyvalues, updatevalues): + def simple_update_txn(txn, table, keyvalues, updatevalues): if keyvalues: where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) else: @@ -1199,8 +1146,8 @@ class SQLBaseStore(object): return txn.rowcount - def _simple_update_one( - self, table, keyvalues, updatevalues, desc="_simple_update_one" + def simple_update_one( + self, table, keyvalues, updatevalues, desc="simple_update_one" ): """Executes an UPDATE query on the named table, setting new values for columns in a row matching the key values. @@ -1220,12 +1167,12 @@ class SQLBaseStore(object): the update column in the 'keyvalues' dict as well. """ return self.runInteraction( - desc, self._simple_update_one_txn, table, keyvalues, updatevalues + desc, self.simple_update_one_txn, table, keyvalues, updatevalues ) @classmethod - def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): - rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) + def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) if rowcount == 0: raise StoreError(404, "No row found (%s)" % (table,)) @@ -1233,7 +1180,7 @@ class SQLBaseStore(object): raise StoreError(500, "More than one row matched (%s)" % (table,)) @staticmethod - def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): + def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False): select_sql = "SELECT %s FROM %s WHERE %s" % ( ", ".join(retcols), table, @@ -1252,7 +1199,7 @@ class SQLBaseStore(object): return dict(zip(retcols, row)) - def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): + def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1260,10 +1207,10 @@ class SQLBaseStore(object): table : string giving the table name keyvalues : dict of column names and values to select the row with """ - return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues) + return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) @staticmethod - def _simple_delete_one_txn(txn, table, keyvalues): + def simple_delete_one_txn(txn, table, keyvalues): """Executes a DELETE query on the named table, expecting to delete a single row. @@ -1282,11 +1229,11 @@ class SQLBaseStore(object): if txn.rowcount > 1: raise StoreError(500, "More than one row matched (%s)" % (table,)) - def _simple_delete(self, table, keyvalues, desc): - return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues) + def simple_delete(self, table, keyvalues, desc): + return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) @staticmethod - def _simple_delete_txn(txn, table, keyvalues): + def simple_delete_txn(txn, table, keyvalues): sql = "DELETE FROM %s WHERE %s" % ( table, " AND ".join("%s = ?" % (k,) for k in keyvalues), @@ -1295,13 +1242,13 @@ class SQLBaseStore(object): txn.execute(sql, list(keyvalues.values())) return txn.rowcount - def _simple_delete_many(self, table, column, iterable, keyvalues, desc): + def simple_delete_many(self, table, column, iterable, keyvalues, desc): return self.runInteraction( - desc, self._simple_delete_many_txn, table, column, iterable, keyvalues + desc, self.simple_delete_many_txn, table, column, iterable, keyvalues ) @staticmethod - def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): + def simple_delete_many_txn(txn, table, column, iterable, keyvalues): """Executes a DELETE query on the named table. Filters rows by if value of `column` is in `iterable`. @@ -1334,7 +1281,7 @@ class SQLBaseStore(object): return txn.rowcount - def _get_cache_dict( + def get_cache_dict( self, db_conn, table, entity_column, stream_column, max_value, limit=100000 ): # Fetch a mapping of room_id -> max stream position for "recent" rooms. @@ -1367,47 +1314,6 @@ class SQLBaseStore(object): return cache, min_val - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - """Invalidates the cache and adds it to the cache stream so slaves - will know to invalidate their caches. - - This should only be used to invalidate caches where slaves won't - otherwise know from other replication streams that the cache should - be invalidated. - """ - txn.call_after(cache_func.invalidate, keys) - self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - - def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): - """Special case invalidation of caches based on current state. - - We special case this so that we can batch the cache invalidations into a - single replication poke. - - Args: - txn - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have changed - """ - txn.call_after(self._invalidate_state_caches, room_id, members_changed) - - if members_changed: - # We need to be careful that the size of the `members_changed` list - # isn't so large that it causes problems sending over replication, so we - # send them in chunks. - # Max line length is 16K, and max user ID length is 255, so 50 should - # be safe. - for chunk in batch_iter(members_changed, 50): - keys = itertools.chain([room_id], chunk) - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, keys - ) - else: - # if no members changed, we still need to invalidate the other caches. - self._send_invalidation_to_replication( - txn, _CURRENT_STATE_CACHE_NAME, [room_id] - ) - def _invalidate_state_caches(self, room_id, members_changed): """Invalidates caches that are based on the current state, but does not stream invalidations down replication. @@ -1441,64 +1347,7 @@ class SQLBaseStore(object): # which is fine. pass - def _send_invalidation_to_replication(self, txn, cache_name, keys): - """Notifies replication that given cache has been invalidated. - - Note that this does *not* invalidate the cache locally. - - Args: - txn - cache_name (str) - keys (iterable[str]) - """ - - if isinstance(self.database_engine, PostgresEngine): - # get_next() returns a context manager which is designed to wrap - # the transaction. However, we want to only get an ID when we want - # to use it, here, so we need to call __enter__ manually, and have - # __exit__ called after the transaction finishes. - ctx = self._cache_id_gen.get_next() - stream_id = ctx.__enter__() - txn.call_on_exception(ctx.__exit__, None, None, None) - txn.call_after(ctx.__exit__, None, None, None) - txn.call_after(self.hs.get_notifier().on_new_replication_data) - - self._simple_insert_txn( - txn, - table="cache_invalidation_stream", - values={ - "stream_id": stream_id, - "cache_func": cache_name, - "keys": list(keys), - "invalidation_ts": self.clock.time_msec(), - }, - ) - - def get_all_updated_caches(self, last_id, current_id, limit): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_caches_txn(txn): - # We purposefully don't bound by the current token, as we want to - # send across cache invalidations as quickly as possible. Cache - # invalidations are idempotent, so duplicates are fine. - sql = ( - "SELECT stream_id, cache_func, keys, invalidation_ts" - " FROM cache_invalidation_stream" - " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" - ) - txn.execute(sql, (last_id, limit)) - return txn.fetchall() - - return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def _simple_select_list_paginate( + def simple_select_list_paginate( self, table, keyvalues, @@ -1507,7 +1356,7 @@ class SQLBaseStore(object): limit, retcols, order_direction="ASC", - desc="_simple_select_list_paginate", + desc="simple_select_list_paginate", ): """ Executes a SELECT query on the named table with start and limit, @@ -1529,7 +1378,7 @@ class SQLBaseStore(object): """ return self.runInteraction( desc, - self._simple_select_list_paginate_txn, + self.simple_select_list_paginate_txn, table, keyvalues, orderby, @@ -1540,7 +1389,7 @@ class SQLBaseStore(object): ) @classmethod - def _simple_select_list_paginate_txn( + def simple_select_list_paginate_txn( cls, txn, table, @@ -1601,9 +1450,7 @@ class SQLBaseStore(object): txn.execute(sql_count) return txn.fetchone()[0] - def _simple_search_list( - self, table, term, col, retcols, desc="_simple_search_list" - ): + def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1618,11 +1465,11 @@ class SQLBaseStore(object): """ return self.runInteraction( - desc, self._simple_search_list_txn, table, term, col, retcols + desc, self.simple_search_list_txn, table, term, col, retcols ) @classmethod - def _simple_search_list_txn(cls, txn, table, term, col, retcols): + def simple_search_list_txn(cls, txn, table, term, col, retcols): """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. @@ -1645,14 +1492,6 @@ class SQLBaseStore(object): return cls.cursor_to_dict(txn) - @property - def database_engine_name(self): - return self.database_engine.module.__name__ - - def get_server_version(self): - """Returns a string describing the server version number""" - return self.database_engine.server_version - class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 37d469ffd7..06955a0537 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -139,7 +139,7 @@ class BackgroundUpdateStore(SQLBaseStore): # otherwise, check if there are updates to be run. This is important, # as we may be running on a worker which doesn't perform the bg updates # itself, but still wants to wait for them to happen. - updates = yield self._simple_select_onecol( + updates = yield self.simple_select_onecol( "background_updates", keyvalues=None, retcol="1", @@ -161,7 +161,7 @@ class BackgroundUpdateStore(SQLBaseStore): if update_name in self._background_update_queue: return False - update_exists = await self._simple_select_one_onecol( + update_exists = await self.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="1", @@ -184,7 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore): no more work to do. """ if not self._background_update_queue: - updates = yield self._simple_select_list( + updates = yield self.simple_select_list( "background_updates", keyvalues=None, retcols=("update_name", "depends_on"), @@ -226,7 +226,7 @@ class BackgroundUpdateStore(SQLBaseStore): else: batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE - progress_json = yield self._simple_select_one_onecol( + progress_json = yield self.simple_select_one_onecol( "background_updates", keyvalues={"update_name": update_name}, retcol="progress_json", @@ -413,7 +413,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [] progress_json = json.dumps(progress) - return self._simple_insert( + return self.simple_insert( "background_updates", {"update_name": update_name, "progress_json": progress_json}, ) @@ -429,7 +429,7 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_queue = [ name for name in self._background_update_queue if name != update_name ] - return self._simple_delete_one( + return self.simple_delete_one( "background_updates", keyvalues={"update_name": update_name} ) @@ -444,7 +444,7 @@ class BackgroundUpdateStore(SQLBaseStore): progress_json = json.dumps(progress) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "background_updates", keyvalues={"update_name": update_name}, diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index 10c940df1e..2a5b33dda1 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -32,6 +32,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache from .account_data import AccountDataStore from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore +from .cache import CacheInvalidationStore from .client_ips import ClientIpStore from .deviceinbox import DeviceInboxStore from .devices import DeviceStore @@ -110,6 +111,7 @@ class DataStore( MonthlyActiveUsersStore, StatsStore, RelationsStore, + CacheInvalidationStore, ): def __init__(self, db_conn, hs): self.hs = hs @@ -171,7 +173,7 @@ class DataStore( self._presence_on_startup = self._get_active_presence(db_conn) - presence_cache_prefill, min_presence_val = self._get_cache_dict( + presence_cache_prefill, min_presence_val = self.get_cache_dict( db_conn, "presence_stream", entity_column="user_id", @@ -185,7 +187,7 @@ class DataStore( ) max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( + device_inbox_prefill, min_device_inbox_id = self.get_cache_dict( db_conn, "device_inbox", entity_column="user_id", @@ -200,7 +202,7 @@ class DataStore( ) # The federation outbox and the local device inbox uses the same # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self._get_cache_dict( + device_outbox_prefill, min_device_outbox_id = self.get_cache_dict( db_conn, "device_federation_outbox", entity_column="destination", @@ -226,7 +228,7 @@ class DataStore( ) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", @@ -240,7 +242,7 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + _group_updates_prefill, min_group_updates_id = self.get_cache_dict( db_conn, "local_group_updates", entity_column="user_id", @@ -480,7 +482,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_select_list( + return self.simple_select_list( table="users", keyvalues={}, retcols=["name", "password_hash", "is_guest", "admin", "user_type"], @@ -502,7 +504,7 @@ class DataStore( """ users = yield self.runInteraction( "get_users_paginate", - self._simple_select_list_paginate_txn, + self.simple_select_list_paginate_txn, table="users", keyvalues={"is_guest": False}, orderby=order, @@ -524,7 +526,7 @@ class DataStore( Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - return self._simple_search_list( + return self.simple_search_list( table="users", term=term, col="name", diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 6afbfc0d74..b0d22faf3f 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -67,7 +67,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_user_txn(txn): - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "account_data", {"user_id": user_id}, @@ -78,7 +78,7 @@ class AccountDataWorkerStore(SQLBaseStore): row["account_data_type"]: json.loads(row["content"]) for row in rows } - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id}, @@ -102,7 +102,7 @@ class AccountDataWorkerStore(SQLBaseStore): Returns: Deferred: A dict """ - result = yield self._simple_select_one_onecol( + result = yield self.simple_select_one_onecol( table="account_data", keyvalues={"user_id": user_id, "account_data_type": data_type}, retcol="content", @@ -127,7 +127,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_txn(txn): - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, @@ -156,7 +156,7 @@ class AccountDataWorkerStore(SQLBaseStore): """ def get_account_data_for_room_and_type_txn(txn): - content_json = self._simple_select_one_onecol_txn( + content_json = self.simple_select_one_onecol_txn( txn, table="room_account_data", keyvalues={ @@ -184,14 +184,14 @@ class AccountDataWorkerStore(SQLBaseStore): current_id(int): The position to fetch up to. Returns: A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, type string, and content string. + room_id string, and type string. """ if last_room_id == current_id and last_global_id == current_id: return defer.succeed(([], [])) def get_updated_account_data_txn(txn): sql = ( - "SELECT stream_id, user_id, account_data_type, content" + "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) @@ -199,7 +199,7 @@ class AccountDataWorkerStore(SQLBaseStore): global_results = txn.fetchall() sql = ( - "SELECT stream_id, user_id, room_id, account_data_type, content" + "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) @@ -300,9 +300,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as room_account_data has a unique constraint - # on (user_id, room_id, account_data_type) so _simple_upsert will + # on (user_id, room_id, account_data_type) so simple_upsert will # retry if there is a conflict. - yield self._simple_upsert( + yield self.simple_upsert( desc="add_room_account_data", table="room_account_data", keyvalues={ @@ -346,9 +346,9 @@ class AccountDataStore(AccountDataWorkerStore): with self._account_data_id_gen.get_next() as next_id: # no need to lock here as account_data has a unique constraint on - # (user_id, account_data_type) so _simple_upsert will retry if + # (user_id, account_data_type) so simple_upsert will retry if # there is a conflict. - yield self._simple_upsert( + yield self.simple_upsert( desc="add_user_account_data", table="account_data", keyvalues={"user_id": user_id, "account_data_type": account_data_type}, diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index 81babf2029..6b82fd392a 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -133,7 +133,7 @@ class ApplicationServiceTransactionWorkerStore( A Deferred which resolves to a list of ApplicationServices, which may be empty. """ - results = yield self._simple_select_list( + results = yield self.simple_select_list( "application_services_state", dict(state=state), ["as_id"] ) # NB: This assumes this class is linked with ApplicationServiceStore @@ -155,7 +155,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves to ApplicationServiceState. """ - result = yield self._simple_select_one( + result = yield self.simple_select_one( "application_services_state", dict(as_id=service.id), ["state"], @@ -175,7 +175,7 @@ class ApplicationServiceTransactionWorkerStore( Returns: A Deferred which resolves when the state was set successfully. """ - return self._simple_upsert( + return self.simple_upsert( "application_services_state", dict(as_id=service.id), dict(state=state) ) @@ -249,7 +249,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Set current txn_id for AS to 'txn_id' - self._simple_upsert_txn( + self.simple_upsert_txn( txn, "application_services_state", dict(as_id=service.id), @@ -257,7 +257,7 @@ class ApplicationServiceTransactionWorkerStore( ) # Delete txn - self._simple_delete_txn( + self.simple_delete_txn( txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id) ) diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py new file mode 100644 index 0000000000..de3256049d --- /dev/null +++ b/synapse/storage/data_stores/main/cache.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import logging + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.engines import PostgresEngine +from synapse.util import batch_iter + +logger = logging.getLogger(__name__) + + +# This is a special cache name we use to batch multiple invalidations of caches +# based on the current state when notifying workers over replication. +CURRENT_STATE_CACHE_NAME = "cs_cache_fake" + + +class CacheInvalidationStore(SQLBaseStore): + def _invalidate_cache_and_stream(self, txn, cache_func, keys): + """Invalidates the cache and adds it to the cache stream so slaves + will know to invalidate their caches. + + This should only be used to invalidate caches where slaves won't + otherwise know from other replication streams that the cache should + be invalidated. + """ + txn.call_after(cache_func.invalidate, keys) + self._send_invalidation_to_replication(txn, cache_func.__name__, keys) + + def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed): + """Special case invalidation of caches based on current state. + + We special case this so that we can batch the cache invalidations into a + single replication poke. + + Args: + txn + room_id (str): Room where state changed + members_changed (iterable[str]): The user_ids of members that have changed + """ + txn.call_after(self._invalidate_state_caches, room_id, members_changed) + + if members_changed: + # We need to be careful that the size of the `members_changed` list + # isn't so large that it causes problems sending over replication, so we + # send them in chunks. + # Max line length is 16K, and max user ID length is 255, so 50 should + # be safe. + for chunk in batch_iter(members_changed, 50): + keys = itertools.chain([room_id], chunk) + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, keys + ) + else: + # if no members changed, we still need to invalidate the other caches. + self._send_invalidation_to_replication( + txn, CURRENT_STATE_CACHE_NAME, [room_id] + ) + + def _send_invalidation_to_replication(self, txn, cache_name, keys): + """Notifies replication that given cache has been invalidated. + + Note that this does *not* invalidate the cache locally. + + Args: + txn + cache_name (str) + keys (iterable[str]) + """ + + if isinstance(self.database_engine, PostgresEngine): + # get_next() returns a context manager which is designed to wrap + # the transaction. However, we want to only get an ID when we want + # to use it, here, so we need to call __enter__ manually, and have + # __exit__ called after the transaction finishes. + ctx = self._cache_id_gen.get_next() + stream_id = ctx.__enter__() + txn.call_on_exception(ctx.__exit__, None, None, None) + txn.call_after(ctx.__exit__, None, None, None) + txn.call_after(self.hs.get_notifier().on_new_replication_data) + + self.simple_insert_txn( + txn, + table="cache_invalidation_stream", + values={ + "stream_id": stream_id, + "cache_func": cache_name, + "keys": list(keys), + "invalidation_ts": self.clock.time_msec(), + }, + ) + + def get_all_updated_caches(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_caches_txn(txn): + # We purposefully don't bound by the current token, as we want to + # send across cache invalidations as quickly as possible. Cache + # invalidations are idempotent, so duplicates are fine. + sql = ( + "SELECT stream_id, cache_func, keys, invalidation_ts" + " FROM cache_invalidation_stream" + " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, limit)) + return txn.fetchall() + + return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn) + + def get_cache_stream_token(self): + if self._cache_id_gen: + return self._cache_id_gen.get_current_token() + else: + return 0 diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 706c6a1f3f..66522a04b7 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -21,8 +21,8 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage import background_updates -from synapse.storage._base import Cache from synapse.util.caches import CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -431,7 +431,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry try: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="user_ips", keyvalues={ @@ -450,7 +450,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -483,7 +483,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - res = yield self._simple_select_list( + res = yield self.simple_select_list( table="devices", keyvalues=keyvalues, retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"), @@ -516,7 +516,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "last_seen"], diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 96cd0fb77a..206d39134d 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our # acknowledgement from the first time we received the message. - already_inserted = self._simple_select_one_txn( + already_inserted = self.simple_select_one_txn( txn, table="device_federation_inbox", keyvalues={"origin": origin, "message_id": message_id}, @@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) # Add an entry for this message_id so that we know we've processed # it. - self._simple_insert_txn( + self.simple_insert_txn( txn, table="device_federation_inbox", values={ @@ -380,7 +380,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) devices = list(messages_by_device.keys()) if len(devices) == 1 and devices[0] == "*": # Handle wildcard device_ids. - sql = "SELECT device_id FROM devices" " WHERE user_id = ?" + sql = "SELECT device_id FROM devices WHERE user_id = ?" txn.execute(sql, (user_id,)) message_json = json.dumps(messages_by_device["*"]) for row in txn: diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py index 71f62036c0..727c582121 100644 --- a/synapse/storage/data_stores/main/devices.py +++ b/synapse/storage/data_stores/main/devices.py @@ -30,16 +30,16 @@ from synapse.logging.opentracing import ( whitelisted_homeserver, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage._base import ( - Cache, - SQLBaseStore, - db_to_json, - make_in_list_sql_clause, -) +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.background_updates import BackgroundUpdateStore from synapse.types import get_verify_key_from_cross_signing_key from synapse.util import batch_iter -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.caches.descriptors import ( + Cache, + cached, + cachedInlineCallbacks, + cachedList, +) logger = logging.getLogger(__name__) @@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - return self._simple_select_one( + return self.simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore): containing "device_id", "user_id" and "display_name" for each device. """ - devices = yield self._simple_select_list( + devices = yield self.simple_select_list( table="devices", keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), @@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore): from_user_id, stream_id, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, "user_signature_stream", values={ @@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2, tree=True) def _get_cached_user_device(self, user_id, device_id): - content = yield self._simple_select_one_onecol( + content = yield self.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="content", @@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore): @cachedInlineCallbacks() def _get_cached_devices_for_user(self, user_id): - devices = yield self._simple_select_list( + devices = yield self.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, retcols=("device_id", "content"), @@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore): SELECT DISTINCT user_ids FROM user_signature_stream WHERE from_user_id = ? AND stream_id > ? """ - rows = yield self._execute( + rows = yield self.execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) return set(user for row in rows for user in json.loads(row[0])) @@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id, destination """ - return self._execute( + return self.execute( "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key ) @@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, retcol="stream_id", @@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_device_list_last_stream_id_for_remotes(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", iterable=user_ids, @@ -722,7 +722,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): return False try: - inserted = yield self._simple_insert( + inserted = yield self.simple_insert( "devices", values={ "user_id": user_id, @@ -736,7 +736,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): if not inserted: # if the device already exists, check if it's a real device, or # if the device ID is reserved by something else - hidden = yield self._simple_select_one_onecol( + hidden = yield self.simple_select_one_onecol( "devices", keyvalues={"user_id": user_id, "device_id": device_id}, retcol="hidden", @@ -771,7 +771,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_one( + yield self.simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", @@ -789,7 +789,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_many( + yield self.simple_delete_many( table="devices", column="device_id", iterable=device_ids, @@ -818,7 +818,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - return self._simple_update_one( + return self.simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, @@ -829,7 +829,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - yield self._simple_delete( + yield self.simple_delete( table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, desc="mark_remote_user_device_list_as_unsubscribed", @@ -866,7 +866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self, txn, user_id, device_id, content, stream_id ): if content.get("deleted"): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -874,7 +874,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id)) else: - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -890,7 +890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -923,11 +923,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_remote_cache", values=[ @@ -946,7 +946,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="device_lists_remote_extremeties", keyvalues={"user_id": user_id}, @@ -995,7 +995,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): [(user_id, device_id, stream_id) for device_id in device_ids], ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_stream", values=[ @@ -1006,7 +1006,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): context = get_active_span_text_map() - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="device_lists_outbound_pokes", values=[ diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py index 297966d9f4..d332f8a409 100644 --- a/synapse/storage/data_stores/main/directory.py +++ b/synapse/storage/data_stores/main/directory.py @@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore): Deferred: results in namedtuple with keys "room_id" and "servers" or None if no association can be found """ - room_id = yield self._simple_select_one_onecol( + room_id = yield self.simple_select_one_onecol( "room_aliases", {"room_alias": room_alias.to_string()}, "room_id", @@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore): if not room_id: return None - servers = yield self._simple_select_onecol( + servers = yield self.simple_select_onecol( "room_alias_servers", {"room_alias": room_alias.to_string()}, "server", @@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore): return RoomAliasMapping(room_id, room_alias.to_string(), servers) def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="room_aliases", keyvalues={"room_alias": room_alias}, retcol="creator", @@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore): @cached(max_entries=5000) def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( + return self.simple_select_onecol( "room_aliases", {"room_id": room_id}, "room_alias", @@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore): """ def alias_txn(txn): - self._simple_insert_txn( + self.simple_insert_txn( txn, "room_aliases", { @@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore): }, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="room_alias_servers", values=[ diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py index 1cbbae5b63..df89eda337 100644 --- a/synapse/storage/data_stores/main/e2e_room_keys.py +++ b/synapse/storage/data_stores/main/e2e_room_keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,49 +25,8 @@ from synapse.storage._base import SQLBaseStore class EndToEndRoomKeyStore(SQLBaseStore): @defer.inlineCallbacks - def get_e2e_room_key(self, user_id, version, room_id, session_id): - """Get the encrypted E2E room key for a given session from a given - backup version of room_keys. We only store the 'best' room key for a given - session at a given time, as determined by the handler. - - Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): the ID of the room whose keys we're querying. - This is a bit redundant as it's implied by the session_id, but - we include for consistency with the rest of the API. - session_id(str): the session whose room_key we're querying. - - Returns: - A deferred dict giving the session_data and message metadata for - this room key. - """ - - row = yield self._simple_select_one( - table="e2e_room_keys", - keyvalues={ - "user_id": user_id, - "version": version, - "room_id": room_id, - "session_id": session_id, - }, - retcols=( - "first_message_index", - "forwarded_count", - "is_verified", - "session_data", - ), - desc="get_e2e_room_key", - ) - - row["session_data"] = json.loads(row["session_data"]) - - return row - - @defer.inlineCallbacks - def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key): - """Replaces or inserts the encrypted E2E room key for a given session in - a given backup + def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key): + """Replaces the encrypted E2E room key for a given session in a given backup Args: user_id(str): the user whose backup we're setting @@ -78,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): StoreError """ - yield self._simple_upsert( + yield self.simple_update_one( table="e2e_room_keys", keyvalues={ "user_id": user_id, @@ -86,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore): "room_id": room_id, "session_id": session_id, }, - values={ + updatevalues={ "first_message_index": room_key["first_message_index"], "forwarded_count": room_key["forwarded_count"], "is_verified": room_key["is_verified"], "session_data": json.dumps(room_key["session_data"]), }, - lock=False, + desc="update_e2e_room_key", ) - log_kv( - { - "message": "Set room key", - "room_id": room_id, - "session_id": session_id, - "room_key": room_key, - } + + @defer.inlineCallbacks + def add_e2e_room_keys(self, user_id, version, room_keys): + """Bulk add room keys to a given backup. + + Args: + user_id (str): the user whose backup we're adding to + version (str): the version ID of the backup for the set of keys we're adding to + room_keys (iterable[(str, str, dict)]): the keys to add, in the form + (roomID, sessionID, keyData) + """ + + values = [] + for (room_id, session_id, room_key) in room_keys: + values.append( + { + "user_id": user_id, + "version": version, + "room_id": room_id, + "session_id": session_id, + "first_message_index": room_key["first_message_index"], + "forwarded_count": room_key["forwarded_count"], + "is_verified": room_key["is_verified"], + "session_data": json.dumps(room_key["session_data"]), + } + ) + log_kv( + { + "message": "Set room key", + "room_id": room_id, + "session_id": session_id, + "room_key": room_key, + } + ) + + yield self.simple_insert_many( + table="e2e_room_keys", values=values, desc="add_e2e_room_keys" ) @trace @@ -110,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore): room, or a given session. Args: - user_id(str): the user whose backup we're querying - version(str): the version ID of the backup for the set of keys we're querying - room_id(str): Optional. the ID of the room whose keys we're querying, if any. + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup for the set of keys we're querying + room_id (str): Optional. the ID of the room whose keys we're querying, if any. If not specified, we return the keys for all the rooms in the backup. - session_id(str): Optional. the session whose room_key we're querying, if any. + session_id (str): Optional. the session whose room_key we're querying, if any. If specified, we also require the room_id to be specified. If not specified, we return all the keys in this version of the backup (or for the specified room) @@ -135,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="e2e_room_keys", keyvalues=keyvalues, retcols=( @@ -162,6 +152,95 @@ class EndToEndRoomKeyStore(SQLBaseStore): return sessions + def get_e2e_room_keys_multi(self, user_id, version, room_keys): + """Get multiple room keys at a time. The difference between this function and + get_e2e_room_keys is that this function can be used to retrieve + multiple specific keys at a time, whereas get_e2e_room_keys is used for + getting all the keys in a backup version, all the keys for a room, or a + specific key. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + room_keys (dict[str, dict[str, iterable[str]]]): a map from + room ID -> {"session": [session ids]} indicating the session IDs + that we want to query + + Returns: + Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key + """ + + return self.runInteraction( + "get_e2e_room_keys_multi", + self._get_e2e_room_keys_multi_txn, + user_id, + version, + room_keys, + ) + + @staticmethod + def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): + if not room_keys: + return {} + + where_clauses = [] + params = [user_id, version] + for room_id, room in room_keys.items(): + sessions = list(room["sessions"]) + if not sessions: + continue + params.append(room_id) + params.extend(sessions) + where_clauses.append( + "(room_id = ? AND session_id IN (%s))" + % (",".join(["?" for _ in sessions]),) + ) + + # check if we're actually querying something + if not where_clauses: + return {} + + sql = """ + SELECT room_id, session_id, first_message_index, forwarded_count, + is_verified, session_data + FROM e2e_room_keys + WHERE user_id = ? AND version = ? AND (%s) + """ % ( + " OR ".join(where_clauses) + ) + + txn.execute(sql, params) + + ret = {} + + for row in txn: + room_id = row[0] + session_id = row[1] + ret.setdefault(room_id, {}) + ret[room_id][session_id] = { + "first_message_index": row[2], + "forwarded_count": row[3], + "is_verified": row[4], + "session_data": json.loads(row[5]), + } + + return ret + + def count_e2e_room_keys(self, user_id, version): + """Get the number of keys in a backup version. + + Args: + user_id (str): the user whose backup we're querying + version (str): the version ID of the backup we're querying about + """ + + return self.simple_select_one_onecol( + table="e2e_room_keys", + keyvalues={"user_id": user_id, "version": version}, + retcol="COUNT(*)", + desc="count_e2e_room_keys", + ) + @trace @defer.inlineCallbacks def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): @@ -188,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): if session_id: keyvalues["session_id"] = session_id - yield self._simple_delete( + yield self.simple_delete( table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys" ) @@ -219,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): version(str) algorithm(str) auth_data(object): opaque dict supplied by the client + etag(int): tag of the keys in the backup """ def _get_e2e_room_keys_version_info_txn(txn): @@ -232,14 +312,16 @@ class EndToEndRoomKeyStore(SQLBaseStore): # it isn't there. raise StoreError(404, "No row found") - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, - retcols=("version", "algorithm", "auth_data"), + retcols=("version", "algorithm", "auth_data", "etag"), ) result["auth_data"] = json.loads(result["auth_data"]) result["version"] = str(result["version"]) + if result["etag"] is None: + result["etag"] = 0 return result return self.runInteraction( @@ -270,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore): new_version = str(int(current_version) + 1) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="e2e_room_keys_versions", values={ @@ -288,21 +370,33 @@ class EndToEndRoomKeyStore(SQLBaseStore): ) @trace - def update_e2e_room_keys_version(self, user_id, version, info): + def update_e2e_room_keys_version( + self, user_id, version, info=None, version_etag=None + ): """Update a given backup version Args: user_id(str): the user whose backup version we're updating version(str): the version ID of the backup version we're updating - info(dict): the new backup version info to store + info (dict): the new backup version info to store. If None, then + the backup version info is not updated + version_etag (Optional[int]): etag of the keys in the backup. If + None, then the etag is not updated """ + updatevalues = {} - return self._simple_update( - table="e2e_room_keys_versions", - keyvalues={"user_id": user_id, "version": version}, - updatevalues={"auth_data": json.dumps(info["auth_data"])}, - desc="update_e2e_room_keys_version", - ) + if info is not None and "auth_data" in info: + updatevalues["auth_data"] = json.dumps(info["auth_data"]) + if version_etag is not None: + updatevalues["etag"] = version_etag + + if updatevalues: + return self.simple_update( + table="e2e_room_keys_versions", + keyvalues={"user_id": user_id, "version": version}, + updatevalues=updatevalues, + desc="update_e2e_room_keys_version", + ) @trace def delete_e2e_room_keys_version(self, user_id, version=None): @@ -326,13 +420,13 @@ class EndToEndRoomKeyStore(SQLBaseStore): else: this_version = version - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_room_keys", keyvalues={"user_id": user_id, "version": this_version}, ) - return self._simple_update_one_txn( + return self.simple_update_one_txn( txn, table="e2e_room_keys_versions", keyvalues={"user_id": user_id, "version": this_version}, diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 073412a78d..08bcdc4725 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -138,20 +138,35 @@ class EndToEndKeyWorkerStore(SQLBaseStore): result.setdefault(user_id, {})[device_id] = None # get signatures on the device - signature_sql = ( - "SELECT * " " FROM e2e_cross_signing_signatures " " WHERE %s" - ) % (" OR ".join("(" + q + ")" for q in signature_query_clauses)) + signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) txn.execute(signature_sql, signature_query_params) rows = self.cursor_to_dict(txn) + # add each cross-signing signature to the correct device in the result dict. for row in rows: + signing_user_id = row["user_id"] + signing_key_id = row["key_id"] target_user_id = row["target_user_id"] target_device_id = row["target_device_id"] - if target_user_id in result and target_device_id in result[target_user_id]: - result[target_user_id][target_device_id].setdefault( - "signatures", {} - ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"] + signature = row["signature"] + + target_user_result = result.get(target_user_id) + if not target_user_result: + continue + + target_device_result = target_user_result.get(target_device_id) + if not target_device_result: + # note that target_device_result will be None for deleted devices. + continue + + target_device_signatures = target_device_result.setdefault("signatures", {}) + signing_user_signatures = target_device_signatures.setdefault( + signing_user_id, {} + ) + signing_user_signatures[signing_key_id] = signature log_kv(result) return result @@ -171,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): key_id) to json string for key """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="e2e_one_time_keys_json", column="key_id", iterable=key_ids, @@ -204,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # a unique constraint. If there is a race of two calls to # `add_e2e_one_time_keys` then they'll conflict and we will only # insert one set. - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="e2e_one_time_keys_json", values=[ @@ -335,7 +350,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): WHERE ? < stream_id AND stream_id <= ? GROUP BY user_id """ - return self._execute( + return self.execute( "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key ) @@ -352,7 +367,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): set_tag("time_now", time_now) set_tag("device_keys", device_keys) - old_key_json = self._simple_select_one_onecol_txn( + old_key_json = self.simple_select_one_onecol_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -368,7 +383,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): log_kv({"Message": "Device key already stored."}) return False - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -427,12 +442,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "user_id": user_id, } ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_device_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="e2e_one_time_keys_json", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -477,7 +492,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self._simple_insert_txn( + self.simple_insert_txn( txn, "devices", values={ @@ -490,7 +505,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self._simple_insert_txn( + self.simple_insert_txn( txn, "e2e_cross_signing_keys", values={ @@ -524,7 +539,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): user_id (str): the user who made the signatures signatures (iterable[SignatureListItem]): signatures to add """ - return self._simple_insert_many( + return self.simple_insert_many( "e2e_cross_signing_signatures", [ { diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py index 90bef0cd2c..051ac7a8cb 100644 --- a/synapse/storage/data_stores/main/event_federation.py +++ b/synapse/storage/data_stores/main/event_federation.py @@ -126,7 +126,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns Deferred[int] """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="events", column="event_id", iterable=event_ids, @@ -140,7 +140,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return max(row["depth"] for row in rows) def _get_oldest_events_in_room_txn(self, txn, room_id): - return self._simple_select_onecol_txn( + return self.simple_select_onecol_txn( txn, table="event_backward_extremities", keyvalues={"room_id": room_id}, @@ -235,7 +235,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): - return self._simple_select_onecol( + return self.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, retcol="event_id", @@ -271,7 +271,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) def _get_min_depth_interaction(self, txn, room_id): - min_depth = self._simple_select_one_onecol_txn( + min_depth = self.simple_select_one_onecol_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -383,7 +383,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas queue = PriorityQueue() for event_id in event_list: - depth = self._simple_select_one_onecol_txn( + depth = self.simple_select_one_onecol_txn( txn, table="events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -468,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas Returns: Deferred[list[str]] """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_edges", column="prev_event_id", iterable=event_ids, @@ -508,7 +508,7 @@ class EventFederationStore(EventFederationWorkerStore): if min_depth and depth >= min_depth: return - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="room_depth", keyvalues={"room_id": room_id}, @@ -520,7 +520,7 @@ class EventFederationStore(EventFederationWorkerStore): For the given event, update the event edges table and forward and backward extremities tables. """ - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_edges", values=[ diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 04ce21ac66..0a37847cfd 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -441,7 +441,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _add_push_actions_to_staging_txn(txn): - # We don't use _simple_insert_many here to avoid the overhead + # We don't use simple_insert_many here to avoid the overhead # of generating lists of dicts. sql = """ @@ -472,7 +472,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): """ try: - res = yield self._simple_delete( + res = yield self.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -677,7 +677,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): ) for event, _ in events_and_contexts: - user_ids = self._simple_select_onecol_txn( + user_ids = self.simple_select_onecol_txn( txn, table="event_push_actions_staging", keyvalues={"event_id": event.event_id}, @@ -844,7 +844,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): the archiving process has caught up or not. """ - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -880,7 +880,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): return caught_up def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): - old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + old_rotate_stream_ordering = self.simple_select_one_onecol_txn( txn, table="event_push_summary_stream_ordering", keyvalues={}, @@ -912,7 +912,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): # If the `old.user_id` above is NULL then we know there isn't already an # entry in the table, so we simply insert it. Otherwise we update the # existing table. - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_push_summary", values=[ diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 878f7568a6..98ae69e996 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -130,6 +130,8 @@ class EventsStore( if self.hs.config.redaction_retention_period is not None: hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def _read_forward_extremities(self): def fetch(txn): @@ -430,7 +432,7 @@ class EventsStore( # event's auth chain, but its easier for now just to store them (and # it doesn't take much storage compared to storing the entire event # anyway). - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_auth", values=[ @@ -578,12 +580,12 @@ class EventsStore( self, txn, new_forward_extremities, max_stream_order ): for room_id, new_extrem in iteritems(new_forward_extremities): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_forward_extremities", values=[ @@ -596,7 +598,7 @@ class EventsStore( # new stream_ordering to new forward extremeties in the room. # This allows us to later efficiently look up the forward extremeties # for a room before a given stream_ordering - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="stream_ordering_to_exterm", values=[ @@ -713,16 +715,14 @@ class EventsStore( metadata_json = encode_json(event.internal_metadata.get_dict()) - sql = ( - "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?" - ) + sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) # Add an entry to the ex_outlier_stream table to replicate the # change in outlier status to our workers. stream_order = event.internal_metadata.stream_ordering state_group_id = context.state_group - self._simple_insert_txn( + self.simple_insert_txn( txn, table="ex_outlier_stream", values={ @@ -732,7 +732,7 @@ class EventsStore( }, ) - sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?" + sql = "UPDATE events SET outlier = ? WHERE event_id = ?" txn.execute(sql, (False, event.event_id)) # Update the event_backward_extremities table now that this @@ -794,7 +794,7 @@ class EventsStore( d.pop("redacted_because", None) return d - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_json", values=[ @@ -811,7 +811,7 @@ class EventsStore( ], ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="events", values=[ @@ -841,7 +841,7 @@ class EventsStore( # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - self._simple_update_txn( + self.simple_update_txn( txn, table="redactions", keyvalues={"redacts": event.event_id}, @@ -929,6 +929,9 @@ class EventsStore( elif event.type == EventTypes.Redaction: # Insert into the redactions table. self._store_redaction(txn, event) + elif event.type == EventTypes.Retention: + # Update the room_retention table. + self._store_retention_policy_for_room_txn(txn, event) self._handle_event_relations(txn, event) @@ -939,6 +942,12 @@ class EventsStore( txn, event.event_id, labels, event.room_id, event.depth ) + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -974,7 +983,7 @@ class EventsStore( state_values.append(vals) - self._simple_insert_many_txn(txn, table="state_events", values=state_values) + self.simple_insert_many_txn(txn, table="state_events", values=state_values) # Prefill the event cache self._add_to_cache(txn, events_and_contexts) @@ -1023,7 +1032,7 @@ class EventsStore( # invalidate the cache for the redacted event txn.call_after(self._invalidate_get_event_cache, event.redacts) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="redactions", values={ @@ -1068,9 +1077,7 @@ class EventsStore( LIMIT ? """ - rows = yield self._execute( - "_censor_redactions_fetch", None, sql, before_ts, 100 - ) + rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100) updates = [] @@ -1100,14 +1107,9 @@ class EventsStore( def _update_censor_txn(txn): for redaction_id, event_id, pruned_json in updates: if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) + self._censor_event_txn(txn, event_id, pruned_json) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="redactions", keyvalues={"event_id": redaction_id}, @@ -1116,6 +1118,22 @@ class EventsStore( yield self.runInteraction("_update_censor_txn", _update_censor_txn) + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self.simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + @defer.inlineCallbacks def count_daily_messages(self): """ @@ -1479,7 +1497,7 @@ class EventsStore( # We do joins against events_to_purge for e.g. calculating state # groups to purge, etc., so lets make an index. - txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)") + txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)") txn.execute("SELECT event_id, should_delete FROM events_to_purge") event_rows = txn.fetchall() @@ -1760,7 +1778,7 @@ class EventsStore( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="state_group_edges", column="prev_state_group", @@ -1787,15 +1805,15 @@ class EventsStore( curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) curr_state = curr_state[sg] - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": sg} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": sg} ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1832,7 +1850,7 @@ class EventsStore( state group. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, @@ -1862,7 +1880,7 @@ class EventsStore( # first we have to delete the state groups states logger.info("[purge] removing %s from state_groups_state", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_groups_state", column="state_group", @@ -1873,7 +1891,7 @@ class EventsStore( # ... and the state group edges logger.info("[purge] removing %s from state_group_edges", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_group_edges", column="state_group", @@ -1884,7 +1902,7 @@ class EventsStore( # ... and the state groups logger.info("[purge] removing %s from state_groups", room_id) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn, table="state_groups", column="id", @@ -1901,7 +1919,7 @@ class EventsStore( @cachedInlineCallbacks(max_entries=5000) def _get_event_ordering(self, event_id): - res = yield self._simple_select_one( + res = yield self.simple_select_one( table="events", retcols=["topological_ordering", "stream_ordering"], keyvalues={"event_id": event_id}, @@ -1942,7 +1960,7 @@ class EventsStore( room_id (str): The ID of the room the event was sent to. topological_ordering (int): The position of the event in the room's topology. """ - return self._simple_insert_many_txn( + return self.simple_insert_many_txn( txn=txn, table="event_labels", values=[ @@ -1956,6 +1974,101 @@ class EventsStore( ], ) + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self.simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json(prune_event_dict(event.get_dict())) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self.simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py index aa87f9abc5..37dfc8c871 100644 --- a/synapse/storage/data_stores/main/events_bg_updates.py +++ b/synapse/storage/data_stores/main/events_bg_updates.py @@ -189,7 +189,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)] for chunk in chunks: - ev_rows = self._simple_select_many_txn( + ev_rows = self.simple_select_many_txn( txn, table="event_json", column="event_id", @@ -366,7 +366,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): to_delete.intersection_update(original_set) - deleted = self._simple_delete_many_txn( + deleted = self.simple_delete_many_txn( txn=txn, table="event_forward_extremities", column="event_id", @@ -382,7 +382,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): if deleted: # We now need to invalidate the caches of these rooms - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="events", column="event_id", @@ -396,7 +396,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): self.get_latest_event_ids_in_room.invalidate, (room_id,) ) - self._simple_delete_many_txn( + self.simple_delete_many_txn( txn=txn, table="_extremities_to_check", column="event_id", @@ -533,7 +533,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore): try: event_json = json.loads(event_json_raw) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn=txn, table="event_labels", values=[ diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 4c4b76bd93..6a08a746b6 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -17,6 +17,7 @@ from __future__ import division import itertools import logging +import threading from collections import namedtuple from canonicaljson import json @@ -34,6 +35,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.types import get_domain_from_id from synapse.util import batch_iter +from synapse.util.caches.descriptors import Cache from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -53,6 +55,17 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) class EventsWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(EventsWorkerStore, self).__init__(db_conn, hs) + + self._get_event_cache = Cache( + "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + ) + + self._event_fetch_lock = threading.Condition() + self._event_fetch_list = [] + self._event_fetch_ongoing = 0 + def get_received_ts(self, event_id): """Get received_ts (when it was persisted) for the event. @@ -65,7 +78,7 @@ class EventsWorkerStore(SQLBaseStore): Deferred[int|None]: Timestamp in milliseconds, or None for events that were persisted before received_ts was implemented. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="received_ts", @@ -439,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore): event_id for events, _ in event_list for event_id in events ) - row_dict = self._new_transaction( + row_dict = self.new_transaction( conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch ) @@ -732,7 +745,7 @@ class EventsWorkerStore(SQLBaseStore): """Given a list of event ids, check if we have already processed and stored them as non outliers. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="events", retcols=("event_id",), column="event_id", @@ -770,40 +783,6 @@ class EventsWorkerStore(SQLBaseStore): yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk) return results - def get_seen_events_with_rejections(self, event_ids): - """Given a list of event ids, check if we rejected them. - - Args: - event_ids (list[str]) - - Returns: - Deferred[dict[str, str|None): - Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps - to None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" - ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction("get_seen_events_with_rejections", f) - def _get_total_state_event_counts_txn(self, txn, room_id): """ See get_total_state_event_counts. diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py index a2a2a67927..17ef7b9354 100644 --- a/synapse/storage/data_stores/main/filtering.py +++ b/synapse/storage/data_stores/main/filtering.py @@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore): except ValueError: raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) - def_json = yield self._simple_select_one_onecol( + def_json = yield self.simple_select_one_onecol( table="user_filters", keyvalues={"user_id": user_localpart, "filter_id": filter_id}, retcol="filter_json", @@ -55,7 +55,7 @@ class FilteringStore(SQLBaseStore): if filter_id_response is not None: return filter_id_response[0] - sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?" + sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?" txn.execute(sql, (user_localpart,)) max_id = txn.fetchone()[0] if max_id is None: diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 5ded539af8..9e1d12bcb7 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore): * "invite" * "open" """ - return self._simple_update_one( + return self.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues={"join_policy": join_policy}, @@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore): ) def get_group(self, group_id): - return self._simple_select_one( + return self.simple_select_one( table="groups", keyvalues={"group_id": group_id}, retcols=( @@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.simple_select_list( table="group_users", keyvalues=keyvalues, retcols=("user_id", "is_public", "is_admin"), @@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore): def get_invited_users_in_group(self, group_id): # TODO: Pagination - return self._simple_select_onecol( + return self.simple_select_onecol( table="group_invites", keyvalues={"group_id": group_id}, retcol="user_id", @@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore): if not include_private: keyvalues["is_public"] = True - return self._simple_select_list( + return self.simple_select_list( table="group_rooms", keyvalues=keyvalues, retcols=("room_id", "is_public"), @@ -180,7 +180,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the room first. Otherwise, the room gets added to the end. """ - room_in_group = self._simple_select_one_onecol_txn( + room_in_group = self.simple_select_one_onecol_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -193,7 +193,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID else: - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.simple_select_one_onecol_txn( txn, table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -204,7 +204,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Category doesn't exist") # TODO: Check category is part of summary already - cat_exists = self._simple_select_one_onecol_txn( + cat_exists = self.simple_select_one_onecol_txn( txn, table="group_summary_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, @@ -224,7 +224,7 @@ class GroupServerStore(SQLBaseStore): (group_id, category_id, group_id, category_id), ) - existing = self._simple_select_one_txn( + existing = self.simple_select_one_txn( txn, table="group_summary_rooms", keyvalues={ @@ -257,7 +257,7 @@ class GroupServerStore(SQLBaseStore): to_update["room_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.simple_update_txn( txn, table="group_summary_rooms", keyvalues={ @@ -271,7 +271,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_summary_rooms", values={ @@ -287,7 +287,7 @@ class GroupServerStore(SQLBaseStore): if category_id is None: category_id = _DEFAULT_CATEGORY_ID - return self._simple_delete( + return self.simple_delete( table="group_summary_rooms", keyvalues={ "group_id": group_id, @@ -299,7 +299,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_categories(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="group_room_categories", keyvalues={"group_id": group_id}, retcols=("category_id", "is_public", "profile"), @@ -316,7 +316,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_category(self, group_id, category_id): - category = yield self._simple_select_one( + category = yield self.simple_select_one( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, retcols=("is_public", "profile"), @@ -343,7 +343,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.simple_upsert( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, values=update_values, @@ -352,7 +352,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_category(self, group_id, category_id): - return self._simple_delete( + return self.simple_delete( table="group_room_categories", keyvalues={"group_id": group_id, "category_id": category_id}, desc="remove_group_category", @@ -360,7 +360,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_roles(self, group_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="group_roles", keyvalues={"group_id": group_id}, retcols=("role_id", "is_public", "profile"), @@ -377,7 +377,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def get_group_role(self, group_id, role_id): - role = yield self._simple_select_one( + role = yield self.simple_select_one( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, retcols=("is_public", "profile"), @@ -404,7 +404,7 @@ class GroupServerStore(SQLBaseStore): else: update_values["is_public"] = is_public - return self._simple_upsert( + return self.simple_upsert( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, values=update_values, @@ -413,7 +413,7 @@ class GroupServerStore(SQLBaseStore): ) def remove_group_role(self, group_id, role_id): - return self._simple_delete( + return self.simple_delete( table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, desc="remove_group_role", @@ -444,7 +444,7 @@ class GroupServerStore(SQLBaseStore): an order of 1 will put the user first. Otherwise, the user gets added to the end. """ - user_in_group = self._simple_select_one_onecol_txn( + user_in_group = self.simple_select_one_onecol_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -457,7 +457,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID else: - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.simple_select_one_onecol_txn( txn, table="group_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -468,7 +468,7 @@ class GroupServerStore(SQLBaseStore): raise SynapseError(400, "Role doesn't exist") # TODO: Check role is part of the summary already - role_exists = self._simple_select_one_onecol_txn( + role_exists = self.simple_select_one_onecol_txn( txn, table="group_summary_roles", keyvalues={"group_id": group_id, "role_id": role_id}, @@ -488,7 +488,7 @@ class GroupServerStore(SQLBaseStore): (group_id, role_id, group_id, role_id), ) - existing = self._simple_select_one_txn( + existing = self.simple_select_one_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id}, @@ -517,7 +517,7 @@ class GroupServerStore(SQLBaseStore): to_update["user_order"] = order if is_public is not None: to_update["is_public"] = is_public - self._simple_update_txn( + self.simple_update_txn( txn, table="group_summary_users", keyvalues={ @@ -531,7 +531,7 @@ class GroupServerStore(SQLBaseStore): if is_public is None: is_public = True - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_summary_users", values={ @@ -547,7 +547,7 @@ class GroupServerStore(SQLBaseStore): if role_id is None: role_id = _DEFAULT_ROLE_ID - return self._simple_delete( + return self.simple_delete( table="group_summary_users", keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id}, desc="remove_user_from_summary", @@ -561,7 +561,7 @@ class GroupServerStore(SQLBaseStore): Deferred[list[str]]: A twisted.Deferred containing a list of group ids containing this room """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="group_rooms", keyvalues={"room_id": room_id}, retcol="group_id", @@ -630,7 +630,7 @@ class GroupServerStore(SQLBaseStore): ) def is_user_in_group(self, user_id, group_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -639,7 +639,7 @@ class GroupServerStore(SQLBaseStore): ).addCallback(lambda r: bool(r)) def is_user_admin_in_group(self, group_id, user_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="is_admin", @@ -650,7 +650,7 @@ class GroupServerStore(SQLBaseStore): def add_group_invite(self, group_id, user_id): """Record that the group server has invited a user """ - return self._simple_insert( + return self.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, desc="add_group_invite", @@ -659,7 +659,7 @@ class GroupServerStore(SQLBaseStore): def is_user_invited_to_local_group(self, group_id, user_id): """Has the group server invited a user? """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, retcol="user_id", @@ -682,7 +682,7 @@ class GroupServerStore(SQLBaseStore): """ def _get_users_membership_in_group_txn(txn): - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -697,7 +697,7 @@ class GroupServerStore(SQLBaseStore): "is_privileged": row["is_admin"], } - row = self._simple_select_one_onecol_txn( + row = self.simple_select_one_onecol_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -738,7 +738,7 @@ class GroupServerStore(SQLBaseStore): """ def _add_user_to_group_txn(txn): - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_users", values={ @@ -749,14 +749,14 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) if local_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -766,7 +766,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -781,27 +781,27 @@ class GroupServerStore(SQLBaseStore): def remove_user_from_group(self, group_id, user_id): def _remove_user_from_group_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_users", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_summary_users", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -812,14 +812,14 @@ class GroupServerStore(SQLBaseStore): ) def add_room_to_group(self, group_id, room_id, is_public): - return self._simple_insert( + return self.simple_insert( table="group_rooms", values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, desc="add_room_to_group", ) def update_room_in_group_visibility(self, group_id, room_id, is_public): - return self._simple_update( + return self.simple_update( table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, updatevalues={"is_public": is_public}, @@ -828,13 +828,13 @@ class GroupServerStore(SQLBaseStore): def remove_room_from_group(self, group_id, room_id): def _remove_room_from_group_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_summary_rooms", keyvalues={"group_id": group_id, "room_id": room_id}, @@ -847,7 +847,7 @@ class GroupServerStore(SQLBaseStore): def get_publicised_groups_for_user(self, user_id): """Get all groups a user is publicising """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, retcol="group_id", @@ -857,7 +857,7 @@ class GroupServerStore(SQLBaseStore): def update_group_publicity(self, group_id, user_id, publicise): """Update whether the user is publicising their membership of the group """ - return self._simple_update_one( + return self.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"is_publicised": publicise}, @@ -893,12 +893,12 @@ class GroupServerStore(SQLBaseStore): def _register_user_group_membership_txn(txn, next_id): # TODO: Upsert? - self._simple_delete_txn( + self.simple_delete_txn( txn, table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_group_membership", values={ @@ -911,7 +911,7 @@ class GroupServerStore(SQLBaseStore): }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_group_updates", values={ @@ -930,7 +930,7 @@ class GroupServerStore(SQLBaseStore): if membership == "join": if local_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_renewals", values={ @@ -940,7 +940,7 @@ class GroupServerStore(SQLBaseStore): }, ) if remote_attestation: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="group_attestations_remote", values={ @@ -951,12 +951,12 @@ class GroupServerStore(SQLBaseStore): }, ) else: - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -976,7 +976,7 @@ class GroupServerStore(SQLBaseStore): def create_group( self, group_id, user_id, name, avatar_url, short_description, long_description ): - yield self._simple_insert( + yield self.simple_insert( table="groups", values={ "group_id": group_id, @@ -991,7 +991,7 @@ class GroupServerStore(SQLBaseStore): @defer.inlineCallbacks def update_group_profile(self, group_id, profile): - yield self._simple_update_one( + yield self.simple_update_one( table="groups", keyvalues={"group_id": group_id}, updatevalues=profile, @@ -1017,7 +1017,7 @@ class GroupServerStore(SQLBaseStore): def update_attestation_renewal(self, group_id, user_id, attestation): """Update an attestation that we have renewed """ - return self._simple_update_one( + return self.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, @@ -1027,7 +1027,7 @@ class GroupServerStore(SQLBaseStore): def update_remote_attestion(self, group_id, user_id, attestation): """Update an attestation that a remote has renewed """ - return self._simple_update_one( + return self.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, updatevalues={ @@ -1046,7 +1046,7 @@ class GroupServerStore(SQLBaseStore): group_id (str) user_id (str) """ - return self._simple_delete( + return self.simple_delete( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, desc="remove_attestation_renewal", @@ -1057,7 +1057,7 @@ class GroupServerStore(SQLBaseStore): """Get the attestation that proves the remote agrees that the user is in the group. """ - row = yield self._simple_select_one( + row = yield self.simple_select_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, retcols=("valid_until_ms", "attestation_json"), @@ -1072,7 +1072,7 @@ class GroupServerStore(SQLBaseStore): return None def get_joined_groups(self, user_id): - return self._simple_select_onecol( + return self.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join"}, retcol="group_id", @@ -1188,7 +1188,7 @@ class GroupServerStore(SQLBaseStore): ] for table in tables: - self._simple_delete_txn( + self.simple_delete_txn( txn, table=table, keyvalues={"group_id": group_id} ) diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index ebc7db3ed6..c7150432b3 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -129,7 +129,7 @@ class KeyStore(SQLBaseStore): return self.runInteraction( "store_server_verify_keys", - self._simple_upsert_many_txn, + self.simple_upsert_many_txn, table="server_signature_keys", key_names=("server_name", "key_id"), key_values=key_values, @@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore): ts_valid_until_ms (int): The time when this json stops being valid. key_json (bytes): The encoded JSON. """ - return self._simple_upsert( + return self.simple_upsert( table="server_keys_json", keyvalues={ "server_name": server_name, @@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore): keyvalues["key_id"] = key_id if from_server is not None: keyvalues["from_server"] = from_server - rows = self._simple_select_list_txn( + rows = self.simple_select_list_txn( txn, "server_keys_json", keyvalues=keyvalues, diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py index 84b5f3ad5e..0cb9446f96 100644 --- a/synapse/storage/data_stores/main/media_repository.py +++ b/synapse/storage/data_stores/main/media_repository.py @@ -39,7 +39,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): Returns: None if the media_id doesn't exist. """ - return self._simple_select_one( + return self.simple_select_one( "local_media_repository", {"media_id": media_id}, ( @@ -64,7 +64,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): user_id, url_cache=None, ): - return self._simple_insert( + return self.simple_insert( "local_media_repository", { "media_id": media_id, @@ -129,7 +129,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def store_url_cache( self, url, response_code, etag, expires_ts, og, media_id, download_ts ): - return self._simple_insert( + return self.simple_insert( "local_media_repository_url_cache", { "url": url, @@ -144,7 +144,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_local_media_thumbnails(self, media_id): - return self._simple_select_list( + return self.simple_select_list( "local_media_repository_thumbnails", {"media_id": media_id}, ( @@ -166,7 +166,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.simple_insert( "local_media_repository_thumbnails", { "media_id": media_id, @@ -180,7 +180,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) def get_cached_remote_media(self, origin, media_id): - return self._simple_select_one( + return self.simple_select_one( "remote_media_cache", {"media_origin": origin, "media_id": media_id}, ( @@ -205,7 +205,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): upload_name, filesystem_id, ): - return self._simple_insert( + return self.simple_insert( "remote_media_cache", { "media_origin": origin, @@ -253,7 +253,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return self.runInteraction("update_cached_last_access_time", update_cache_txn) def get_remote_media_thumbnails(self, origin, media_id): - return self._simple_select_list( + return self.simple_select_list( "remote_media_cache_thumbnails", {"media_origin": origin, "media_id": media_id}, ( @@ -278,7 +278,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - return self._simple_insert( + return self.simple_insert( "remote_media_cache_thumbnails", { "media_origin": origin, @@ -300,18 +300,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE last_access_ts < ?" ) - return self._execute( + return self.execute( "get_remote_media_before", self.cursor_to_dict, sql, before_ts ) def delete_remote_media(self, media_origin, media_id): def delete_remote_media_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, "remote_media_cache", keyvalues={"media_origin": media_origin, "media_id": media_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, "remote_media_cache_thumbnails", keyvalues={"media_origin": media_origin, "media_id": media_id}, @@ -337,7 +337,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): if len(media_ids) == 0: return - sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?" + sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" def _delete_url_cache_txn(txn): txn.executemany(sql, [(media_id,) for media_id in media_ids]) @@ -365,11 +365,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): return def _delete_url_cache_media_txn(txn): - sql = "DELETE FROM local_media_repository" " WHERE media_id = ?" + sql = "DELETE FROM local_media_repository WHERE media_id = ?" txn.executemany(sql, [(media_id,) for media_id in media_ids]) - sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?" + sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" txn.executemany(sql, [(media_id,) for media_id in media_ids]) diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py index b41c3d317a..b8fc28f97b 100644 --- a/synapse/storage/data_stores/main/monthly_active_users.py +++ b/synapse/storage/data_stores/main/monthly_active_users.py @@ -32,7 +32,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs # Do not add more reserved users than the total allowable number - self._new_transaction( + self.new_transaction( dbconn, "initialise_mau_threepids", [], @@ -261,7 +261,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): # never be a big table and alternative approaches (batching multiple # upserts into a single txn) introduced a lot of extra complexity. # See https://github.com/matrix-org/synapse/issues/3854 for more - is_insert = self._simple_upsert_txn( + is_insert = self.simple_upsert_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, @@ -281,7 +281,7 @@ class MonthlyActiveUsersStore(SQLBaseStore): """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="monthly_active_users", keyvalues={"user_id": user_id}, retcol="timestamp", diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py index 79b40044d9..650e49750e 100644 --- a/synapse/storage/data_stores/main/openid.py +++ b/synapse/storage/data_stores/main/openid.py @@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore class OpenIdStore(SQLBaseStore): def insert_open_id_token(self, token, ts_valid_until_ms, user_id): - return self._simple_insert( + return self.simple_insert( table="open_id_tokens", values={ "token": token, diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py index 523ed6575e..a5e121efd1 100644 --- a/synapse/storage/data_stores/main/presence.py +++ b/synapse/storage/data_stores/main/presence.py @@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore): txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,)) # Actually insert new rows - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="presence_stream", values=[ @@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore): inlineCallbacks=True, ) def get_presence_for_users(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="presence_stream", column="user_id", iterable=user_ids, @@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore): return self._presence_id_gen.get_current_token() def allow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_insert( + return self.simple_insert( table="presence_allow_inbound", values={ "observed_user_id": observed_localpart, @@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore): ) def disallow_presence_visible(self, observed_localpart, observer_userid): - return self._simple_delete_one( + return self.simple_delete_one( table="presence_allow_inbound", keyvalues={ "observed_user_id": observed_localpart, diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py index e4e8a1c1d6..c8b5b60301 100644 --- a/synapse/storage/data_stores/main/profile.py +++ b/synapse/storage/data_stores/main/profile.py @@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_profileinfo(self, user_localpart): try: - profile = yield self._simple_select_one( + profile = yield self.simple_select_one( table="profiles", keyvalues={"user_id": user_localpart}, retcols=("displayname", "avatar_url"), @@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_displayname(self, user_localpart): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="displayname", @@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_profile_avatar_url(self, user_localpart): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, retcol="avatar_url", @@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def get_from_remote_profile_cache(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="remote_profile_cache", keyvalues={"user_id": user_id}, retcols=("displayname", "avatar_url"), @@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore): ) def create_profile(self, user_localpart): - return self._simple_insert( + return self.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" ) def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( + return self.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"displayname": new_displayname}, @@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore): ) def set_profile_avatar_url(self, user_localpart, new_avatar_url): - return self._simple_update_one( + return self.simple_update_one( table="profiles", keyvalues={"user_id": user_localpart}, updatevalues={"avatar_url": new_avatar_url}, @@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore): This should only be called when `is_subscribed_remote_profile_for_user` would return true for the user. """ - return self._simple_upsert( + return self.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore): ) def update_remote_profile_cache(self, user_id, displayname, avatar_url): - return self._simple_update( + return self.simple_update( table="remote_profile_cache", keyvalues={"user_id": user_id}, values={ @@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore): """ subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) if not subscribed: - yield self._simple_delete( + yield self.simple_delete( table="remote_profile_cache", keyvalues={"user_id": user_id}, desc="delete_remote_profile_cache", @@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): def is_subscribed_remote_profile_for_user(self, user_id): """Check whether we are interested in a remote user's profile. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, retcol="user_id", @@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore): if res: return True - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="group_invites", keyvalues={"user_id": user_id}, retcol="user_id", diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index b520062d84..75bd499bcd 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -75,7 +75,7 @@ class PushRulesWorkerStore( def __init__(self, db_conn, hs): super(PushRulesWorkerStore, self).__init__(db_conn, hs) - push_rules_prefill, push_rules_id = self._get_cache_dict( + push_rules_prefill, push_rules_id = self.get_cache_dict( db_conn, "push_rules_stream", entity_column="user_id", @@ -100,7 +100,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="push_rules", keyvalues={"user_name": user_id}, retcols=( @@ -124,7 +124,7 @@ class PushRulesWorkerStore( @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): - results = yield self._simple_select_list( + results = yield self.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, retcols=("user_name", "rule_id", "enabled"), @@ -162,7 +162,7 @@ class PushRulesWorkerStore( results = {user_id: [] for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="push_rules", column="user_name", iterable=user_ids, @@ -320,7 +320,7 @@ class PushRulesWorkerStore( results = {user_id: {} for user_id in user_ids} - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="push_rules_enable", column="user_name", iterable=user_ids, @@ -395,7 +395,7 @@ class PushRuleStore(PushRulesWorkerStore): relative_to_rule = before or after - res = self._simple_select_one_txn( + res = self.simple_select_one_txn( txn, table="push_rules", keyvalues={"user_name": user_id, "rule_id": relative_to_rule}, @@ -499,7 +499,7 @@ class PushRuleStore(PushRulesWorkerStore): actions_json, update_stream=True, ): - """Specialised version of _simple_upsert_txn that picks a push_rule_id + """Specialised version of simple_upsert_txn that picks a push_rule_id using the _push_rule_id_gen if it needs to insert the rule. It assumes that the "push_rules" table is locked""" @@ -518,7 +518,7 @@ class PushRuleStore(PushRulesWorkerStore): # We didn't update a row with the given rule_id so insert one push_rule_id = self._push_rule_id_gen.get_next() - self._simple_insert_txn( + self.simple_insert_txn( txn, table="push_rules", values={ @@ -561,7 +561,7 @@ class PushRuleStore(PushRulesWorkerStore): """ def delete_push_rule_txn(txn, stream_id, event_stream_ordering): - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id} ) @@ -596,7 +596,7 @@ class PushRuleStore(PushRulesWorkerStore): self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled ): new_id = self._push_rules_enable_id_gen.get_next() - self._simple_upsert_txn( + self.simple_upsert_txn( txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}, @@ -636,7 +636,7 @@ class PushRuleStore(PushRulesWorkerStore): update_stream=False, ) else: - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}, @@ -675,7 +675,7 @@ class PushRuleStore(PushRulesWorkerStore): if data is not None: values.update(data) - self._simple_insert_txn(txn, "push_rules_stream", values=values) + self.simple_insert_txn(txn, "push_rules_stream", values=values) txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,)) txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,)) diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index d76861cdc0..d5a169872b 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_has_pusher(self, user_id): - ret = yield self._simple_select_one_onecol( + ret = yield self.simple_select_one_onecol( "pushers", {"user_name": user_id}, "id", allow_none=True ) return ret is not None @@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_pushers_by(self, keyvalues): - ret = yield self._simple_select_list( + ret = yield self.simple_select_list( "pushers", keyvalues, [ @@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore): inlineCallbacks=True, ) def get_if_users_have_pushers(self, user_ids): - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="pushers", column="user_name", iterable=user_ids, @@ -229,8 +229,8 @@ class PusherStore(PusherWorkerStore): ): with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on - # (app_id, pushkey, user_name) so _simple_upsert will retry - yield self._simple_upsert( + # (app_id, pushkey, user_name) so simple_upsert will retry + yield self.simple_upsert( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, values={ @@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, @@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore): # it's possible for us to end up with duplicate rows for # (app_id, pushkey, user_id) at different stream_ids, but that # doesn't really matter. - self._simple_insert_txn( + self.simple_insert_txn( txn, table="deleted_pushers", values={ @@ -296,7 +296,7 @@ class PusherStore(PusherWorkerStore): def update_pusher_last_stream_ordering( self, app_id, pushkey, user_id, last_stream_ordering ): - yield self._simple_update_one( + yield self.simple_update_one( "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, {"last_stream_ordering": last_stream_ordering}, @@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore): Returns: Deferred[bool]: True if the pusher still exists; False if it has been deleted. """ - updated = yield self._simple_update( + updated = yield self.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={ @@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since): - yield self._simple_update( + yield self.simple_update( table="pushers", keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, updatevalues={"failing_since": failing_since}, @@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def get_throttle_params_by_room(self, pusher_id): - res = yield self._simple_select_list( + res = yield self.simple_select_list( "pusher_throttle", {"pusher": pusher_id}, ["room_id", "last_sent_ts", "throttle_ms"], @@ -361,8 +361,8 @@ class PusherStore(PusherWorkerStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): # no need to lock because `pusher_throttle` has a primary key on - # (pusher, room_id) so _simple_upsert will retry - yield self._simple_upsert( + # (pusher, room_id) so simple_upsert will retry + yield self.simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py index 0c24430f28..380f388e30 100644 --- a/synapse/storage/data_stores/main/receipts.py +++ b/synapse/storage/data_stores/main/receipts.py @@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): - return self._simple_select_list( + return self.simple_select_list( table="receipts_linearized", keyvalues={"room_id": room_id, "receipt_type": receipt_type}, retcols=("user_id", "event_id"), @@ -70,7 +70,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(num_args=3) def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="receipts_linearized", keyvalues={ "room_id": room_id, @@ -84,7 +84,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cachedInlineCallbacks(num_args=2) def get_receipts_for_user(self, user_id, receipt_type): - rows = yield self._simple_select_list( + rows = yield self.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, retcols=("room_id", "event_id"), @@ -280,7 +280,7 @@ class ReceiptsWorkerStore(SQLBaseStore): args.append(limit) txn.execute(sql, args) - return (r[0:5] + (json.loads(r[5]),) for r in txn) + return list(r[0:5] + (json.loads(r[5]),) for r in txn) return self.runInteraction( "get_all_updated_receipts", get_all_updated_receipts_txn @@ -335,7 +335,7 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ - res = self._simple_select_one_txn( + res = self.simple_select_one_txn( txn, table="events", retcols=["stream_ordering", "received_ts"], @@ -388,7 +388,7 @@ class ReceiptsStore(ReceiptsWorkerStore): (user_id, room_id, receipt_type), ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="receipts_linearized", keyvalues={ @@ -398,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore): }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="receipts_linearized", values={ @@ -514,7 +514,7 @@ class ReceiptsStore(ReceiptsWorkerStore): self._get_linearized_receipts_for_room.invalidate_many, (room_id,) ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="receipts_graph", keyvalues={ @@ -523,7 +523,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "user_id": user_id, }, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="receipts_graph", values={ diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py index 89147ad511..debc6706f5 100644 --- a/synapse/storage/data_stores/main/registration.py +++ b/synapse/storage/data_stores/main/registration.py @@ -19,7 +19,6 @@ import logging import re from six import iterkeys -from six.moves import range from twisted.internet import defer from twisted.internet.defer import Deferred @@ -46,7 +45,7 @@ class RegistrationWorkerStore(SQLBaseStore): @cached() def get_user_by_id(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="users", keyvalues={"name": user_id}, retcols=[ @@ -110,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore): otherwise int representation of the timestamp (as a number of milliseconds since epoch). """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="expiration_ts_ms", @@ -138,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ def set_account_validity_for_user_txn(txn): - self._simple_update_txn( + self.simple_update_txn( txn=txn, table="account_validity", keyvalues={"user_id": user_id}, @@ -168,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore): Raises: StoreError: The provided token is already set for another user. """ - yield self._simple_update_one( + yield self.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"renewal_token": renewal_token}, @@ -185,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The ID of the user to which the token belongs. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"renewal_token": renewal_token}, retcol="user_id", @@ -204,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: defer.Deferred[str]: The renewal token associated with this user ID. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="account_validity", keyvalues={"user_id": user_id}, retcol="renewal_token", @@ -251,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore): email_sent (bool): Flag which indicates whether a renewal email has been sent to this user. """ - yield self._simple_update_one( + yield self.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, updatevalues={"email_sent": email_sent}, @@ -266,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore): Args: user_id (str): ID of the user to remove from the account validity table. """ - yield self._simple_delete_one( + yield self.simple_delete_one( table="account_validity", keyvalues={"user_id": user_id}, desc="delete_account_validity_for_user", @@ -282,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns (bool): true iff the user is a server admin, false otherwise. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user.to_string()}, retcol="admin", @@ -300,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore): admin (bool): true iff the user is to be a server admin, false otherwise. """ - return self._simple_update_one( + return self.simple_update_one( table="users", keyvalues={"name": user.to_string()}, updatevalues={"admin": 1 if admin else 0}, @@ -352,7 +351,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res def is_real_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -362,7 +361,7 @@ class RegistrationWorkerStore(SQLBaseStore): return res is None def is_support_user_txn(self, txn, user_id): - res = self._simple_select_one_onecol_txn( + res = self.simple_select_one_onecol_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -377,9 +376,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ def f(txn): - sql = ( - "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)" - ) + sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)" txn.execute(sql, (user_id,)) return dict(txn) @@ -397,7 +394,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: the mxid of the user, or None if they are not known """ - return await self._simple_select_one_onecol( + return await self.simple_select_one_onecol( table="user_external_ids", keyvalues={"auth_provider": auth_provider, "external_id": external_id}, retcol="user_id", @@ -484,12 +481,8 @@ class RegistrationWorkerStore(SQLBaseStore): """ Gets the localpart of the next generated user ID. - Generated user IDs are integers, and we aim for them to be as small as - we can. Unfortunately, it's possible some of them are already taken by - existing users, and there may be gaps in the already taken range. This - function returns the start of the first allocatable gap. This is to - avoid the case of ID 1000 being pre-allocated and starting at 1001 while - 0-999 are available. + Generated user IDs are integers, so we find the largest integer user ID + already taken and return that plus one. """ def _find_next_generated_user_id(txn): @@ -499,15 +492,14 @@ class RegistrationWorkerStore(SQLBaseStore): regex = re.compile(r"^@(\d+):") - found = set() + max_found = 0 for (user_id,) in txn: match = regex.search(user_id) if match: - found.add(int(match.group(1))) - for i in range(len(found) + 1): - if i not in found: - return i + max_found = max(int(match.group(1)), max_found) + + return max_found + 1 return ( ( @@ -544,7 +536,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: str|None: user id or None if no user id/threepid mapping exists """ - ret = self._simple_select_one_txn( + ret = self.simple_select_one_txn( txn, "user_threepids", {"medium": medium, "address": address}, @@ -557,7 +549,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): - yield self._simple_upsert( + yield self.simple_upsert( "user_threepids", {"medium": medium, "address": address}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, @@ -565,7 +557,7 @@ class RegistrationWorkerStore(SQLBaseStore): @defer.inlineCallbacks def user_get_threepids(self, user_id): - ret = yield self._simple_select_list( + ret = yield self.simple_select_list( "user_threepids", {"user_id": user_id}, ["medium", "address", "validated_at", "added_at"], @@ -574,7 +566,7 @@ class RegistrationWorkerStore(SQLBaseStore): return ret def user_delete_threepid(self, user_id, medium, address): - return self._simple_delete( + return self.simple_delete( "user_threepids", keyvalues={"user_id": user_id, "medium": medium, "address": address}, desc="user_delete_threepid", @@ -587,7 +579,7 @@ class RegistrationWorkerStore(SQLBaseStore): user_id: The user id to delete all threepids of """ - return self._simple_delete( + return self.simple_delete( "user_threepids", keyvalues={"user_id": user_id}, desc="user_delete_threepids", @@ -609,7 +601,7 @@ class RegistrationWorkerStore(SQLBaseStore): """ # We need to use an upsert, in case they user had already bound the # threepid - return self._simple_upsert( + return self.simple_upsert( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -635,7 +627,7 @@ class RegistrationWorkerStore(SQLBaseStore): medium (str): The medium of the threepid (e.g "email") address (str): The address of the threepid (e.g "bob@example.com") """ - return self._simple_select_list( + return self.simple_select_list( table="user_threepid_id_server", keyvalues={"user_id": user_id}, retcols=["medium", "address"], @@ -656,7 +648,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred """ - return self._simple_delete( + return self.simple_delete( table="user_threepid_id_server", keyvalues={ "user_id": user_id, @@ -679,7 +671,7 @@ class RegistrationWorkerStore(SQLBaseStore): Returns: Deferred[list[str]]: Resolves to a list of identity servers """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="user_threepid_id_server", keyvalues={"user_id": user_id, "medium": medium, "address": address}, retcol="id_server", @@ -697,7 +689,7 @@ class RegistrationWorkerStore(SQLBaseStore): defer.Deferred(bool): The requested value. """ - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="deactivated", @@ -784,12 +776,12 @@ class RegistrationWorkerStore(SQLBaseStore): """ def delete_threepid_session_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -934,6 +926,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._account_validity = hs.config.account_validity + if self._account_validity.enabled: + self._clock.call_later( + 0.0, + run_as_background_process, + "account_validity_set_expiration_dates", + self._set_expiration_date_when_missing, + ) + # Create a background job for culling expired 3PID validity tokens def start_cull(): # run as a background process to make sure that the database transactions @@ -961,7 +961,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ next_id = self._access_tokens_id_gen.get_next() - yield self._simple_insert( + yield self.simple_insert( "access_tokens", { "id": next_id, @@ -1037,7 +1037,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Ensure that the guest user actually exists # ``allow_none=False`` makes this raise an exception # if the row isn't in the database. - self._simple_select_one_txn( + self.simple_select_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1045,7 +1045,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): allow_none=False, ) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "users", keyvalues={"name": user_id, "is_guest": 1}, @@ -1059,7 +1059,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): }, ) else: - self._simple_insert_txn( + self.simple_insert_txn( txn, "users", values={ @@ -1114,7 +1114,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): external_id: id on that system user_id: complete mxid that it is mapped to """ - return self._simple_insert( + return self.simple_insert( table="user_external_ids", values={ "auth_provider": auth_provider, @@ -1132,7 +1132,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def user_set_password_hash_txn(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "users", {"name": user_id}, {"password_hash": password_hash} ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -1152,7 +1152,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1176,7 +1176,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ def f(txn): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="users", keyvalues={"name": user_id}, @@ -1234,7 +1234,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def delete_access_token(self, access_token): def f(txn): - self._simple_delete_one_txn( + self.simple_delete_one_txn( txn, table="access_tokens", keyvalues={"token": access_token} ) @@ -1246,7 +1246,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): @cachedInlineCallbacks() def is_guest(self, user_id): - res = yield self._simple_select_one_onecol( + res = yield self.simple_select_one_onecol( table="users", keyvalues={"name": user_id}, retcol="is_guest", @@ -1261,7 +1261,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Adds a user to the table of users who need to be parted from all the rooms they're in """ - return self._simple_insert( + return self.simple_insert( "users_pending_deactivation", values={"user_id": user_id}, desc="add_user_pending_deactivation", @@ -1274,7 +1274,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): """ # XXX: This should be simple_delete_one but we failed to put a unique index on # the table, so somehow duplicate entries have ended up in it. - return self._simple_delete( + return self.simple_delete( "users_pending_deactivation", keyvalues={"user_id": user_id}, desc="del_user_pending_deactivation", @@ -1285,7 +1285,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): Gets one user from the table of users waiting to be parted from all the rooms they're in. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( "users_pending_deactivation", keyvalues={}, retcol="user_id", @@ -1315,7 +1315,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): # Insert everything into a transaction in order to run atomically def validate_threepid_session_txn(txn): - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1333,7 +1333,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): 400, "This client_secret does not match the provided session_id" ) - row = self._simple_select_one_txn( + row = self.simple_select_one_txn( txn, table="threepid_validation_token", keyvalues={"session_id": session_id, "token": token}, @@ -1358,7 +1358,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Looks good. Validate the session - self._simple_update_txn( + self.simple_update_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1401,7 +1401,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): if validated_at: insertion_values["validated_at"] = validated_at - return self._simple_upsert( + return self.simple_upsert( table="threepid_validation_session", keyvalues={"session_id": session_id}, values={"last_send_attempt": send_attempt}, @@ -1439,7 +1439,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def start_or_continue_validation_session_txn(txn): # Create or update a validation session - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, @@ -1452,7 +1452,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) # Create a new validation token with this session ID - self._simple_insert_txn( + self.simple_insert_txn( txn, table="threepid_validation_token", values={ @@ -1501,7 +1501,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): ) def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self._simple_update_one_txn( + self.simple_update_one_txn( txn=txn, table="users", keyvalues={"name": user_id}, @@ -1510,3 +1510,59 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + + @defer.inlineCallbacks + def _set_expiration_date_when_missing(self): + """ + Retrieves the list of registered users that don't have an expiration date, and + adds an expiration date for each of them. + """ + + def select_users_with_no_expiration_date_txn(txn): + """Retrieves the list of registered users with no expiration date from the + database, filtering out deactivated users. + """ + sql = ( + "SELECT users.name FROM users" + " LEFT JOIN account_validity ON (users.name = account_validity.user_id)" + " WHERE account_validity.user_id is NULL AND users.deactivated = 0;" + ) + txn.execute(sql, []) + + res = self.cursor_to_dict(txn) + if res: + for user in res: + self.set_expiration_date_for_user_txn( + txn, user["name"], use_delta=True + ) + + yield self.runInteraction( + "get_users_with_no_expiration_date", + select_users_with_no_expiration_date_txn, + ) + + def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False): + """Sets an expiration date to the account with the given user ID. + + Args: + user_id (str): User ID to set an expiration date for. + use_delta (bool): If set to False, the expiration date for the user will be + now + validity period. If set to True, this expiration date will be a + random value in the [now + period - d ; now + period] range, d being a + delta equal to 10% of the validity period. + """ + now_ms = self._clock.time_msec() + expiration_ts = now_ms + self._account_validity.period + + if use_delta: + expiration_ts = self.rand.randrange( + expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts, + ) + + self.simple_upsert_txn( + txn, + "account_validity", + keyvalues={"user_id": user_id}, + values={"expiration_ts_ms": expiration_ts, "email_sent": False}, + ) diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py index 7d5de0ea2e..f81f9279a1 100644 --- a/synapse/storage/data_stores/main/rejections.py +++ b/synapse/storage/data_stores/main/rejections.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class RejectionsStore(SQLBaseStore): def _store_rejections_txn(self, txn, event_id, reason): - self._simple_insert_txn( + self.simple_insert_txn( txn, table="rejections", values={ @@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore): ) def get_rejection_reason(self, event_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="rejections", retcol="reason", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py index 858f65582b..aa5e10538b 100644 --- a/synapse/storage/data_stores/main/relations.py +++ b/synapse/storage/data_stores/main/relations.py @@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore): aggregation_key = relation.get("key") - self._simple_insert_txn( + self.simple_insert_txn( txn, table="event_relations", values={ @@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore): redacted_event_id (str): The event that was redacted. """ - self._simple_delete_txn( + self.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} ) diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index 67bb1b6f60..f309e3640c 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -19,12 +19,16 @@ import logging import re from typing import Optional, Tuple +from six import integer_types + from canonicaljson import json from twisted.internet import defer +from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore +from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchStore from synapse.types import ThirdPartyInstanceID from synapse.util.caches.descriptors import cached, cachedInlineCallbacks @@ -50,7 +54,7 @@ class RoomWorkerStore(SQLBaseStore): Returns: A dict containing the room information, or None if the room is unknown. """ - return self._simple_select_one( + return self.simple_select_one( table="rooms", keyvalues={"room_id": room_id}, retcols=("room_id", "is_public", "creator"), @@ -59,7 +63,7 @@ class RoomWorkerStore(SQLBaseStore): ) def get_public_room_ids(self): - return self._simple_select_onecol( + return self.simple_select_onecol( table="rooms", keyvalues={"is_public": True}, retcol="room_id", @@ -263,7 +267,7 @@ class RoomWorkerStore(SQLBaseStore): @cached(max_entries=10000) def is_room_blocked(self, room_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="blocked_rooms", keyvalues={"room_id": room_id}, retcol="1", @@ -284,7 +288,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self._simple_select_one( + row = yield self.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -300,8 +304,148 @@ class RoomWorkerStore(SQLBaseStore): else: return None + @cachedInlineCallbacks() + def get_retention_policy_for_room(self, room_id): + """Get the retention policy for a given room. + + If no retention policy has been found for this room, returns a policy defined + by the configured default policy (which has None as both the 'min_lifetime' and + the 'max_lifetime' if no default policy has been defined in the server's + configuration). + + Args: + room_id (str): The ID of the room to get the retention policy of. + + Returns: + dict[int, int]: "min_lifetime" and "max_lifetime" for this room. + """ + + def get_retention_policy_for_room_txn(txn): + txn.execute( + """ + SELECT min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + WHERE room_id = ?; + """, + (room_id,), + ) + + return self.cursor_to_dict(txn) + + ret = yield self.runInteraction( + "get_retention_policy_for_room", get_retention_policy_for_room_txn, + ) + + # If we don't know this room ID, ret will be None, in this case return the default + # policy. + if not ret: + defer.returnValue( + { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } + ) + + row = ret[0] + + # If one of the room's policy's attributes isn't defined, use the matching + # attribute from the default policy. + # The default values will be None if no default policy has been defined, or if one + # of the attributes is missing from the default policy. + if row["min_lifetime"] is None: + row["min_lifetime"] = self.config.retention_default_min_lifetime + + if row["max_lifetime"] is None: + row["max_lifetime"] = self.config.retention_default_max_lifetime + + defer.returnValue(row) + + +class RoomBackgroundUpdateStore(BackgroundUpdateStore): + def __init__(self, db_conn, hs): + super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs) + + self.config = hs.config + + self.register_background_update_handler( + "insert_room_retention", self._background_insert_retention, + ) + + @defer.inlineCallbacks + def _background_insert_retention(self, progress, batch_size): + """Retrieves a list of all rooms within a range and inserts an entry for each of + them into the room_retention table. + NULLs the property's columns if missing from the retention event in the room's + state (or NULLs all of them if there's no retention event in the room's state), + so that we fall back to the server's retention policy. + """ + + last_room = progress.get("room_id", "") + + def _background_insert_retention_txn(txn): + txn.execute( + """ + SELECT state.room_id, state.event_id, events.json + FROM current_state_events as state + LEFT JOIN event_json AS events ON (state.event_id = events.event_id) + WHERE state.room_id > ? AND state.type = '%s' + ORDER BY state.room_id ASC + LIMIT ?; + """ + % EventTypes.Retention, + (last_room, batch_size), + ) + + rows = self.cursor_to_dict(txn) + + if not rows: + return True + + for row in rows: + if not row["json"]: + retention_policy = {} + else: + ev = json.loads(row["json"]) + retention_policy = json.dumps(ev["content"]) + + self.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": row["room_id"], + "event_id": row["event_id"], + "min_lifetime": retention_policy.get("min_lifetime"), + "max_lifetime": retention_policy.get("max_lifetime"), + }, + ) + + logger.info("Inserted %d rows into room_retention", len(rows)) + + self._background_update_progress_txn( + txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]} + ) + + if batch_size > len(rows): + return True + else: + return False + + end = yield self.runInteraction( + "insert_room_retention", _background_insert_retention_txn, + ) + + if end: + yield self._end_background_update("insert_room_retention") + + defer.returnValue(batch_size) + + +class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): + def __init__(self, db_conn, hs): + super(RoomStore, self).__init__(db_conn, hs) + + self.config = hs.config -class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): """Stores a room. @@ -317,7 +461,7 @@ class RoomStore(RoomWorkerStore, SearchStore): try: def store_room_txn(txn, next_id): - self._simple_insert_txn( + self.simple_insert_txn( txn, "rooms", { @@ -327,7 +471,7 @@ class RoomStore(RoomWorkerStore, SearchStore): }, ) if is_public: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -346,14 +490,14 @@ class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="rooms", keyvalues={"room_id": room_id}, updatevalues={"is_public": is_public}, ) - entries = self._simple_select_list_txn( + entries = self.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -371,7 +515,7 @@ class RoomStore(RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -411,7 +555,7 @@ class RoomStore(RoomWorkerStore, SearchStore): def set_room_is_public_appservice_txn(txn, next_id): if is_public: try: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="appservice_room_list", values={ @@ -424,7 +568,7 @@ class RoomStore(RoomWorkerStore, SearchStore): # We've already inserted, nothing to do. return else: - self._simple_delete_txn( + self.simple_delete_txn( txn, table="appservice_room_list", keyvalues={ @@ -434,7 +578,7 @@ class RoomStore(RoomWorkerStore, SearchStore): }, ) - entries = self._simple_select_list_txn( + entries = self.simple_select_list_txn( txn, table="public_room_list_stream", keyvalues={ @@ -452,7 +596,7 @@ class RoomStore(RoomWorkerStore, SearchStore): add_to_stream = bool(entries[-1]["visibility"]) != is_public if add_to_stream: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="public_room_list_stream", values={ @@ -502,11 +646,40 @@ class RoomStore(RoomWorkerStore, SearchStore): txn, event, "content.body", event.content["body"] ) + def _store_retention_policy_for_room_txn(self, txn, event): + if hasattr(event, "content") and ( + "min_lifetime" in event.content or "max_lifetime" in event.content + ): + if ( + "min_lifetime" in event.content + and not isinstance(event.content.get("min_lifetime"), integer_types) + ) or ( + "max_lifetime" in event.content + and not isinstance(event.content.get("max_lifetime"), integer_types) + ): + # Ignore the event if one of the value isn't an integer. + return + + self.simple_insert_txn( + txn=txn, + table="room_retention", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "min_lifetime": event.content.get("min_lifetime"), + "max_lifetime": event.content.get("max_lifetime"), + }, + ) + + self._invalidate_cache_and_stream( + txn, self.get_retention_policy_for_room, (event.room_id,) + ) + def add_event_report( self, room_id, event_id, user_id, reason, content, received_ts ): next_id = self._event_reports_id_gen.get_next() - return self._simple_insert( + return self.simple_insert( table="event_reports", values={ "id": next_id, @@ -552,7 +725,7 @@ class RoomStore(RoomWorkerStore, SearchStore): Returns: Deferred """ - yield self._simple_upsert( + yield self.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, @@ -683,3 +856,89 @@ class RoomStore(RoomWorkerStore, SearchStore): remote_media_mxcs.append((hostname, media_id)) return local_media_mxcs, remote_media_mxcs + + @defer.inlineCallbacks + def get_rooms_for_retention_period_in_range( + self, min_ms, max_ms, include_null=False + ): + """Retrieves all of the rooms within the given retention range. + + Optionally includes the rooms which don't have a retention policy. + + Args: + min_ms (int|None): Duration in milliseconds that define the lower limit of + the range to handle (exclusive). If None, doesn't set a lower limit. + max_ms (int|None): Duration in milliseconds that define the upper limit of + the range to handle (inclusive). If None, doesn't set an upper limit. + include_null (bool): Whether to include rooms which retention policy is NULL + in the returned set. + + Returns: + dict[str, dict]: The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). + """ + + def get_rooms_for_retention_period_in_range_txn(txn): + range_conditions = [] + args = [] + + if min_ms is not None: + range_conditions.append("max_lifetime > ?") + args.append(min_ms) + + if max_ms is not None: + range_conditions.append("max_lifetime <= ?") + args.append(max_ms) + + # Do a first query which will retrieve the rooms that have a retention policy + # in their current state. + sql = """ + SELECT room_id, min_lifetime, max_lifetime FROM room_retention + INNER JOIN current_state_events USING (event_id, room_id) + """ + + if len(range_conditions): + sql += " WHERE (" + " AND ".join(range_conditions) + ")" + + if include_null: + sql += " OR max_lifetime IS NULL" + + txn.execute(sql, args) + + rows = self.cursor_to_dict(txn) + rooms_dict = {} + + for row in rows: + rooms_dict[row["room_id"]] = { + "min_lifetime": row["min_lifetime"], + "max_lifetime": row["max_lifetime"], + } + + if include_null: + # If required, do a second query that retrieves all of the rooms we know + # of so we can handle rooms with no retention policy. + sql = "SELECT DISTINCT room_id FROM current_state_events" + + txn.execute(sql) + + rows = self.cursor_to_dict(txn) + + # If a room isn't already in the dict (i.e. it doesn't have a retention + # policy in its state), add it with a null policy. + for row in rows: + if row["room_id"] not in rooms_dict: + rooms_dict[row["room_id"]] = { + "min_lifetime": None, + "max_lifetime": None, + } + + return rooms_dict + + rooms = yield self.runInteraction( + "get_rooms_for_retention_period_in_range", + get_rooms_for_retention_period_in_range_txn, + ) + + defer.returnValue(rooms) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 2af24a20b7..fe2428a281 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import Iterable, List from six import iteritems, itervalues @@ -127,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership column is up to date """ - pending_update = self._simple_select_one_txn( + pending_update = self.simple_select_one_txn( txn, table="background_updates", keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME}, @@ -602,7 +603,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): to `user_id` and ProfileInfo (or None if not join event). """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="room_memberships", column="event_id", iterable=event_ids, @@ -642,7 +643,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -682,7 +683,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # the returned user actually has the correct domain. like_clause = "%:" + host - rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause) if not rows: return False @@ -804,7 +805,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): Deferred[set[str]]: Set of room IDs. """ - room_ids = yield self._simple_select_onecol( + room_ids = yield self.simple_select_onecol( table="room_memberships", keyvalues={"membership": Membership.JOIN, "user_id": user_id}, retcol="room_id", @@ -813,6 +814,22 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) + def get_membership_from_event_ids( + self, member_event_ids: Iterable[str] + ) -> List[dict]: + """Get user_id and membership of a set of event IDs. + """ + + return self.simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=("user_id", "membership", "event_id"), + keyvalues={}, + batch_size=500, + desc="get_membership_from_event_ids", + ) + class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore): def __init__(self, db_conn, hs): @@ -973,7 +990,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): def _store_room_members_txn(self, txn, events, backfilled): """Store a room member in the database. """ - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="room_memberships", values=[ @@ -1011,7 +1028,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): is_mine = self.hs.is_mine_id(event.state_key) if is_new_state and is_mine: if event.membership == Membership.INVITE: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="local_invites", values={ diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql new file mode 100644 index 0000000000..81a36a8b1d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); + +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql new file mode 100644 index 0000000000..7d70dd071e --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- store the current etag of backup version +ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT; diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql new file mode 100644 index 0000000000..ee6cdf7a14 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql @@ -0,0 +1,33 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Tracks the retention policy of a room. +-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in +-- the room's retention policy state event. +-- If a room doesn't have a retention policy state event in its state, both max_lifetime +-- and min_lifetime are NULL. +CREATE TABLE IF NOT EXISTS room_retention( + room_id TEXT, + event_id TEXT, + min_lifetime BIGINT, + max_lifetime BIGINT, + + PRIMARY KEY(room_id, event_id) +); + +CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime); + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('insert_room_retention', '{}'); diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql index 27a96123e3..5c5fffcafb 100644 --- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql @@ -40,7 +40,8 @@ CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( signature TEXT NOT NULL ); -CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); +-- replaced by the index created in signing_keys_nonunique_signatures.sql +-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); -- stream of user signature updates CREATE TABLE IF NOT EXISTS user_signature_stream ( diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql new file mode 100644 index 0000000000..0aa90ebf0c --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql @@ -0,0 +1,22 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* The cross-signing signatures index should not be a unique index, because a + * user may upload multiple signatures for the same target user. The previous + * index was unique, so delete it if it's there and create a new non-unique + * index. */ + +DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx; CREATE INDEX IF NOT +EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index d1d7c6863d..f735cf095c 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -441,7 +441,7 @@ class SearchStore(SearchBackgroundUpdateStore): # entire table from the database. sql += " ORDER BY rank DESC LIMIT 500" - results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args) + results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -455,7 +455,7 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( + count_results = yield self.execute( "search_rooms_count", self.cursor_to_dict, count_sql, *count_args ) @@ -586,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore): args.append(limit) - results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args) + results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args) results = list(filter(lambda row: row["room_id"] in room_ids, results)) @@ -600,7 +600,7 @@ class SearchStore(SearchBackgroundUpdateStore): count_sql += " GROUP BY room_id" - count_results = yield self._execute( + count_results = yield self.execute( "search_rooms_count", self.cursor_to_dict, count_sql, *count_args ) diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py index 556191b76f..f3da29ce14 100644 --- a/synapse/storage/data_stores/main/signatures.py +++ b/synapse/storage/data_stores/main/signatures.py @@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore): } ) - self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) + self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 6a90daea31..2b33ec1a35 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): count = 0 while next_group: - next_group = self._simple_select_one_onecol_txn( + next_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): ): break - next_group = self._simple_select_one_onecol_txn( + next_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": next_group}, @@ -431,7 +431,7 @@ class StateGroupWorkerStore( """ def _get_state_group_delta_txn(txn): - prev_group = self._simple_select_one_onecol_txn( + prev_group = self.simple_select_one_onecol_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, @@ -442,7 +442,7 @@ class StateGroupWorkerStore( if not prev_group: return _GetStateGroupDelta(None, None) - delta_ids = self._simple_select_list_txn( + delta_ids = self.simple_select_list_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, @@ -644,7 +644,7 @@ class StateGroupWorkerStore( @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="event_to_state_groups", keyvalues={"event_id": event_id}, retcol="state_group", @@ -661,7 +661,7 @@ class StateGroupWorkerStore( def _get_state_group_for_events(self, event_ids): """Returns mapping event_id -> state_group """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_to_state_groups", column="event_id", iterable=event_ids, @@ -902,7 +902,7 @@ class StateGroupWorkerStore( state_group = self.database_engine.get_next_state_group_id(txn) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_groups", values={"id": state_group, "room_id": room_id, "event_id": event_id}, @@ -911,7 +911,7 @@ class StateGroupWorkerStore( # We persist as a delta if we can, while also ensuring the chain # of deltas isn't tooo long, as otherwise read performance degrades. if prev_group: - is_in_db = self._simple_select_one_onecol_txn( + is_in_db = self.simple_select_one_onecol_txn( txn, table="state_groups", keyvalues={"id": prev_group}, @@ -926,13 +926,13 @@ class StateGroupWorkerStore( potential_hops = self._count_state_group_hops_txn(txn, prev_group) if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_group_edges", values={"state_group": state_group, "prev_state_group": prev_group}, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -947,7 +947,7 @@ class StateGroupWorkerStore( ], ) else: - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1007,7 +1007,7 @@ class StateGroupWorkerStore( referenced. """ - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, @@ -1065,7 +1065,7 @@ class StateBackgroundUpdateStore( batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) if max_group is None: - rows = yield self._execute( + rows = yield self.execute( "_background_deduplicate_state", None, "SELECT coalesce(max(id), 0) FROM state_groups", @@ -1135,13 +1135,13 @@ class StateBackgroundUpdateStore( if prev_state.get(key, None) != value } - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_group_edges", keyvalues={"state_group": state_group}, ) - self._simple_insert_txn( + self.simple_insert_txn( txn, table="state_group_edges", values={ @@ -1150,13 +1150,13 @@ class StateBackgroundUpdateStore( }, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="state_groups_state", keyvalues={"state_group": state_group}, ) - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="state_groups_state", values=[ @@ -1263,7 +1263,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): state_groups[event.event_id] = context.state_group - self._simple_insert_many_txn( + self.simple_insert_many_txn( txn, table="event_to_state_groups", values=[ diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py index 28f33ec18f..03b908026b 100644 --- a/synapse/storage/data_stores/main/state_deltas.py +++ b/synapse/storage/data_stores/main/state_deltas.py @@ -105,7 +105,7 @@ class StateDeltasStore(SQLBaseStore): ) def _get_max_stream_id_in_current_state_deltas_txn(self, txn): - return self._simple_select_one_onecol_txn( + return self.simple_select_one_onecol_txn( txn, table="current_state_delta_stream", keyvalues={}, diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 45b3de7d56..3aeba859fd 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore): """ Returns the stats processor positions. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="stats_incremental_position", keyvalues={}, retcol="stream_id", @@ -215,7 +215,7 @@ class StatsStore(StateDeltasStore): if field and "\0" in field: fields[col] = None - return self._simple_upsert( + return self.simple_upsert( table="room_stats_state", keyvalues={"room_id": room_id}, values=fields, @@ -257,7 +257,7 @@ class StatsStore(StateDeltasStore): ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type] ) - slice_list = self._simple_select_list_paginate_txn( + slice_list = self.simple_select_list_paginate_txn( txn, table + "_historical", {id_col: stats_id}, @@ -282,7 +282,7 @@ class StatsStore(StateDeltasStore): "name", "topic", "canonical_alias", "avatar", "join_rules", "history_visibility" """ - return self._simple_select_one( + return self.simple_select_one( "room_stats_state", {"room_id": room_id}, retcols=( @@ -308,7 +308,7 @@ class StatsStore(StateDeltasStore): """ table, id_col = TYPE_TO_TABLE[stats_type] - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( "%s_current" % (table,), keyvalues={id_col: id}, retcol="completed_delta_stream_id", @@ -344,7 +344,7 @@ class StatsStore(StateDeltasStore): complete_with_stream_id=stream_id, ) - self._simple_update_one_txn( + self.simple_update_one_txn( txn, table="stats_incremental_position", keyvalues={}, @@ -517,17 +517,17 @@ class StatsStore(StateDeltasStore): else: self.database_engine.lock_table(txn, table) retcols = list(chain(absolutes.keys(), additive_relatives.keys())) - current_row = self._simple_select_one_txn( + current_row = self.simple_select_one_txn( txn, table, keyvalues, retcols, allow_none=True ) if current_row is None: merged_dict = {**keyvalues, **absolutes, **additive_relatives} - self._simple_insert_txn(txn, table, merged_dict) + self.simple_insert_txn(txn, table, merged_dict) else: for (key, val) in additive_relatives.items(): current_row[key] += val current_row.update(absolutes) - self._simple_update_one_txn(txn, table, keyvalues, current_row) + self.simple_update_one_txn(txn, table, keyvalues, current_row) def _upsert_copy_from_table_with_additive_relatives_txn( self, @@ -614,11 +614,11 @@ class StatsStore(StateDeltasStore): txn.execute(sql, qargs) else: self.database_engine.lock_table(txn, into_table) - src_row = self._simple_select_one_txn( + src_row = self.simple_select_one_txn( txn, src_table, keyvalues, copy_columns ) all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues} - dest_current_row = self._simple_select_one_txn( + dest_current_row = self.simple_select_one_txn( txn, into_table, keyvalues=all_dest_keyvalues, @@ -634,11 +634,11 @@ class StatsStore(StateDeltasStore): **src_row, **additive_relatives, } - self._simple_insert_txn(txn, into_table, merged_dict) + self.simple_insert_txn(txn, into_table, merged_dict) else: for (key, val) in additive_relatives.items(): src_row[key] = dest_current_row[key] + val - self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) + self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row) def get_changes_room_total_events_and_bytes(self, min_pos, max_pos): """Fetches the counts of events in the given range of stream IDs. @@ -735,7 +735,7 @@ class StatsStore(StateDeltasStore): def _fetch_current_state_stats(txn): pos = self.get_room_max_stream_ordering() - rows = self._simple_select_many_txn( + rows = self.simple_select_many_txn( txn, table="current_state_events", column="type", diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py index 8780fdd989..60487c4559 100644 --- a/synapse/storage/data_stores/main/stream.py +++ b/synapse/storage/data_stores/main/stream.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -252,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): super(StreamWorkerStore, self).__init__(db_conn, hs) events_max = self.get_room_max_stream_ordering() - event_cache_prefill, min_event_val = self._get_cache_dict( + event_cache_prefill, min_event_val = self.get_cache_dict( db_conn, "events", entity_column="room_id", @@ -573,7 +576,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "s%d" stream token. """ - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" ).addCallback(lambda row: "s%d" % (row,)) @@ -586,7 +589,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): Returns: A deferred "t%d-%d" topological token. """ - return self._simple_select_one( + return self.simple_select_one( table="events", keyvalues={"event_id": event_id}, retcols=("stream_ordering", "topological_ordering"), @@ -610,13 +613,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "SELECT coalesce(max(topological_ordering), 0) FROM events" " WHERE room_id = ? AND stream_ordering < ?" ) - return self._execute( + return self.execute( "get_max_topological_token", None, sql, room_id, stream_key ).addCallback(lambda r: r[0][0] if r else 0) def _get_max_topological_txn(self, txn, room_id): txn.execute( - "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?", + "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?", (room_id,), ) @@ -706,7 +709,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): dict """ - results = self._simple_select_one_txn( + results = self.simple_select_one_txn( txn, "events", keyvalues={"event_id": event_id, "room_id": room_id}, @@ -794,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, events def get_federation_out_pos(self, typ): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="federation_stream_position", retcol="stream_id", keyvalues={"type": typ}, @@ -802,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) def update_federation_out_pos(self, typ, stream_id): - return self._simple_update_one( + return self.simple_update_one( table="federation_stream_position", keyvalues={"type": typ}, updatevalues={"stream_id": stream_id}, diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py index 10d1887f75..85012403be 100644 --- a/synapse/storage/data_stores/main/tags.py +++ b/synapse/storage/data_stores/main/tags.py @@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore): tag strings to tag content. """ - deferred = self._simple_select_list( + deferred = self.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) @@ -83,9 +83,7 @@ class TagsWorkerStore(AccountDataWorkerStore): ) def get_tag_content(txn, tag_ids): - sql = ( - "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?" - ) + sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?" results = [] for stream_id, user_id, room_id in tag_ids: txn.execute(sql, (user_id, room_id)) @@ -155,7 +153,7 @@ class TagsWorkerStore(AccountDataWorkerStore): Returns: A deferred list of string tags. """ - return self._simple_select_list( + return self.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), @@ -180,7 +178,7 @@ class TagsStore(TagsWorkerStore): content_json = json.dumps(content) def add_tag_txn(txn, next_id): - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag}, diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index 01b1be5e14..c162f3ea16 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -85,7 +85,7 @@ class TransactionStore(SQLBaseStore): ) def _get_received_txn_response(self, txn, transaction_id, origin): - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="received_transactions", keyvalues={"transaction_id": transaction_id, "origin": origin}, @@ -119,7 +119,7 @@ class TransactionStore(SQLBaseStore): response_json (str) """ - return self._simple_insert( + return self.simple_insert( table="received_transactions", values={ "transaction_id": transaction_id, @@ -160,7 +160,7 @@ class TransactionStore(SQLBaseStore): return result def _get_destination_retry_timings(self, txn, destination): - result = self._simple_select_one_txn( + result = self.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -227,7 +227,7 @@ class TransactionStore(SQLBaseStore): # We need to be careful here as the data may have changed from under us # due to a worker setting the timings. - prev_row = self._simple_select_one_txn( + prev_row = self.simple_select_one_txn( txn, table="destinations", keyvalues={"destination": destination}, @@ -236,7 +236,7 @@ class TransactionStore(SQLBaseStore): ) if not prev_row: - self._simple_insert_txn( + self.simple_insert_txn( txn, table="destinations", values={ @@ -247,7 +247,7 @@ class TransactionStore(SQLBaseStore): }, ) elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval: - self._simple_update_one_txn( + self.simple_update_one_txn( txn, "destinations", keyvalues={"destination": destination}, diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py index 652abe0e6a..1a85aabbfb 100644 --- a/synapse/storage/data_stores/main/user_directory.py +++ b/synapse/storage/data_stores/main/user_directory.py @@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ txn.execute(sql) rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) + self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms) del rooms # If search all users is on, get all the users we want to add. @@ -100,13 +100,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore txn.execute("SELECT name FROM users") users = [{"user_id": x[0]} for x in txn.fetchall()] - self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) + self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users) new_pos = yield self.get_max_stream_id_in_current_state_deltas() yield self.runInteraction( "populate_user_directory_temp_build", _make_staging_area ) - yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) + yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos}) yield self._end_background_update("populate_user_directory_createtables") return 1 @@ -116,7 +116,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ Update the user directory stream position, then clean up the old tables. """ - position = yield self._simple_select_one_onecol( + position = yield self.simple_select_one_onecol( TEMP_TABLE + "_position", None, "position" ) yield self.update_user_directory_stream_pos(position) @@ -243,7 +243,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore to_insert.clear() # We've finished a room. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) + yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id}) # Update the remaining counter. progress["remaining"] -= 1 yield self.runInteraction( @@ -312,7 +312,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) # We've finished processing a user. Delete it from the table. - yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) + yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id}) # Update the remaining counter. progress["remaining"] -= 1 yield self.runInteraction( @@ -361,7 +361,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _update_profile_in_user_dir_txn(txn): - new_entry = self._simple_upsert_txn( + new_entry = self.simple_upsert_txn( txn, table="user_directory", keyvalues={"user_id": user_id}, @@ -435,7 +435,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) elif isinstance(self.database_engine, Sqlite3Engine): value = "%s %s" % (user_id, display_name) if display_name else user_id - self._simple_upsert_txn( + self.simple_upsert_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id}, @@ -462,7 +462,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore """ def _add_users_who_share_room_txn(txn): - self._simple_upsert_many_txn( + self.simple_upsert_many_txn( txn, table="users_who_share_private_rooms", key_names=["user_id", "other_user_id", "room_id"], @@ -489,7 +489,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore def _add_users_in_public_rooms_txn(txn): - self._simple_upsert_many_txn( + self.simple_upsert_many_txn( txn, table="users_in_public_rooms", key_names=["user_id", "room_id"], @@ -519,7 +519,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore @cached() def get_user_in_directory(self, user_id): - return self._simple_select_one( + return self.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, retcols=("display_name", "avatar_url"), @@ -528,7 +528,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore ) def update_user_directory_stream_pos(self, stream_id): - return self._simple_update_one( + return self.simple_update_one( table="user_directory_stream_pos", keyvalues={}, updatevalues={"stream_id": stream_id}, @@ -547,21 +547,21 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def remove_from_user_dir(self, user_id): def _remove_from_user_dir_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="user_directory_search", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id} ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id}, @@ -575,14 +575,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """Get all user_ids that are in the room directory because they're in the given room_id """ - user_ids_share_pub = yield self._simple_select_onecol( + user_ids_share_pub = yield self.simple_select_onecol( table="users_in_public_rooms", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) - user_ids_share_priv = yield self._simple_select_onecol( + user_ids_share_priv = yield self.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"room_id": room_id}, retcol="other_user_id", @@ -605,17 +605,17 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): """ def _remove_user_who_share_room_txn(txn): - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_who_share_private_rooms", keyvalues={"other_user_id": user_id, "room_id": room_id}, ) - self._simple_delete_txn( + self.simple_delete_txn( txn, table="users_in_public_rooms", keyvalues={"user_id": user_id, "room_id": room_id}, @@ -636,14 +636,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): Returns: list: user_id """ - rows = yield self._simple_select_onecol( + rows = yield self.simple_select_onecol( table="users_who_share_private_rooms", keyvalues={"user_id": user_id}, retcol="room_id", desc="get_rooms_user_is_in", ) - pub_rows = yield self._simple_select_onecol( + pub_rows = yield self.simple_select_onecol( table="users_in_public_rooms", keyvalues={"user_id": user_id}, retcol="room_id", @@ -674,14 +674,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) f2 USING (room_id) """ - rows = yield self._execute( + rows = yield self.execute( "get_rooms_in_common_for_users", None, sql, user_id, other_user_id ) return [room_id for room_id, in rows] def get_user_directory_stream_pos(self): - return self._simple_select_one_onecol( + return self.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, retcol="stream_id", @@ -786,9 +786,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): # This should be unreachable. raise Exception("Unrecognized database engine") - results = yield self._execute( - "search_user_dir", self.cursor_to_dict, sql, *args - ) + results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args) limited = len(results) > limit diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py index aa4f0da5f0..37860af070 100644 --- a/synapse/storage/data_stores/main/user_erasure_store.py +++ b/synapse/storage/data_stores/main/user_erasure_store.py @@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore): Returns: Deferred[bool]: True if the user has requested erasure """ - return self._simple_select_onecol( + return self.simple_select_onecol( table="erased_users", keyvalues={"user_id": user_id}, retcol="1", @@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore): # iterate it multiple times, and (b) avoiding duplicates. user_ids = tuple(set(user_ids)) - rows = yield self._simple_select_many_batch( + rows = yield self.simple_select_many_batch( table="erased_users", column="user_id", iterable=user_ids, diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 2e7753820e..731e1c9d9c 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -447,7 +447,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) # Mark as done. cur.execute( database_engine.convert_param_style( - "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)" + "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)" ), (modname, name), ) diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 02994ab2a5..cd56cd91ed 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -88,9 +88,12 @@ class PaginationConfig(object): raise SynapseError(400, "Invalid request.") def __repr__(self): - return ( - "PaginationConfig(from_tok=%r, to_tok=%r," " direction=%r, limit=%r)" - ) % (self.from_token, self.to_token, self.direction, self.limit) + return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % ( + self.from_token, + self.to_token, + self.direction, + self.limit, + ) def get_source_config(self, source_name): keyname = "%s_key" % source_name diff --git a/synapse/visibility.py b/synapse/visibility.py index 8c843febd8..dffe943b28 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -44,7 +44,12 @@ MEMBERSHIP_PRIORITY = ( @defer.inlineCallbacks def filter_events_for_client( - storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset() + storage: Storage, + user_id, + events, + is_peeking=False, + always_include_ids=frozenset(), + apply_retention_policies=True, ): """ Check which events a user is allowed to see @@ -59,6 +64,10 @@ def filter_events_for_client( events always_include_ids (set(event_id)): set of event ids to specifically include (unless sender is ignored) + apply_retention_policies (bool): Whether to filter out events that's older than + allowed by the room's retention policy. Useful when this function is called + to e.g. check whether a user should be allowed to see the state at a given + event rather than to know if it should send an event to a user's client(s). Returns: Deferred[list[synapse.events.EventBase]] @@ -86,6 +95,15 @@ def filter_events_for_client( erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + if apply_retention_policies: + room_ids = set(e.room_id for e in events) + retention_policies = {} + + for room_id in room_ids: + retention_policies[ + room_id + ] = yield storage.main.get_retention_policy_for_room(room_id) + def allowed(event): """ Args: @@ -103,6 +121,18 @@ def filter_events_for_client( if not event.is_state() and event.sender in ignore_list: return None + # Don't try to apply the room's retention policy if the event is a state event, as + # MSC1763 states that retention is only considered for non-state events. + if apply_retention_policies and not event.is_state(): + retention_policy = retention_policies[event.room_id] + max_lifetime = retention_policy.get("max_lifetime") + + if max_lifetime is not None: + oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime + + if event.origin_server_ts < oldest_allowed_ts: + return None + if event.event_id in always_include_ids: return event diff --git a/synmark/__init__.py b/synmark/__init__.py new file mode 100644 index 0000000000..570eb818d9 --- /dev/null +++ b/synmark/__init__.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +from twisted.internet import epollreactor +from twisted.internet.main import installReactor + +from synapse.config.homeserver import HomeServerConfig +from synapse.util import Clock + +from tests.utils import default_config, setup_test_homeserver + + +async def make_homeserver(reactor, config=None): + """ + Make a Homeserver suitable for running benchmarks against. + + Args: + reactor: A Twisted reactor to run under. + config: A HomeServerConfig to use, or None. + """ + cleanup_tasks = [] + clock = Clock(reactor) + + if not config: + config = default_config("test") + + config_obj = HomeServerConfig() + config_obj.parse_config_dict(config, "", "") + + hs = await setup_test_homeserver( + cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock + ) + stor = hs.get_datastore() + + # Run the database background updates. + if hasattr(stor, "do_next_background_update"): + while not await stor.has_completed_background_updates(): + await stor.do_next_background_update(1) + + def cleanup(): + for i in cleanup_tasks: + i() + + return hs, clock.sleep, cleanup + + +def make_reactor(): + """ + Instantiate and install a Twisted reactor suitable for testing (i.e. not the + default global one). + """ + reactor = epollreactor.EPollReactor() + + if "twisted.internet.reactor" in sys.modules: + del sys.modules["twisted.internet.reactor"] + installReactor(reactor) + + return reactor diff --git a/synmark/__main__.py b/synmark/__main__.py new file mode 100644 index 0000000000..ac59befbd4 --- /dev/null +++ b/synmark/__main__.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from contextlib import redirect_stderr +from io import StringIO + +import pyperf +from synmark import make_reactor +from synmark.suites import SUITES + +from twisted.internet.defer import ensureDeferred +from twisted.logger import globalLogBeginner, textFileLogObserver +from twisted.python.failure import Failure + +from tests.utils import setupdb + + +def make_test(main): + """ + Take a benchmark function and wrap it in a reactor start and stop. + """ + + def _main(loops): + + reactor = make_reactor() + + file_out = StringIO() + with redirect_stderr(file_out): + + d = ensureDeferred(main(reactor, loops)) + + def on_done(_): + if isinstance(_, Failure): + _.printTraceback() + print(file_out.getvalue()) + reactor.stop() + return _ + + d.addBoth(on_done) + reactor.run() + + return d.result + + return _main + + +if __name__ == "__main__": + + def add_cmdline_args(cmd, args): + if args.log: + cmd.extend(["--log"]) + + runner = pyperf.Runner( + processes=3, min_time=2, show_name=True, add_cmdline_args=add_cmdline_args + ) + runner.argparser.add_argument("--log", action="store_true") + runner.parse_args() + + orig_loops = runner.args.loops + runner.args.inherit_environ = ["SYNAPSE_POSTGRES"] + + if runner.args.worker: + if runner.args.log: + globalLogBeginner.beginLoggingTo( + [textFileLogObserver(sys.__stdout__)], redirectStandardIO=False + ) + setupdb() + + for suite, loops in SUITES: + if loops: + runner.args.loops = loops + else: + runner.args.loops = orig_loops + loops = "auto" + runner.bench_time_func( + suite.__name__ + "_" + str(loops), make_test(suite.main), + ) diff --git a/synmark/suites/__init__.py b/synmark/suites/__init__.py new file mode 100644 index 0000000000..cfa3b0ba38 --- /dev/null +++ b/synmark/suites/__init__.py @@ -0,0 +1,3 @@ +from . import logging + +SUITES = [(logging, 1000), (logging, 10000), (logging, None)] diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py new file mode 100644 index 0000000000..d8e4c7d58f --- /dev/null +++ b/synmark/suites/logging.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from io import StringIO + +from mock import Mock + +from pyperf import perf_counter +from synmark import make_homeserver + +from twisted.internet.defer import Deferred +from twisted.internet.protocol import ServerFactory +from twisted.logger import LogBeginner, Logger, LogPublisher +from twisted.protocols.basic import LineOnlyReceiver + +from synapse.logging._structured import setup_structured_logging + + +class LineCounter(LineOnlyReceiver): + + delimiter = b"\n" + + def __init__(self, *args, **kwargs): + self.count = 0 + super().__init__(*args, **kwargs) + + def lineReceived(self, line): + self.count += 1 + + if self.count >= self.factory.wait_for and self.factory.on_done: + on_done = self.factory.on_done + self.factory.on_done = None + on_done.callback(True) + + +async def main(reactor, loops): + """ + Benchmark how long it takes to send `loops` messages. + """ + servers = [] + + def protocol(): + p = LineCounter() + servers.append(p) + return p + + logger_factory = ServerFactory.forProtocol(protocol) + logger_factory.wait_for = loops + logger_factory.on_done = Deferred() + port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1") + + hs, wait, cleanup = await make_homeserver(reactor) + + errors = StringIO() + publisher = LogPublisher() + mock_sys = Mock() + beginner = LogBeginner( + publisher, errors, mock_sys, warnings, initialBufferSize=loops + ) + + log_config = { + "loggers": {"synapse": {"level": "DEBUG"}}, + "drains": { + "tersejson": { + "type": "network_json_terse", + "host": "127.0.0.1", + "port": port.getHost().port, + "maximum_buffer": 100, + } + }, + } + + logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher) + logging_system = setup_structured_logging( + hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False + ) + + # Wait for it to connect... + await logging_system._observers[0]._service.whenConnected() + + start = perf_counter() + + # Send a bunch of useful messages + for i in range(0, loops): + logger.info("test message %s" % (i,)) + + if ( + len(logging_system._observers[0]._buffer) + == logging_system._observers[0].maximum_buffer + ): + while ( + len(logging_system._observers[0]._buffer) + > logging_system._observers[0].maximum_buffer / 2 + ): + await wait(0.01) + + await logger_factory.on_done + + end = perf_counter() - start + + logging_system.stop() + port.stopListening() + cleanup() + + return end diff --git a/sytest-blacklist b/sytest-blacklist index 11785fd43f..411cce0692 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,6 +1,6 @@ # This file serves as a blacklist for SyTest tests that we expect will fail in # Synapse. -# +# # Each line of this file is scanned by sytest during a run and if the line # exactly matches the name of a test, it will be marked as "expected fail", # meaning the test will still run, but failure will not mark the entire test @@ -29,3 +29,7 @@ Enabling an unknown default rule fails with 404 # Blacklisted due to https://github.com/matrix-org/synapse/issues/1663 New federated private chats get full presence information (SYN-115) + +# Blacklisted due to https://github.com/matrix-org/matrix-doc/pull/2314 removing +# this requirement from the spec +Inbound federation of state requires event_id as a mandatory paramater diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 2dc5052249..63d8633582 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 51714a2b06..24fa8dbb45 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -18,17 +18,14 @@ from mock import Mock from twisted.internet import defer from synapse.api.errors import Codes, SynapseError -from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.federation.transport import server from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.types import UserID -from synapse.util.ratelimitutils import FederationRateLimiter from tests import unittest -class RoomComplexityTests(unittest.HomeserverTestCase): +class RoomComplexityTests(unittest.FederatingHomeserverTestCase): servlets = [ admin.register_servlets, @@ -41,25 +38,6 @@ class RoomComplexityTests(unittest.HomeserverTestCase): config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} return config - def prepare(self, reactor, clock, homeserver): - class Authenticator(object): - def authenticate_request(self, request, content): - return defer.succeed("otherserver.nottld") - - ratelimiter = FederationRateLimiter( - clock, - FederationRateLimitConfig( - window_size=1, - sleep_limit=1, - sleep_msec=1, - reject_limit=1000, - concurrent_requests=1000, - ), - ) - server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - def test_complexity_simple(self): u1 = self.register_user("u1", "pass") @@ -105,7 +83,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], "roomid", UserID.from_string(u1), {"membership": "join"}, @@ -146,7 +124,7 @@ class RoomComplexityTests(unittest.HomeserverTestCase): d = handler._remote_join( None, - ["otherserver.example"], + ["other.example.com"], room_1, UserID.from_string(u1), {"membership": "join"}, diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index cce8d8c6de..d456267b87 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.types import ReadReceipt -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class FederationSenderTestCases(HomeserverTestCase): @@ -29,6 +29,7 @@ class FederationSenderTestCases(HomeserverTestCase): federation_transport_client=Mock(spec=["send_transaction"]), ) + @override_config({"send_federation": True}) def test_send_receipts(self): mock_state_handler = self.hs.get_state_handler() mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"] @@ -69,6 +70,7 @@ class FederationSenderTestCases(HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_send_receipts_with_backoff(self): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b08be451aa..1ec8c40901 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +17,8 @@ import logging from synapse.events import FrozenEvent from synapse.federation.federation_server import server_matches_acl_event +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from tests import unittest @@ -41,6 +44,66 @@ class ServerACLsTestCase(unittest.TestCase): self.assertTrue(server_matches_acl_event("1:2:3:4", e)) +class StateQueryTests(unittest.FederatingHomeserverTestCase): + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def test_without_event_id(self): + """ + Querying v1/state/<room_id> without an event ID will return the current + known state. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + self.inject_room_member(room_1, "@user:other.example.com", "join") + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + self.assertEqual( + channel.json_body["room_version"], + self.hs.config.default_room_version.identifier, + ) + + members = set( + map( + lambda x: x["state_key"], + filter( + lambda x: x["type"] == "m.room.member", channel.json_body["pdus"] + ), + ) + ) + + self.assertEqual(members, set(["@user:other.example.com", u1])) + self.assertEqual(len(channel.json_body["pdus"]), 6) + + def test_needs_to_be_in_room(self): + """ + Querying v1/state/<room_id> requires the server + be in the room to provide data. + """ + u1 = self.register_user("u1", "pass") + u1_token = self.login("u1", "pass") + + room_1 = self.helper.create_room_as(u1, tok=u1_token) + + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/state/%s" % (room_1,) + ) + self.render(request) + self.assertEquals(403, channel.code, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + + def _create_acl_event(content): return FrozenEvent( { diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py new file mode 100644 index 0000000000..27d83bb7d9 --- /dev/null +++ b/tests/federation/transport/test_server.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server +from synapse.util.ratelimitutils import FederationRateLimiter + +from tests import unittest +from tests.unittest import override_config + + +class RoomDirectoryFederationTests(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return defer.succeed("otherserver.nottld") + + ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig()) + server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + @override_config({"allow_public_rooms_over_federation": False}) + def test_blocked_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(403, channel.code) + + @override_config({"allow_public_rooms_over_federation": True}) + def test_open_public_room_list_over_federation(self): + request, channel = self.make_request( + "GET", "/_matrix/federation/v1/publicRooms" + ) + self.render(request) + self.assertEquals(200, channel.code) diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 0bb96674a2..70f172eb02 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd # Copyright 2017 New Vector Ltd +# Copyright 2019 Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -94,23 +95,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + version_etag = res["etag"] + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) # check we can retrieve it as a specific version res = yield self.handler.get_version_info(self.local_user, "1") + self.assertEqual(res["etag"], version_etag) + del res["etag"] self.assertDictEqual( res, { "version": "1", "algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data", + "count": 0, }, ) @@ -126,12 +133,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "version": "2", "algorithm": "m.megolm_backup.v1", "auth_data": "second_version_auth_data", + "count": 0, }, ) @@ -158,12 +167,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -207,12 +218,14 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): # check we can retrieve it as the current version res = yield self.handler.get_version_info(self.local_user) + del res["etag"] # etag is opaque, so don't test its contents self.assertDictEqual( res, { "algorithm": "m.megolm_backup.v1", "auth_data": "revised_first_version_auth_data", "version": version, + "count": 0, }, ) @@ -409,6 +422,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): yield self.handler.upload_room_keys(self.local_user, version, room_keys) + # get the etag to compare to future versions + res = yield self.handler.get_version_info(self.local_user) + backup_etag = res["etag"] + self.assertEqual(res["count"], 1) + new_room_keys = copy.deepcopy(room_keys) new_room_key = new_room_keys["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"] @@ -423,6 +441,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "SSBBTSBBIEZJU0gK", ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # test that marking the session as verified however /does/ replace it new_room_key["is_verified"] = True yield self.handler.upload_room_keys(self.local_user, version, new_room_keys) @@ -432,6 +454,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should NOT be equal now, since the key changed + res = yield self.handler.get_version_info(self.local_user) + self.assertNotEqual(res["etag"], backup_etag) + backup_etag = res["etag"] + # test that a session with a higher forwarded_count doesn't replace one # with a lower forwarding count new_room_key["forwarded_count"] = 2 @@ -443,6 +470,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" ) + # the etag should be the same since the session did not change + res = yield self.handler.get_version_info(self.local_user) + self.assertEqual(res["etag"], backup_etag) + # TODO: check edge cases as well as the common variations here @defer.inlineCallbacks diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index e0075ccd32..380fd0d107 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -45,13 +45,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -61,7 +61,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -71,7 +71,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -82,7 +82,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) def get_all_room_state(self): - return self.store._simple_select_list( + return self.store.simple_select_list( "room_stats_state", None, retcols=("name", "topic", "canonical_alias") ) @@ -96,7 +96,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000) return self.get_success( - self.store._simple_select_one( + self.store.simple_select_one( table + "_historical", {id_col: stat_id, end_ts: end_ts}, cols, @@ -180,7 +180,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.handler.stats_enabled = True self.store._all_done = False self.get_success( - self.store._simple_update_one( + self.store.simple_update_one( table="stats_incremental_position", keyvalues={}, updatevalues={"stream_id": 0}, @@ -188,7 +188,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_prepare", "progress_json": "{}"}, ) @@ -205,13 +205,13 @@ class StatsRoomTests(unittest.HomeserverTestCase): # Now do the initial ingestion. self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", {"update_name": "populate_stats_process_rooms", "progress_json": "{}"}, ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", @@ -656,12 +656,12 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_delete( + self.store.simple_delete( "room_stats_current", {"1": 1}, "test_delete_stats" ) ) self.get_success( - self.store._simple_delete( + self.store.simple_delete( "user_stats_current", {"1": 1}, "test_delete_stats" ) ) @@ -675,7 +675,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_rooms", @@ -685,7 +685,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_process_users", @@ -695,7 +695,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_stats_cleanup", diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 5ec568f4e6..f6d8660285 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -24,6 +24,7 @@ from synapse.api.errors import AuthError from synapse.types import UserID from tests import unittest +from tests.unittest import override_config from tests.utils import register_federation_servlets # Some local users to test with @@ -174,6 +175,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_started_typing_remote_send(self): self.room_members = [U_APPLE, U_ONION] @@ -237,6 +239,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"send_federation": True}) def test_stopped_typing(self): self.room_members = [U_APPLE, U_BANANA, U_ONION] diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c5e91a8c41..d5b1c5b4ac 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -158,7 +158,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_in_public_rooms(self): r = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( "users_in_public_rooms", None, ("user_id", "room_id") ) ) @@ -169,7 +169,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def get_users_who_share_private_rooms(self): return self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], @@ -184,7 +184,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store._all_done = False self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_createtables", @@ -193,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_rooms", @@ -203,7 +203,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_process_users", @@ -213,7 +213,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) self.get_success( - self.store._simple_insert( + self.store.simple_insert( "background_updates", { "update_name": "populate_user_directory_cleanup", diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 4f924ce451..e7472e3a93 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -48,7 +48,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = server_factory.streamer + handler_factory = Mock() self.replication_handler = ReplicationClientHandler(self.slaved_store) + self.replication_handler.factory = handler_factory + client_factory = ReplicationClientFactory( self.hs, "client_name", self.replication_handler ) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index ce3835ae6a..1d14e77255 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from synapse.replication.tcp.commands import ReplicateCommand from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -30,7 +32,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server = server_factory.buildProtocol(None) # build a replication client, with a dummy handler + handler_factory = Mock() self.test_handler = TestReplicationClientHandler() + self.test_handler.factory = handler_factory self.client = ClientReplicationStreamProtocol( "client", "test", clock, self.test_handler ) diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 9575058252..124ce0768a 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -632,7 +632,7 @@ class PurgeRoomTestCase(unittest.HomeserverTestCase): "state_groups_state", ): count = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table=table, keyvalues={"room_id": room_id}, retcol="COUNT(*)", diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 0000000000..5e9c07ebf3 --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventContentFields, EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import room + +from tests import unittest + + +class EphemeralMessageTestCase(unittest.HomeserverTestCase): + + user_id = "@user:test" + + servlets = [ + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["enable_ephemeral_messages"] = True + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_message_expiry_no_delay(self): + """Tests that sending a message sent with a m.self_destruct_after field set to the + past results in that event being deleted right away. + """ + # Send a message in the room that has expired. From here, the reactor clock is + # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock + # is at 0ms the code path is the same if the event's expiry timestamp is the + # current timestamp. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: 0, + }, + ) + event_id = res["event_id"] + + # Check that we can't retrieve the content of the event. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def test_message_expiry_delay(self): + """Tests that sending a message with a m.self_destruct_after field set to the + future results in that event not being deleted right away, but advancing the + clock to after that expiry timestamp causes the event to be deleted. + """ + # Send a message in the room that'll expire in 1s. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000, + }, + ) + event_id = res["event_id"] + + # Check that we can retrieve the content of the event before it has expired. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertTrue(bool(event_content), event_content) + + # Advance the clock to after the deletion. + self.reactor.advance(1) + + # Check that we can't retrieve the content of the event anymore. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py new file mode 100644 index 0000000000..95475bb651 --- /dev/null +++ b/tests/rest/client/test_retention.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from mock import Mock + +from synapse.api.constants import EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import login, room +from synapse.visibility import filter_events_for_client + +from tests import unittest + +one_hour_ms = 3600000 +one_day_ms = one_hour_ms * 24 + + +class RetentionTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["retention"] = { + "enabled": True, + "default_policy": { + "min_lifetime": one_day_ms, + "max_lifetime": one_day_ms * 3, + }, + "allowed_lifetime_min": one_day_ms, + "allowed_lifetime_max": one_day_ms * 3, + } + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_retention_state_event(self): + """Tests that the server configuration can limit the values a user can set to the + room's retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms * 4}, + tok=self.token, + expect_code=400, + ) + + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_hour_ms}, + tok=self.token, + expect_code=400, + ) + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": lifetime}, + tok=self.token, + ) + + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_without_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by the server's configuration's default retention policy. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention_event_purged(room_id, one_day_ms * 2) + + def test_visibility(self): + """Tests that synapse.visibility.filter_events_for_client correctly filters out + outdated events + """ + store = self.hs.get_datastore() + storage = self.hs.get_storage() + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + events = [] + + # Send a first event, which should be filtered out at the end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + # Get the event from the store so that we end up with a FrozenEvent that we can + # give to filter_events_for_client. We need to do this now because the event won't + # be in the database anymore after it has expired. + events.append(self.get_success(store.get_event(resp.get("event_id")))) + + # Advance the time by 2 days. We're using the default retention policy, therefore + # after this the first event will still be valid. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Send another event, which shouldn't get filtered out. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + valid_event_id = resp.get("event_id") + + events.append(self.get_success(store.get_event(valid_event_id))) + + # Advance the time by anothe 2 days. After this, the first event should be + # outdated but not the second one. + self.reactor.advance(one_day_ms * 2 / 1000) + + # Run filter_events_for_client with our list of FrozenEvents. + filtered_events = self.get_success( + filter_events_for_client(storage, self.user_id, events) + ) + + # We should only get one event back. + self.assertEqual(len(filtered_events), 1, filtered_events) + # That event should be the second, not outdated event. + self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) + + def _test_retention_event_purged(self, room_id, increment): + # Get the create event to, later, check that we can still access it. + message_handler = self.hs.get_message_handler() + create_event = self.get_success( + message_handler.get_room_data(self.user_id, room_id, EventTypes.Create) + ) + + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + expired_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, expired_event_id) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) + + # Advance the time. + self.reactor.advance(increment / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + valid_event_id = resp.get("event_id") + + # Advance the time again. Now our first event should have expired but our second + # one should still be kept. + self.reactor.advance(increment / 1000) + + # Check that the event has been purged from the database. + self.get_event(room_id, expired_event_id, expected_code=404) + + # Check that the event that hasn't been purged can still be retrieved. + valid_event = self.get_event(room_id, valid_event_id) + self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) + + # Check that we can still access state events that were sent before the event that + # has been purged. + self.get_event(room_id, create_event.event_id) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body + + +class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + config["retention"] = { + "enabled": True, + } + + mock_federation_client = Mock(spec=["backfill"]) + + self.hs = self.setup_test_homeserver( + config=config, federation_client=mock_federation_client, + ) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + + def test_no_default_policy(self): + """Tests that an event doesn't get expired if there is neither a default retention + policy nor a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + self._test_retention(room_id) + + def test_state_policy(self): + """Tests that an event gets correctly expired if there is no default retention + policy but there's a policy specific to the room. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set the maximum lifetime to 35 days so that the first event gets expired but not + # the second one. + self.helper.send_state( + room_id=room_id, + event_type=EventTypes.Retention, + body={"max_lifetime": one_day_ms * 35}, + tok=self.token, + ) + + self._test_retention(room_id, expected_code_for_first_event=404) + + def _test_retention(self, room_id, expected_code_for_first_event=200): + # Send a first event to the room. This is the event we'll want to be purged at the + # end of the test. + resp = self.helper.send(room_id=room_id, body="1", tok=self.token) + + first_event_id = resp.get("event_id") + + # Check that we can retrieve the event. + expired_event = self.get_event(room_id, first_event_id) + self.assertEqual( + expired_event.get("content", {}).get("body"), "1", expired_event + ) + + # Advance the time by a month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Send another event. We need this because the purge job won't purge the most + # recent event in the room. + resp = self.helper.send(room_id=room_id, body="2", tok=self.token) + + second_event_id = resp.get("event_id") + + # Advance the time by another month. + self.reactor.advance(one_day_ms * 30 / 1000) + + # Check if the event has been purged from the database. + first_event = self.get_event( + room_id, first_event_id, expected_code=expected_code_for_first_event + ) + + if expected_code_for_first_event == 200: + self.assertEqual( + first_event.get("content", {}).get("body"), "1", first_event + ) + + # Check that the event that hasn't been purged can still be retrieved. + second_event = self.get_event(room_id, second_event_id) + self.assertEqual(second_event.get("content", {}).get("body"), "2", second_event) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url, access_token=self.token) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 5e38fd6ced..1ca7fa742f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd # Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +27,9 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import login, profile, room +from synapse.util.stringutils import random_string from tests import unittest @@ -811,104 +815,77 @@ class RoomMessageListTestCase(RoomBase): self.assertTrue("chunk" in channel.json_body) self.assertTrue("end" in channel.json_body) - def test_filter_labels(self): - """Test that we can filter by a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]} - ) - - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 2, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + def test_room_messages_purge(self): + store = self.hs.get_datastore() + pagination_handler = self.hs.get_pagination_handler() - def test_filter_not_labels(self): - """Test that we can filter by the absence of a label.""" - message_filter = json.dumps( - {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]} + # Send a first message in the room, which will be removed by the purge. + first_event_id = self.helper.send(self.room_id, "message 1")["event_id"] + first_token = self.get_success( + store.get_topological_token_for_event(first_event_id) ) - events = self._test_filter_labels(message_filter) - - self.assertEqual(len(events), 3, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "without label", events[0]) - self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1]) - self.assertEqual( - events[2]["content"]["body"], "with two wrong labels", events[2] + # Send a second message in the room, which won't be removed, and which we'll + # use as the marker to purge events before. + second_event_id = self.helper.send(self.room_id, "message 2")["event_id"] + second_token = self.get_success( + store.get_topological_token_for_event(second_event_id) ) - def test_filter_labels_not_labels(self): - """Test that we can filter by both a label and the absence of another label.""" - sync_filter = json.dumps( - { - "types": [EventTypes.Message], - "org.matrix.labels": ["#work"], - "org.matrix.not_labels": ["#notfun"], - } - ) + # Send a third event in the room to ensure we don't fall under any edge case + # due to our marker being the latest forward extremity in the room. + self.helper.send(self.room_id, "message 3") - events = self._test_filter_labels(sync_filter) - - self.assertEqual(len(events), 1, [event["content"] for event in events]) - self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) - - def _test_filter_labels(self, message_filter): - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, + # Check that we get the first and second message when querying /messages. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "without label"}, - ) + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with wrong label", - EventContentFields.LABELS: ["#work"], - }, + # Purge every event before the second event. + purge_id = random_string(16) + pagination_handler._purges_by_id[purge_id] = PurgeStatus() + self.get_success( + pagination_handler._purge_history( + purge_id=purge_id, + room_id=self.room_id, + token=second_token, + delete_local_events=True, + ) ) - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with two wrong labels", - EventContentFields.LABELS: ["#work", "#notfun"], - }, + # Check that we only get the second message through /message now that the first + # has been purged. + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})), ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) - self.helper.send_event( - room_id=self.room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "with right label", - EventContentFields.LABELS: ["#fun"], - }, - ) + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 1, [event["content"] for event in chunk]) - token = "s0_0_0_0_0_0_0_0_0" + # Check that we get no event, but also no error, when querying /messages with + # the token that was pointing at the first event, because we don't have it + # anymore. request, channel = self.make_request( "GET", - "/rooms/%s/messages?access_token=x&from=%s&filter=%s" - % (self.room_id, token, message_filter), + "/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s" + % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})), ) self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) - return channel.json_body["chunk"] + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 0, [event["content"] for event in chunk]) class RoomSearchTestCase(unittest.HomeserverTestCase): @@ -1106,3 +1083,517 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): res_displayname = channel.json_body["content"]["displayname"] self.assertEqual(res_displayname, self.displayname, channel.result) + + +class RoomMembershipReasonTestCase(unittest.HomeserverTestCase): + """Tests that clients can add a "reason" field to membership events and + that they get correctly added to the generated events and propagated. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.second_user_id = self.register_user("second", "test") + self.second_tok = self.login("second", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + def test_join_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/join".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_leave_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_kick_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/kick".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_ban_reason(self): + self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/ban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_unban_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/unban".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_invite_reason(self): + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/invite".format(self.room_id), + content={"reason": reason, "user_id": self.second_user_id}, + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def test_reject_invite_reason(self): + self.helper.invite( + self.room_id, + src=self.creator, + targ=self.second_user_id, + tok=self.creator_tok, + ) + + reason = "hello" + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/{}/leave".format(self.room_id), + content={"reason": reason}, + access_token=self.second_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + self._check_for_reason(reason) + + def _check_for_reason(self, reason): + request, channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format( + self.room_id, self.second_user_id + ), + access_token=self.creator_tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + event_content = channel.json_body + + self.assertEqual(event_content.get("reason"), reason, channel.result) + + +class LabelsTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + profile.register_servlets, + ] + + # Filter that should only catch messages with the label "#fun". + FILTER_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#fun"], + } + # Filter that should only catch messages without the label "#fun". + FILTER_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.not_labels": ["#fun"], + } + # Filter that should only catch messages with the label "#work" but without the label + # "#notfun". + FILTER_LABELS_NOT_LABELS = { + "types": [EventTypes.Message], + "org.matrix.labels": ["#work"], + "org.matrix.not_labels": ["#notfun"], + } + + def prepare(self, reactor, clock, homeserver): + self.user_id = self.register_user("test", "test") + self.tok = self.login("test", "test") + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + def test_context_filter_labels(self): + """Test that we can filter by a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "with right label", events_before[0] + ) + + events_after = channel.json_body["events_before"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with right label", events_after[0] + ) + + def test_context_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /context request.""" + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 1, [event["content"] for event in events_before] + ) + self.assertEqual( + events_before[0]["content"]["body"], "without label", events_before[0] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 2, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + self.assertEqual( + events_after[1]["content"]["body"], "with two wrong labels", events_after[1] + ) + + def test_context_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /context request. + """ + event_id = self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "GET", + "/rooms/%s/context/%s?filter=%s" + % (self.room_id, event_id, json.dumps(self.FILTER_LABELS_NOT_LABELS)), + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.result) + + events_before = channel.json_body["events_before"] + + self.assertEqual( + len(events_before), 0, [event["content"] for event in events_before] + ) + + events_after = channel.json_body["events_after"] + + self.assertEqual( + len(events_after), 1, [event["content"] for event in events_after] + ) + self.assertEqual( + events_after[0]["content"]["body"], "with wrong label", events_after[0] + ) + + def test_messages_filter_labels(self): + """Test that we can filter by a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 2, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with right label", events[0]) + self.assertEqual(events[1]["content"]["body"], "with right label", events[1]) + + def test_messages_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /messages request.""" + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % (self.room_id, self.tok, token, json.dumps(self.FILTER_NOT_LABELS)), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 4, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "without label", events[0]) + self.assertEqual(events[1]["content"]["body"], "without label", events[1]) + self.assertEqual(events[2]["content"]["body"], "with wrong label", events[2]) + self.assertEqual( + events[3]["content"]["body"], "with two wrong labels", events[3] + ) + + def test_messages_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /messages request. + """ + self._send_labelled_messages_in_room() + + token = "s0_0_0_0_0_0_0_0_0" + request, channel = self.make_request( + "GET", + "/rooms/%s/messages?access_token=%s&from=%s&filter=%s" + % ( + self.room_id, + self.tok, + token, + json.dumps(self.FILTER_LABELS_NOT_LABELS), + ), + ) + self.render(request) + + events = channel.json_body["chunk"] + + self.assertEqual(len(events), 1, [event["content"] for event in events]) + self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0]) + + def test_search_filter_labels(self): + """Test that we can filter by a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 2, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with right label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "with right label", + results[1]["result"]["content"]["body"], + ) + + def test_search_filter_not_labels(self): + """Test that we can filter by the absence of a label on a /search request.""" + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 4, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "without label", + results[0]["result"]["content"]["body"], + ) + self.assertEqual( + results[1]["result"]["content"]["body"], + "without label", + results[1]["result"]["content"]["body"], + ) + self.assertEqual( + results[2]["result"]["content"]["body"], + "with wrong label", + results[2]["result"]["content"]["body"], + ) + self.assertEqual( + results[3]["result"]["content"]["body"], + "with two wrong labels", + results[3]["result"]["content"]["body"], + ) + + def test_search_filter_labels_not_labels(self): + """Test that we can filter by both a label and the absence of another label on a + /search request. + """ + request_data = json.dumps( + { + "search_categories": { + "room_events": { + "search_term": "label", + "filter": self.FILTER_LABELS_NOT_LABELS, + } + } + } + ) + + self._send_labelled_messages_in_room() + + request, channel = self.make_request( + "POST", "/search?access_token=%s" % self.tok, request_data + ) + self.render(request) + + results = channel.json_body["search_categories"]["room_events"]["results"] + + self.assertEqual( + len(results), 1, [result["result"]["content"] for result in results], + ) + self.assertEqual( + results[0]["result"]["content"]["body"], + "with wrong label", + results[0]["result"]["content"]["body"], + ) + + def _send_labelled_messages_in_room(self): + """Sends several messages to a room with different labels (or without any) to test + filtering by label. + Returns: + The ID of the event to use if we're testing filtering on /context. + """ + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "without label"}, + tok=self.tok, + ) + # Return this event's ID when we test filtering in /context requests. + event_id = res["event_id"] + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with wrong label", + EventContentFields.LABELS: ["#work"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with two wrong labels", + EventContentFields.LABELS: ["#work", "#notfun"], + }, + tok=self.tok, + ) + + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "with right label", + EventContentFields.LABELS: ["#fun"], + }, + tok=self.tok, + ) + + return event_id diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 8ea0cb05ea..e7417b3d14 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -1,5 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2017 Vector Creations Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index dab87e5edf..c0d0d2b44e 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -203,6 +203,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @unittest.override_config( { + "public_baseurl": "https://test_server", "enable_registration_captcha": True, "user_consent": { "version": "1", diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 3283c0e47b..661c1f88b9 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2018 New Vector +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/server.py b/tests/server.py index f878aeaada..2b7cf4242e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -379,6 +379,7 @@ class FakeTransport(object): disconnecting = False disconnected = False + connected = True buffer = attr.ib(default=b"") producer = attr.ib(default=None) autoflush = attr.ib(default=True) @@ -402,6 +403,7 @@ class FakeTransport(object): "FakeTransport: Delaying disconnect until buffer is flushed" ) else: + self.connected = False self.disconnected = True def abortConnection(self): diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 9b81b536f5..7b7434a468 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -356,7 +356,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.get_success( self.storage.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -367,7 +367,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.simple_select_list( self.table_name, None, ["id, username, value"] ) ) @@ -383,7 +383,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): self.get_success( self.storage.runInteraction( "test", - self.storage._simple_upsert_many_txn, + self.storage.simple_upsert_many_txn, self.table_name, key_names, key_values, @@ -394,7 +394,7 @@ class UpsertManyTests(unittest.HomeserverTestCase): # Check results are what we expect res = self.get_success( - self.storage._simple_select_list( + self.storage.simple_select_list( self.table_name, None, ["id, username, value"] ) ) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index c778de1f0c..de5e4a5fce 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -65,7 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.simple_insert( table="tablename", values={"columname": "Value"} ) @@ -77,7 +77,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_insert_3cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_insert( + yield self.datastore.simple_insert( table="tablename", # Use OrderedDict() so we can assert on the SQL generated values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]), @@ -92,7 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) - value = yield self.datastore._simple_select_one_onecol( + value = yield self.datastore.simple_select_one_onecol( table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol" ) @@ -106,7 +106,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 1 self.mock_txn.fetchone.return_value = (1, 2, 3) - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.simple_select_one( table="tablename", keyvalues={"keycol": "TheKey"}, retcols=["colA", "colB", "colC"], @@ -122,7 +122,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.rowcount = 0 self.mock_txn.fetchone.return_value = None - ret = yield self.datastore._simple_select_one( + ret = yield self.datastore.simple_select_one( table="tablename", keyvalues={"keycol": "Not here"}, retcols=["colA"], @@ -137,7 +137,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.description = (("colA", None, None, None, None, None, None),) - ret = yield self.datastore._simple_select_list( + ret = yield self.datastore.simple_select_list( table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] ) @@ -150,7 +150,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_1col(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.simple_update_one( table="tablename", keyvalues={"keycol": "TheKey"}, updatevalues={"columnname": "New Value"}, @@ -165,7 +165,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_update_one_4cols(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_update_one( + yield self.datastore.simple_update_one( table="tablename", keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), @@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase): def test_delete_one(self): self.mock_txn.rowcount = 1 - yield self.datastore._simple_delete_one( + yield self.datastore.simple_delete_one( table="tablename", keyvalues={"keycol": "Go away"} ) diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index afac5dec7f..25bdd2c163 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -81,7 +81,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -112,7 +112,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.pump(0) result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -218,7 +218,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But clear the associated entry in devices table self.get_success( - self.store._simple_update( + self.store.simple_update( table="devices", keyvalues={"user_id": user_id, "device_id": "device_id"}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, @@ -245,7 +245,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.simple_insert( table="background_updates", values={ "update_name": "devices_last_seen", @@ -297,7 +297,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should see that in the DB result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], @@ -323,7 +323,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should get no results. result = self.get_success( - self.store._simple_select_list( + self.store.simple_select_list( table="user_ips", keyvalues={"user_id": user_id}, retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index d128fde441..35dafbb904 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -39,8 +39,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version1, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version1, [("room", "session", room_key)] ) ) @@ -51,8 +51,8 @@ class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.get_success( - self.store.set_e2e_room_key( - "user_id", version2, "room", "session", room_key + self.store.add_e2e_room_keys( + "user_id", version2, [("room", "session", room_key)] ) ) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b114c6fb1d..2337a1ae46 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -116,7 +116,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _inject_actions(6, PlAIN_NOTIF) yield _rotate(7) - yield self.store._simple_delete( + yield self.store.simple_delete( table="event_push_actions", keyvalues={"1": 1}, desc="" ) @@ -135,7 +135,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): - return self.store._simple_insert( + return self.store.simple_insert( "events", { "stream_ordering": so, diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 4561c3e383..4930b6777e 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -338,7 +338,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): ) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", @@ -356,7 +356,7 @@ class RedactionTestCase(unittest.HomeserverTestCase): self.reactor.advance(60 * 60 * 2) event_json = self.get_success( - self.store._simple_select_one_onecol( + self.store.simple_select_one_onecol( table="event_json", keyvalues={"event_id": msg_event.event_id}, retcol="json", diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 9ddd17f73d..d389cf578f 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -16,8 +16,7 @@ from unittest.mock import Mock -from synapse.api.constants import EventTypes, Membership -from synapse.api.room_versions import RoomVersions +from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room from synapse.types import Requester, UserID @@ -44,9 +43,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # We can't test the RoomMemberStore on its own without the other event # storage logic self.store = hs.get_datastore() - self.storage = hs.get_storage() - self.event_builder_factory = hs.get_event_builder_factory() - self.event_creation_handler = hs.get_event_creation_handler() self.u_alice = self.register_user("alice", "pass") self.t_alice = self.login("alice", "pass") @@ -55,26 +51,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase): # User elsewhere on another host self.u_charlie = UserID.from_string("@charlie:elsewhere") - def inject_room_member(self, room, user, membership, replaces_state=None): - builder = self.event_builder_factory.for_room_version( - RoomVersions.V1, - { - "type": EventTypes.Member, - "sender": user, - "state_key": user, - "room_id": room, - "content": {"membership": membership}, - }, - ) - - event, context = self.get_success( - self.event_creation_handler.create_new_client_event(builder) - ) - - self.get_success(self.storage.persistence.persist_event(event, context)) - - return event - def test_one_member(self): # Alice creates the room, and is automatically joined @@ -156,7 +132,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Register the background update to run again. self.get_success( - self.store._simple_insert( + self.store.simple_insert( table="background_updates", values={ "update_name": "current_state_events_membership", diff --git a/tests/unittest.py b/tests/unittest.py index 561cebc223..295573bc46 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector +# Copyright 2019 Matrix.org Federation C.I.C # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import gc import hashlib import hmac @@ -27,13 +29,17 @@ from twisted.internet.defer import Deferred, succeed from twisted.python.threadpool import ThreadPool from twisted.trial import unittest -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import FederationRateLimitConfig +from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester +from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import get_clock, make_request, render, setup_test_homeserver from tests.test_utils.logging_setup import setup_logging @@ -538,7 +544,7 @@ class HomeserverTestCase(TestCase): Add the given event as an extremity to the room. """ self.get_success( - self.hs.get_datastore()._simple_insert( + self.hs.get_datastore().simple_insert( table="event_forward_extremities", values={"room_id": room_id, "event_id": event_id}, desc="test_add_extremity", @@ -559,6 +565,66 @@ class HomeserverTestCase(TestCase): self.render(request) self.assertEqual(channel.code, 403, channel.result) + def inject_room_member(self, room: str, user: str, membership: Membership) -> None: + """ + Inject a membership event into a room. + + Args: + room: Room ID to inject the event into. + user: MXID of the user to inject the membership for. + membership: The membership type. + """ + event_builder_factory = self.hs.get_event_builder_factory() + event_creation_handler = self.hs.get_event_creation_handler() + + room_version = self.get_success(self.hs.get_datastore().get_room_version(room)) + + builder = event_builder_factory.for_room_version( + KNOWN_ROOM_VERSIONS[room_version], + { + "type": EventTypes.Member, + "sender": user, + "state_key": user, + "room_id": room, + "content": {"membership": membership}, + }, + ) + + event, context = self.get_success( + event_creation_handler.create_new_client_event(builder) + ) + + self.get_success( + self.hs.get_storage().persistence.persist_event(event, context) + ) + + +class FederatingHomeserverTestCase(HomeserverTestCase): + """ + A federating homeserver that authenticates incoming requests as `other.example.com`. + """ + + def prepare(self, reactor, clock, homeserver): + class Authenticator(object): + def authenticate_request(self, request, content): + return succeed("other.example.com") + + ratelimiter = FederationRateLimiter( + clock, + FederationRateLimitConfig( + window_size=1, + sleep_limit=1, + sleep_msec=1, + reject_limit=1000, + concurrent_requests=1000, + ), + ) + federation_server.register_servlets( + homeserver, self.resource, Authenticator(), ratelimiter + ) + + return super().prepare(reactor, clock, homeserver) + def override_config(extra_config): """A decorator which can be applied to test functions to give additional HS config diff --git a/tests/utils.py b/tests/utils.py index 7dc9bdc505..de2ac1ed33 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -109,6 +109,7 @@ def default_config(name, parse=False): """ config_dict = { "server_name": name, + "send_federation": False, "media_store_path": "media", "uploads_path": "uploads", # the test signing key is just an arbitrary ed25519 key to keep the config diff --git a/tox.ini b/tox.ini index 62b350ea6a..903a245fb0 100644 --- a/tox.ini +++ b/tox.ini @@ -102,6 +102,15 @@ commands = {envbindir}/coverage run "{envbindir}/trial" {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:} +[testenv:benchmark] +deps = + {[base]deps} + pyperf +setenv = + SYNAPSE_POSTGRES = 1 +commands = + python -m synmark {posargs:} + [testenv:packaging] skip_install=True deps = |