diff options
199 files changed, 4753 insertions, 2015 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c17e3b2399..f7bea79b0d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -46,7 +46,7 @@ locally. You'll need python 3.6 or later, and to install a number of tools: ``` # Install the dependencies -pip install -e ".[lint]" +pip install -e ".[lint,mypy]" # Run the linter script ./scripts-dev/lint.sh @@ -63,7 +63,7 @@ run-time: ./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder ``` -You can also provided the `-d` option, which will lint the files that have been +You can also provide the `-d` option, which will lint the files that have been changed since the last git commit. This will often be significantly faster than linting the whole codebase. diff --git a/INSTALL.md b/INSTALL.md index 22f7b7c029..c6fcb3bd7f 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -57,7 +57,7 @@ light workloads. System requirements: - POSIX-compliant system (tested on Linux & OS X) -- Python 3.5.2 or later, up to Python 3.8. +- Python 3.5.2 or later, up to Python 3.9. - At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org Synapse is written in Python but some of the libraries it uses are written in diff --git a/README.rst b/README.rst index d609b4b62e..59d5a4389b 100644 --- a/README.rst +++ b/README.rst @@ -256,9 +256,9 @@ directory of your choice:: Synapse has a number of external dependencies, that are easiest to install using pip and a virtualenv:: - virtualenv -p python3 env - source env/bin/activate - python -m pip install --no-use-pep517 -e ".[all]" + python3 -m venv ./env + source ./env/bin/activate + pip install -e ".[all,test]" This will run a process of downloading and installing all the needed dependencies into a virtual env. @@ -270,9 +270,9 @@ check that everything is installed as it should be:: This should end with a 'PASSED' result:: - Ran 143 tests in 0.601s + Ran 1266 tests in 643.930s - PASSED (successes=143) + PASSED (skips=15, successes=1251) Running the Integration Tests ============================= diff --git a/UPGRADE.rst b/UPGRADE.rst index 5a68312217..960c2aeb2b 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,22 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.23.0 +==================== + +Structured logging configuration breaking changes +------------------------------------------------- + +This release deprecates use of the ``structured: true`` logging configuration for +structured logging. If your logging configuration contains ``structured: true`` +then it should be modified based on the `structured logging documentation +<https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md>`_. + +The ``structured`` and ``drains`` logging options are now deprecated and should +be replaced by standard logging configuration of ``handlers`` and ``formatters`. + +A future will release of Synapse will make using ``structured: true`` an error. + Upgrading to v1.22.0 ==================== diff --git a/changelog.d/8455.bugfix b/changelog.d/8455.bugfix new file mode 100644 index 0000000000..561e73f5e0 --- /dev/null +++ b/changelog.d/8455.bugfix @@ -0,0 +1 @@ +Fix fetching of E2E cross signing keys over federation when only one of the master key and device signing key is cached already. diff --git a/changelog.d/8519.feature b/changelog.d/8519.feature new file mode 100644 index 0000000000..e2ab548681 --- /dev/null +++ b/changelog.d/8519.feature @@ -0,0 +1 @@ +Add an admin api to delete a single file or files were not used for a defined time from server. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/8539.feature b/changelog.d/8539.feature new file mode 100644 index 0000000000..15ce02fb86 --- /dev/null +++ b/changelog.d/8539.feature @@ -0,0 +1 @@ +Split admin API for reported events (`GET /_synapse/admin/v1/event_reports`) into detail and list endpoints. This is a breaking change to #8217 which was introduced in Synapse v1.21.0. Those who already use this API should check their scripts. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/8559.misc b/changelog.d/8559.misc new file mode 100644 index 0000000000..d7bd00964e --- /dev/null +++ b/changelog.d/8559.misc @@ -0,0 +1 @@ +Optimise `/createRoom` with multiple invited users. diff --git a/changelog.d/8580.bugfix b/changelog.d/8580.bugfix new file mode 100644 index 0000000000..31734fd97d --- /dev/null +++ b/changelog.d/8580.bugfix @@ -0,0 +1 @@ +Fix a bug where Synapse would blindly forward bad responses from federation to clients when retrieving profile information. diff --git a/changelog.d/8582.doc b/changelog.d/8582.doc new file mode 100644 index 0000000000..041f168717 --- /dev/null +++ b/changelog.d/8582.doc @@ -0,0 +1 @@ +Instructions for Azure AD in the OpenID Connect documentation. Contributed by peterk. diff --git a/changelog.d/8595.misc b/changelog.d/8595.misc new file mode 100644 index 0000000000..24fab65cda --- /dev/null +++ b/changelog.d/8595.misc @@ -0,0 +1 @@ +Implement and use an @lru_cache decorator. diff --git a/changelog.d/8607.feature b/changelog.d/8607.feature new file mode 100644 index 0000000000..fef1eccb92 --- /dev/null +++ b/changelog.d/8607.feature @@ -0,0 +1 @@ +Support generating structured logs via the standard logging configuration. diff --git a/changelog.d/8610.feature b/changelog.d/8610.feature new file mode 100644 index 0000000000..ed8d926964 --- /dev/null +++ b/changelog.d/8610.feature @@ -0,0 +1 @@ +Add an admin APIs to allow server admins to list users' pushers. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/8614.misc b/changelog.d/8614.misc new file mode 100644 index 0000000000..1bf9ea08f0 --- /dev/null +++ b/changelog.d/8614.misc @@ -0,0 +1 @@ +Don't instansiate Requester directly. diff --git a/changelog.d/8615.misc b/changelog.d/8615.misc new file mode 100644 index 0000000000..79fa7b7ff8 --- /dev/null +++ b/changelog.d/8615.misc @@ -0,0 +1 @@ +Type hints for `RegistrationStore`. diff --git a/changelog.d/8616.misc b/changelog.d/8616.misc new file mode 100644 index 0000000000..385b14063e --- /dev/null +++ b/changelog.d/8616.misc @@ -0,0 +1 @@ +Change schema to support access tokens belonging to one user but granting access to another. diff --git a/changelog.d/8620.bugfix b/changelog.d/8620.bugfix new file mode 100644 index 0000000000..c1078a3fb5 --- /dev/null +++ b/changelog.d/8620.bugfix @@ -0,0 +1 @@ +Fix a bug where the account validity endpoint would silently fail if the user ID did not have an expiration time. It now returns a 400 error. diff --git a/changelog.d/8621.misc b/changelog.d/8621.misc new file mode 100644 index 0000000000..5720b665fe --- /dev/null +++ b/changelog.d/8621.misc @@ -0,0 +1 @@ +Remove unused OPTIONS handlers. diff --git a/changelog.d/8627.bugfix b/changelog.d/8627.bugfix new file mode 100644 index 0000000000..143cf95f92 --- /dev/null +++ b/changelog.d/8627.bugfix @@ -0,0 +1 @@ +Fix email notifications for invites without local state. diff --git a/changelog.d/8628.bugfix b/changelog.d/8628.bugfix new file mode 100644 index 0000000000..1316136ca2 --- /dev/null +++ b/changelog.d/8628.bugfix @@ -0,0 +1 @@ +Fix handling of invalid group IDs to return a 400 rather than log an exception and return a 500. diff --git a/changelog.d/8632.bugfix b/changelog.d/8632.bugfix new file mode 100644 index 0000000000..7d834aa2e2 --- /dev/null +++ b/changelog.d/8632.bugfix @@ -0,0 +1 @@ +Fix handling of User-Agent headers that are invalid UTF-8, which caused user agents of users to not get correctly recorded. diff --git a/changelog.d/8633.misc b/changelog.d/8633.misc new file mode 100644 index 0000000000..8e1d006b36 --- /dev/null +++ b/changelog.d/8633.misc @@ -0,0 +1 @@ +Run `mypy` as part of the lint.sh script. diff --git a/changelog.d/8634.misc b/changelog.d/8634.misc new file mode 100644 index 0000000000..c4f74ba7c9 --- /dev/null +++ b/changelog.d/8634.misc @@ -0,0 +1 @@ +Correct Synapse's PyPI package name in the OpenID Connect installation instructions. \ No newline at end of file diff --git a/changelog.d/8635.doc b/changelog.d/8635.doc new file mode 100644 index 0000000000..00fb1e61a7 --- /dev/null +++ b/changelog.d/8635.doc @@ -0,0 +1 @@ +Improve the sample configuration for single sign-on providers. diff --git a/changelog.d/8639.misc b/changelog.d/8639.misc new file mode 100644 index 0000000000..20a213df39 --- /dev/null +++ b/changelog.d/8639.misc @@ -0,0 +1 @@ +Fix typos and spelling errors in the code. diff --git a/changelog.d/8640.misc b/changelog.d/8640.misc new file mode 100644 index 0000000000..cf6023f783 --- /dev/null +++ b/changelog.d/8640.misc @@ -0,0 +1 @@ +Reduce number of OpenTracing spans started. diff --git a/changelog.d/8643.bugfix b/changelog.d/8643.bugfix new file mode 100644 index 0000000000..fcda1ca871 --- /dev/null +++ b/changelog.d/8643.bugfix @@ -0,0 +1 @@ +Fix a bug in the `joined_rooms` admin API if the user has never joined any rooms. The bug was introduced, along with the API, in v1.21.0. diff --git a/changelog.d/8644.misc b/changelog.d/8644.misc new file mode 100644 index 0000000000..87f2b72924 --- /dev/null +++ b/changelog.d/8644.misc @@ -0,0 +1 @@ +Add field `total` to device list in admin API. \ No newline at end of file diff --git a/changelog.d/8647.feature b/changelog.d/8647.feature new file mode 100644 index 0000000000..79e98f6e90 --- /dev/null +++ b/changelog.d/8647.feature @@ -0,0 +1 @@ +Add an admin API `GET /_synapse/admin/v1/users/<user_id>/media` to get information about uploaded media. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/8655.misc b/changelog.d/8655.misc new file mode 100644 index 0000000000..b588bdd3e2 --- /dev/null +++ b/changelog.d/8655.misc @@ -0,0 +1 @@ +Add more type hints to the application services code. diff --git a/changelog.d/8657.doc b/changelog.d/8657.doc new file mode 100644 index 0000000000..3dcbb221af --- /dev/null +++ b/changelog.d/8657.doc @@ -0,0 +1 @@ +Fix the filepath of Dex's example config and the link to Dex's Getting Started guide in the OpenID Connect docs. diff --git a/changelog.d/8664.misc b/changelog.d/8664.misc new file mode 100644 index 0000000000..278cf53adc --- /dev/null +++ b/changelog.d/8664.misc @@ -0,0 +1 @@ +Tell Black to format code for Python 3.5. diff --git a/changelog.d/8665.doc b/changelog.d/8665.doc new file mode 100644 index 0000000000..3b75307dc5 --- /dev/null +++ b/changelog.d/8665.doc @@ -0,0 +1 @@ +Note support for Python 3.9. diff --git a/changelog.d/8666.doc b/changelog.d/8666.doc new file mode 100644 index 0000000000..dee86b4a26 --- /dev/null +++ b/changelog.d/8666.doc @@ -0,0 +1 @@ +Minor updates to docs on running tests. diff --git a/changelog.d/8667.doc b/changelog.d/8667.doc new file mode 100644 index 0000000000..422d697da6 --- /dev/null +++ b/changelog.d/8667.doc @@ -0,0 +1 @@ +Interlink prometheus/grafana documentation. diff --git a/changelog.d/8668.misc b/changelog.d/8668.misc new file mode 100644 index 0000000000..cf6023f783 --- /dev/null +++ b/changelog.d/8668.misc @@ -0,0 +1 @@ +Reduce number of OpenTracing spans started. diff --git a/changelog.d/8669.misc b/changelog.d/8669.misc new file mode 100644 index 0000000000..5228105cd3 --- /dev/null +++ b/changelog.d/8669.misc @@ -0,0 +1 @@ +Don't pull event from DB when handling replication traffic. diff --git a/changelog.d/8670.misc b/changelog.d/8670.misc new file mode 100644 index 0000000000..cf6023f783 --- /dev/null +++ b/changelog.d/8670.misc @@ -0,0 +1 @@ +Reduce number of OpenTracing spans started. diff --git a/changelog.d/8671.misc b/changelog.d/8671.misc new file mode 100644 index 0000000000..bef8dc425a --- /dev/null +++ b/changelog.d/8671.misc @@ -0,0 +1 @@ +Abstract some invite-related code in preparation for landing knocking. \ No newline at end of file diff --git a/changelog.d/8679.misc b/changelog.d/8679.misc new file mode 100644 index 0000000000..662eced4cf --- /dev/null +++ b/changelog.d/8679.misc @@ -0,0 +1 @@ +Clarify representation of events in logfiles. diff --git a/changelog.d/8680.misc b/changelog.d/8680.misc new file mode 100644 index 0000000000..2ca2975464 --- /dev/null +++ b/changelog.d/8680.misc @@ -0,0 +1 @@ +Don't require `hiredis` package to be installed to run unit tests. diff --git a/changelog.d/8682.bugfix b/changelog.d/8682.bugfix new file mode 100644 index 0000000000..e61276aa05 --- /dev/null +++ b/changelog.d/8682.bugfix @@ -0,0 +1 @@ +Fix exception during handling multiple concurrent requests for remote media when using multiple media repositories. diff --git a/changelog.d/8684.misc b/changelog.d/8684.misc new file mode 100644 index 0000000000..1d23d42926 --- /dev/null +++ b/changelog.d/8684.misc @@ -0,0 +1 @@ +Fix typing info on cache call signature to accept `on_invalidate`. diff --git a/changelog.d/8685.feature b/changelog.d/8685.feature new file mode 100644 index 0000000000..fef1eccb92 --- /dev/null +++ b/changelog.d/8685.feature @@ -0,0 +1 @@ +Support generating structured logs via the standard logging configuration. diff --git a/changelog.d/8688.misc b/changelog.d/8688.misc new file mode 100644 index 0000000000..bef8dc425a --- /dev/null +++ b/changelog.d/8688.misc @@ -0,0 +1 @@ +Abstract some invite-related code in preparation for landing knocking. \ No newline at end of file diff --git a/changelog.d/8689.feature b/changelog.d/8689.feature new file mode 100644 index 0000000000..ed8d926964 --- /dev/null +++ b/changelog.d/8689.feature @@ -0,0 +1 @@ +Add an admin APIs to allow server admins to list users' pushers. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/8690.misc b/changelog.d/8690.misc new file mode 100644 index 0000000000..0f38ba1f5d --- /dev/null +++ b/changelog.d/8690.misc @@ -0,0 +1 @@ +Fail tests if they do not await coroutines. diff --git a/contrib/grafana/README.md b/contrib/grafana/README.md index ca780d412e..4608793394 100644 --- a/contrib/grafana/README.md +++ b/contrib/grafana/README.md @@ -3,4 +3,4 @@ 0. Set up Prometheus and Grafana. Out of scope for this readme. Useful documentation about using Grafana with Prometheus: http://docs.grafana.org/features/datasources/prometheus/ 1. Have your Prometheus scrape your Synapse. https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md 2. Import dashboard into Grafana. Download `synapse.json`. Import it to Grafana and select the correct Prometheus datasource. http://docs.grafana.org/reference/export_import/ -3. Set up additional recording rules +3. Set up required recording rules. https://github.com/matrix-org/synapse/tree/master/contrib/prometheus diff --git a/docs/admin_api/event_reports.rst b/docs/admin_api/event_reports.rst index 461be01230..5f7b0fa6bb 100644 --- a/docs/admin_api/event_reports.rst +++ b/docs/admin_api/event_reports.rst @@ -17,67 +17,26 @@ It returns a JSON body like the following: { "event_reports": [ { - "content": { - "reason": "foo", - "score": -100 - }, "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY", - "event_json": { - "auth_events": [ - "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M", - "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws" - ], - "content": { - "body": "matrix.org: This Week in Matrix", - "format": "org.matrix.custom.html", - "formatted_body": "<strong>matrix.org</strong>:<br><a href=\"https://matrix.org/blog/\"><strong>This Week in Matrix</strong></a>", - "msgtype": "m.notice" - }, - "depth": 546, - "hashes": { - "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw" - }, - "origin": "matrix.org", - "origin_server_ts": 1592291711430, - "prev_events": [ - "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M" - ], - "prev_state": [], - "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", - "sender": "@foobar:matrix.org", - "signatures": { - "matrix.org": { - "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg" - } - }, - "type": "m.room.message", - "unsigned": { - "age_ts": 1592291711430, - } - }, "id": 2, "reason": "foo", + "score": -100, "received_ts": 1570897107409, - "room_alias": "#alias1:matrix.org", + "canonical_alias": "#alias1:matrix.org", "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "name": "Matrix HQ", "sender": "@foobar:matrix.org", "user_id": "@foo:matrix.org" }, { - "content": { - "reason": "bar", - "score": -100 - }, "event_id": "$3IcdZsDaN_En-S1DF4EMCy3v4gNRKeOJs8W5qTOKj4I", - "event_json": { - // hidden items - // see above - }, "id": 3, "reason": "bar", + "score": -100, "received_ts": 1598889612059, - "room_alias": "#alias2:matrix.org", + "canonical_alias": "#alias2:matrix.org", "room_id": "!eGvUQuTCkHGVwNMOjv:matrix.org", + "name": "Your room name here", "sender": "@foobar:matrix.org", "user_id": "@bar:matrix.org" } @@ -113,17 +72,94 @@ The following fields are returned in the JSON response body: - ``id``: integer - ID of event report. - ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent. - ``room_id``: string - The ID of the room in which the event being reported is located. +- ``name``: string - The name of the room. - ``event_id``: string - The ID of the reported event. - ``user_id``: string - This is the user who reported the event and wrote the reason. - ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. -- ``content``: object - Content of reported event. - - - ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. - - ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive". - +- ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive". - ``sender``: string - This is the ID of the user who sent the original message/event that was reported. -- ``room_alias``: string - The alias of the room. ``null`` if the room does not have a canonical alias set. -- ``event_json``: object - Details of the original event that was reported. +- ``canonical_alias``: string - The canonical alias of the room. ``null`` if the room does not have a canonical alias set. - ``next_token``: integer - Indication for pagination. See above. - ``total``: integer - Total number of event reports related to the query (``user_id`` and ``room_id``). +Show details of a specific event report +======================================= + +This API returns information about a specific event report. + +The api is:: + + GET /_synapse/admin/v1/event_reports/<report_id> + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +It returns a JSON body like the following: + +.. code:: jsonc + + { + "event_id": "$bNUFCwGzWca1meCGkjp-zwslF-GfVcXukvRLI1_FaVY", + "event_json": { + "auth_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M", + "$oggsNXxzPFRE3y53SUNd7nsj69-QzKv03a1RucHu-ws" + ], + "content": { + "body": "matrix.org: This Week in Matrix", + "format": "org.matrix.custom.html", + "formatted_body": "<strong>matrix.org</strong>:<br><a href=\"https://matrix.org/blog/\"><strong>This Week in Matrix</strong></a>", + "msgtype": "m.notice" + }, + "depth": 546, + "hashes": { + "sha256": "xK1//xnmvHJIOvbgXlkI8eEqdvoMmihVDJ9J4SNlsAw" + }, + "origin": "matrix.org", + "origin_server_ts": 1592291711430, + "prev_events": [ + "$YK4arsKKcc0LRoe700pS8DSjOvUT4NDv0HfInlMFw2M" + ], + "prev_state": [], + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "sender": "@foobar:matrix.org", + "signatures": { + "matrix.org": { + "ed25519:a_JaEG": "cs+OUKW/iHx5pEidbWxh0UiNNHwe46Ai9LwNz+Ah16aWDNszVIe2gaAcVZfvNsBhakQTew51tlKmL2kspXk/Dg" + } + }, + "type": "m.room.message", + "unsigned": { + "age_ts": 1592291711430, + } + }, + "id": <report_id>, + "reason": "foo", + "score": -100, + "received_ts": 1570897107409, + "canonical_alias": "#alias1:matrix.org", + "room_id": "!ERAgBpSOcCCuTJqQPk:matrix.org", + "name": "Matrix HQ", + "sender": "@foobar:matrix.org", + "user_id": "@foo:matrix.org" + } + +**URL parameters:** + +- ``report_id``: string - The ID of the event report. + +**Response** + +The following fields are returned in the JSON response body: + +- ``id``: integer - ID of event report. +- ``received_ts``: integer - The timestamp (in milliseconds since the unix epoch) when this report was sent. +- ``room_id``: string - The ID of the room in which the event being reported is located. +- ``name``: string - The name of the room. +- ``event_id``: string - The ID of the reported event. +- ``user_id``: string - This is the user who reported the event and wrote the reason. +- ``reason``: string - Comment made by the ``user_id`` in this report. May be blank. +- ``score``: integer - Content is reported based upon a negative score, where -100 is "most offensive" and 0 is "inoffensive". +- ``sender``: string - This is the ID of the user who sent the original message/event that was reported. +- ``canonical_alias``: string - The canonical alias of the room. ``null`` if the room does not have a canonical alias set. +- ``event_json``: object - Details of the original event that was reported. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index 26948770d8..3994e1f1a9 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -100,3 +100,82 @@ Response: "num_quarantined": 10 # The number of media items successfully quarantined } ``` + +# Delete local media +This API deletes the *local* media from the disk of your own server. +This includes any local thumbnails and copies of media downloaded from +remote homeservers. +This API will not affect media that has been uploaded to external +media repositories (e.g https://github.com/turt2live/matrix-media-repo/). +See also [purge_remote_media.rst](purge_remote_media.rst). + +## Delete a specific local media +Delete a specific `media_id`. + +Request: + +``` +DELETE /_synapse/admin/v1/media/<server_name>/<media_id> + +{} +``` + +URL Parameters + +* `server_name`: string - The name of your local server (e.g `matrix.org`) +* `media_id`: string - The ID of the media (e.g `abcdefghijklmnopqrstuvwx`) + +Response: + +```json + { + "deleted_media": [ + "abcdefghijklmnopqrstuvwx" + ], + "total": 1 + } +``` + +The following fields are returned in the JSON response body: + +* `deleted_media`: an array of strings - List of deleted `media_id` +* `total`: integer - Total number of deleted `media_id` + +## Delete local media by date or size + +Request: + +``` +POST /_synapse/admin/v1/media/<server_name>/delete?before_ts=<before_ts> + +{} +``` + +URL Parameters + +* `server_name`: string - The name of your local server (e.g `matrix.org`). +* `before_ts`: string representing a positive integer - Unix timestamp in ms. +Files that were last used before this timestamp will be deleted. It is the timestamp of +last access and not the timestamp creation. +* `size_gt`: Optional - string representing a positive integer - Size of the media in bytes. +Files that are larger will be deleted. Defaults to `0`. +* `keep_profiles`: Optional - string representing a boolean - Switch to also delete files +that are still used in image data (e.g user profile, room avatar). +If `false` these files will be deleted. Defaults to `true`. + +Response: + +```json + { + "deleted_media": [ + "abcdefghijklmnopqrstuvwx", + "abcdefghijklmnopqrstuvwz" + ], + "total": 2 + } +``` + +The following fields are returned in the JSON response body: + +* `deleted_media`: an array of strings - List of deleted `media_id` +* `total`: integer - Total number of deleted `media_id` diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 7ca902faba..d4051d0257 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -341,6 +341,89 @@ The following fields are returned in the JSON response body: - ``total`` - Number of rooms. +List media of an user +================================ +Gets a list of all local media that a specific ``user_id`` has created. +The response is ordered by creation date descending and media ID descending. +The newest media is on top. + +The API is:: + + GET /_synapse/admin/v1/users/<user_id>/media + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +A response body like the following is returned: + +.. code:: json + + { + "media": [ + { + "created_ts": 100400, + "last_access_ts": null, + "media_id": "qXhyRzulkwLsNHTbpHreuEgo", + "media_length": 67, + "media_type": "image/png", + "quarantined_by": null, + "safe_from_quarantine": false, + "upload_name": "test1.png" + }, + { + "created_ts": 200400, + "last_access_ts": null, + "media_id": "FHfiSnzoINDatrXHQIXBtahw", + "media_length": 67, + "media_type": "image/png", + "quarantined_by": null, + "safe_from_quarantine": false, + "upload_name": "test2.png" + } + ], + "next_token": 3, + "total": 2 + } + +To paginate, check for ``next_token`` and if present, call the endpoint again +with ``from`` set to the value of ``next_token``. This will return a new page. + +If the endpoint does not return a ``next_token`` then there are no more +reports to paginate through. + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - string - fully qualified: for example, ``@user:server.com``. +- ``limit``: string representing a positive integer - Is optional but is used for pagination, + denoting the maximum number of items to return in this call. Defaults to ``100``. +- ``from``: string representing a positive integer - Is optional but used for pagination, + denoting the offset in the returned results. This should be treated as an opaque value and + not explicitly set to anything other than the return value of ``next_token`` from a previous call. + Defaults to ``0``. + +**Response** + +The following fields are returned in the JSON response body: + +- ``media`` - An array of objects, each containing information about a media. + Media objects contain the following fields: + + - ``created_ts`` - integer - Timestamp when the content was uploaded in ms. + - ``last_access_ts`` - integer - Timestamp when the content was last accessed in ms. + - ``media_id`` - string - The id used to refer to the media. + - ``media_length`` - integer - Length of the media in bytes. + - ``media_type`` - string - The MIME-type of the media. + - ``quarantined_by`` - string - The user ID that initiated the quarantine request + for this media. + + - ``safe_from_quarantine`` - bool - Status if this media is safe from quarantining. + - ``upload_name`` - string - The name the media was uploaded with. + +- ``next_token``: integer - Indication for pagination. See above. +- ``total`` - integer - Total number of media. + User devices ============ @@ -375,7 +458,8 @@ A response body like the following is returned: "last_seen_ts": 1474491775025, "user_id": "<user_id>" } - ] + ], + "total": 2 } **Parameters** @@ -400,6 +484,8 @@ The following fields are returned in the JSON response body: devices was last seen. (May be a few minutes out of date, for efficiency reasons). - ``user_id`` - Owner of device. +- ``total`` - Total number of user's devices. + Delete multiple devices ------------------ Deletes the given devices for a specific ``user_id``, and invalidates @@ -525,3 +611,82 @@ The following parameters should be set in the URL: - ``user_id`` - fully qualified: for example, ``@user:server.com``. - ``device_id`` - The device to delete. + +List all pushers +================ +Gets information about all pushers for a specific ``user_id``. + +The API is:: + + GET /_synapse/admin/v1/users/<user_id>/pushers + +To use it, you will need to authenticate by providing an ``access_token`` for a +server admin: see `README.rst <README.rst>`_. + +A response body like the following is returned: + +.. code:: json + + { + "pushers": [ + { + "app_display_name":"HTTP Push Notifications", + "app_id":"m.http", + "data": { + "url":"example.com" + }, + "device_display_name":"pushy push", + "kind":"http", + "lang":"None", + "profile_tag":"", + "pushkey":"a@example.com" + } + ], + "total": 1 + } + +**Parameters** + +The following parameters should be set in the URL: + +- ``user_id`` - fully qualified: for example, ``@user:server.com``. + +**Response** + +The following fields are returned in the JSON response body: + +- ``pushers`` - An array containing the current pushers for the user + + - ``app_display_name`` - string - A string that will allow the user to identify + what application owns this pusher. + + - ``app_id`` - string - This is a reverse-DNS style identifier for the application. + Max length, 64 chars. + + - ``data`` - A dictionary of information for the pusher implementation itself. + + - ``url`` - string - Required if ``kind`` is ``http``. The URL to use to send + notifications to. + + - ``format`` - string - The format to use when sending notifications to the + Push Gateway. + + - ``device_display_name`` - string - A string that will allow the user to identify + what device owns this pusher. + + - ``profile_tag`` - string - This string determines which set of device specific rules + this pusher executes. + + - ``kind`` - string - The kind of pusher. "http" is a pusher that sends HTTP pokes. + - ``lang`` - string - The preferred language for receiving notifications + (e.g. 'en' or 'en-US') + + - ``profile_tag`` - string - This string determines which set of device specific rules + this pusher executes. + + - ``pushkey`` - string - This is a unique identifier for this pusher. + Max length, 512 bytes. + +- ``total`` - integer - Number of pushers. + +See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_ diff --git a/docs/metrics-howto.md b/docs/metrics-howto.md index b386ec91c1..fb71af4911 100644 --- a/docs/metrics-howto.md +++ b/docs/metrics-howto.md @@ -60,6 +60,8 @@ 1. Restart Prometheus. +1. Consider using the [grafana dashboard](https://github.com/matrix-org/synapse/tree/master/contrib/grafana/) and required [recording rules](https://github.com/matrix-org/synapse/tree/master/contrib/prometheus/) + ## Monitoring workers To monitor a Synapse installation using diff --git a/docs/openid.md b/docs/openid.md index 4873681999..6670f36261 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -37,7 +37,7 @@ as follows: provided by `matrix.org` so no further action is needed. * If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip - install synapse[oidc]` to install the necessary dependencies. + install matrix-synapse[oidc]` to install the necessary dependencies. * For other installation mechanisms, see the documentation provided by the maintainer. @@ -52,14 +52,39 @@ specific providers. Here are a few configs for providers that should work with Synapse. +### Microsoft Azure Active Directory +Azure AD can act as an OpenID Connect Provider. Register a new application under +*App registrations* in the Azure AD management console. The RedirectURI for your +application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback` + +Go to *Certificates & secrets* and register a new client secret. Make note of your +Directory (tenant) ID as it will be used in the Azure links. +Edit your Synapse config file and change the `oidc_config` section: + +```yaml +oidc_config: + enabled: true + issuer: "https://login.microsoftonline.com/<tenant id>/v2.0" + client_id: "<client id>" + client_secret: "<client secret>" + scopes: ["openid", "profile"] + authorization_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize" + token_endpoint: "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token" + userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" + + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username.split('@')[0] }}" + display_name_template: "{{ user.name }}" +``` + ### [Dex][dex-idp] [Dex][dex-idp] is a simple, open-source, certified OpenID Connect Provider. Although it is designed to help building a full-blown provider with an external database, it can be configured with static passwords in a config file. -Follow the [Getting Started -guide](https://github.com/dexidp/dex/blob/master/Documentation/getting-started.md) +Follow the [Getting Started guide](https://dexidp.io/docs/getting-started/) to install Dex. Edit `examples/config-dev.yaml` config file from the Dex repo to add a client: @@ -73,7 +98,7 @@ staticClients: name: 'Synapse' ``` -Run with `dex serve examples/config-dex.yaml`. +Run with `dex serve examples/config-dev.yaml`. Synapse config: diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 061226ea6f..7e2cf97c3e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1505,10 +1505,8 @@ trusted_key_servers: ## Single sign-on integration ## -# Enable SAML2 for registration and login. Uses pysaml2. -# -# At least one of `sp_config` or `config_path` must be set in this section to -# enable SAML login. +# The following settings can be used to make Synapse use a single sign-on +# provider for authentication, instead of its internal password database. # # You will probably also want to set the following options to `false` to # disable the regular login/registration flows: @@ -1517,6 +1515,11 @@ trusted_key_servers: # # You will also want to investigate the settings under the "sso" configuration # section below. + +# Enable SAML2 for registration and login. Uses pysaml2. +# +# At least one of `sp_config` or `config_path` must be set in this section to +# enable SAML login. # # Once SAML support is enabled, a metadata file will be exposed at # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to @@ -1532,40 +1535,42 @@ saml2_config: # so it is not normally necessary to specify them unless you need to # override them. # - #sp_config: - # # point this to the IdP's metadata. You can use either a local file or - # # (preferably) a URL. - # metadata: - # #local: ["saml2/idp.xml"] - # remote: - # - url: https://our_idp/metadata.xml - # - # # By default, the user has to go to our login page first. If you'd like - # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a - # # 'service.sp' section: - # # - # #service: - # # sp: - # # allow_unsolicited: true - # - # # The examples below are just used to generate our metadata xml, and you - # # may well not need them, depending on your setup. Alternatively you - # # may need a whole lot more detail - see the pysaml2 docs! - # - # description: ["My awesome SP", "en"] - # name: ["Test SP", "en"] - # - # organization: - # name: Example com - # display_name: - # - ["Example co", "en"] - # url: "http://example.com" - # - # contact_person: - # - given_name: Bob - # sur_name: "the Sysadmin" - # email_address": ["admin@example.com"] - # contact_type": technical + sp_config: + # Point this to the IdP's metadata. You must provide either a local + # file via the `local` attribute or (preferably) a URL via the + # `remote` attribute. + # + #metadata: + # local: ["saml2/idp.xml"] + # remote: + # - url: https://our_idp/metadata.xml + + # By default, the user has to go to our login page first. If you'd like + # to allow IdP-initiated login, set 'allow_unsolicited: true' in a + # 'service.sp' section: + # + #service: + # sp: + # allow_unsolicited: true + + # The examples below are just used to generate our metadata xml, and you + # may well not need them, depending on your setup. Alternatively you + # may need a whole lot more detail - see the pysaml2 docs! + + #description: ["My awesome SP", "en"] + #name: ["Test SP", "en"] + + #organization: + # name: Example com + # display_name: + # - ["Example co", "en"] + # url: "http://example.com" + + #contact_person: + # - given_name: Bob + # sur_name: "the Sysadmin" + # email_address": ["admin@example.com"] + # contact_type": technical # Instead of putting the config inline as above, you can specify a # separate pysaml2 configuration file: @@ -1641,11 +1646,10 @@ saml2_config: # value: "sales" -# OpenID Connect integration. The following settings can be used to make Synapse -# use an OpenID Connect Provider for authentication, instead of its internal -# password database. +# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. # -# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md. +# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md +# for some example configurations. # oidc_config: # Uncomment the following to enable authorization against an OpenID Connect @@ -1778,15 +1782,37 @@ oidc_config: -# Enable CAS for registration and login. +# Enable Central Authentication Service (CAS) for registration and login. # -#cas_config: -# enabled: true -# server_url: "https://cas-server.com" -# service_url: "https://homeserver.domain.com:8448" -# #displayname_attribute: name -# #required_attributes: -# # name: value +cas_config: + # Uncomment the following to enable authorization against a CAS server. + # Defaults to false. + # + #enabled: true + + # The URL of the CAS authorization endpoint. + # + #server_url: "https://cas-server.com" + + # The public URL of the homeserver. + # + #service_url: "https://homeserver.domain.com:8448" + + # The attribute of the CAS response to use as the display name. + # + # If unset, no displayname will be set. + # + #displayname_attribute: name + + # It is possible to configure Synapse to only allow logins if CAS attributes + # match particular values. All of the keys in the mapping below must exist + # and the values must match the given value. Alternately if the given value + # is None then any value is allowed (the attribute just must exist). + # All of the listed attributes must match for the login to be permitted. + # + #required_attributes: + # userGroup: "staff" + # department: None # Additional settings to use with single-sign on systems such as OpenID Connect, @@ -1886,7 +1912,7 @@ sso: # and issued at ("iat") claims are validated if present. # # Note that this is a non-standard login type and client support is -# expected to be non-existant. +# expected to be non-existent. # # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md. # @@ -2402,7 +2428,7 @@ spam_checker: # # Options for the rules include: # -# user_id: Matches agaisnt the creator of the alias +# user_id: Matches against the creator of the alias # room_id: Matches against the room ID being published # alias: Matches against any current local or canonical aliases # associated with the room @@ -2448,7 +2474,7 @@ opentracing: # This is a list of regexes which are matched against the server_name of the # homeserver. # - # By defult, it is empty, so no servers are matched. + # By default, it is empty, so no servers are matched. # #homeserver_whitelist: # - ".*" diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 55a48a9ed6..ff3c747180 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -3,7 +3,11 @@ # This is a YAML file containing a standard Python logging configuration # dictionary. See [1] for details on the valid settings. # +# Synapse also supports structured logging for machine readable logs which can +# be ingested by ELK stacks. See [2] for details. +# # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema +# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md version: 1 @@ -59,7 +63,7 @@ root: # then write them to a file. # # Replace "buffer" with "console" to log to stderr instead. (Note that you'll - # also need to update the configuation for the `twisted` logger above, in + # also need to update the configuration for the `twisted` logger above, in # this case.) # handlers: [buffer] diff --git a/docs/structured_logging.md b/docs/structured_logging.md index decec9b8fa..b1281667e0 100644 --- a/docs/structured_logging.md +++ b/docs/structured_logging.md @@ -1,83 +1,161 @@ # Structured Logging -A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack". +A structured logging system can be useful when your logs are destined for a +machine to parse and process. By maintaining its machine-readable characteristics, +it enables more efficient searching and aggregations when consumed by software +such as the "ELK stack". -Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to). +Synapse's structured logging system is configured via the file that Synapse's +`log_config` config option points to. The file should include a formatter which +uses the `synapse.logging.TerseJsonFormatter` class included with Synapse and a +handler which uses the above formatter. + +There is also a `synapse.logging.JsonFormatter` option which does not include +a timestamp in the resulting JSON. This is useful if the log ingester adds its +own timestamp. A structured logging configuration looks similar to the following: ```yaml -structured: true +version: 1 + +formatters: + structured: + class: synapse.logging.TerseJsonFormatter + +handlers: + file: + class: logging.handlers.TimedRotatingFileHandler + formatter: structured + filename: /path/to/my/logs/homeserver.log + when: midnight + backupCount: 3 # Does not include the current log file. + encoding: utf8 loggers: synapse: level: INFO + handlers: [remote] synapse.storage.SQL: level: WARNING - -drains: - console: - type: console - location: stdout - file: - type: file_json - location: homeserver.log ``` -The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON). - -## Drain Types +The above logging config will set Synapse as 'INFO' logging level by default, +with the SQL layer at 'WARNING', and will log to a file, stored as JSON. -Drain types can be specified by the `type` key. +It is also possible to figure Synapse to log to a remote endpoint by using the +`synapse.logging.RemoteHandler` class included with Synapse. It takes the +following arguments: -### `console` +- `host`: Hostname or IP address of the log aggregator. +- `port`: Numerical port to contact on the host. +- `maximum_buffer`: (Optional, defaults to 1000) The maximum buffer size to allow. -Outputs human-readable logs to the console. +A remote structured logging configuration looks similar to the following: -Arguments: +```yaml +version: 1 -- `location`: Either `stdout` or `stderr`. +formatters: + structured: + class: synapse.logging.TerseJsonFormatter -### `console_json` +handlers: + remote: + class: synapse.logging.RemoteHandler + formatter: structured + host: 10.1.2.3 + port: 9999 -Outputs machine-readable JSON logs to the console. +loggers: + synapse: + level: INFO + handlers: [remote] + synapse.storage.SQL: + level: WARNING +``` -Arguments: +The above logging config will set Synapse as 'INFO' logging level by default, +with the SQL layer at 'WARNING', and will log JSON formatted messages to a +remote endpoint at 10.1.2.3:9999. -- `location`: Either `stdout` or `stderr`. +## Upgrading from legacy structured logging configuration -### `console_json_terse` +Versions of Synapse prior to v1.23.0 included a custom structured logging +configuration which is deprecated. It used a `structured: true` flag and +configured `drains` instead of ``handlers`` and `formatters`. -Outputs machine-readable JSON logs to the console, separated by newlines. This -format is not designed to be read and re-formatted into human-readable text, but -is optimal for a logging aggregation system. +Synapse currently automatically converts the old configuration to the new +configuration, but this will be removed in a future version of Synapse. The +following reference can be used to update your configuration. Based on the drain +`type`, we can pick a new handler: -Arguments: +1. For a type of `console`, `console_json`, or `console_json_terse`: a handler + with a class of `logging.StreamHandler` and a `stream` of `ext://sys.stdout` + or `ext://sys.stderr` should be used. +2. For a type of `file` or `file_json`: a handler of `logging.FileHandler` with + a location of the file path should be used. +3. For a type of `network_json_terse`: a handler of `synapse.logging.RemoteHandler` + with the host and port should be used. -- `location`: Either `stdout` or `stderr`. +Then based on the drain `type` we can pick a new formatter: -### `file` +1. For a type of `console` or `file` no formatter is necessary. +2. For a type of `console_json` or `file_json`: a formatter of + `synapse.logging.JsonFormatter` should be used. +3. For a type of `console_json_terse` or `network_json_terse`: a formatter of + `synapse.logging.TerseJsonFormatter` should be used. -Outputs human-readable logs to a file. +For each new handler and formatter they should be added to the logging configuration +and then assigned to either a logger or the root logger. -Arguments: +An example legacy configuration: -- `location`: An absolute path to the file to log to. +```yaml +structured: true -### `file_json` +loggers: + synapse: + level: INFO + synapse.storage.SQL: + level: WARNING -Outputs machine-readable logs to a file. +drains: + console: + type: console + location: stdout + file: + type: file_json + location: homeserver.log +``` -Arguments: +Would be converted into a new configuration: -- `location`: An absolute path to the file to log to. +```yaml +version: 1 -### `network_json_terse` +formatters: + json: + class: synapse.logging.JsonFormatter -Delivers machine-readable JSON logs to a log aggregator over TCP. This is -compatible with LogStash's TCP input with the codec set to `json_lines`. +handlers: + console: + class: logging.StreamHandler + location: ext://sys.stdout + file: + class: logging.FileHandler + formatter: json + filename: homeserver.log -Arguments: +loggers: + synapse: + level: INFO + handlers: [console, file] + synapse.storage.SQL: + level: WARNING +``` -- `host`: Hostname or IP address of the log aggregator. -- `port`: Numerical port to contact on the host. \ No newline at end of file +The new logging configuration is a bit more verbose, but significantly more +flexible. It allows for configuration that were not previously possible, such as +sending plain logs over the network, or using different handlers for different +modules. diff --git a/mypy.ini b/mypy.ini index 5e9f7b1259..1ece2ba082 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,7 @@ files = synapse/federation, synapse/handlers/_base.py, synapse/handlers/account_data.py, + synapse/handlers/account_validity.py, synapse/handlers/appservice.py, synapse/handlers/auth.py, synapse/handlers/cas_handler.py, @@ -56,7 +57,9 @@ files = synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, + synapse/storage/databases/main/registration.py, synapse/storage/databases/main/stream.py, synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, @@ -80,6 +83,9 @@ ignore_missing_imports = True [mypy-zope] ignore_missing_imports = True +[mypy-bcrypt] +ignore_missing_imports = True + [mypy-constantly] ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index db4a2e41e4..cd880d4e39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ showcontent = true [tool.black] -target-version = ['py34'] +target-version = ['py35'] exclude = ''' ( diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh index f2b65a2105..f328ab57d5 100755 --- a/scripts-dev/lint.sh +++ b/scripts-dev/lint.sh @@ -80,7 +80,7 @@ else # then lint everything! if [[ -z ${files+x} ]]; then # Lint all source code files and directories - files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py") + files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark") fi fi @@ -94,3 +94,4 @@ isort "${files[@]}" python3 -m black "${files[@]}" ./scripts-dev/config-lint.sh flake8 "${files[@]}" +mypy diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index a5b88731f1..5882f3a0b0 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -19,9 +19,10 @@ can crop up, e.g the cache descriptors. from typing import Callable, Optional +from mypy.nodes import ARG_NAMED_OPT from mypy.plugin import MethodSigContext, Plugin from mypy.typeops import bind_self -from mypy.types import CallableType +from mypy.types import CallableType, NoneType class SynapsePlugin(Plugin): @@ -40,8 +41,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: It already has *almost* the correct signature, except: - 1. the `self` argument needs to be marked as "bound"; and - 2. any `cache_context` argument should be removed. + 1. the `self` argument needs to be marked as "bound"; + 2. any `cache_context` argument should be removed; + 3. an optional keyword argument `on_invalidated` should be added. """ # First we mark this as a bound function signature. @@ -58,19 +60,33 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: context_arg_index = idx break + arg_types = list(signature.arg_types) + arg_names = list(signature.arg_names) + arg_kinds = list(signature.arg_kinds) + if context_arg_index: - arg_types = list(signature.arg_types) arg_types.pop(context_arg_index) - - arg_names = list(signature.arg_names) arg_names.pop(context_arg_index) - - arg_kinds = list(signature.arg_kinds) arg_kinds.pop(context_arg_index) - signature = signature.copy_modified( - arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, - ) + # Third, we add an optional "on_invalidate" argument. + # + # This is a callable which accepts no input and returns nothing. + calltyp = CallableType( + arg_types=[], + arg_kinds=[], + arg_names=[], + ret_type=NoneType(), + fallback=ctx.api.named_generic_type("builtins.function", []), + ) + + arg_types.append(calltyp) + arg_names.append("on_invalidate") + arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg. + + signature = signature.copy_modified( + arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds, + ) return signature diff --git a/setup.py b/setup.py index 2f4a3170d2..9730afb41b 100755 --- a/setup.py +++ b/setup.py @@ -131,6 +131,7 @@ setup( "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], scripts=["synctl"] + glob.glob("scripts/*"), cmdclass={"test": TestCommand}, diff --git a/synapse/api/auth.py b/synapse/api/auth.py index bff87fabde..bfcaf68b2a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -33,6 +33,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.logging import opentracing as opentracing +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import StateMap, UserID from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -184,18 +185,12 @@ class Auth: """ try: ip_addr = self.hs.get_ip_from_request(request) - user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[b""] - )[0].decode("ascii", "surrogateescape") + user_agent = request.get_user_agent("") access_token = self.get_access_token_from_request(request) user_id, app_service = await self._get_appservice_user_id(request) if user_id: - request.authenticated_entity = user_id - opentracing.set_tag("authenticated_entity", user_id) - opentracing.set_tag("appservice_id", app_service.id) - if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( user_id=user_id, @@ -205,31 +200,38 @@ class Auth: device_id="dummy-device", # stubbed ) - return synapse.types.create_requester(user_id, app_service=app_service) + requester = synapse.types.create_requester( + user_id, app_service=app_service + ) + + request.requester = user_id + opentracing.set_tag("authenticated_entity", user_id) + opentracing.set_tag("user_id", user_id) + opentracing.set_tag("appservice_id", app_service.id) + + return requester user_info = await self.get_user_by_access_token( access_token, rights, allow_expired=allow_expired ) - user = user_info["user"] - token_id = user_info["token_id"] - is_guest = user_info["is_guest"] - shadow_banned = user_info["shadow_banned"] + token_id = user_info.token_id + is_guest = user_info.is_guest + shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. if self._account_validity.enabled and not allow_expired: - user_id = user.to_string() - if await self.store.is_account_expired(user_id, self.clock.time_msec()): + if await self.store.is_account_expired( + user_info.user_id, self.clock.time_msec() + ): raise AuthError( 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT ) - # device_id may not be present if get_user_by_access_token has been - # stubbed out. - device_id = user_info.get("device_id") + device_id = user_info.device_id - if user and access_token and ip_addr: + if access_token and ip_addr: await self.store.insert_client_ip( - user_id=user.to_string(), + user_id=user_info.token_owner, access_token=access_token, ip=ip_addr, user_agent=user_agent, @@ -243,19 +245,23 @@ class Auth: errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) - request.authenticated_entity = user.to_string() - opentracing.set_tag("authenticated_entity", user.to_string()) - if device_id: - opentracing.set_tag("device_id", device_id) - - return synapse.types.create_requester( - user, + requester = synapse.types.create_requester( + user_info.user_id, token_id, is_guest, shadow_banned, device_id, app_service=app_service, + authenticated_entity=user_info.token_owner, ) + + request.requester = requester + opentracing.set_tag("authenticated_entity", user_info.token_owner) + opentracing.set_tag("user_id", user_info.user_id) + if device_id: + opentracing.set_tag("device_id", device_id) + + return requester except KeyError: raise MissingClientTokenError() @@ -286,7 +292,7 @@ class Auth: async def get_user_by_access_token( self, token: str, rights: str = "access", allow_expired: bool = False, - ) -> dict: + ) -> TokenLookupResult: """ Validate access token and get user_id from it Args: @@ -295,13 +301,7 @@ class Auth: allow this allow_expired: If False, raises an InvalidClientTokenError if the token is expired - Returns: - dict that includes: - `user` (UserID) - `is_guest` (bool) - `shadow_banned` (bool) - `token_id` (int|None): access token id. May be None if guest - `device_id` (str|None): device corresponding to access token + Raises: InvalidClientTokenError if a user by that token exists, but the token is expired @@ -311,9 +311,9 @@ class Auth: if rights == "access": # first look in the database - r = await self._look_up_user_by_access_token(token) + r = await self.store.get_user_by_access_token(token) if r: - valid_until_ms = r["valid_until_ms"] + valid_until_ms = r.valid_until_ms if ( not allow_expired and valid_until_ms is not None @@ -330,7 +330,6 @@ class Auth: # otherwise it needs to be a valid macaroon try: user_id, guest = self._parse_and_validate_macaroon(token, rights) - user = UserID.from_string(user_id) if rights == "access": if not guest: @@ -356,23 +355,17 @@ class Auth: raise InvalidClientTokenError( "Guest access token used for regular user" ) - ret = { - "user": user, - "is_guest": True, - "shadow_banned": False, - "token_id": None, + + ret = TokenLookupResult( + user_id=user_id, + is_guest=True, # all guests get the same device id - "device_id": GUEST_DEVICE_ID, - } + device_id=GUEST_DEVICE_ID, + ) elif rights == "delete_pusher": # We don't store these tokens in the database - ret = { - "user": user, - "is_guest": False, - "shadow_banned": False, - "token_id": None, - "device_id": None, - } + + ret = TokenLookupResult(user_id=user_id, is_guest=False) else: raise RuntimeError("Unknown rights setting %s", rights) return ret @@ -481,31 +474,15 @@ class Auth: now = self.hs.get_clock().time_msec() return now < expiry - async def _look_up_user_by_access_token(self, token): - ret = await self.store.get_user_by_access_token(token) - if not ret: - return None - - # we use ret.get() below because *lots* of unit tests stub out - # get_user_by_access_token in a way where it only returns a couple of - # the fields. - user_info = { - "user": UserID.from_string(ret.get("name")), - "token_id": ret.get("token_id", None), - "is_guest": False, - "shadow_banned": ret.get("shadow_banned"), - "device_id": ret.get("device_id"), - "valid_until_ms": ret.get("valid_until_ms"), - } - return user_info - def get_appservice_by_req(self, request): token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) if not service: logger.warning("Unrecognised appservice access token.") raise InvalidClientTokenError() - request.authenticated_entity = service.sender + request.requester = synapse.types.create_requester( + service.sender, app_service=service + ) return service async def is_server_admin(self, user: UserID) -> bool: diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index f70841ae86..3944780a42 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -52,11 +52,11 @@ class ApplicationService: self, token, hostname, + id, + sender, url=None, namespaces=None, hs_token=None, - sender=None, - id=None, protocols=None, rate_limited=True, ip_range_whitelist=None, diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 4526c1a67b..2f97e6d258 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -26,14 +26,14 @@ class CasConfig(Config): def read_config(self, config, **kwargs): cas_config = config.get("cas_config", None) - if cas_config: - self.cas_enabled = cas_config.get("enabled", True) + self.cas_enabled = cas_config and cas_config.get("enabled", True) + + if self.cas_enabled: self.cas_server_url = cas_config["server_url"] self.cas_service_url = cas_config["service_url"] self.cas_displayname_attribute = cas_config.get("displayname_attribute") - self.cas_required_attributes = cas_config.get("required_attributes", {}) + self.cas_required_attributes = cas_config.get("required_attributes") or {} else: - self.cas_enabled = False self.cas_server_url = None self.cas_service_url = None self.cas_displayname_attribute = None @@ -41,13 +41,35 @@ class CasConfig(Config): def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ - # Enable CAS for registration and login. + # Enable Central Authentication Service (CAS) for registration and login. # - #cas_config: - # enabled: true - # server_url: "https://cas-server.com" - # service_url: "https://homeserver.domain.com:8448" - # #displayname_attribute: name - # #required_attributes: - # # name: value + cas_config: + # Uncomment the following to enable authorization against a CAS server. + # Defaults to false. + # + #enabled: true + + # The URL of the CAS authorization endpoint. + # + #server_url: "https://cas-server.com" + + # The public URL of the homeserver. + # + #service_url: "https://homeserver.domain.com:8448" + + # The attribute of the CAS response to use as the display name. + # + # If unset, no displayname will be set. + # + #displayname_attribute: name + + # It is possible to configure Synapse to only allow logins if CAS attributes + # match particular values. All of the keys in the mapping below must exist + # and the values must match the given value. Alternately if the given value + # is None then any value is allowed (the attribute just must exist). + # All of the listed attributes must match for the login to be permitted. + # + #required_attributes: + # userGroup: "staff" + # department: None """ diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py index 3252ad9e7f..f30330abb6 100644 --- a/synapse/config/jwt_config.py +++ b/synapse/config/jwt_config.py @@ -63,7 +63,7 @@ class JWTConfig(Config): # and issued at ("iat") claims are validated if present. # # Note that this is a non-standard login type and client support is - # expected to be non-existant. + # expected to be non-existent. # # See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md. # diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 13d6f6a3ea..d4e887a3e0 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -23,7 +23,6 @@ from string import Template import yaml from twisted.logger import ( - ILogObserver, LogBeginner, STDLibLogObserver, eventAsText, @@ -32,11 +31,9 @@ from twisted.logger import ( import synapse from synapse.app import _base as appbase -from synapse.logging._structured import ( - reload_structured_logging, - setup_structured_logging, -) +from synapse.logging._structured import setup_structured_logging from synapse.logging.context import LoggingContextFilter +from synapse.logging.filter import MetadataFilter from synapse.util.versionstring import get_version_string from ._base import Config, ConfigError @@ -48,7 +45,11 @@ DEFAULT_LOG_CONFIG = Template( # This is a YAML file containing a standard Python logging configuration # dictionary. See [1] for details on the valid settings. # +# Synapse also supports structured logging for machine readable logs which can +# be ingested by ELK stacks. See [2] for details. +# # [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema +# [2]: https://github.com/matrix-org/synapse/blob/master/docs/structured_logging.md version: 1 @@ -105,7 +106,7 @@ root: # then write them to a file. # # Replace "buffer" with "console" to log to stderr instead. (Note that you'll - # also need to update the configuation for the `twisted` logger above, in + # also need to update the configuration for the `twisted` logger above, in # this case.) # handlers: [buffer] @@ -176,11 +177,11 @@ class LoggingConfig(Config): log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file)) -def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): +def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> None: """ - Set up Python stdlib logging. + Set up Python standard library logging. """ - if log_config is None: + if log_config_path is None: log_format = ( "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" " - %(message)s" @@ -196,7 +197,8 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): handler.setFormatter(formatter) logger.addHandler(handler) else: - logging.config.dictConfig(log_config) + # Load the logging configuration. + _load_logging_config(log_config_path) # We add a log record factory that runs all messages through the # LoggingContextFilter so that we get the context *at the time we log* @@ -204,12 +206,14 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): # filter options, but care must when using e.g. MemoryHandler to buffer # writes. - log_filter = LoggingContextFilter(request="") + log_context_filter = LoggingContextFilter(request="") + log_metadata_filter = MetadataFilter({"server_name": config.server_name}) old_factory = logging.getLogRecordFactory() def factory(*args, **kwargs): record = old_factory(*args, **kwargs) - log_filter.filter(record) + log_context_filter.filter(record) + log_metadata_filter.filter(record) return record logging.setLogRecordFactory(factory) @@ -255,21 +259,40 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner): if not config.no_redirect_stdio: print("Redirected stdout/stderr to logs") - return observer - -def _reload_stdlib_logging(*args, log_config=None): - logger = logging.getLogger("") +def _load_logging_config(log_config_path: str) -> None: + """ + Configure logging from a log config path. + """ + with open(log_config_path, "rb") as f: + log_config = yaml.safe_load(f.read()) if not log_config: - logger.warning("Reloaded a blank config?") + logging.warning("Loaded a blank logging config?") + + # If the old structured logging configuration is being used, convert it to + # the new style configuration. + if "structured" in log_config and log_config.get("structured"): + log_config = setup_structured_logging(log_config) logging.config.dictConfig(log_config) +def _reload_logging_config(log_config_path): + """ + Reload the log configuration from the file and apply it. + """ + # If no log config path was given, it cannot be reloaded. + if log_config_path is None: + return + + _load_logging_config(log_config_path) + logging.info("Reloaded log config from %s due to SIGHUP", log_config_path) + + def setup_logging( hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner -) -> ILogObserver: +) -> None: """ Set up the logging subsystem. @@ -282,41 +305,18 @@ def setup_logging( logBeginner: The Twisted logBeginner to use. - Returns: - The "root" Twisted Logger observer, suitable for sending logs to from a - Logger instance. """ - log_config = config.worker_log_config if use_worker_options else config.log_config - - def read_config(*args, callback=None): - if log_config is None: - return None - - with open(log_config, "rb") as f: - log_config_body = yaml.safe_load(f.read()) - - if callback: - callback(log_config=log_config_body) - logging.info("Reloaded log config from %s due to SIGHUP", log_config) - - return log_config_body + log_config_path = ( + config.worker_log_config if use_worker_options else config.log_config + ) - log_config_body = read_config() + # Perform one-time logging configuration. + _setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner) + # Add a SIGHUP handler to reload the logging configuration, if one is available. + appbase.register_sighup(_reload_logging_config, log_config_path) - if log_config_body and log_config_body.get("structured") is True: - logger = setup_structured_logging( - hs, config, log_config_body, logBeginner=logBeginner - ) - appbase.register_sighup(read_config, callback=reload_structured_logging) - else: - logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner) - appbase.register_sighup(read_config, callback=_reload_stdlib_logging) - - # make sure that the first thing we log is a thing we can grep backwards - # for + # Log immediately so we can grep backwards. logging.warning("***** STARTING SERVER *****") logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse)) logging.info("Server hostname: %s", config.server_name) logging.info("Instance name: %s", hs.get_instance_name()) - - return logger diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 7597fbc864..69d188341c 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -87,11 +87,10 @@ class OIDCConfig(Config): def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ - # OpenID Connect integration. The following settings can be used to make Synapse - # use an OpenID Connect Provider for authentication, instead of its internal - # password database. + # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. # - # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md. + # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md + # for some example configurations. # oidc_config: # Uncomment the following to enable authorization against an OpenID Connect diff --git a/synapse/config/registration.py b/synapse/config/registration.py index d7e3690a32..b0a77a2e43 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -143,7 +143,7 @@ class RegistrationConfig(Config): RoomCreationPreset.TRUSTED_PRIVATE_CHAT, } - # Pull the creater/inviter from the configuration, this gets used to + # Pull the creator/inviter from the configuration, this gets used to # send invites for invite-only rooms. mxid_localpart = config.get("auto_join_mxid_localpart") self.auto_join_user_id = None diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 6de1f9d103..92e1b67528 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -99,7 +99,7 @@ class RoomDirectoryConfig(Config): # # Options for the rules include: # - # user_id: Matches agaisnt the creator of the alias + # user_id: Matches against the creator of the alias # room_id: Matches against the room ID being published # alias: Matches against any current local or canonical aliases # associated with the room diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 99aa8b3bf1..778750f43b 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -216,10 +216,8 @@ class SAML2Config(Config): return """\ ## Single sign-on integration ## - # Enable SAML2 for registration and login. Uses pysaml2. - # - # At least one of `sp_config` or `config_path` must be set in this section to - # enable SAML login. + # The following settings can be used to make Synapse use a single sign-on + # provider for authentication, instead of its internal password database. # # You will probably also want to set the following options to `false` to # disable the regular login/registration flows: @@ -228,6 +226,11 @@ class SAML2Config(Config): # # You will also want to investigate the settings under the "sso" configuration # section below. + + # Enable SAML2 for registration and login. Uses pysaml2. + # + # At least one of `sp_config` or `config_path` must be set in this section to + # enable SAML login. # # Once SAML support is enabled, a metadata file will be exposed at # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to @@ -243,40 +246,42 @@ class SAML2Config(Config): # so it is not normally necessary to specify them unless you need to # override them. # - #sp_config: - # # point this to the IdP's metadata. You can use either a local file or - # # (preferably) a URL. - # metadata: - # #local: ["saml2/idp.xml"] - # remote: - # - url: https://our_idp/metadata.xml - # - # # By default, the user has to go to our login page first. If you'd like - # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a - # # 'service.sp' section: - # # - # #service: - # # sp: - # # allow_unsolicited: true - # - # # The examples below are just used to generate our metadata xml, and you - # # may well not need them, depending on your setup. Alternatively you - # # may need a whole lot more detail - see the pysaml2 docs! - # - # description: ["My awesome SP", "en"] - # name: ["Test SP", "en"] - # - # organization: - # name: Example com - # display_name: - # - ["Example co", "en"] - # url: "http://example.com" - # - # contact_person: - # - given_name: Bob - # sur_name: "the Sysadmin" - # email_address": ["admin@example.com"] - # contact_type": technical + sp_config: + # Point this to the IdP's metadata. You must provide either a local + # file via the `local` attribute or (preferably) a URL via the + # `remote` attribute. + # + #metadata: + # local: ["saml2/idp.xml"] + # remote: + # - url: https://our_idp/metadata.xml + + # By default, the user has to go to our login page first. If you'd like + # to allow IdP-initiated login, set 'allow_unsolicited: true' in a + # 'service.sp' section: + # + #service: + # sp: + # allow_unsolicited: true + + # The examples below are just used to generate our metadata xml, and you + # may well not need them, depending on your setup. Alternatively you + # may need a whole lot more detail - see the pysaml2 docs! + + #description: ["My awesome SP", "en"] + #name: ["Test SP", "en"] + + #organization: + # name: Example com + # display_name: + # - ["Example co", "en"] + # url: "http://example.com" + + #contact_person: + # - given_name: Bob + # sur_name: "the Sysadmin" + # email_address": ["admin@example.com"] + # contact_type": technical # Instead of putting the config inline as above, you can specify a # separate pysaml2 configuration file: diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 8be1346113..0c1a854f09 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -67,7 +67,7 @@ class TracerConfig(Config): # This is a list of regexes which are matched against the server_name of the # homeserver. # - # By defult, it is empty, so no servers are matched. + # By default, it is empty, so no servers are matched. # #homeserver_whitelist: # - ".*" diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 79668a402e..57fd426e87 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -149,7 +149,7 @@ class FederationPolicyForHTTPS: return SSLClientConnectionCreator(host, ssl_context, should_verify) def creatorForNetloc(self, hostname, port): - """Implements the IPolicyForHTTPS interace so that this can be passed + """Implements the IPolicyForHTTPS interface so that this can be passed directly to agents. """ return self.get_options(hostname) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 65df62107f..8028663fa8 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -59,7 +59,7 @@ class DictProperty: # # To exclude the KeyError from the traceback, we explicitly # 'raise from e1.__context__' (which is better than 'raise from None', - # becuase that would omit any *earlier* exceptions). + # because that would omit any *earlier* exceptions). # raise AttributeError( "'%s' has no '%s' property" % (type(instance), self.key) @@ -368,7 +368,7 @@ class FrozenEvent(EventBase): return self.__repr__() def __repr__(self): - return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % ( + return "<FrozenEvent event_id=%r, type=%r, state_key=%r>" % ( self.get("event_id", None), self.get("type", None), self.get("state_key", None), @@ -451,7 +451,7 @@ class FrozenEventV2(EventBase): return self.__repr__() def __repr__(self): - return "<%s event_id='%s', type='%s', state_key='%s'>" % ( + return "<%s event_id=%r, type=%r, state_key=%r>" % ( self.__class__.__name__, self.event_id, self.get("type", None), diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 355cbe05f1..14f7f1156f 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -180,7 +180,7 @@ def only_fields(dictionary, fields): in 'fields'. If there are no event fields specified then all fields are included. - The entries may include '.' charaters to indicate sub-fields. + The entries may include '.' characters to indicate sub-fields. So ['content.body'] will include the 'body' field of the 'content' object. A literal '.' character in a field name may be escaped using a '\'. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 3a6b95631e..a0933fae88 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -154,7 +154,7 @@ class Authenticator: ) logger.debug("Request from %s", origin) - request.authenticated_entity = origin + request.requester = origin # If we get a valid signed request from the other side, its probably # alive diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index a86b3debc5..41cf07cc88 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -22,7 +22,7 @@ attestations have a validity period so need to be periodically renewed. If a user leaves (or gets kicked out of) a group, either side can still use their attestation to "prove" their membership, until the attestation expires. Therefore attestations shouldn't be relied on to prove membership in important -cases, but can for less important situtations, e.g. showing a users membership +cases, but can for less important situations, e.g. showing a users membership of groups on their profile, showing flairs, etc. An attestation is a signed blob of json that looks like: diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index e5f85b472d..0d042cbfac 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -113,7 +113,7 @@ class GroupsServerWorkerHandler: entry = await self.room_list_handler.generate_room_entry( room_id, len(joined_users), with_alias=False, allow_private=True ) - entry = dict(entry) # so we don't change whats cached + entry = dict(entry) # so we don't change what's cached entry.pop("room_id", None) room_entry["profile"] = entry @@ -550,7 +550,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): group_id, room_id, is_public=is_public ) else: - raise SynapseError(400, "Uknown config option") + raise SynapseError(400, "Unknown config option") return {} diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index fd4f762f33..664d09da1c 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,19 +18,22 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import List +from typing import TYPE_CHECKING, List -from synapse.api.errors import StoreError +from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.types import UserID from synapse.util import stringutils +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class AccountValidityHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.config = hs.config self.store = self.hs.get_datastore() @@ -67,7 +70,7 @@ class AccountValidityHandler: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) @wrap_as_background_process("send_renewals") - async def _send_renewal_emails(self): + async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they @@ -81,11 +84,25 @@ class AccountValidityHandler: user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - async def send_renewal_email_to_user(self, user_id: str): + async def send_renewal_email_to_user(self, user_id: str) -> None: + """ + Send a renewal email for a specific user. + + Args: + user_id: The user ID to send a renewal email for. + + Raises: + SynapseError if the user is not set to renew. + """ expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + + # If this user isn't set to be expired, raise an error. + if expiration_ts is None: + raise SynapseError(400, "User has no expiration time: %s" % (user_id,)) + await self._send_renewal_email(user_id, expiration_ts) - async def _send_renewal_email(self, user_id: str, expiration_ts: int): + async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None: """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 1ce2091b46..a703944543 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -88,7 +88,7 @@ class AdminHandler(BaseHandler): # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle - # those seperately. + # those separately. rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): @@ -226,7 +226,7 @@ class ExfiltrationWriter: """ def finished(self): - """Called when all data has succesfully been exported and written. + """Called when all data has successfully been exported and written. This functions return value is passed to the caller of `export_user_data`. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 64dea23fc5..9fc8444228 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union from prometheus_client import Counter @@ -30,17 +29,24 @@ from synapse.metrics import ( event_processing_loop_counter, event_processing_loop_room_count, ) -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import Collection, JsonDict, RoomStreamToken, UserID +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) +from synapse.storage.databases.main.directory import RoomAliasMapping +from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") class ApplicationServicesHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() @@ -53,7 +59,7 @@ class ApplicationServicesHandler: self.current_max = 0 self.is_processing = False - async def notify_interested_services(self, max_token: RoomStreamToken): + def notify_interested_services(self, max_token: RoomStreamToken): """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any @@ -72,6 +78,12 @@ class ApplicationServicesHandler: if self.is_processing: return + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). + self._notify_interested_services(max_token) + + @wrap_as_background_process("notify_interested_services") + async def _notify_interested_services(self, max_token: RoomStreamToken): with Measure(self.clock, "notify_interested_services"): self.is_processing = True try: @@ -166,8 +178,11 @@ class ApplicationServicesHandler: finally: self.is_processing = False - async def notify_interested_services_ephemeral( - self, stream_key: str, new_token: Optional[int], users: Collection[UserID] = [], + def notify_interested_services_ephemeral( + self, + stream_key: str, + new_token: Optional[int], + users: Collection[Union[str, UserID]] = [], ): """This is called by the notifier in the background when a ephemeral event handled by the homeserver. @@ -183,13 +198,34 @@ class ApplicationServicesHandler: new_token: The latest stream token users: The user(s) involved with the event. """ + if not self.notify_appservices: + return + + if stream_key not in ("typing_key", "receipt_key", "presence_key"): + return + services = [ service for service in self.store.get_app_services() if service.supports_ephemeral ] - if not services or not self.notify_appservices: + if not services: return + + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). + self._notify_interested_services_ephemeral( + services, stream_key, new_token, users + ) + + @wrap_as_background_process("notify_interested_services_ephemeral") + async def _notify_interested_services_ephemeral( + self, + services: List[ApplicationService], + stream_key: str, + new_token: Optional[int], + users: Collection[Union[str, UserID]], + ): logger.info("Checking interested services for %s" % (stream_key)) with Measure(self.clock, "notify_interested_services_ephemeral"): for service in services: @@ -214,7 +250,9 @@ class ApplicationServicesHandler: service, "presence", new_token ) - async def _handle_typing(self, service: ApplicationService, new_token: int): + async def _handle_typing( + self, service: ApplicationService, new_token: int + ) -> List[JsonDict]: typing_source = self.event_sources.sources["typing"] # Get the typing events from just before current typing, _ = await typing_source.get_new_events_as( @@ -226,7 +264,7 @@ class ApplicationServicesHandler: ) return typing - async def _handle_receipts(self, service: ApplicationService): + async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) @@ -237,7 +275,7 @@ class ApplicationServicesHandler: return receipts async def _handle_presence( - self, service: ApplicationService, users: Collection[UserID] + self, service: ApplicationService, users: Collection[Union[str, UserID]] ) -> List[JsonDict]: events = [] # type: List[JsonDict] presence_source = self.event_sources.sources["presence"] @@ -245,6 +283,9 @@ class ApplicationServicesHandler: service, "presence" ) for user in users: + if isinstance(user, str): + user = UserID.from_string(user) + interested = await service.is_interested_in_presence(user, self.store) if not interested: continue @@ -265,11 +306,11 @@ class ApplicationServicesHandler: return events - async def query_user_exists(self, user_id): + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. Args: - user_id(str): The user to query if they exist on any AS. + user_id: The user to query if they exist on any AS. Returns: True if this user exists on at least one application service. """ @@ -280,11 +321,13 @@ class ApplicationServicesHandler: return True return False - async def query_room_alias_exists(self, room_alias): + async def query_room_alias_exists( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: """Check if an application service knows this room alias exists. Args: - room_alias(RoomAlias): The room alias to query. + room_alias: The room alias to query. Returns: namedtuple: with keys "room_id" and "servers" or None if no association can be found. @@ -300,10 +343,13 @@ class ApplicationServicesHandler: ) if is_known_alias: # the alias exists now so don't query more ASes. - result = await self.store.get_association_from_room_alias(room_alias) - return result + return await self.store.get_association_from_room_alias(room_alias) - async def query_3pe(self, kind, protocol, fields): + return None + + async def query_3pe( + self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]] + ) -> List[JsonDict]: services = self._get_services_for_3pn(protocol) results = await make_deferred_yieldable( @@ -325,7 +371,9 @@ class ApplicationServicesHandler: return ret - async def get_3pe_protocols(self, only_protocol=None): + async def get_3pe_protocols( + self, only_protocol: Optional[str] = None + ) -> Dict[str, JsonDict]: services = self.store.get_app_services() protocols = {} # type: Dict[str, List[JsonDict]] @@ -343,7 +391,7 @@ class ApplicationServicesHandler: if info is not None: protocols[p].append(info) - def _merge_instances(infos): + def _merge_instances(infos: List[JsonDict]) -> JsonDict: if not infos: return {} @@ -358,19 +406,17 @@ class ApplicationServicesHandler: return combined - for p in protocols.keys(): - protocols[p] = _merge_instances(protocols[p]) - - return protocols + return {p: _merge_instances(protocols[p]) for p in protocols.keys()} - async def _get_services_for_event(self, event): + async def _get_services_for_event( + self, event: EventBase + ) -> List[ApplicationService]: """Retrieve a list of application services interested in this event. Args: - event(Event): The event to check. Can be None if alias_list is not. + event: The event to check. Can be None if alias_list is not. Returns: - list<ApplicationService>: A list of services interested in this - event based on the service regex. + A list of services interested in this event based on the service regex. """ services = self.store.get_app_services() @@ -384,17 +430,15 @@ class ApplicationServicesHandler: return interested_list - def _get_services_for_user(self, user_id): + def _get_services_for_user(self, user_id: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if (s.is_interested_in_user(user_id))] - return interested_list + return [s for s in services if (s.is_interested_in_user(user_id))] - def _get_services_for_3pn(self, protocol): + def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] - return interested_list + return [s for s in services if s.is_interested_in_protocol(protocol)] - async def _is_unknown_user(self, user_id): + async def _is_unknown_user(self, user_id: str) -> bool: if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. @@ -409,9 +453,8 @@ class ApplicationServicesHandler: service_list = [s for s in services if s.sender == user_id] return len(service_list) == 0 - async def _check_user_exists(self, user_id): + async def _check_user_exists(self, user_id: str) -> bool: unknown_user = await self._is_unknown_user(user_id) if unknown_user: - exists = await self.query_user_exists(user_id) - return exists + return await self.query_user_exists(user_id) return True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8619fbb982..ff103cbb92 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,10 +18,20 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) import attr -import bcrypt # type: ignore[import] +import bcrypt import pymacaroons from synapse.api.constants import LoginType @@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -149,11 +162,7 @@ class SsoLoginExtraAttributes: class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] @@ -470,9 +479,7 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") + user_agent = request.get_user_agent("") await self.store.add_user_agent_ip_to_ui_auth_session( session.session_id, user_agent, clientip @@ -692,7 +699,7 @@ class AuthHandler(BaseHandler): Creates a new access token for the user with the given user ID. The user is assumed to have been authenticated by some other - machanism (e.g. CAS), and the user_id converted to the canonical case. + mechanism (e.g. CAS), and the user_id converted to the canonical case. The device will be recorded in the table if it is not there already. @@ -984,17 +991,17 @@ class AuthHandler(BaseHandler): # This might return an awaitable, if it does block the log out # until it completes. result = provider.on_logged_out( - user_id=str(user_info["user"]), - device_id=user_info["device_id"], + user_id=user_info.user_id, + device_id=user_info.device_id, access_token=access_token, ) if inspect.isawaitable(result): await result # delete pushers associated with this access token - if user_info["token_id"] is not None: + if user_info.token_id is not None: await self.hs.get_pusherpool().remove_pushers_by_access_token( - str(user_info["user"]), (user_info["token_id"],) + user_info.user_id, (user_info.token_id,) ) async def delete_access_tokens_for_user( diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index a4cc4b9a5a..048a3b3c0b 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -212,9 +212,7 @@ class CasHandler: else: if not registered_user_id: # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders( - b"User-Agent", default=[b""] - )[0].decode("ascii", "surrogateescape") + user_agent = request.get_user_agent("") ip_address = self.hs.get_ip_from_request(request) registered_user_id = await self._registration_handler.register_user( diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 611742ae72..929752150d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -129,6 +129,11 @@ class E2eKeysHandler: if user_id in local_query: results[user_id] = keys + # Get cached cross-signing keys + cross_signing_keys = await self.get_cross_signing_keys_from_cache( + device_keys_query, from_user_id + ) + # Now attempt to get any remote devices from our local cache. remote_queries_not_in_cache = {} if remote_queries: @@ -155,16 +160,28 @@ class E2eKeysHandler: unsigned["device_display_name"] = device_display_name user_devices[device_id] = result + # check for missing cross-signing keys. + for user_id in remote_queries.keys(): + cached_cross_master = user_id in cross_signing_keys["master_keys"] + cached_cross_selfsigning = ( + user_id in cross_signing_keys["self_signing_keys"] + ) + + # check if we are missing only one of cross-signing master or + # self-signing key, but the other one is cached. + # as we need both, this will issue a federation request. + # if we don't have any of the keys, either the user doesn't have + # cross-signing set up, or the cached device list + # is not (yet) updated. + if cached_cross_master ^ cached_cross_selfsigning: + user_ids_not_in_cache.add(user_id) + + # add those users to the list to fetch over federation. for user_id in user_ids_not_in_cache: domain = get_domain_from_id(user_id) r = remote_queries_not_in_cache.setdefault(domain, {}) r[user_id] = remote_queries[user_id] - # Get cached cross-signing keys - cross_signing_keys = await self.get_cross_signing_keys_from_cache( - device_keys_query, from_user_id - ) - # Now fetch any devices that we don't have in our cache @trace async def do_remote_query(destination): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fde8f00531..c386957706 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -112,7 +112,7 @@ class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resoultion) + of the homeserver (including auth and state conflict resolutions) b) converting events that were produced by local clients that may need to be sent to remote homeservers. c) doing the necessary dances to invite remote users and join remote @@ -477,7 +477,7 @@ class FederationHandler(BaseHandler): # ---- # # Update richvdh 2018/09/18: There are a number of problems with timing this - # request out agressively on the client side: + # request out aggressively on the client side: # # - it plays badly with the server-side rate-limiter, which starts tarpitting you # if you send too many requests at once, so you end up with the server carefully @@ -495,13 +495,13 @@ class FederationHandler(BaseHandler): # we'll end up back here for the *next* PDU in the list, which exacerbates the # problem. # - # - the agressive 10s timeout was introduced to deal with incoming federation + # - the aggressive 10s timeout was introduced to deal with incoming federation # requests taking 8 hours to process. It's not entirely clear why that was going # on; certainly there were other issues causing traffic storms which are now # resolved, and I think in any case we may be more sensible about our locking # now. We're *certainly* more sensible about our logging. # - # All that said: Let's try increasing the timout to 60s and see what happens. + # All that said: Let's try increasing the timeout to 60s and see what happens. try: missing_events = await self.federation_client.get_missing_events( @@ -1120,7 +1120,7 @@ class FederationHandler(BaseHandler): logger.info(str(e)) continue except RequestSendFailed as e: - logger.info("Falied to get backfill from %s because %s", dom, e) + logger.info("Failed to get backfill from %s because %s", dom, e) continue except FederationDeniedError as e: logger.info(e) @@ -1545,7 +1545,7 @@ class FederationHandler(BaseHandler): # # The reasons we have the destination server rather than the origin # server send it are slightly mysterious: the origin server should have - # all the neccessary state once it gets the response to the send_join, + # all the necessary state once it gets the response to the send_join, # so it could send the event itself if it wanted to. It may be that # doing it this way reduces failure modes, or avoids certain attacks # where a new server selectively tells a subset of the federation that @@ -1649,7 +1649,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = True event.internal_metadata.out_of_band_membership = True - # Try the host that we succesfully called /make_leave/ on first for + # Try the host that we successfully called /make_leave/ on first for # the /send_leave/ request. host_list = list(target_hosts) try: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 9684e60fc8..abd8d2af44 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -17,7 +17,7 @@ import logging from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import get_domain_from_id +from synapse.types import GroupID, get_domain_from_id logger = logging.getLogger(__name__) @@ -28,6 +28,9 @@ def _create_rerouter(func_name): """ async def f(self, group_id, *args, **kwargs): + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s was not legal group ID" % (group_id,)) + if self.is_mine_id(group_id): return await getattr(self.groups_server_handler, func_name)( group_id, *args, **kwargs @@ -346,7 +349,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): server_name=get_domain_from_id(group_id), ) - # TODO: Check that the group is public and we're being added publically + # TODO: Check that the group is public and we're being added publicly is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( @@ -391,7 +394,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): server_name=get_domain_from_id(group_id), ) - # TODO: Check that the group is public and we're being added publically + # TODO: Check that the group is public and we're being added publicly is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index fb0a04e9a7..ca5602c13e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -656,7 +656,7 @@ class EventCreationHandler: context: The event context. Returns: - The previous verion of the event is returned, if it is found in the + The previous version of the event is returned, if it is found in the event context. Otherwise, None is returned. """ prev_state_ids = await context.get_prev_state_ids() @@ -1099,34 +1099,13 @@ class EventCreationHandler: if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: - - def is_inviter_member_event(e): - return e.type == EventTypes.Member and e.sender == event.sender - - current_state_ids = await context.get_current_state_ids() - - # We know this event is not an outlier, so this must be - # non-None. - assert current_state_ids is not None - - state_to_include_ids = [ - e_id - for k, e_id in current_state_ids.items() - if k[0] in self.room_invite_state_types - or k == (EventTypes.Member, event.sender) - ] - - state_to_include = await self.store.get_events(state_to_include_ids) - - event.unsigned["invite_room_state"] = [ - { - "type": e.type, - "state_key": e.state_key, - "content": e.content, - "sender": e.sender, - } - for e in state_to_include.values() - ] + event.unsigned[ + "invite_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, + self.room_invite_state_types, + membership_user_id=event.sender, + ) invitee = UserID.from_string(event.state_key) if not self.hs.is_mine(invitee): diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 05ac86e697..331d4e7e96 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -217,7 +217,7 @@ class OidcHandler: This is based on the requested scopes: if the scopes include ``openid``, the provider should give use an ID token containing the - user informations. If not, we should fetch them using the + user information. If not, we should fetch them using the ``access_token`` with the ``userinfo_endpoint``. """ @@ -426,7 +426,7 @@ class OidcHandler: return resp async def _fetch_userinfo(self, token: Token) -> UserInfo: - """Fetch user informations from the ``userinfo_endpoint``. + """Fetch user information from the ``userinfo_endpoint``. Args: token: the token given by the ``token_endpoint``. @@ -695,9 +695,7 @@ class OidcHandler: return # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") + user_agent = request.get_user_agent("") ip_address = self.hs.get_ip_from_request(request) # Call the mapper to register/login the user @@ -756,7 +754,7 @@ class OidcHandler: Defaults to an hour. Returns: - A signed macaroon token with the session informations. + A signed macaroon token with the session information. """ macaroon = pymacaroons.Macaroon( location=self._server_name, identifier="key", key=self._macaroon_secret_key, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1000ac95ff..8e014c9bb5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -48,7 +48,7 @@ from synapse.util.wheel_timer import WheelTimer MYPY = False if MYPY: - import synapse.server + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -101,7 +101,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class BasePresenceHandler(abc.ABC): """Parts of the PresenceHandler that are shared between workers and master""" - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -199,7 +199,7 @@ class BasePresenceHandler(abc.ABC): class PresenceHandler(BasePresenceHandler): - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.is_mine_id = hs.is_mine_id @@ -802,7 +802,7 @@ class PresenceHandler(BasePresenceHandler): between the requested tokens due to the limit. The token returned can be used in a subsequent call to this - function to get further updatees. + function to get further updates. The updates are a list of 2-tuples of stream ID and the row data """ @@ -977,7 +977,7 @@ def should_notify(old_state, new_state): new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY ): - # Only notify about last active bumps if we're not currently acive + # Only notify about last active bumps if we're not currently active if not new_state.currently_active: notify_reason_counter.labels("last_active_change_online").inc() return True @@ -1011,7 +1011,7 @@ def format_user_presence_state(state, now, include_user_id=True): class PresenceEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): # We can't call get_presence_handler here because there's a cycle: # # Presence -> Notifier -> PresenceEventSource -> Presence @@ -1071,12 +1071,14 @@ class PresenceEventSource: users_interested_in = await self._get_interested_in(user, explicit_room_id) - user_ids_changed = set() + user_ids_changed = set() # type: Collection[str] changed = None if from_key: changed = stream_change_cache.get_all_entities_changed(from_key) if changed is not None and len(changed) < 500: + assert isinstance(user_ids_changed, set) + # For small deltas, its quicker to get all changes and then # work out if we share a room or they're in our presence list get_updates_counter.labels("stream").inc() diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 92700b589c..14348faaf3 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -98,11 +98,18 @@ class ProfileHandler(BaseHandler): except RequestSendFailed as e: raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: + if e.code < 500 and e.code != 404: + # Other codes are not allowed in c2s API + logger.info( + "Server replied with wrong response: %s %s", e.code, e.msg + ) + + raise SynapseError(502, "Failed to fetch profile") raise e.to_synapse_error() async def get_profile_from_cache(self, user_id: str) -> JsonDict: """Get the profile information from our local cache. If the user is - ours then the profile information will always be corect. Otherwise, + ours then the profile information will always be correct. Otherwise, it may be out of date/missing. """ target_user = UserID.from_string(user_id) @@ -124,7 +131,7 @@ class ProfileHandler(BaseHandler): profile = await self.store.get_from_remote_profile_cache(user_id) return profile or {} - async def get_displayname(self, target_user: UserID) -> str: + async def get_displayname(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: displayname = await self.store.get_profile_displayname( @@ -211,7 +218,7 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) - async def get_avatar_url(self, target_user: UserID) -> str: + async def get_avatar_url(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: avatar_url = await self.store.get_profile_avatar_url( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a6f1d21674..ed1ff62599 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler): 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) user_data = await self.auth.get_user_by_access_token(guest_access_token) - if not user_data["is_guest"] or user_data["user"].localpart != localpart: + if ( + not user_data.is_guest + or UserID.from_string(user_data.user_id).localpart != localpart + ): raise AuthError( 403, "Cannot register taken user ID without valid guest " @@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler): # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. user_tuple = await self.store.get_user_by_access_token(token) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id await self.pusher_pool.add_pusher( user_id=user_id, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ec300d8877..e73031475f 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -771,22 +771,29 @@ class RoomCreationHandler(BaseHandler): ratelimit=False, ) - for invitee in invite_list: + # we avoid dropping the lock between invites, as otherwise joins can + # start coming in and making the createRoom slow. + # + # we also don't need to check the requester's shadow-ban here, as we + # have already done so above (and potentially emptied invite_list). + with (await self.room_member_handler.member_linearizer.queue((room_id,))): content = {} is_direct = config.get("is_direct", None) if is_direct: content["is_direct"] = is_direct - # Note that update_membership with an action of "invite" can raise a - # ShadowBanError, but this was handled above by emptying invite_list. - _, last_stream_id = await self.room_member_handler.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - ratelimit=False, - content=content, - ) + for invitee in invite_list: + ( + _, + last_stream_id, + ) = await self.room_member_handler.update_membership_locked( + requester, + UserID.from_string(invitee), + room_id, + "invite", + ratelimit=False, + content=content, + ) for invite_3pid in invite_3pid_list: id_server = invite_3pid["id_server"] @@ -1268,7 +1275,7 @@ class RoomShutdownHandler: ) # We now wait for the create room to come back in via replication so - # that we can assume that all the joins/invites have propogated before + # that we can assume that all the joins/invites have propagated before # we try and auto join below. await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(new_room_id), diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ec784030e9..7cd858b7db 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -307,7 +307,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): key = (room_id,) with (await self.member_linearizer.queue(key)): - result = await self._update_membership( + result = await self.update_membership_locked( requester, target, room_id, @@ -322,7 +322,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return result - async def _update_membership( + async def update_membership_locked( self, requester: Requester, target: UserID, @@ -335,6 +335,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): content: Optional[dict] = None, require_consent: bool = True, ) -> Tuple[str, int]: + """Helper for update_membership. + + Assumes that the membership linearizer is already held for the room. + """ content_specified = bool(content) if content is None: content = {} diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 285c481a96..fd6c5e9ea8 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -216,9 +216,7 @@ class SamlHandler: return # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") + user_agent = request.get_user_agent("") ip_address = self.hs.get_ip_from_request(request) # Call the mapper to register/login the user diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index e9402e6e2e..66f1bbcfc4 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -139,7 +139,7 @@ class SearchHandler(BaseHandler): # Filter to apply to results filter_dict = room_cat.get("filter", {}) - # What to order results by (impacts whether pagination can be doen) + # What to order results by (impacts whether pagination can be done) order_by = room_cat.get("order_by", "rank") # Return the current state of the rooms? diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index 7a4ae0727a..fb4f70e8e2 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -32,7 +32,7 @@ class StateDeltasHandler: Returns: None if the field in the events either both match `public_value` or if neither do, i.e. there has been no change. - True if it didnt match `public_value` but now does + True if it didn't match `public_value` but now does False if it did match `public_value` but now doesn't """ prev_event = None diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b527724bc4..32e53c2d25 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -754,7 +754,7 @@ class SyncHandler: """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state - # updates even if they occured logically before the previous event. + # updates even if they occurred logically before the previous event. # TODO(mjark) Check for new redactions in the state events. with Measure(self.clock, "compute_state_delta"): @@ -1882,7 +1882,7 @@ class SyncHandler: # members (as the client otherwise doesn't have enough info to form # the name itself). if sync_config.filter_collection.lazy_load_members() and ( - # we recalulate the summary: + # we recalculate the summary: # if there are membership changes in the timeline, or # if membership has changed during a gappy sync, or # if this is an initial sync. diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index d3692842e3..e919a8f9ed 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -167,20 +167,25 @@ class FollowerTypingHandler: now_typing = set(row.user_ids) self._room_typing[row.room_id] = row.user_ids - run_as_background_process( - "_handle_change_in_typing", - self._handle_change_in_typing, - row.room_id, - prev_typing, - now_typing, - ) + if self.federation: + run_as_background_process( + "_send_changes_in_typing_to_remotes", + self._send_changes_in_typing_to_remotes, + row.room_id, + prev_typing, + now_typing, + ) - async def _handle_change_in_typing( + async def _send_changes_in_typing_to_remotes( self, room_id: str, prev_typing: Set[str], now_typing: Set[str] ): """Process a change in typing of a room from replication, sending EDUs for any local users. """ + + if not self.federation: + return + for user_id in now_typing - prev_typing: if self.is_mine_id(user_id): await self._push_remote(RoomMember(room_id, user_id), True) @@ -371,7 +376,7 @@ class TypingWriterHandler(FollowerTypingHandler): between the requested tokens due to the limit. The token returned can be used in a subsequent call to this - function to get further updatees. + function to get further updates. The updates are a list of 2-tuples of stream ID and the row data """ diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 79393c8829..afbebfc200 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -31,7 +31,7 @@ class UserDirectoryHandler(StateDeltasHandler): N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY The user directory is filled with users who this server can see are joined to a - world_readable or publically joinable room. We keep a database table up to date + world_readable or publicly joinable room. We keep a database table up to date by streaming changes of the current state and recalculating whether users should be in the directory or not when necessary. """ diff --git a/synapse/http/client.py b/synapse/http/client.py index 8324632cb6..f409368802 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -359,7 +359,7 @@ class SimpleHttpClient: agent=self.agent, data=body_producer, headers=headers, - **self._extra_treq_args + **self._extra_treq_args, ) # type: defer.Deferred # we use our own timeout mechanism rather than treq's as a workaround diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index a306faa267..1cc666fbf6 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -172,7 +172,7 @@ class WellKnownResolver: had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) # We do this in two steps to differentiate between possibly transient - # errors (e.g. can't connect to host, 503 response) and more permenant + # errors (e.g. can't connect to host, 503 response) and more permanent # errors (such as getting a 404 response). response, body = await self._make_well_known_request( server_name, retry=had_valid_well_known diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c23a4d7c0c..04766ca965 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -587,7 +587,7 @@ class MatrixFederationHttpClient: """ Builds the Authorization headers for a federation request Args: - destination (bytes|None): The desination homeserver of the request. + destination (bytes|None): The destination homeserver of the request. May be None if the destination is an identity server, in which case destination_is must be non-None. method (bytes): The HTTP method of the request @@ -640,7 +640,7 @@ class MatrixFederationHttpClient: backoff_on_404=False, try_trailing_slash_on_400=False, ): - """ Sends the specifed json data using PUT + """ Sends the specified json data using PUT Args: destination (str): The remote server to send the HTTP request @@ -729,7 +729,7 @@ class MatrixFederationHttpClient: ignore_backoff=False, args={}, ): - """ Sends the specifed json data using POST + """ Sends the specified json data using POST Args: destination (str): The remote server to send the HTTP request diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index cd94e789e8..7c5defec82 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -109,7 +109,7 @@ in_flight_requests_db_sched_duration = Counter( # The set of all in flight requests, set[RequestMetrics] _in_flight_requests = set() -# Protects the _in_flight_requests set from concurrent accesss +# Protects the _in_flight_requests set from concurrent access _in_flight_requests_lock = threading.Lock() diff --git a/synapse/http/server.py b/synapse/http/server.py index d8e354f0a9..c0919f8cb7 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -180,7 +180,7 @@ class HttpServer: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. - If the regex contains groups these gets passed to the calback via + If the regex contains groups these gets passed to the callback via an unpacked tuple. Args: @@ -239,7 +239,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): async def _async_render(self, request: Request): """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if - no appropriate method exists. Can be overriden in sub classes for + no appropriate method exists. Can be overridden in sub classes for different routing. """ # Treat HEAD requests as GET requests. @@ -384,7 +384,7 @@ class JsonResource(DirectServeJsonResource): async def _async_render(self, request): callback, servlet_classname, group_dict = self._get_handler_for_request(request) - # Make sure we have an appopriate name for this handler in prometheus + # Make sure we have an appropriate name for this handler in prometheus # (rather than the default of JsonResource). request.request_metrics.name = servlet_classname diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index fd90ba7828..b361b7cbaf 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -272,7 +272,6 @@ class RestServlet: on_PUT on_POST on_DELETE - on_OPTIONS Automatically handles turning CodeMessageExceptions thrown by these methods into the appropriate HTTP response. @@ -283,7 +282,7 @@ class RestServlet: if hasattr(self, "PATTERNS"): patterns = self.PATTERNS - for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): + for method in ("GET", "PUT", "POST", "DELETE"): if hasattr(self, "on_%s" % (method,)): servlet_classname = self.__class__.__name__ method_handler = getattr(self, "on_%s" % (method,)) diff --git a/synapse/http/site.py b/synapse/http/site.py index 6e79b47828..5f0581dc3f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Optional +from typing import Optional, Union from twisted.python.failure import Failure from twisted.web.server import Request, Site @@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.types import Requester logger = logging.getLogger(__name__) @@ -54,9 +55,12 @@ class SynapseRequest(Request): Request.__init__(self, channel, *args, **kw) self.site = channel.site self._channel = channel # this is used by the tests - self.authenticated_entity = None self.start_time = 0.0 + # The requester, if authenticated. For federation requests this is the + # server name, for client requests this is the Requester object. + self.requester = None # type: Optional[Union[Requester, str]] + # we can't yet create the logcontext, as we don't know the method. self.logcontext = None # type: Optional[LoggingContext] @@ -109,8 +113,14 @@ class SynapseRequest(Request): method = self.method.decode("ascii") return method - def get_user_agent(self): - return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] + def get_user_agent(self, default: str) -> str: + """Return the last User-Agent header, or the given default. + """ + user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] + if user_agent is None: + return default + + return user_agent.decode("ascii", "replace") def render(self, resrc): # this is called once a Resource has been found to serve the request; in our @@ -161,7 +171,9 @@ class SynapseRequest(Request): yield except Exception: # this should already have been caught, and sent back to the client as a 500. - logger.exception("Asynchronous messge handler raised an uncaught exception") + logger.exception( + "Asynchronous message handler raised an uncaught exception" + ) finally: # the request handler has finished its work and either sent the whole response # back, or handed over responsibility to a Producer. @@ -263,22 +275,30 @@ class SynapseRequest(Request): # to the client (nb may be negative) response_send_time = self.finish_time - self._processing_finished_time - # need to decode as it could be raw utf-8 bytes - # from a IDN servname in an auth header - authenticated_entity = self.authenticated_entity - if authenticated_entity is not None and isinstance(authenticated_entity, bytes): - authenticated_entity = authenticated_entity.decode("utf-8", "replace") + # Convert the requester into a string that we can log + authenticated_entity = None + if isinstance(self.requester, str): + authenticated_entity = self.requester + elif isinstance(self.requester, Requester): + authenticated_entity = self.requester.authenticated_entity + + # If this is a request where the target user doesn't match the user who + # authenticated (e.g. and admin is puppetting a user) then we log both. + if self.requester.user.to_string() != authenticated_entity: + authenticated_entity = "{},{}".format( + authenticated_entity, self.requester.user.to_string(), + ) + elif self.requester is not None: + # This shouldn't happen, but we log it so we don't lose information + # and can see that we're doing something wrong. + authenticated_entity = repr(self.requester) # type: ignore[unreachable] # ...or could be raw utf-8 bytes in the User-Agent header. # N.B. if you don't do this, the logger explodes cryptically # with maximum recursion trying to log errors about # the charset problem. # c.f. https://github.com/matrix-org/synapse/issues/3471 - user_agent = self.get_user_agent() - if user_agent is not None: - user_agent = user_agent.decode("utf-8", "replace") - else: - user_agent = "-" + user_agent = self.get_user_agent("-") code = str(self.code) if not self.finished: diff --git a/synapse/logging/__init__.py b/synapse/logging/__init__.py index e69de29bb2..b28b7b2ef7 100644 --- a/synapse/logging/__init__.py +++ b/synapse/logging/__init__.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These are imported to allow for nicer logging configuration files. +from synapse.logging._remote import RemoteHandler +from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter + +__all__ = ["RemoteHandler", "JsonFormatter", "TerseJsonFormatter"] diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 0caf325916..fb937b3f28 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import sys import traceback from collections import deque @@ -21,10 +22,11 @@ from math import floor from typing import Callable, Optional import attr +from typing_extensions import Deque from zope.interface import implementer from twisted.application.internet import ClientService -from twisted.internet.defer import Deferred +from twisted.internet.defer import CancelledError, Deferred from twisted.internet.endpoints import ( HostnameEndpoint, TCP4ClientEndpoint, @@ -32,7 +34,9 @@ from twisted.internet.endpoints import ( ) from twisted.internet.interfaces import IPushProducer, ITransport from twisted.internet.protocol import Factory, Protocol -from twisted.logger import ILogObserver, Logger, LogLevel +from twisted.python.failure import Failure + +logger = logging.getLogger(__name__) @attr.s @@ -45,11 +49,11 @@ class LogProducer: Args: buffer: Log buffer to read logs from. transport: Transport to write to. - format_event: A callable to format the log entry to a string. + format: A callable to format the log record to a string. """ transport = attr.ib(type=ITransport) - format_event = attr.ib(type=Callable[[dict], str]) + _format = attr.ib(type=Callable[[logging.LogRecord], str]) _buffer = attr.ib(type=deque) _paused = attr.ib(default=False, type=bool, init=False) @@ -61,16 +65,19 @@ class LogProducer: self._buffer = deque() def resumeProducing(self): + # If we're already producing, nothing to do. self._paused = False + # Loop until paused. while self._paused is False and (self._buffer and self.transport.connected): try: - # Request the next event and format it. - event = self._buffer.popleft() - msg = self.format_event(event) + # Request the next record and format it. + record = self._buffer.popleft() + msg = self._format(record) # Send it as a new line over the transport. self.transport.write(msg.encode("utf8")) + self.transport.write(b"\n") except Exception: # Something has gone wrong writing to the transport -- log it # and break out of the while. @@ -78,76 +85,85 @@ class LogProducer: break -@attr.s -@implementer(ILogObserver) -class TCPLogObserver: +class RemoteHandler(logging.Handler): """ - An IObserver that writes JSON logs to a TCP target. + An logging handler that writes logs to a TCP target. Args: - hs (HomeServer): The homeserver that is being logged for. host: The host of the logging target. port: The logging target's port. - format_event: A callable to format the log entry to a string. maximum_buffer: The maximum buffer size. """ - hs = attr.ib() - host = attr.ib(type=str) - port = attr.ib(type=int) - format_event = attr.ib(type=Callable[[dict], str]) - maximum_buffer = attr.ib(type=int) - _buffer = attr.ib(default=attr.Factory(deque), type=deque) - _connection_waiter = attr.ib(default=None, type=Optional[Deferred]) - _logger = attr.ib(default=attr.Factory(Logger)) - _producer = attr.ib(default=None, type=Optional[LogProducer]) - - def start(self) -> None: + def __init__( + self, + host: str, + port: int, + maximum_buffer: int = 1000, + level=logging.NOTSET, + _reactor=None, + ): + super().__init__(level=level) + self.host = host + self.port = port + self.maximum_buffer = maximum_buffer + + self._buffer = deque() # type: Deque[logging.LogRecord] + self._connection_waiter = None # type: Optional[Deferred] + self._producer = None # type: Optional[LogProducer] # Connect without DNS lookups if it's a direct IP. + if _reactor is None: + from twisted.internet import reactor + + _reactor = reactor + try: ip = ip_address(self.host) if isinstance(ip, IPv4Address): - endpoint = TCP4ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) + endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port) elif isinstance(ip, IPv6Address): - endpoint = TCP6ClientEndpoint( - self.hs.get_reactor(), self.host, self.port - ) + endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port) else: raise ValueError("Unknown IP address provided: %s" % (self.host,)) except ValueError: - endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port) + endpoint = HostnameEndpoint(_reactor, self.host, self.port) factory = Factory.forProtocol(Protocol) - self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor()) + self._service = ClientService(endpoint, factory, clock=_reactor) self._service.startService() + self._stopping = False self._connect() - def stop(self): + def close(self): + self._stopping = True self._service.stopService() def _connect(self) -> None: """ Triggers an attempt to connect then write to the remote if not already writing. """ + # Do not attempt to open multiple connections. if self._connection_waiter: return self._connection_waiter = self._service.whenConnected(failAfterFailures=1) - @self._connection_waiter.addErrback - def fail(r): - r.printTraceback(file=sys.__stderr__) + def fail(failure: Failure) -> None: + # If the Deferred was cancelled (e.g. during shutdown) do not try to + # reconnect (this will cause an infinite loop of errors). + if failure.check(CancelledError) and self._stopping: + return + + # For a different error, print the traceback and re-connect. + failure.printTraceback(file=sys.__stderr__) self._connection_waiter = None self._connect() - @self._connection_waiter.addCallback - def writer(r): + def writer(result: Protocol) -> None: # We have a connection. If we already have a producer, and its # transport is the same, just trigger a resumeProducing. - if self._producer and r.transport is self._producer.transport: + if self._producer and result.transport is self._producer.transport: self._producer.resumeProducing() self._connection_waiter = None return @@ -158,29 +174,29 @@ class TCPLogObserver: # Make a new producer and start it. self._producer = LogProducer( - buffer=self._buffer, - transport=r.transport, - format_event=self.format_event, + buffer=self._buffer, transport=result.transport, format=self.format, ) - r.transport.registerProducer(self._producer, True) + result.transport.registerProducer(self._producer, True) self._producer.resumeProducing() self._connection_waiter = None + self._connection_waiter.addCallbacks(writer, fail) + def _handle_pressure(self) -> None: """ - Handle backpressure by shedding events. + Handle backpressure by shedding records. The buffer will, in this order, until the buffer is below the maximum: - - Shed DEBUG events - - Shed INFO events - - Shed the middle 50% of the events. + - Shed DEBUG records. + - Shed INFO records. + - Shed the middle 50% of the records. """ if len(self._buffer) <= self.maximum_buffer: return # Strip out DEBUGs self._buffer = deque( - filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer) + filter(lambda record: record.levelno > logging.DEBUG, self._buffer) ) if len(self._buffer) <= self.maximum_buffer: @@ -188,7 +204,7 @@ class TCPLogObserver: # Strip out INFOs self._buffer = deque( - filter(lambda event: event["log_level"] != LogLevel.info, self._buffer) + filter(lambda record: record.levelno > logging.INFO, self._buffer) ) if len(self._buffer) <= self.maximum_buffer: @@ -209,17 +225,17 @@ class TCPLogObserver: self._buffer.extend(reversed(end_buffer)) - def __call__(self, event: dict) -> None: - self._buffer.append(event) + def emit(self, record: logging.LogRecord) -> None: + self._buffer.append(record) # Handle backpressure, if it exists. try: self._handle_pressure() except Exception: - # If handling backpressure fails,clear the buffer and log the + # If handling backpressure fails, clear the buffer and log the # exception. self._buffer.clear() - self._logger.failure("Failed clearing backpressure") + logger.warning("Failed clearing backpressure") # Try and write immediately. self._connect() diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py index 0fc2ea609e..14d9c104c2 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py @@ -12,138 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import os.path -import sys -import typing -import warnings -from typing import List +from typing import Any, Dict, Generator, Optional, Tuple -import attr -from constantly import NamedConstant, Names, ValueConstant, Values -from zope.interface import implementer - -from twisted.logger import ( - FileLogObserver, - FilteringLogObserver, - ILogObserver, - LogBeginner, - Logger, - LogLevel, - LogLevelFilterPredicate, - LogPublisher, - eventAsText, - jsonFileLogObserver, -) +from constantly import NamedConstant, Names from synapse.config._base import ConfigError -from synapse.logging._terse_json import ( - TerseJSONToConsoleLogObserver, - TerseJSONToTCPLogObserver, -) -from synapse.logging.context import current_context - - -def stdlib_log_level_to_twisted(level: str) -> LogLevel: - """ - Convert a stdlib log level to Twisted's log level. - """ - lvl = level.lower().replace("warning", "warn") - return LogLevel.levelWithName(lvl) - - -@attr.s -@implementer(ILogObserver) -class LogContextObserver: - """ - An ILogObserver which adds Synapse-specific log context information. - - Attributes: - observer (ILogObserver): The target parent observer. - """ - - observer = attr.ib() - - def __call__(self, event: dict) -> None: - """ - Consume a log event and emit it to the parent observer after filtering - and adding log context information. - - Args: - event (dict) - """ - # Filter out some useless events that Twisted outputs - if "log_text" in event: - if event["log_text"].startswith("DNSDatagramProtocol starting on "): - return - - if event["log_text"].startswith("(UDP Port "): - return - - if event["log_text"].startswith("Timing out client") or event[ - "log_format" - ].startswith("Timing out client"): - return - - context = current_context() - - # Copy the context information to the log event. - context.copy_to_twisted_log_entry(event) - - self.observer(event) - - -class PythonStdlibToTwistedLogger(logging.Handler): - """ - Transform a Python stdlib log message into a Twisted one. - """ - - def __init__(self, observer, *args, **kwargs): - """ - Args: - observer (ILogObserver): A Twisted logging observer. - *args, **kwargs: Args/kwargs to be passed to logging.Handler. - """ - self.observer = observer - super().__init__(*args, **kwargs) - - def emit(self, record: logging.LogRecord) -> None: - """ - Emit a record to Twisted's observer. - - Args: - record (logging.LogRecord) - """ - - self.observer( - { - "log_time": record.created, - "log_text": record.getMessage(), - "log_format": "{log_text}", - "log_namespace": record.name, - "log_level": stdlib_log_level_to_twisted(record.levelname), - } - ) - - -def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver: - """ - A log observer that formats events like the traditional log formatter and - sends them to `outFile`. - - Args: - outFile (file object): The file object to write to. - """ - - def formatEvent(_event: dict) -> str: - event = dict(_event) - event["log_level"] = event["log_level"].name.upper() - event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + ( - event.get("log_format", "{log_text}") or "{log_text}" - ) - return eventAsText(event, includeSystem=False) + "\n" - - return FileLogObserver(outFile, formatEvent) class DrainType(Names): @@ -155,30 +29,12 @@ class DrainType(Names): NETWORK_JSON_TERSE = NamedConstant() -class OutputPipeType(Values): - stdout = ValueConstant(sys.__stdout__) - stderr = ValueConstant(sys.__stderr__) - - -@attr.s -class DrainConfiguration: - name = attr.ib() - type = attr.ib() - location = attr.ib() - options = attr.ib(default=None) - - -@attr.s -class NetworkJSONTerseOptions: - maximum_buffer = attr.ib(type=int) - - -DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}} +DEFAULT_LOGGERS = {"synapse": {"level": "info"}} def parse_drain_configs( drains: dict, -) -> typing.Generator[DrainConfiguration, None, None]: +) -> Generator[Tuple[str, Dict[str, Any]], None, None]: """ Parse the drain configurations. @@ -186,11 +42,12 @@ def parse_drain_configs( drains (dict): A list of drain configurations. Yields: - DrainConfiguration instances. + dict instances representing a logging handler. Raises: ConfigError: If any of the drain configuration items are invalid. """ + for name, config in drains.items(): if "type" not in config: raise ConfigError("Logging drains require a 'type' key.") @@ -202,6 +59,18 @@ def parse_drain_configs( "%s is not a known logging drain type." % (config["type"],) ) + # Either use the default formatter or the tersejson one. + if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,): + formatter = "json" # type: Optional[str] + elif logging_type in ( + DrainType.CONSOLE_JSON_TERSE, + DrainType.NETWORK_JSON_TERSE, + ): + formatter = "tersejson" + else: + # A formatter of None implies using the default formatter. + formatter = None + if logging_type in [ DrainType.CONSOLE, DrainType.CONSOLE_JSON, @@ -217,9 +86,11 @@ def parse_drain_configs( % (logging_type,) ) - pipe = OutputPipeType.lookupByName(location).value - - yield DrainConfiguration(name=name, type=logging_type, location=pipe) + yield name, { + "class": "logging.StreamHandler", + "formatter": formatter, + "stream": "ext://sys." + location, + } elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]: if "location" not in config: @@ -233,18 +104,25 @@ def parse_drain_configs( "File paths need to be absolute, '%s' is a relative path" % (location,) ) - yield DrainConfiguration(name=name, type=logging_type, location=location) + + yield name, { + "class": "logging.FileHandler", + "formatter": formatter, + "filename": location, + } elif logging_type in [DrainType.NETWORK_JSON_TERSE]: host = config.get("host") port = config.get("port") maximum_buffer = config.get("maximum_buffer", 1000) - yield DrainConfiguration( - name=name, - type=logging_type, - location=(host, port), - options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer), - ) + + yield name, { + "class": "synapse.logging.RemoteHandler", + "formatter": formatter, + "host": host, + "port": port, + "maximum_buffer": maximum_buffer, + } else: raise ConfigError( @@ -253,126 +131,29 @@ def parse_drain_configs( ) -class StoppableLogPublisher(LogPublisher): +def setup_structured_logging(log_config: dict,) -> dict: """ - A log publisher that can tell its observers to shut down any external - communications. - """ - - def stop(self): - for obs in self._observers: - if hasattr(obs, "stop"): - obs.stop() - - -def setup_structured_logging( - hs, - config, - log_config: dict, - logBeginner: LogBeginner, - redirect_stdlib_logging: bool = True, -) -> LogPublisher: - """ - Set up Twisted's structured logging system. - - Args: - hs: The homeserver to use. - config (HomeserverConfig): The configuration of the Synapse homeserver. - log_config (dict): The log configuration to use. + Convert a legacy structured logging configuration (from Synapse < v1.23.0) + to one compatible with the new standard library handlers. """ - if config.no_redirect_stdio: - raise ConfigError( - "no_redirect_stdio cannot be defined using structured logging." - ) - - logger = Logger() - if "drains" not in log_config: raise ConfigError("The logging configuration requires a list of drains.") - observers = [] # type: List[ILogObserver] - - for observer in parse_drain_configs(log_config["drains"]): - # Pipe drains - if observer.type == DrainType.CONSOLE: - logger.debug( - "Starting up the {name} console logger drain", name=observer.name - ) - observers.append(SynapseFileLogObserver(observer.location)) - elif observer.type == DrainType.CONSOLE_JSON: - logger.debug( - "Starting up the {name} JSON console logger drain", name=observer.name - ) - observers.append(jsonFileLogObserver(observer.location)) - elif observer.type == DrainType.CONSOLE_JSON_TERSE: - logger.debug( - "Starting up the {name} terse JSON console logger drain", - name=observer.name, - ) - observers.append( - TerseJSONToConsoleLogObserver(observer.location, metadata={}) - ) - - # File drains - elif observer.type == DrainType.FILE: - logger.debug("Starting up the {name} file logger drain", name=observer.name) - log_file = open(observer.location, "at", buffering=1, encoding="utf8") - observers.append(SynapseFileLogObserver(log_file)) - elif observer.type == DrainType.FILE_JSON: - logger.debug( - "Starting up the {name} JSON file logger drain", name=observer.name - ) - log_file = open(observer.location, "at", buffering=1, encoding="utf8") - observers.append(jsonFileLogObserver(log_file)) - - elif observer.type == DrainType.NETWORK_JSON_TERSE: - metadata = {"server_name": hs.config.server_name} - log_observer = TerseJSONToTCPLogObserver( - hs=hs, - host=observer.location[0], - port=observer.location[1], - metadata=metadata, - maximum_buffer=observer.options.maximum_buffer, - ) - log_observer.start() - observers.append(log_observer) - else: - # We should never get here, but, just in case, throw an error. - raise ConfigError("%s drain type cannot be configured" % (observer.type,)) - - publisher = StoppableLogPublisher(*observers) - log_filter = LogLevelFilterPredicate() - - for namespace, namespace_config in log_config.get( - "loggers", DEFAULT_LOGGERS - ).items(): - # Set the log level for twisted.logger.Logger namespaces - log_filter.setLogLevelForNamespace( - namespace, - stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")), - ) - - # Also set the log levels for the stdlib logger namespaces, to prevent - # them getting to PythonStdlibToTwistedLogger and having to be formatted - if "level" in namespace_config: - logging.getLogger(namespace).setLevel(namespace_config.get("level")) - - f = FilteringLogObserver(publisher, [log_filter]) - lco = LogContextObserver(f) - - if redirect_stdlib_logging: - stuff_into_twisted = PythonStdlibToTwistedLogger(lco) - stdliblogger = logging.getLogger() - stdliblogger.addHandler(stuff_into_twisted) - - # Always redirect standard I/O, otherwise other logging outputs might miss - # it. - logBeginner.beginLoggingTo([lco], redirectStandardIO=True) + new_config = { + "version": 1, + "formatters": { + "json": {"class": "synapse.logging.JsonFormatter"}, + "tersejson": {"class": "synapse.logging.TerseJsonFormatter"}, + }, + "handlers": {}, + "loggers": log_config.get("loggers", DEFAULT_LOGGERS), + "root": {"handlers": []}, + } - return publisher + for handler_name, handler in parse_drain_configs(log_config["drains"]): + new_config["handlers"][handler_name] = handler + # Add each handler to the root logger. + new_config["root"]["handlers"].append(handler_name) -def reload_structured_logging(*args, log_config=None) -> None: - warnings.warn( - "Currently the structured logging system can not be reloaded, doing nothing" - ) + return new_config diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 9b46956ca9..2fbf5549a1 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -16,141 +16,65 @@ """ Log formatters that output terse JSON. """ - import json -from typing import IO - -from twisted.logger import FileLogObserver - -from synapse.logging._remote import TCPLogObserver +import logging _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) - -def flatten_event(event: dict, metadata: dict, include_time: bool = False): - """ - Flatten a Twisted logging event to an dictionary capable of being sent - as a log event to a logging aggregation system. - - The format is vastly simplified and is not designed to be a "human readable - string" in the sense that traditional logs are. Instead, the structure is - optimised for searchability and filtering, with human-understandable log - keys. - - Args: - event (dict): The Twisted logging event we are flattening. - metadata (dict): Additional data to include with each log message. This - can be information like the server name. Since the target log - consumer does not know who we are other than by host IP, this - allows us to forward through static information. - include_time (bool): Should we include the `time` key? If False, the - event time is stripped from the event. - """ - new_event = {} - - # If it's a failure, make the new event's log_failure be the traceback text. - if "log_failure" in event: - new_event["log_failure"] = event["log_failure"].getTraceback() - - # If it's a warning, copy over a string representation of the warning. - if "warning" in event: - new_event["warning"] = str(event["warning"]) - - # Stdlib logging events have "log_text" as their human-readable portion, - # Twisted ones have "log_format". For now, include the log_format, so that - # context only given in the log format (e.g. what is being logged) is - # available. - if "log_text" in event: - new_event["log"] = event["log_text"] - else: - new_event["log"] = event["log_format"] - - # We want to include the timestamp when forwarding over the network, but - # exclude it when we are writing to stdout. This is because the log ingester - # (e.g. logstash, fluentd) can add its own timestamp. - if include_time: - new_event["time"] = round(event["log_time"], 2) - - # Convert the log level to a textual representation. - new_event["level"] = event["log_level"].name.upper() - - # Ignore these keys, and do not transfer them over to the new log object. - # They are either useless (isError), transferred manually above (log_time, - # log_level, etc), or contain Python objects which are not useful for output - # (log_logger, log_source). - keys_to_delete = [ - "isError", - "log_failure", - "log_format", - "log_level", - "log_logger", - "log_source", - "log_system", - "log_time", - "log_text", - "observer", - "warning", - ] - - # If it's from the Twisted legacy logger (twisted.python.log), it adds some - # more keys we want to purge. - if event.get("log_namespace") == "log_legacy": - keys_to_delete.extend(["message", "system", "time"]) - - # Rather than modify the dictionary in place, construct a new one with only - # the content we want. The original event should be considered 'frozen'. - for key in event.keys(): - - if key in keys_to_delete: - continue - - if isinstance(event[key], (str, int, bool, float)) or event[key] is None: - # If it's a plain type, include it as is. - new_event[key] = event[key] - else: - # If it's not one of those basic types, write out a string - # representation. This should probably be a warning in development, - # so that we are sure we are only outputting useful data. - new_event[key] = str(event[key]) - - # Add the metadata information to the event (e.g. the server_name). - new_event.update(metadata) - - return new_event - - -def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver: - """ - A log observer that formats events to a flattened JSON representation. - - Args: - outFile: The file object to write to. - metadata: Metadata to be added to each log object. - """ - - def formatEvent(_event: dict) -> str: - flattened = flatten_event(_event, metadata) - return _encoder.encode(flattened) + "\n" - - return FileLogObserver(outFile, formatEvent) - - -def TerseJSONToTCPLogObserver( - hs, host: str, port: int, metadata: dict, maximum_buffer: int -) -> FileLogObserver: - """ - A log observer that formats events to a flattened JSON representation. - - Args: - hs (HomeServer): The homeserver that is being logged for. - host: The host of the logging target. - port: The logging target's port. - metadata: Metadata to be added to each log object. - maximum_buffer: The maximum buffer size. - """ - - def formatEvent(_event: dict) -> str: - flattened = flatten_event(_event, metadata, include_time=True) - return _encoder.encode(flattened) + "\n" - - return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer) +# The properties of a standard LogRecord. +_LOG_RECORD_ATTRIBUTES = { + "args", + "asctime", + "created", + "exc_info", + # exc_text isn't a public attribute, but is used to cache the result of formatException. + "exc_text", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "message", + "module", + "msecs", + "msg", + "name", + "pathname", + "process", + "processName", + "relativeCreated", + "stack_info", + "thread", + "threadName", +} + + +class JsonFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> str: + event = { + "log": record.getMessage(), + "namespace": record.name, + "level": record.levelname, + } + + return self._format(record, event) + + def _format(self, record: logging.LogRecord, event: dict) -> str: + # Add any extra attributes to the event. + for key, value in record.__dict__.items(): + if key not in _LOG_RECORD_ATTRIBUTES: + event[key] = value + + return _encoder.encode(event) + + +class TerseJsonFormatter(JsonFormatter): + def format(self, record: logging.LogRecord) -> str: + event = { + "log": record.getMessage(), + "namespace": record.name, + "level": record.levelname, + "time": round(record.created, 2), + } + + return self._format(record, event) diff --git a/synapse/logging/filter.py b/synapse/logging/filter.py new file mode 100644 index 0000000000..1baf8dd679 --- /dev/null +++ b/synapse/logging/filter.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from typing_extensions import Literal + + +class MetadataFilter(logging.Filter): + """Logging filter that adds constant values to each record. + + Args: + metadata: Key-value pairs to add to each record. + """ + + def __init__(self, metadata: dict): + self._metadata = metadata + + def filter(self, record: logging.LogRecord) -> Literal[True]: + for key, value in self._metadata.items(): + setattr(record, key, value) + return True diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index e58850faff..ab586c318c 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -317,7 +317,7 @@ def ensure_active_span(message, ret=None): @contextlib.contextmanager -def _noop_context_manager(*args, **kwargs): +def noop_context_manager(*args, **kwargs): """Does exactly what it says on the tin""" yield @@ -413,7 +413,7 @@ def start_active_span( """ if opentracing is None: - return _noop_context_manager() + return noop_context_manager() return opentracing.tracer.start_active_span( operation_name, @@ -428,7 +428,7 @@ def start_active_span( def start_active_span_follows_from(operation_name, contexts): if opentracing is None: - return _noop_context_manager() + return noop_context_manager() references = [opentracing.follows_from(context) for context in contexts] scope = start_active_span(operation_name, references=references) @@ -459,7 +459,7 @@ def start_active_span_from_request( # Also, twisted uses byte arrays while opentracing expects strings. if opentracing is None: - return _noop_context_manager() + return noop_context_manager() header_dict = { k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders() @@ -497,7 +497,7 @@ def start_active_span_from_edu( """ if opentracing is None: - return _noop_context_manager() + return noop_context_manager() carrier = json_decoder.decode(edu_content.get("context", "{}")).get( "opentracing", {} diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index ea5f1c7b62..658f6ecd72 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -24,7 +24,7 @@ from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer from synapse.logging.context import LoggingContext, PreserveLoggingContext -from synapse.logging.opentracing import start_active_span +from synapse.logging.opentracing import noop_context_manager, start_active_span if TYPE_CHECKING: import resource @@ -167,7 +167,7 @@ class _BackgroundProcess: ) -def run_as_background_process(desc: str, func, *args, **kwargs): +def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs): """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -181,6 +181,9 @@ def run_as_background_process(desc: str, func, *args, **kwargs): Args: desc: a description for this background process type func: a function, which may return a Deferred or a coroutine + bg_start_span: Whether to start an opentracing span. Defaults to True. + Should only be disabled for processes that will not log to or tag + a span. args: positional args for func kwargs: keyword args for func @@ -199,7 +202,10 @@ def run_as_background_process(desc: str, func, *args, **kwargs): with BackgroundProcessLoggingContext(desc) as context: context.request = "%s-%i" % (desc, count) try: - with start_active_span(desc, tags={"request_id": context.request}): + ctx = noop_context_manager() + if bg_start_span: + ctx = start_active_span(desc, tags={"request_id": context.request}) + with ctx: result = func(*args, **kwargs) if inspect.isawaitable(result): @@ -266,7 +272,7 @@ class BackgroundProcessLoggingContext(LoggingContext): super().__exit__(type, value, traceback) - # The background process has finished. We explictly remove and manually + # The background process has finished. We explicitly remove and manually # update the metrics here so that if nothing is scraping metrics the set # doesn't infinitely grow. with _bg_metrics_lock: diff --git a/synapse/notifier.py b/synapse/notifier.py index 2e993411b9..a17352ef46 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -28,6 +28,7 @@ from typing import ( Union, ) +import attr from prometheus_client import Counter from twisted.internet import defer @@ -40,7 +41,6 @@ from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import PreserveLoggingContext from synapse.logging.utils import log_function from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.streams.config import PaginationConfig from synapse.types import ( Collection, @@ -174,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))): return bool(self.events) +@attr.s(slots=True, frozen=True) +class _PendingRoomEventEntry: + event_pos = attr.ib(type=PersistedEventPosition) + extra_users = attr.ib(type=Collection[UserID]) + + room_id = attr.ib(type=str) + type = attr.ib(type=str) + state_key = attr.ib(type=Optional[str]) + membership = attr.ib(type=Optional[str]) + + class Notifier: """ This class is responsible for notifying any listeners when there are new events available for it. @@ -191,9 +202,7 @@ class Notifier: self.storage = hs.get_storage() self.event_sources = hs.get_event_sources() self.store = hs.get_datastore() - self.pending_new_room_events = ( - [] - ) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]] + self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry] # Called when there are new things to stream over replication self.replication_callbacks = [] # type: List[Callable[[], None]] @@ -256,7 +265,29 @@ class Notifier: max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): - """ Used by handlers to inform the notifier something has happened + """Unwraps event and calls `on_new_room_event_args`. + """ + self.on_new_room_event_args( + event_pos=event_pos, + room_id=event.room_id, + event_type=event.type, + state_key=event.get("state_key"), + membership=event.content.get("membership"), + max_room_stream_token=max_room_stream_token, + extra_users=extra_users, + ) + + def on_new_room_event_args( + self, + room_id: str, + event_type: str, + state_key: Optional[str], + membership: Optional[str], + event_pos: PersistedEventPosition, + max_room_stream_token: RoomStreamToken, + extra_users: Collection[UserID] = [], + ): + """Used by handlers to inform the notifier something has happened in the room, room event wise. This triggers the notifier to wake up any listeners that are @@ -267,7 +298,16 @@ class Notifier: until all previous events have been persisted before notifying the client streams. """ - self.pending_new_room_events.append((event_pos, event, extra_users)) + self.pending_new_room_events.append( + _PendingRoomEventEntry( + event_pos=event_pos, + extra_users=extra_users, + room_id=room_id, + type=event_type, + state_key=state_key, + membership=membership, + ) + ) self._notify_pending_new_room_events(max_room_stream_token) self.notify_replication() @@ -285,18 +325,19 @@ class Notifier: users = set() # type: Set[UserID] rooms = set() # type: Set[str] - for event_pos, event, extra_users in pending: - if event_pos.persisted_after(max_room_stream_token): - self.pending_new_room_events.append((event_pos, event, extra_users)) + for entry in pending: + if entry.event_pos.persisted_after(max_room_stream_token): + self.pending_new_room_events.append(entry) else: if ( - event.type == EventTypes.Member - and event.membership == Membership.JOIN + entry.type == EventTypes.Member + and entry.membership == Membership.JOIN + and entry.state_key ): - self._user_joined_room(event.state_key, event.room_id) + self._user_joined_room(entry.state_key, entry.room_id) - users.update(extra_users) - rooms.add(event.room_id) + users.update(entry.extra_users) + rooms.add(entry.room_id) if users or rooms: self.on_new_event( @@ -310,44 +351,37 @@ class Notifier: """ # poke any interested application service. - run_as_background_process( - "_notify_app_services", self._notify_app_services, max_room_stream_token - ) - - run_as_background_process( - "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token - ) + self._notify_app_services(max_room_stream_token) + self._notify_pusher_pool(max_room_stream_token) if self.federation_sender: self.federation_sender.notify_new_events(max_room_stream_token) - async def _notify_app_services(self, max_room_stream_token: RoomStreamToken): + def _notify_app_services(self, max_room_stream_token: RoomStreamToken): try: - await self.appservice_handler.notify_interested_services( - max_room_stream_token - ) + self.appservice_handler.notify_interested_services(max_room_stream_token) except Exception: logger.exception("Error notifying application services of event") - async def _notify_app_services_ephemeral( + def _notify_app_services_ephemeral( self, stream_key: str, new_token: Union[int, RoomStreamToken], - users: Collection[UserID] = [], + users: Collection[Union[str, UserID]] = [], ): try: stream_token = None if isinstance(new_token, int): stream_token = new_token - await self.appservice_handler.notify_interested_services_ephemeral( + self.appservice_handler.notify_interested_services_ephemeral( stream_key, stream_token, users ) except Exception: logger.exception("Error notifying application services of event") - async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): + def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): try: - await self._pusher_pool.on_new_notifications(max_room_stream_token) + self._pusher_pool.on_new_notifications(max_room_stream_token) except Exception: logger.exception("Error pusher pool of event") @@ -384,16 +418,12 @@ class Notifier: self.notify_replication() # Notify appservices - run_as_background_process( - "_notify_app_services_ephemeral", - self._notify_app_services_ephemeral, - stream_key, - new_token, - users, + self._notify_app_services_ephemeral( + stream_key, new_token, users, ) def on_new_replication_data(self) -> None: - """Used to inform replication listeners that something has happend + """Used to inform replication listeners that something has happened without waking up any of the normal user event streams""" self.notify_replication() diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 8047873ff1..2858b61fb1 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -37,7 +37,7 @@ def list_with_base_rules(rawrules, use_new_defaults=False): modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0} # Remove the modified base rules from the list, They'll be added back - # in the default postions in the list. + # in the default positions in the list. rawrules = [r for r in rawrules if r["priority_class"] >= 0] # shove the server default rules for each kind onto the end of each diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a701defcdd..82a72dc34f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -15,8 +15,8 @@ # limitations under the License. import logging -from collections import namedtuple +import attr from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership, RelationTypes @@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.util.async_helpers import Linearizer from synapse.util.caches import register_cache -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import lru_cache +from synapse.util.caches.lrucache import LruCache from .push_rule_evaluator import PushRuleEvaluatorForEvent @@ -120,7 +121,7 @@ class BulkPushRuleEvaluator: dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = await self._get_rules_for_room(room_id) + rules_for_room = self._get_rules_for_room(room_id) rules_by_user = await rules_for_room.get_rules(event, context) @@ -138,7 +139,7 @@ class BulkPushRuleEvaluator: return rules_by_user - @cached() + @lru_cache() def _get_rules_for_room(self, room_id): """Get the current RulesForRoom object for the given room id @@ -275,12 +276,14 @@ class RulesForRoom: the entire cache for the room. """ - def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics): + def __init__( + self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics + ): """ Args: hs (HomeServer) room_id (str) - rules_for_room_cache(Cache): The cache object that caches these + rules_for_room_cache: The cache object that caches these RoomsForUser objects. room_push_rule_cache_metrics (CacheMetric) """ @@ -390,12 +393,12 @@ class RulesForRoom: continue # If a user has left a room we remove their push rule. If they - # joined then we readd it later in _update_rules_with_member_event_ids + # joined then we re-add it later in _update_rules_with_member_event_ids ret_rules_by_user.pop(user_id, None) missing_member_event_ids[user_id] = event_id if missing_member_event_ids: - # If we have some memebr events we haven't seen, look them up + # If we have some member events we haven't seen, look them up # and fetch push rules for them if appropriate. logger.debug("Found new member events %r", missing_member_event_ids) await self._update_rules_with_member_event_ids( @@ -489,13 +492,21 @@ class RulesForRoom: self.state_group = state_group -class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))): - # We rely on _CacheContext implementing __eq__ and __hash__ sensibly, - # which namedtuple does for us (i.e. two _CacheContext are the same if - # their caches and keys match). This is important in particular to - # dedupe when we add callbacks to lru cache nodes, otherwise the number - # of callbacks would grow. +@attr.attrs(slots=True, frozen=True) +class _Invalidation: + # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules, + # which means that it it is stored on the bulk_get_push_rules cache entry. In order + # to ensure that we don't accumulate lots of redunant callbacks on the cache entry, + # we need to ensure that two _Invalidation objects are "equal" if they refer to the + # same `cache` and `room_id`. + # + # attrs provides suitable __hash__ and __eq__ methods, provided we remember to + # set `frozen=True`. + + cache = attr.ib(type=LruCache) + room_id = attr.ib(type=str) + def __call__(self): - rules = self.cache.get_immediate(self.room_id, None, update_metrics=False) + rules = self.cache.get(self.room_id, None, update_metrics=False) if rules: rules.invalidate_all() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 155791b754..38195c8eea 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -24,7 +24,7 @@ from typing import Iterable, List, TypeVar import bleach import jinja2 -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import StoreError from synapse.config.emailconfig import EmailSubjectConfig from synapse.logging.context import make_deferred_yieldable @@ -317,9 +317,14 @@ class Mailer: async def get_room_vars( self, room_id, user_id, notifs, notif_events, room_state_ids ): - my_member_event_id = room_state_ids[("m.room.member", user_id)] - my_member_event = await self.store.get_event(my_member_event_id) - is_invite = my_member_event.content["membership"] == "invite" + # Check if one of the notifs is an invite event for the user. + is_invite = False + for n in notifs: + ev = notif_events[n["event_id"]] + if ev.type == EventTypes.Member and ev.state_key == user_id: + if ev.content.get("membership") == Membership.INVITE: + is_invite = True + break room_name = await calculate_room_name(self.store, room_state_ids, user_id) @@ -461,16 +466,26 @@ class Mailer: self.store, room_state_ids[room_id], user_id, fallback_to_members=False ) - my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)] - my_member_event = await self.store.get_event(my_member_event_id) - if my_member_event.content["membership"] == "invite": - inviter_member_event_id = room_state_ids[room_id][ - ("m.room.member", my_member_event.sender) - ] - inviter_member_event = await self.store.get_event( - inviter_member_event_id + # See if one of the notifs is an invite event for the user + invite_event = None + for n in notifs_by_room[room_id]: + ev = notif_events[n["event_id"]] + if ev.type == EventTypes.Member and ev.state_key == user_id: + if ev.content.get("membership") == Membership.INVITE: + invite_event = ev + break + + if invite_event: + inviter_member_event_id = room_state_ids[room_id].get( + ("m.room.member", invite_event.sender) ) - inviter_name = name_from_member_event(inviter_member_event) + inviter_name = invite_event.sender + if inviter_member_event_id: + inviter_member_event = await self.store.get_event( + inviter_member_event_id, allow_none=True + ) + if inviter_member_event: + inviter_name = name_from_member_event(inviter_member_event) if room_name is None: return self.email_subjects.invite_from_person % { diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0080c68ce2..f325964983 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,7 +19,10 @@ from typing import TYPE_CHECKING, Dict, Union from prometheus_client import Gauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher from synapse.push.httppusher import HttpPusher @@ -187,7 +190,7 @@ class PusherPool: ) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - async def on_new_notifications(self, max_token: RoomStreamToken): + def on_new_notifications(self, max_token: RoomStreamToken): if not self.pushers: # nothing to do here. return @@ -201,6 +204,17 @@ class PusherPool: # Nothing to do return + # We only start a new background process if necessary rather than + # optimistically (to cut down on overhead). + self._on_new_notifications(max_token) + + @wrap_as_background_process("on_new_notifications") + async def _on_new_notifications(self, max_token: RoomStreamToken): + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + max_stream_id = max_token.stream + prev_stream_id = self._last_room_stream_id_seen self._last_room_stream_id_seen = max_stream_id diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index e7cc74a5d2..f0c37eaf5e 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): requester = Requester.deserialize(self.store, content["requester"]) - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester logger.info("remote_join: %s into room: %s", user_id, room_id) @@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): requester = Requester.deserialize(self.store, content["requester"]) - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester # hopefully we're now on the master, so this won't recurse! event_id, stream_id = await self.member_handler.remote_reject_invite( diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index fc129dbaa7..8fa104c8d3 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] - if requester.user: - request.authenticated_entity = requester.user.to_string() + request.requester = requester logger.info( "Got event to send with ID: %s into room: %s", event.event_id, event.room_id diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e27ee216f0..2618eb1e53 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -141,21 +141,25 @@ class ReplicationDataHandler: if row.type != EventsStreamEventRow.TypeId: continue assert isinstance(row, EventsStreamRow) + assert isinstance(row.data, EventsStreamEventRow) - event = await self.store.get_event( - row.data.event_id, allow_rejected=True - ) - if event.rejected_reason: + if row.data.rejected: continue extra_users = () # type: Tuple[UserID, ...] - if event.type == EventTypes.Member: - extra_users = (UserID.from_string(event.state_key),) + if row.data.type == EventTypes.Member and row.data.state_key: + extra_users = (UserID.from_string(row.data.state_key),) max_token = self.store.get_room_max_token() event_pos = PersistedEventPosition(instance_name, token) - self.notifier.on_new_room_event( - event, event_pos, max_token, extra_users + self.notifier.on_new_room_event_args( + event_pos=event_pos, + max_room_stream_token=max_token, + extra_users=extra_users, + room_id=row.data.room_id, + event_type=row.data.type, + state_key=row.data.state_key, + membership=row.data.membership, ) # Notify any waiting deferreds. The list is ordered by position so we diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index de19705c1f..bc6ba709a7 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -166,7 +166,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): Args: cmd (Command) """ - run_as_background_process("send-cmd", self._async_send_command, cmd) + run_as_background_process( + "send-cmd", self._async_send_command, cmd, bg_start_span=False + ) async def _async_send_command(self, cmd: Command): """Encode a replication command and send it over our outbound connection""" diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 666c13fdb7..1d4ceac0f1 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -117,6 +117,16 @@ class ReplicationStreamer: stream.discard_updates_and_advance() return + # We check up front to see if anything has actually changed, as we get + # poked because of changes that happened on other instances. + if all( + stream.last_token == stream.current_token(self._instance_name) + for stream in self.streams + ): + return + + # If there are updates then we need to set this even if we're already + # looping, as the loop needs to know that he might need to loop again. self.pending_updates = True if self.is_looping: diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 82e9e0d64e..86a62b71eb 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -15,12 +15,15 @@ # limitations under the License. import heapq from collections.abc import Iterable -from typing import List, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import attr from ._base import Stream, StreamUpdateResult, Token +if TYPE_CHECKING: + from synapse.server import HomeServer + """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' @@ -81,12 +84,14 @@ class BaseEventsStreamRow: class EventsStreamEventRow(BaseEventsStreamRow): TypeId = "ev" - event_id = attr.ib() # str - room_id = attr.ib() # str - type = attr.ib() # str - state_key = attr.ib() # str, optional - redacts = attr.ib() # str, optional - relates_to = attr.ib() # str, optional + event_id = attr.ib(type=str) + room_id = attr.ib(type=str) + type = attr.ib(type=str) + state_key = attr.ib(type=Optional[str]) + redacts = attr.ib(type=Optional[str]) + relates_to = attr.ib(type=Optional[str]) + membership = attr.ib(type=Optional[str]) + rejected = attr.ib(type=bool) @attr.s(slots=True, frozen=True) @@ -113,7 +118,7 @@ class EventsStream(Stream): NAME = "events" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 789431ef25..fa7e9e4043 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -31,7 +31,10 @@ from synapse.rest.admin.devices import ( DeviceRestServlet, DevicesRestServlet, ) -from synapse.rest.admin.event_reports import EventReportsRestServlet +from synapse.rest.admin.event_reports import ( + EventReportDetailRestServlet, + EventReportsRestServlet, +) from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -47,9 +50,11 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet from synapse.rest.admin.users import ( AccountValidityRenewServlet, DeactivateAccountRestServlet, + PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, UserAdminServlet, + UserMediaRestServlet, UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, @@ -215,13 +220,16 @@ def register_servlets(hs, http_server): SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UserMediaRestServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) DeviceRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server) + EventReportDetailRestServlet(hs).register(http_server) EventReportsRestServlet(hs).register(http_server) + PushersRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py index a163863322..ffd3aa38f7 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py @@ -119,7 +119,7 @@ class DevicesRestServlet(RestServlet): raise NotFoundError("Unknown user") devices = await self.device_handler.get_devices_by_user(target_user.to_string()) - return 200, {"devices": devices} + return 200, {"devices": devices, "total": len(devices)} class DeleteDevicesRestServlet(RestServlet): diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py index 5b8d0594cd..fd482f0e32 100644 --- a/synapse/rest/admin/event_reports.py +++ b/synapse/rest/admin/event_reports.py @@ -15,7 +15,7 @@ import logging -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin @@ -86,3 +86,47 @@ class EventReportsRestServlet(RestServlet): ret["next_token"] = start + len(event_reports) return 200, ret + + +class EventReportDetailRestServlet(RestServlet): + """ + Get a specific reported event that is known to the homeserver. Results are returned + in a dictionary containing report information. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/event_reports/<report_id> + returns: + 200 OK with details report if success otherwise an error. + + Args: + The parameter `report_id` is the ID of the event report in the database. + Returns: + JSON blob of information about the event report + """ + + PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, report_id): + await assert_requester_is_admin(self.auth, request) + + message = ( + "The report_id parameter must be a string representing a positive integer." + ) + try: + report_id = int(report_id) + except ValueError: + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + + if report_id < 0: + raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) + + ret = await self.store.get_event_report(report_id) + if not ret: + raise NotFoundError("Event report not found") + + return 200, ret diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index ee75095c0e..ba50cb876d 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -16,9 +16,10 @@ import logging -from synapse.api.errors import AuthError -from synapse.http.servlet import RestServlet, parse_integer +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.http.servlet import RestServlet, parse_boolean, parse_integer from synapse.rest.admin._base import ( + admin_patterns, assert_requester_is_admin, assert_user_is_admin, historical_admin_path_patterns, @@ -150,6 +151,80 @@ class PurgeMediaCacheRestServlet(RestServlet): return 200, ret +class DeleteMediaByID(RestServlet): + """Delete local media by a given ID. Removes it from this server. + """ + + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.server_name = hs.hostname + self.media_repository = hs.get_media_repository() + + async def on_DELETE(self, request, server_name: str, media_id: str): + await assert_requester_is_admin(self.auth, request) + + if self.server_name != server_name: + raise SynapseError(400, "Can only delete local media") + + if await self.store.get_local_media(media_id) is None: + raise NotFoundError("Unknown media") + + logging.info("Deleting local media by ID: %s", media_id) + + deleted_media, total = await self.media_repository.delete_local_media(media_id) + return 200, {"deleted_media": deleted_media, "total": total} + + +class DeleteMediaByDateSize(RestServlet): + """Delete local media and local copies of remote media by + timestamp and size. + """ + + PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete") + + def __init__(self, hs): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.server_name = hs.hostname + self.media_repository = hs.get_media_repository() + + async def on_POST(self, request, server_name: str): + await assert_requester_is_admin(self.auth, request) + + before_ts = parse_integer(request, "before_ts", required=True) + size_gt = parse_integer(request, "size_gt", default=0) + keep_profiles = parse_boolean(request, "keep_profiles", default=True) + + if before_ts < 0: + raise SynapseError( + 400, + "Query parameter before_ts must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + if size_gt < 0: + raise SynapseError( + 400, + "Query parameter size_gt must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if self.server_name != server_name: + raise SynapseError(400, "Can only delete local media") + + logging.info( + "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s" + % (before_ts, size_gt, keep_profiles) + ) + + deleted_media, total = await self.media_repository.delete_old_local_media( + before_ts, size_gt, keep_profiles + ) + return 200, {"deleted_media": deleted_media, "total": total} + + def register_servlets_for_media_repo(hs, http_server): """ Media repo specific APIs. @@ -159,3 +234,5 @@ def register_servlets_for_media_repo(hs, http_server): QuarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) ListMediaInRoom(hs).register(http_server) + DeleteMediaByID(hs).register(http_server) + DeleteMediaByDateSize(hs).register(http_server) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 8efefbc0a0..b337311a37 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -16,6 +16,7 @@ import hashlib import hmac import logging from http import HTTPStatus +from typing import Tuple from synapse.api.constants import UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -27,16 +28,28 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, assert_user_is_admin, historical_admin_path_patterns, ) -from synapse.types import UserID +from synapse.types import JsonDict, UserID logger = logging.getLogger(__name__) +_GET_PUSHERS_ALLOWED_KEYS = { + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", +} + class UsersRestServlet(RestServlet): PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$") @@ -702,9 +715,114 @@ class UserMembershipRestServlet(RestServlet): if not self.is_mine(UserID.from_string(user_id)): raise SynapseError(400, "Can only lookup local users") + user = await self.store.get_user_by_id(user_id) + if user is None: + raise NotFoundError("Unknown user") + room_ids = await self.store.get_rooms_for_user(user_id) - if not room_ids: + ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} + return 200, ret + + +class PushersRestServlet(RestServlet): + """ + Gets information about all pushers for a specific `user_id`. + + Example: + http://localhost:8008/_synapse/admin/v1/users/ + @user:server/pushers + + Returns: + pushers: Dictionary containing pushers information. + total: Number of pushers in dictonary `pushers`. + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$") + + def __init__(self, hs): + self.is_mine = hs.is_mine + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only lookup local users") + + if not await self.store.get_user_by_id(user_id): raise NotFoundError("User not found") - ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} + pushers = await self.store.get_pushers_by_user_id(user_id) + + filtered_pushers = [ + {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS} + for p in pushers + ] + + return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} + + +class UserMediaRestServlet(RestServlet): + """ + Gets information about all uploaded local media for a specific `user_id`. + + Example: + http://localhost:8008/_synapse/admin/v1/users/ + @user:server/media + + Args: + The parameters `from` and `limit` are required for pagination. + By default, a `limit` of 100 is used. + Returns: + A list of media and an integer representing the total number of + media that exist given for this user + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$") + + def __init__(self, hs): + self.is_mine = hs.is_mine + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET( + self, request: SynapseRequest, user_id: str + ) -> Tuple[int, JsonDict]: + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only lookup local users") + + user = await self.store.get_user_by_id(user_id) + if user is None: + raise NotFoundError("Unknown user") + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + media, total = await self.store.get_local_media_by_user_paginate( + start, limit, user_id + ) + + ret = {"media": media, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(media) + return 200, ret diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 1ecb77aa26..6de4078290 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -67,9 +67,6 @@ class EventStreamRestServlet(RestServlet): return 200, chunk - def on_OPTIONS(self, request): - return 200, {} - class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index b82a4e978a..94452fcbf5 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -114,9 +114,6 @@ class LoginRestServlet(RestServlet): return 200, {"flows": flows} - def on_OPTIONS(self, request: SynapseRequest): - return 200, {} - async def on_POST(self, request: SynapseRequest): self._address_ratelimiter.ratelimit(request.getClientIP()) diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index f792b50cdc..ad8cea49c6 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -30,9 +30,6 @@ class LogoutRestServlet(RestServlet): self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - def on_OPTIONS(self, request): - return 200, {} - async def on_POST(self, request): requester = await self.auth.get_user_by_req(request, allow_expired=True) @@ -58,9 +55,6 @@ class LogoutAllRestServlet(RestServlet): self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - def on_OPTIONS(self, request): - return 200, {} - async def on_POST(self, request): requester = await self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 79d8e3057f..23a529f8e3 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -86,9 +86,6 @@ class PresenceStatusRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, request): - return 200, {} - def register_servlets(hs, http_server): PresenceStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e7fcd2b1ff..85a66458c5 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -67,9 +67,6 @@ class ProfileDisplaynameRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, request, user_id): - return 200, {} - class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) @@ -118,9 +115,6 @@ class ProfileAvatarURLRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, request, user_id): - return 200, {} - class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index f9eecb7cf5..241e535917 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -155,9 +155,6 @@ class PushRuleRestServlet(RestServlet): else: raise UnrecognizedRequestError() - def on_OPTIONS(self, request, path): - return 200, {} - def notify_user(self, user_id): stream_id = self.store.get_max_push_rules_stream_id() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 28dabf1c7a..8fe83f321a 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -60,9 +60,6 @@ class PushersRestServlet(RestServlet): return 200, {"pushers": filtered_pushers} - def on_OPTIONS(self, _): - return 200, {} - class PushersSetRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/set$", v1=True) @@ -140,9 +137,6 @@ class PushersSetRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, _): - return 200, {} - class PushersRemoveRestServlet(RestServlet): """ @@ -182,9 +176,6 @@ class PushersRemoveRestServlet(RestServlet): ) return None - def on_OPTIONS(self, _): - return 200, {} - def register_servlets(hs, http_server): PushersRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 00b4397082..25d3cc6148 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -72,20 +72,6 @@ class RoomCreateRestServlet(TransactionRestServlet): def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) - # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_paths( - "OPTIONS", - client_patterns("/rooms(?:/.*)?$", v1=True), - self.on_OPTIONS, - self.__class__.__name__, - ) - # define CORS for /createRoom[/txnid] - http_server.register_paths( - "OPTIONS", - client_patterns("/createRoom(?:/.*)?$", v1=True), - self.on_OPTIONS, - self.__class__.__name__, - ) def on_PUT(self, request, txn_id): set_tag("txn_id", txn_id) @@ -104,9 +90,6 @@ class RoomCreateRestServlet(TransactionRestServlet): user_supplied_config = parse_json_object_from_request(request) return user_supplied_config - def on_OPTIONS(self, request): - return 200, {} - # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index b8d491ca5c..d07ca2c47c 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -69,9 +69,6 @@ class VoipRestServlet(RestServlet): }, ) - def on_OPTIONS(self, request): - return 200, {} - def register_servlets(hs, http_server): VoipRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index e857cff176..51effc4d8e 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -268,9 +268,6 @@ class PasswordRestServlet(RestServlet): return 200, {} - def on_OPTIONS(self, _): - return 200, {} - class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 5fbfae5991..fab077747f 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -176,9 +176,6 @@ class AuthRestServlet(RestServlet): respond_with_html(request, 200, html) return None - def on_OPTIONS(self, _): - return 200, {} - def register_servlets(hs, http_server): AuthRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 395b6a82a9..8f2c8cd991 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -642,9 +642,6 @@ class RegisterRestServlet(RestServlet): return 200, return_dict - def on_OPTIONS(self, _): - return 200, {} - async def _do_appservice_registration(self, username, as_token, body): user_id = await self.registration_handler.appservice_register( username, as_token diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 7447eeaebe..9e079f672f 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -69,6 +69,23 @@ class MediaFilePaths: local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + def local_media_thumbnail_dir(self, media_id: str) -> str: + """ + Retrieve the local store path of thumbnails of a given media_id + + Args: + media_id: The media ID to query. + Returns: + Path of local_thumbnails from media_id + """ + return os.path.join( + self.base_path, + "local_thumbnails", + media_id[0:2], + media_id[2:4], + media_id[4:], + ) + def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index e1192b47cd..9cac74ebd8 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -18,7 +18,7 @@ import errno import logging import os import shutil -from typing import IO, Dict, Optional, Tuple +from typing import IO, Dict, List, Optional, Tuple import twisted.internet.error import twisted.web.http @@ -305,15 +305,12 @@ class MediaRepository: # file_id is the ID we use to track the file locally. If we've already # seen the file then reuse the existing ID, otherwise genereate a new # one. - if media_info: - file_id = media_info["filesystem_id"] - else: - file_id = random_string(24) - - file_info = FileInfo(server_name, file_id) # If we have an entry in the DB, try and look for it if media_info: + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + if media_info["quarantined_by"]: logger.info("Media is quarantined") raise NotFoundError() @@ -324,14 +321,34 @@ class MediaRepository: # Failed to find the file anywhere, lets download it. - media_info = await self._download_remote_file(server_name, media_id, file_id) + try: + media_info = await self._download_remote_file(server_name, media_id,) + except SynapseError: + raise + except Exception as e: + # An exception may be because we downloaded media in another + # process, so let's check if we magically have the media. + media_info = await self.store.get_cached_remote_media(server_name, media_id) + if not media_info: + raise e + + file_id = media_info["filesystem_id"] + file_info = FileInfo(server_name, file_id) + + # We generate thumbnails even if another process downloaded the media + # as a) it's conceivable that the other download request dies before it + # generates thumbnails, but mainly b) we want to be sure the thumbnails + # have finished being generated before responding to the client, + # otherwise they'll request thumbnails and get a 404 if they're not + # ready yet. + await self._generate_thumbnails( + server_name, media_id, file_id, media_info["media_type"] + ) responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file( - self, server_name: str, media_id: str, file_id: str - ) -> dict: + async def _download_remote_file(self, server_name: str, media_id: str,) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -346,6 +363,8 @@ class MediaRepository: The media info of the file. """ + file_id = random_string(24) + file_info = FileInfo(server_name=server_name, file_id=file_id) with self.media_storage.store_into_file(file_info) as (f, fname, finish): @@ -401,22 +420,32 @@ class MediaRepository: await finish() - media_type = headers[b"Content-Type"][0].decode("ascii") - upload_name = get_filename_from_headers(headers) - time_now_ms = self.clock.time_msec() + media_type = headers[b"Content-Type"][0].decode("ascii") + upload_name = get_filename_from_headers(headers) + time_now_ms = self.clock.time_msec() + + # Multiple remote media download requests can race (when using + # multiple media repos), so this may throw a violation constraint + # exception. If it does we'll delete the newly downloaded file from + # disk (as we're in the ctx manager). + # + # However: we've already called `finish()` so we may have also + # written to the storage providers. This is preferable to the + # alternative where we call `finish()` *after* this, where we could + # end up having an entry in the DB but fail to write the files to + # the storage providers. + await self.store.store_cached_remote_media( + origin=server_name, + media_id=media_id, + media_type=media_type, + time_now_ms=self.clock.time_msec(), + upload_name=upload_name, + media_length=length, + filesystem_id=file_id, + ) logger.info("Stored remote media in file %r", fname) - await self.store.store_cached_remote_media( - origin=server_name, - media_id=media_id, - media_type=media_type, - time_now_ms=self.clock.time_msec(), - upload_name=upload_name, - media_length=length, - filesystem_id=file_id, - ) - media_info = { "media_type": media_type, "media_length": length, @@ -425,8 +454,6 @@ class MediaRepository: "filesystem_id": file_id, } - await self._generate_thumbnails(server_name, media_id, file_id, media_type) - return media_info def _get_thumbnail_requirements(self, media_type): @@ -692,42 +719,60 @@ class MediaRepository: if not t_byte_source: continue - try: - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - thumbnail=True, - thumbnail_width=t_width, - thumbnail_height=t_height, - thumbnail_method=t_method, - thumbnail_type=t_type, - url_cache=url_cache, - ) - - output_path = await self.media_storage.store_file( - t_byte_source, file_info - ) - finally: - t_byte_source.close() - - t_len = os.path.getsize(output_path) + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + thumbnail=True, + thumbnail_width=t_width, + thumbnail_height=t_height, + thumbnail_method=t_method, + thumbnail_type=t_type, + url_cache=url_cache, + ) - # Write to database - if server_name: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) + with self.media_storage.store_into_file(file_info) as (f, fname, finish): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = await self.store.get_remote_media_thumbnail( + server_name, media_id, t_width, t_height, t_type, + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) return {"width": m_width, "height": m_height} @@ -767,6 +812,76 @@ class MediaRepository: return {"deleted": deleted} + async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]: + """ + Delete the given local or remote media ID from this server + + Args: + media_id: The media ID to delete. + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + return await self._remove_local_media_from_disk([media_id]) + + async def delete_old_local_media( + self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True, + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server by size and timestamp. Removes + media files, any thumbnails and cached URLs. + + Args: + before_ts: Unix timestamp in ms. + Files that were last used before this timestamp will be deleted + size_gt: Size of the media in bytes. Files that are larger will be deleted + keep_profiles: Switch to delete also files that are still used in image data + (e.g user profile, room avatar) + If false these files will be deleted + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + old_media = await self.store.get_local_media_before( + before_ts, size_gt, keep_profiles, + ) + return await self._remove_local_media_from_disk(old_media) + + async def _remove_local_media_from_disk( + self, media_ids: List[str] + ) -> Tuple[List[str], int]: + """ + Delete local or remote media from this server. Removes media files, + any thumbnails and cached URLs. + + Args: + media_ids: List of media_id to delete + Returns: + A tuple of (list of deleted media IDs, total deleted media IDs). + """ + removed_media = [] + for media_id in media_ids: + logger.info("Deleting media with ID '%s'", media_id) + full_path = self.filepaths.local_media_filepath(media_id) + try: + os.remove(full_path) + except OSError as e: + logger.warning("Failed to remove file: %r: %s", full_path, e) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + await self.store.delete_remote_media(self.server_name, media_id) + + await self.store.delete_url_cache((media_id,)) + await self.store.delete_url_cache_media((media_id,)) + + removed_media.append(media_id) + + return removed_media, len(removed_media) + class MediaRepositoryResource(Resource): """File uploading and downloading. diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index a9586fb0b7..268e0c8f50 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -52,6 +52,7 @@ class MediaStorage: storage_providers: Sequence["StorageProviderWrapper"], ): self.hs = hs + self.reactor = hs.get_reactor() self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers @@ -70,13 +71,16 @@ class MediaStorage: with self.store_into_file(file_info) as (f, fname, finish_cb): # Write to the main repository - await defer_to_thread( - self.hs.get_reactor(), _write_file_synchronously, source, f - ) + await self.write_to_file(source, f) await finish_cb() return fname + async def write_to_file(self, source: IO, output: IO): + """Asynchronously write the `source` to `output`. + """ + await defer_to_thread(self.reactor, _write_file_synchronously, source, output) + @contextlib.contextmanager def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as @@ -112,14 +116,20 @@ class MediaStorage: finished_called = [False] - async def finish(): - for provider in self.storage_providers: - await provider.store_file(path, file_info) - - finished_called[0] = True - try: with open(fname, "wb") as f: + + async def finish(): + # Ensure that all writes have been flushed and close the + # file. + f.flush() + f.close() + + for provider in self.storage_providers: + await provider.store_file(path, file_info) + + finished_called[0] = True + yield f, fname, finish except Exception: try: @@ -210,7 +220,7 @@ class MediaStorage: if res: with res: consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.hs.get_reactor() + open(local_path, "wb"), self.reactor ) await res.write_to_consumer(consumer) await consumer.wait() diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 3673e7f47e..9137c4edb1 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -104,7 +104,7 @@ class ConsentServerNotices: def copy_with_str_subst(x: Any, substitutions: Any) -> Any: - """Deep-copy a structure, carrying out string substitions on any strings + """Deep-copy a structure, carrying out string substitutions on any strings Args: x (object): structure to be copied diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5b0900aa3c..1fa3b280b4 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -547,7 +547,7 @@ class StateResolutionHandler: event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing + used as a starting point for finding the state we need; any missing events will be requested via state_res_store. If None, all events will be fetched via state_res_store. diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a493279cbd..85edae053d 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -56,7 +56,7 @@ async def resolve_events_with_store( event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing + used as a starting point for finding the state we need; any missing events will be requested via state_map_factory. If None, all events will be fetched via state_map_factory. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index edf94e7ad6..f57df0d728 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -69,7 +69,7 @@ async def resolve_events_with_store( event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing + used as a starting point for finding the state we need; any missing events will be requested via state_res_store. If None, all events will be fetched via state_res_store. diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js index 3678670ec7..744800ec77 100644 --- a/synapse/static/client/login/js/login.js +++ b/synapse/static/client/login/js/login.js @@ -182,7 +182,7 @@ matrixLogin.passwordLogin = function() { }; /* - * The onLogin function gets called after a succesful login. + * The onLogin function gets called after a successful login. * * It is expected that implementations override this to be notified when the * login is complete. The response to the login call is provided as the single diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 0217e63108..a0572b2952 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -94,7 +94,7 @@ def make_pool( cp_openfun=lambda conn: engine.on_new_connection( LoggingDatabaseConnection(conn, engine, "on_new_connection") ), - **db_config.config.get("args", {}) + **db_config.config.get("args", {}), ) @@ -632,7 +632,7 @@ class DatabasePool: func, *args, db_autocommit=db_autocommit, - **kwargs + **kwargs, ) for after_callback, after_args, after_kwargs in after_callbacks: diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 9b16f45f3e..43660ec4fb 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -146,7 +146,6 @@ class DataStore( db_conn, "e2e_cross_signing_keys", "stream_id" ) - self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 637a938bac..26eef6eb61 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -15,21 +15,31 @@ # limitations under the License. import logging import re -from typing import List +from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple -from synapse.appservice import ApplicationService, AppServiceTransaction +from synapse.appservice import ( + ApplicationService, + ApplicationServiceState, + AppServiceTransaction, +) from synapse.config.appservice import load_appservices from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.types import Connection from synapse.types import JsonDict from synapse.util import json_encoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) -def _make_exclusive_regex(services_cache): +def _make_exclusive_regex( + services_cache: List[ApplicationService], +) -> Optional[Pattern]: # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ @@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache): ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) - exclusive_user_regex = re.compile(exclusive_user_regex) + exclusive_user_pattern = re.compile( + exclusive_user_regex + ) # type: Optional[Pattern] else: # We handle this case specially otherwise the constructed regex # will always match - exclusive_user_regex = None + exclusive_user_pattern = None - return exclusive_user_regex + return exclusive_user_pattern class ApplicationServiceWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) @@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): def get_app_services(self): return self.services_cache - def get_if_app_services_interested_in_user(self, user_id): + def get_if_app_services_interested_in_user(self, user_id: str) -> bool: """Check if the user is one associated with an app service (exclusively) """ if self.exclusive_user_regex: @@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): else: return False - def get_app_service_by_user_id(self, user_id): + def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]: """Retrieve an application service from their user ID. All application services have associated with them a particular user ID. @@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore): a user ID to an application service. Args: - user_id(str): The user ID to see if it is an application service. + user_id: The user ID to see if it is an application service. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.sender == user_id: return service return None - def get_app_service_by_token(self, token): + def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]: """Get the application service with the given appservice token. Args: - token (str): The application service token. + token: The application service token. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.token == token: return service return None - def get_app_service_by_id(self, as_id): + def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]: """Get the application service with the given appservice ID. Args: - as_id (str): The application service ID. + as_id: The application service ID. Returns: - synapse.appservice.ApplicationService or None. + The application service or None. """ for service in self.services_cache: if service.id == as_id: @@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore): class ApplicationServiceTransactionWorkerStore( ApplicationServiceWorkerStore, EventsWorkerStore ): - async def get_appservices_by_state(self, state): + async def get_appservices_by_state( + self, state: ApplicationServiceState + ) -> List[ApplicationService]: """Get a list of application services based on their state. Args: - state(ApplicationServiceState): The state to filter on. + state: The state to filter on. Returns: A list of ApplicationServices, which may be empty. """ @@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore( services.append(service) return services - async def get_appservice_state(self, service): + async def get_appservice_state( + self, service: ApplicationService + ) -> Optional[ApplicationServiceState]: """Get the application service state. Args: - service(ApplicationService): The service whose state to set. + service: The service whose state to set. Returns: - An ApplicationServiceState. + An ApplicationServiceState or none. """ result = await self.db_pool.simple_select_one( "application_services_state", @@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore( return result.get("state") return None - async def set_appservice_state(self, service, state) -> None: + async def set_appservice_state( + self, service: ApplicationService, state: ApplicationServiceState + ) -> None: """Set the application service state. Args: - service(ApplicationService): The service whose state to set. - state(ApplicationServiceState): The connectivity state to apply. + service: The service whose state to set. + state: The connectivity state to apply. """ await self.db_pool.simple_upsert( "application_services_state", {"as_id": service.id}, {"state": state} @@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore( "create_appservice_txn", _create_appservice_txn ) - async def complete_appservice_txn(self, txn_id, service) -> None: + async def complete_appservice_txn( + self, txn_id: int, service: ApplicationService + ) -> None: """Completes an application service transaction. Args: - txn_id(str): The transaction ID being completed. - service(ApplicationService): The application service which was sent - this transaction. + txn_id: The transaction ID being completed. + service: The application service which was sent this transaction. """ txn_id = int(txn_id) @@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore( # has probably missed some events), so whine loudly but still continue, # since it shouldn't fail completion of the transaction. last_txn_id = self._get_last_txn(txn, service.id) - if (last_txn_id + 1) != txn_id: + if (txn_id + 1) != txn_id: logger.error( "appservice: Completing a transaction which has an ID > 1 from " "the last ID sent to this AS. We've either dropped events or " @@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore( "complete_appservice_txn", _complete_appservice_txn ) - async def get_oldest_unsent_txn(self, service): - """Get the oldest transaction which has not been sent for this - service. + async def get_oldest_unsent_txn( + self, service: ApplicationService + ) -> Optional[AppServiceTransaction]: + """Get the oldest transaction which has not been sent for this service. Args: - service(ApplicationService): The app service to get the oldest txn. + service: The app service to get the oldest txn. Returns: An AppServiceTransaction or None. """ @@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore( service=service, id=entry["txn_id"], events=events, ephemeral=[] ) - def _get_last_txn(self, txn, service_id): + def _get_last_txn(self, txn, service_id: Optional[str]) -> int: txn.execute( "SELECT last_txn FROM application_services_state WHERE as_id=?", (service_id,), @@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore( else: return int(last_txn_id[0]) # select 'last_txn' col - async def set_appservice_last_pos(self, pos) -> None: + async def set_appservice_last_pos(self, pos: int) -> None: def set_appservice_last_pos_txn(txn): txn.execute( "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,) @@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore( "set_appservice_last_pos", set_appservice_last_pos_txn ) - async def get_new_events_for_appservice(self, current_id, limit): + async def get_new_events_for_appservice( + self, current_id: int, limit: int + ) -> Tuple[int, List[EventBase]]: """Get all new events for an appservice""" def get_new_events_for_appservice_txn(txn): @@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore( ) async def set_type_stream_id_for_appservice( - self, service: ApplicationService, type: str, pos: int + self, service: ApplicationService, type: str, pos: Optional[int] ) -> None: if type not in ("read_receipt", "presence"): raise ValueError( diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 5e4af2eb51..97b6754846 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -92,6 +92,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): where_clause="NOT have_censored", ) + self.db_pool.updates.register_background_index_update( + "users_have_local_media", + index_name="users_have_local_media", + table="local_media_repository", + columns=["user_id", "created_ts"], + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6e7f16f39c..4732685f6e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -31,6 +31,7 @@ from synapse.api.room_versions import ( RoomVersions, ) from synapse.events import EventBase, make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event from synapse.logging.context import PreserveLoggingContext, current_context from synapse.metrics.background_process_metrics import ( @@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator -from synapse.types import Collection, get_domain_from_id +from synapse.types import Collection, JsonDict, get_domain_from_id from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore): return event_map + async def get_stripped_room_state_from_event_context( + self, + context: EventContext, + state_types_to_include: List[EventTypes], + membership_user_id: Optional[str] = None, + ) -> List[JsonDict]: + """ + Retrieve the stripped state from a room, given an event context to retrieve state + from as well as the state types to include. Optionally, include the membership + events from a specific user. + + "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys + are included from each state event. + + Args: + context: The event context to retrieve state of the room from. + state_types_to_include: The type of state events to include. + membership_user_id: An optional user ID to include the stripped membership state + events of. This is useful when generating the stripped state of a room for + invites. We want to send membership events of the inviter, so that the + invitee can display the inviter's profile information if the room lacks any. + + Returns: + A list of dictionaries, each representing a stripped state event from the room. + """ + current_state_ids = await context.get_current_state_ids() + + # We know this event is not an outlier, so this must be + # non-None. + assert current_state_ids is not None + + # The state to include + state_to_include_ids = [ + e_id + for k, e_id in current_state_ids.items() + if k[0] in state_types_to_include + or (membership_user_id and k == (EventTypes.Member, membership_user_id)) + ] + + state_to_include = await self.get_events(state_to_include_ids) + + return [ + { + "type": e.type, + "state_key": e.state_key, + "content": e.content, + "sender": e.sender, + } + for e in state_to_include.values() + ] + def _do_fetch(self, conn): """Takes a database connection and waits for requests for events from the _event_fetch_list queue. @@ -1065,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore): def get_all_new_forward_event_rows(txn): sql = ( "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < stream_ordering AND stream_ordering <= ?" " AND instance_name = ?" " ORDER BY stream_ordering ASC" @@ -1100,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore): def get_ex_outlier_stream_rows_txn(txn): sql = ( "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," - " state_key, redacts, relates_to_id" + " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" " FROM events AS e" " INNER JOIN ex_outlier_stream AS out USING (event_id)" " LEFT JOIN redactions USING (event_id)" " LEFT JOIN state_events USING (event_id)" " LEFT JOIN event_relations USING (event_id)" + " LEFT JOIN room_memberships USING (event_id)" + " LEFT JOIN rejections USING (event_id)" " WHERE ? < event_stream_ordering" " AND event_stream_ordering <= ?" " AND out.instance_name = ?" diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index cc538c5c10..4b2f224718 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -93,6 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + self.server_name = hs.hostname async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]: """Get the metadata for a local piece of media @@ -115,6 +116,109 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_local_media", ) + async def get_local_media_by_user_paginate( + self, start: int, limit: int, user_id: str + ) -> Tuple[List[Dict[str, Any]], int]: + """Get a paginated list of metadata for a local piece of media + which an user_id has uploaded + + Args: + start: offset in the list + limit: maximum amount of media_ids to retrieve + user_id: fully-qualified user id + Returns: + A paginated list of all metadata of user's media, + plus the total count of all the user's media + """ + + def get_local_media_by_user_paginate_txn(txn): + + args = [user_id] + sql = """ + SELECT COUNT(*) as total_media + FROM local_media_repository + WHERE user_id = ? + """ + txn.execute(sql, args) + count = txn.fetchone()[0] + + sql = """ + SELECT + "media_id", + "media_type", + "media_length", + "upload_name", + "created_ts", + "last_access_ts", + "quarantined_by", + "safe_from_quarantine" + FROM local_media_repository + WHERE user_id = ? + ORDER BY created_ts DESC, media_id DESC + LIMIT ? OFFSET ? + """ + + args += [limit, start] + txn.execute(sql, args) + media = self.db_pool.cursor_to_dict(txn) + return media, count + + return await self.db_pool.runInteraction( + "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn + ) + + async def get_local_media_before( + self, before_ts: int, size_gt: int, keep_profiles: bool, + ) -> Optional[List[str]]: + + # to find files that have never been accessed (last_access_ts IS NULL) + # compare with `created_ts` + sql = """ + SELECT media_id + FROM local_media_repository AS lmr + WHERE + ( last_access_ts < ? + OR ( created_ts < ? AND last_access_ts IS NULL ) ) + AND media_length > ? + """ + + if keep_profiles: + sql_keep = """ + AND ( + NOT EXISTS + (SELECT 1 + FROM profiles + WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM groups + WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM room_memberships + WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM user_directory + WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id) + AND NOT EXISTS + (SELECT 1 + FROM room_stats_state + WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id) + ) + """.format( + media_prefix="mxc://%s/" % (self.server_name,), + ) + sql += sql_keep + + def _get_local_media_before_txn(txn): + txn.execute(sql, (before_ts, before_ts, size_gt)) + return [row[0] for row in txn] + + return await self.db_pool.runInteraction( + "get_local_media_before", _get_local_media_before_txn + ) + async def store_local_media( self, media_id, @@ -348,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="get_remote_media_thumbnails", ) + async def get_remote_media_thumbnail( + self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str, + ) -> Optional[Dict[str, Any]]: + """Fetch the thumbnail info of given width, height and type. + """ + + return await self.db_pool.simple_select_one( + table="remote_media_cache_thumbnails", + keyvalues={ + "media_origin": origin, + "media_id": media_id, + "thumbnail_width": t_width, + "thumbnail_height": t_height, + "thumbnail_type": t_type, + }, + retcols=( + "thumbnail_width", + "thumbnail_height", + "thumbnail_method", + "thumbnail_type", + "thumbnail_length", + "filesystem_id", + ), + allow_none=True, + desc="get_remote_media_thumbnail", + ) + async def store_remote_media_thumbnail( self, origin, diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index a6d1eb908a..0e25ca3d7a 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) - async def get_profile_displayname(self, user_localpart: str) -> str: + async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, @@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_displayname", ) - async def get_profile_avatar_url(self, user_localpart: str) -> str: + async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", keyvalues={"user_id": user_localpart}, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 4c843b7679..e5d07ce72a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,29 +16,64 @@ # limitations under the License. import logging import re -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool -from synapse.storage.types import Cursor +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.types import Connection, Cursor +from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID from synapse.util.caches.descriptors import cached +if TYPE_CHECKING: + from synapse.server import HomeServer + THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 logger = logging.getLogger(__name__) -class RegistrationWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn, hs): +@attr.s(frozen=True, slots=True) +class TokenLookupResult: + """Result of looking up an access token. + + Attributes: + user_id: The user that this token authenticates as + is_guest + shadow_banned + token_id: The ID of the access token looked up + device_id: The device associated with the token, if any. + valid_until_ms: The timestamp the token expires, if any. + token_owner: The "owner" of the token. This is either the same as the + user, or a server admin who is logged in as the user. + """ + + user_id = attr.ib(type=str) + is_guest = attr.ib(type=bool, default=False) + shadow_banned = attr.ib(type=bool, default=False) + token_id = attr.ib(type=Optional[int], default=None) + device_id = attr.ib(type=Optional[str], default=None) + valid_until_ms = attr.ib(type=Optional[int], default=None) + token_owner = attr.ib(type=str) + + # Make the token owner default to the user ID, which is the common case. + @token_owner.default + def _default_token_owner(self): + return self.user_id + + +class RegistrationWorkerStore(CacheInvalidationWorkerStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self.config = hs.config - self.clock = hs.get_clock() # Note: we don't check this sequence for consistency as we'd have to # call `find_max_generated_user_id_localpart` each time, which is @@ -55,7 +90,7 @@ class RegistrationWorkerStore(SQLBaseStore): # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: - self.clock.looping_call( + self._clock.looping_call( self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @@ -92,21 +127,19 @@ class RegistrationWorkerStore(SQLBaseStore): if not info: return False - now = self.clock.time_msec() + now = self._clock.time_msec() trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial @cached() - async def get_user_by_access_token(self, token: str) -> Optional[dict]: + async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]: """Get a user from the given access token. Args: token: The access token of a user. Returns: - None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`, - `valid_until_ms`. + None, if the token did not match, otherwise a `TokenLookupResult` """ return await self.db_pool.runInteraction( "get_user_by_access_token", self._query_for_auth, token @@ -236,13 +269,13 @@ class RegistrationWorkerStore(SQLBaseStore): desc="get_renewal_token_for_user", ) - async def get_users_expiring_soon(self) -> List[Dict[str, int]]: + async def get_users_expiring_soon(self) -> List[Dict[str, Any]]: """Selects users whose account will expire in the [now, now + renew_at] time window (see configuration for account_validity for information on what renew_at refers to). Returns: - A list of dictionaries mapping user ID to expiration time (in milliseconds). + A list of dictionaries, each with a user ID and expiration time (in milliseconds). """ def select_users_txn(txn, now_ms, renew_at): @@ -257,7 +290,7 @@ class RegistrationWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, - self.clock.time_msec(), + self._clock.time_msec(), self.config.account_validity.renew_at, ) @@ -327,19 +360,24 @@ class RegistrationWorkerStore(SQLBaseStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," - " access_tokens.device_id, access_tokens.valid_until_ms" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: + sql = """ + SELECT users.name as user_id, + users.is_guest, + users.shadow_banned, + access_tokens.id as token_id, + access_tokens.device_id, + access_tokens.valid_until_ms, + access_tokens.user_id as token_owner + FROM users + INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) + WHERE token = ? + """ txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) if rows: - return rows[0] + return TokenLookupResult(**rows[0]) return None @@ -803,7 +841,7 @@ class RegistrationWorkerStore(SQLBaseStore): await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, - self.clock.time_msec(), + self._clock.time_msec(), ) @wrap_as_background_process("account_validity_set_expiration_dates") @@ -890,10 +928,10 @@ class RegistrationWorkerStore(SQLBaseStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self.clock = hs.get_clock() + self._clock = hs.get_clock() self.config = hs.config self.db_pool.updates.register_background_index_update( @@ -1016,13 +1054,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): return 1 + async def set_user_deactivated_status( + self, user_id: str, deactivated: bool + ) -> None: + """Set the `deactivated` property for the provided user to the provided value. + + Args: + user_id: The ID of the user to set the status for. + deactivated: The value to set for `deactivated`. + """ -class RegistrationStore(RegistrationBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn, hs): + await self.db_pool.runInteraction( + "set_user_deactivated_status", + self.set_user_deactivated_status_txn, + user_id, + deactivated, + ) + + def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool): + self.db_pool.simple_update_one_txn( + txn=txn, + table="users", + keyvalues={"name": user_id}, + updatevalues={"deactivated": 1 if deactivated else 0}, + ) + self._invalidate_cache_and_stream( + txn, self.get_user_deactivated_status, (user_id,) + ) + txn.call_after(self.is_guest.invalidate, (user_id,)) + + @cached() + async def is_guest(self, user_id: str) -> bool: + res = await self.db_pool.simple_select_one_onecol( + table="users", + keyvalues={"name": user_id}, + retcol="is_guest", + allow_none=True, + desc="is_guest", + ) + + return res if res else False + + +class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + async def add_access_token_to_user( self, user_id: str, @@ -1138,19 +1219,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): def _register_user( self, txn, - user_id, - password_hash, - was_guest, - make_guest, - appservice_id, - create_profile_with_displayname, - admin, - user_type, - shadow_banned, + user_id: str, + password_hash: Optional[str], + was_guest: bool, + make_guest: bool, + appservice_id: Optional[str], + create_profile_with_displayname: Optional[str], + admin: bool, + user_type: Optional[str], + shadow_banned: bool, ): user_id_obj = UserID.from_string(user_id) - now = int(self.clock.time()) + now = int(self._clock.time()) try: if was_guest: @@ -1374,18 +1455,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) - @cached() - async def is_guest(self, user_id: str) -> bool: - res = await self.db_pool.simple_select_one_onecol( - table="users", - keyvalues={"name": user_id}, - retcol="is_guest", - allow_none=True, - desc="is_guest", - ) - - return res if res else False - async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're @@ -1479,7 +1548,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self.clock.time_msec()}, + updatevalues={"validated_at": self._clock.time_msec()}, ) return next_link @@ -1547,35 +1616,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore): start_or_continue_validation_session_txn, ) - async def set_user_deactivated_status( - self, user_id: str, deactivated: bool - ) -> None: - """Set the `deactivated` property for the provided user to the provided value. - - Args: - user_id: The ID of the user to set the status for. - deactivated: The value to set for `deactivated`. - """ - - await self.db_pool.runInteraction( - "set_user_deactivated_status", - self.set_user_deactivated_status_txn, - user_id, - deactivated, - ) - - def set_user_deactivated_status_txn(self, txn, user_id, deactivated): - self.db_pool.simple_update_one_txn( - txn=txn, - table="users", - keyvalues={"name": user_id}, - updatevalues={"deactivated": 1 if deactivated else 0}, - ) - self._invalidate_cache_and_stream( - txn, self.get_user_deactivated_status, (user_id,) - ) - txn.call_after(self.is_guest.invalidate, (user_id,)) - def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index e83d961c20..dc0c4b5499 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1411,6 +1411,65 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): desc="add_event_report", ) + async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]: + """Retrieve an event report + + Args: + report_id: ID of reported event in database + Returns: + event_report: json list of information from event report + """ + + def _get_event_report_txn(txn, report_id): + + sql = """ + SELECT + er.id, + er.received_ts, + er.room_id, + er.event_id, + er.user_id, + er.content, + events.sender, + room_stats_state.canonical_alias, + room_stats_state.name, + event_json.json AS event_json + FROM event_reports AS er + LEFT JOIN events + ON events.event_id = er.event_id + JOIN event_json + ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id + WHERE er.id = ? + """ + + txn.execute(sql, [report_id]) + row = txn.fetchone() + + if not row: + return None + + event_report = { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": db_to_json(row[5]).get("score"), + "reason": db_to_json(row[5]).get("reason"), + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + "event_json": db_to_json(row[9]), + } + + return event_report + + return await self.db_pool.runInteraction( + "get_event_report", _get_event_report_txn, report_id + ) + async def get_event_reports_paginate( self, start: int, @@ -1468,18 +1527,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): er.room_id, er.event_id, er.user_id, - er.reason, er.content, events.sender, - room_aliases.room_alias, - event_json.json AS event_json + room_stats_state.canonical_alias, + room_stats_state.name FROM event_reports AS er - LEFT JOIN room_aliases - ON room_aliases.room_id = er.room_id - JOIN events + LEFT JOIN events ON events.event_id = er.event_id - JOIN event_json - ON event_json.event_id = er.event_id + JOIN room_stats_state + ON room_stats_state.room_id = er.room_id {where_clause} ORDER BY er.received_ts {order} LIMIT ? @@ -1490,15 +1546,29 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): args += [limit, start] txn.execute(sql, args) - event_reports = self.db_pool.cursor_to_dict(txn) - - if count > 0: - for row in event_reports: - try: - row["content"] = db_to_json(row["content"]) - row["event_json"] = db_to_json(row["event_json"]) - except Exception: - continue + + event_reports = [] + for row in txn: + try: + s = db_to_json(row[5]).get("score") + r = db_to_json(row[5]).get("reason") + except Exception: + logger.error("Unable to parse json from event_reports: %s", row[0]) + continue + event_reports.append( + { + "id": row[0], + "received_ts": row[1], + "room_id": row[2], + "event_id": row[3], + "user_id": row[4], + "score": s, + "reason": r, + "sender": row[6], + "canonical_alias": row[7], + "name": row[8], + } + ) return event_reports, count diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql new file mode 100644 index 0000000000..00a9431a97 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Whether the access token is an admin token for controlling another user. +ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT; diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql new file mode 100644 index 0000000000..a2842687f1 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql @@ -0,0 +1,2 @@ +INSERT INTO background_updates (update_name, progress_json) VALUES + ('users_have_local_media', '{}'); \ No newline at end of file diff --git a/synapse/types.py b/synapse/types.py index 5bde67cc07..66bb5bac8d 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -29,6 +29,7 @@ from typing import ( Tuple, Type, TypeVar, + Union, ) import attr @@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError if TYPE_CHECKING: + from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore # define a version of typing.Collection that works on python 3.5 @@ -74,6 +76,7 @@ class Requester( "shadow_banned", "device_id", "app_service", + "authenticated_entity", ], ) ): @@ -104,6 +107,7 @@ class Requester( "shadow_banned": self.shadow_banned, "device_id": self.device_id, "app_server_id": self.app_service.id if self.app_service else None, + "authenticated_entity": self.authenticated_entity, } @staticmethod @@ -129,16 +133,18 @@ class Requester( shadow_banned=input["shadow_banned"], device_id=input["device_id"], app_service=appservice, + authenticated_entity=input["authenticated_entity"], ) def create_requester( - user_id, - access_token_id=None, - is_guest=False, - shadow_banned=False, - device_id=None, - app_service=None, + user_id: Union[str, "UserID"], + access_token_id: Optional[int] = None, + is_guest: Optional[bool] = False, + shadow_banned: Optional[bool] = False, + device_id: Optional[str] = None, + app_service: Optional["ApplicationService"] = None, + authenticated_entity: Optional[str] = None, ): """ Create a new ``Requester`` object @@ -151,14 +157,27 @@ def create_requester( shadow_banned (bool): True if the user making this request is shadow-banned. device_id (str|None): device_id which was set at authentication time app_service (ApplicationService|None): the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. Returns: Requester """ if not isinstance(user_id, UserID): user_id = UserID.from_string(user_id) + + if authenticated_entity is None: + authenticated_entity = user_id.to_string() + return Requester( - user_id, access_token_id, is_guest, shadow_banned, device_id, app_service + user_id, + access_token_id, + is_guest, + shadow_banned, + device_id, + app_service, + authenticated_entity, ) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 5d7fffee66..a924140cdf 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,10 +13,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import functools import inspect import logging -from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast +from typing import ( + Any, + Callable, + Generic, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) from weakref import WeakValueDictionary from twisted.internet import defer @@ -24,6 +37,7 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.caches.deferred_cache import DeferredCache +from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -48,7 +62,7 @@ class _CachedFunction(Generic[F]): class _CacheDescriptorBase: - def __init__(self, orig: _CachedFunction, num_args, cache_context=False): + def __init__(self, orig: Callable[..., Any], num_args, cache_context=False): self.orig = orig arg_spec = inspect.getfullargspec(orig) @@ -97,8 +111,107 @@ class _CacheDescriptorBase: self.add_cache_context = cache_context + self.cache_key_builder = get_cache_key_builder( + self.arg_names, self.arg_defaults + ) + + +class _LruCachedFunction(Generic[F]): + cache = None # type: LruCache[CacheKey, Any] + __call__ = None # type: F + + +def lru_cache( + max_entries: int = 1000, cache_context: bool = False, +) -> Callable[[F], _LruCachedFunction[F]]: + """A method decorator that applies a memoizing cache around the function. + + This is more-or-less a drop-in equivalent to functools.lru_cache, although note + that the signature is slightly different. + + The main differences with functools.lru_cache are: + (a) the size of the cache can be controlled via the cache_factor mechanism + (b) the wrapped function can request a "cache_context" which provides a + callback mechanism to indicate that the result is no longer valid + (c) prometheus metrics are exposed automatically. + + The function should take zero or more arguments, which are used as the key for the + cache. Single-argument functions use that argument as the cache key; otherwise the + arguments are built into a tuple. + + Cached functions can be "chained" (i.e. a cached function can call other cached + functions and get appropriately invalidated when they called caches are + invalidated) by adding a special "cache_context" argument to the function + and passing that as a kwarg to all caches called. For example: + + @lru_cache(cache_context=True) + def foo(self, key, cache_context): + r1 = self.bar1(key, on_invalidate=cache_context.invalidate) + r2 = self.bar2(key, on_invalidate=cache_context.invalidate) + return r1 + r2 + + The wrapped function also has a 'cache' property which offers direct access to the + underlying LruCache. + """ + + def func(orig: F) -> _LruCachedFunction[F]: + desc = LruCacheDescriptor( + orig, max_entries=max_entries, cache_context=cache_context, + ) + return cast(_LruCachedFunction[F], desc) + + return func + + +class LruCacheDescriptor(_CacheDescriptorBase): + """Helper for @lru_cache""" + + class _Sentinel(enum.Enum): + sentinel = object() + + def __init__( + self, orig, max_entries: int = 1000, cache_context: bool = False, + ): + super().__init__(orig, num_args=None, cache_context=cache_context) + self.max_entries = max_entries + + def __get__(self, obj, owner): + cache = LruCache( + cache_name=self.orig.__name__, max_size=self.max_entries, + ) # type: LruCache[CacheKey, Any] + + get_cache_key = self.cache_key_builder + sentinel = LruCacheDescriptor._Sentinel.sentinel + + @functools.wraps(self.orig) + def _wrapped(*args, **kwargs): + invalidate_callback = kwargs.pop("on_invalidate", None) + callbacks = (invalidate_callback,) if invalidate_callback else () + + cache_key = get_cache_key(args, kwargs) -class CacheDescriptor(_CacheDescriptorBase): + ret = cache.get(cache_key, default=sentinel, callbacks=callbacks) + if ret != sentinel: + return ret + + # Add our own `cache_context` to argument list if the wrapped function + # has asked for one + if self.add_cache_context: + kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) + + ret2 = self.orig(obj, *args, **kwargs) + cache.set(cache_key, ret2, callbacks=callbacks) + + return ret2 + + wrapped = cast(_CachedFunction, _wrapped) + wrapped.cache = cache + obj.__dict__[self.orig.__name__] = wrapped + + return wrapped + + +class DeferredCacheDescriptor(_CacheDescriptorBase): """ A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that @@ -141,7 +254,6 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=False, iterable=False, ): - super().__init__(orig, num_args=num_args, cache_context=cache_context) self.max_entries = max_entries @@ -157,41 +269,7 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) # type: DeferredCache[CacheKey, Any] - def get_cache_key_gen(args, kwargs): - """Given some args/kwargs return a generator that resolves into - the cache_key. - - We loop through each arg name, looking up if its in the `kwargs`, - otherwise using the next argument in `args`. If there are no more - args then we try looking the arg name up in the defaults - """ - pos = 0 - for nm in self.arg_names: - if nm in kwargs: - yield kwargs[nm] - elif pos < len(args): - yield args[pos] - pos += 1 - else: - yield self.arg_defaults[nm] - - # By default our cache key is a tuple, but if there is only one item - # then don't bother wrapping in a tuple. This is to save memory. - if self.num_args == 1: - nm = self.arg_names[0] - - def get_cache_key(args, kwargs): - if nm in kwargs: - return kwargs[nm] - elif len(args): - return args[0] - else: - return self.arg_defaults[nm] - - else: - - def get_cache_key(args, kwargs): - return tuple(get_cache_key_gen(args, kwargs)) + get_cache_key = self.cache_key_builder @functools.wraps(self.orig) def _wrapped(*args, **kwargs): @@ -223,7 +301,6 @@ class CacheDescriptor(_CacheDescriptorBase): wrapped.prefill = lambda key, val: cache.prefill(key[0], val) else: wrapped.invalidate = cache.invalidate - wrapped.invalidate_all = cache.invalidate_all wrapped.invalidate_many = cache.invalidate_many wrapped.prefill = cache.prefill @@ -236,7 +313,7 @@ class CacheDescriptor(_CacheDescriptorBase): return wrapped -class CacheListDescriptor(_CacheDescriptorBase): +class DeferredCacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes @@ -382,11 +459,13 @@ class _CacheContext: on a lower level. """ + Cache = Union[DeferredCache, LruCache] + _cache_context_objects = ( WeakValueDictionary() - ) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext] + ) # type: WeakValueDictionary[Tuple[_CacheContext.Cache, CacheKey], _CacheContext] - def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None + def __init__(self, cache: "_CacheContext.Cache", cache_key: CacheKey) -> None: self._cache = cache self._cache_key = cache_key @@ -396,8 +475,8 @@ class _CacheContext: @classmethod def get_instance( - cls, cache, cache_key - ): # type: (DeferredCache, CacheKey) -> _CacheContext + cls, cache: "_CacheContext.Cache", cache_key: CacheKey + ) -> "_CacheContext": """Returns an instance constructed with the given arguments. A new instance is only created if none already exists. @@ -418,7 +497,7 @@ def cached( cache_context: bool = False, iterable: bool = False, ) -> Callable[[F], _CachedFunction[F]]: - func = lambda orig: CacheDescriptor( + func = lambda orig: DeferredCacheDescriptor( orig, max_entries=max_entries, num_args=num_args, @@ -460,7 +539,7 @@ def cachedList( def batch_do_something(self, first_arg, second_args): ... """ - func = lambda orig: CacheListDescriptor( + func = lambda orig: DeferredCacheListDescriptor( orig, cached_method_name=cached_method_name, list_name=list_name, @@ -468,3 +547,65 @@ def cachedList( ) return cast(Callable[[F], _CachedFunction[F]], func) + + +def get_cache_key_builder( + param_names: Sequence[str], param_defaults: Mapping[str, Any] +) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]: + """Construct a function which will build cache keys suitable for a cached function + + Args: + param_names: list of formal parameter names for the cached function + param_defaults: a mapping from parameter name to default value for that param + + Returns: + A function which will take an (args, kwargs) pair and return a cache key + """ + + # By default our cache key is a tuple, but if there is only one item + # then don't bother wrapping in a tuple. This is to save memory. + + if len(param_names) == 1: + nm = param_names[0] + + def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: + if nm in kwargs: + return kwargs[nm] + elif len(args): + return args[0] + else: + return param_defaults[nm] + + else: + + def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey: + return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs)) + + return get_cache_key + + +def _get_cache_key_gen( + param_names: Iterable[str], + param_defaults: Mapping[str, Any], + args: Sequence[Any], + kwargs: Mapping[str, Any], +) -> Iterable[Any]: + """Given some args/kwargs return a generator that resolves into + the cache_key. + + This is essentially the same operation as `inspect.getcallargs`, but optimised so + that we don't need to inspect the target function for each call. + """ + + # We loop through each arg name, looking up if its in the `kwargs`, + # otherwise using the next argument in `args`. If there are no more + # args then we try looking the arg name up in the defaults. + pos = 0 + for nm in param_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield param_defaults[nm] diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index a5cc9d0551..4ab379e429 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k failure_ts, retry_interval, backoff_on_failure=backoff_on_failure, - **kwargs + **kwargs, ) diff --git a/synmark/__init__.py b/synmark/__init__.py index 09bc7e7927..3d4ec3e184 100644 --- a/synmark/__init__.py +++ b/synmark/__init__.py @@ -21,45 +21,6 @@ except ImportError: from twisted.internet.pollreactor import PollReactor as Reactor from twisted.internet.main import installReactor -from synapse.config.homeserver import HomeServerConfig -from synapse.util import Clock - -from tests.utils import default_config, setup_test_homeserver - - -async def make_homeserver(reactor, config=None): - """ - Make a Homeserver suitable for running benchmarks against. - - Args: - reactor: A Twisted reactor to run under. - config: A HomeServerConfig to use, or None. - """ - cleanup_tasks = [] - clock = Clock(reactor) - - if not config: - config = default_config("test") - - config_obj = HomeServerConfig() - config_obj.parse_config_dict(config, "", "") - - hs = setup_test_homeserver( - cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock - ) - stor = hs.get_datastore() - - # Run the database background updates. - if hasattr(stor.db_pool.updates, "do_next_background_update"): - while not await stor.db_pool.updates.has_completed_background_updates(): - await stor.db_pool.updates.do_next_background_update(1) - - def cleanup(): - for i in cleanup_tasks: - i() - - return hs, clock.sleep, cleanup - def make_reactor(): """ diff --git a/synmark/__main__.py b/synmark/__main__.py index 17df9ddeb7..de13c1a909 100644 --- a/synmark/__main__.py +++ b/synmark/__main__.py @@ -12,20 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import sys from argparse import REMAINDER from contextlib import redirect_stderr from io import StringIO import pyperf -from synmark import make_reactor -from synmark.suites import SUITES from twisted.internet.defer import Deferred, ensureDeferred from twisted.logger import globalLogBeginner, textFileLogObserver from twisted.python.failure import Failure +from synmark import make_reactor +from synmark.suites import SUITES + from tests.utils import setupdb diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index d8e4c7d58f..c9d9cf761e 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -13,20 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import warnings from io import StringIO from mock import Mock from pyperf import perf_counter -from synmark import make_homeserver from twisted.internet.defer import Deferred from twisted.internet.protocol import ServerFactory -from twisted.logger import LogBeginner, Logger, LogPublisher +from twisted.logger import LogBeginner, LogPublisher from twisted.protocols.basic import LineOnlyReceiver -from synapse.logging._structured import setup_structured_logging +from synapse.config.logger import _setup_stdlib_logging +from synapse.logging import RemoteHandler +from synapse.util import Clock class LineCounter(LineOnlyReceiver): @@ -62,7 +64,15 @@ async def main(reactor, loops): logger_factory.on_done = Deferred() port = reactor.listenTCP(0, logger_factory, interface="127.0.0.1") - hs, wait, cleanup = await make_homeserver(reactor) + # A fake homeserver config. + class Config: + server_name = "synmark-" + str(loops) + no_redirect_stdio = True + + hs_config = Config() + + # To be able to sleep. + clock = Clock(reactor) errors = StringIO() publisher = LogPublisher() @@ -72,47 +82,49 @@ async def main(reactor, loops): ) log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { + "version": 1, + "loggers": {"synapse": {"level": "DEBUG", "handlers": ["tersejson"]}}, + "formatters": {"tersejson": {"class": "synapse.logging.TerseJsonFormatter"}}, + "handlers": { "tersejson": { - "type": "network_json_terse", + "class": "synapse.logging.RemoteHandler", "host": "127.0.0.1", "port": port.getHost().port, "maximum_buffer": 100, + "_reactor": reactor, } }, } - logger = Logger(namespace="synapse.logging.test_terse_json", observer=publisher) - logging_system = setup_structured_logging( - hs, hs.config, log_config, logBeginner=beginner, redirect_stdlib_logging=False + logger = logging.getLogger("synapse.logging.test_terse_json") + _setup_stdlib_logging( + hs_config, log_config, logBeginner=beginner, ) # Wait for it to connect... - await logging_system._observers[0]._service.whenConnected() + for handler in logging.getLogger("synapse").handlers: + if isinstance(handler, RemoteHandler): + break + else: + raise RuntimeError("Improperly configured: no RemoteHandler found.") + + await handler._service.whenConnected() start = perf_counter() # Send a bunch of useful messages for i in range(0, loops): - logger.info("test message %s" % (i,)) - - if ( - len(logging_system._observers[0]._buffer) - == logging_system._observers[0].maximum_buffer - ): - while ( - len(logging_system._observers[0]._buffer) - > logging_system._observers[0].maximum_buffer / 2 - ): - await wait(0.01) + logger.info("test message %s", i) + + if len(handler._buffer) == handler.maximum_buffer: + while len(handler._buffer) > handler.maximum_buffer / 2: + await clock.sleep(0.01) await logger_factory.on_done end = perf_counter() - start - logging_system.stop() + handler.close() port.stopListening() - cleanup() return end diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cb6f29d670..0fd55f428a 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -29,6 +29,7 @@ from synapse.api.errors import ( MissingClientTokenError, ResourceLimitError, ) +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import UserID from tests import unittest @@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_user_valid_token(self): - user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} + user_info = TokenLookupResult( + user_id=self.test_user, token_id=5, device_id="device" + ) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase): self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") def test_get_user_by_req_user_missing_token(self): - user_info = {"name": self.test_user, "token_id": "ditto"} + user_info = TokenLookupResult(user_id=self.test_user, token_id=5) self.store.get_user_by_access_token = Mock( return_value=defer.succeed(user_info) ) @@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase): def test_get_user_from_macaroon(self): self.store.get_user_by_access_token = Mock( return_value=defer.succeed( - {"name": "@baldrick:matrix.org", "device_id": "device"} + TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device") ) ) @@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(macaroon.serialize()) ) - user = user_info["user"] - self.assertEqual(UserID.from_string(user_id), user) + self.assertEqual(user_id, user_info.user_id) # TODO: device_id should come from the macaroon, but currently comes # from the db. - self.assertEqual(user_info["device_id"], "device") + self.assertEqual(user_info.device_id, "device") @defer.inlineCallbacks def test_get_guest_user_from_macaroon(self): @@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase): user_info = yield defer.ensureDeferred( self.auth.get_user_by_access_token(serialized) ) - user = user_info["user"] - is_guest = user_info["is_guest"] - self.assertEqual(UserID.from_string(user_id), user) - self.assertTrue(is_guest) + self.assertEqual(user_id, user_info.user_id) + self.assertTrue(user_info.is_guest) self.store.get_user_by_id.assert_called_with(user_id) @defer.inlineCallbacks @@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase): if token != tok: return defer.succeed(None) return defer.succeed( - { - "name": USER_ID, - "is_guest": False, - "token_id": 1234, - "device_id": "DEVICE", - } + TokenLookupResult( + user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE", + ) ) self.store.get_user_by_access_token = get_user diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 1e1f30d790..fe504d0869 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=True, + None, "example.com", id="foo", rate_limited=True, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) @@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( - None, "example.com", id="foo", rate_limited=False, + None, "example.com", id="foo", rate_limited=False, sender="@as:example.com", ) as_requester = create_requester("@user:example.com", app_service=appservice) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 236b608d58..0bffeb1150 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase): def setUp(self): self.service = ApplicationService( id="unique_identifier", + sender="@as:test", url="some_url", token="some_token", hostname="matrix.org", # only used by get_groups_for_user diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ee4f3da31c..53763cd0f9 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -42,7 +42,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): hs.get_clock.return_value = MockClock() self.handler = ApplicationServicesHandler(hs) - @defer.inlineCallbacks def test_notify_interested_services(self): interested_service = self._mkservice(is_interested=True) services = [ @@ -62,14 +61,12 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event ) - @defer.inlineCallbacks def test_query_user_exists_unknown_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -83,12 +80,11 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) - @defer.inlineCallbacks def test_query_user_exists_known_user(self): user_id = "@someone:anywhere" services = [self._mkservice(is_interested=True)] @@ -102,9 +98,9 @@ class AppServiceHandlerTestCase(unittest.TestCase): defer.succeed((0, [event])), defer.succeed((0, [])), ] - yield defer.ensureDeferred( - self.handler.notify_interested_services(RoomStreamToken(None, 0)) - ) + + self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.assertFalse( self.mock_as_api.query_user.called, "query_user called when it shouldn't have been.", diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 4512c51311..875aaec2c6 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): # make sure that our device ID has changed user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) - self.assertEqual(user_info["device_id"], retrieved_device_id) + self.assertEqual(user_info.device_id, retrieved_device_id) # make sure the device has the display name that was set from the login res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 9f6f21a6e2..2e0fea04af 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self.info = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token,) ) - self.token_id = self.info["token_id"] + self.token_id = self.info.token_id self.requester = create_requester(self.user_id, access_token_id=self.token_id) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index b6f436c016..0d51705849 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -394,7 +394,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() request = Mock( - spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] ) code = "code" @@ -414,9 +421,8 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] - request.requestHeaders = Mock(spec=["getRawHeaders"]) - request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] request.getClientIP.return_value = ip_address + request.get_user_agent.return_value = user_agent self.get_success(self.handler.handle_oidc_callback(request)) @@ -621,7 +627,14 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() request = Mock( - spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] ) state = "state" @@ -637,9 +650,8 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [b"code"] request.args[b"state"] = [state.encode("utf-8")] - request.requestHeaders = Mock(spec=["getRawHeaders"]) - request.requestHeaders.getRawHeaders.return_value = [b"Browser"] request.getClientIP.return_value = "10.0.0.1" + request.get_user_agent.return_value = "Browser" self.get_success(self.handler.handle_oidc_callback(request)) diff --git a/tests/logging/__init__.py b/tests/logging/__init__.py index e69de29bb2..a58d51441c 100644 --- a/tests/logging/__init__.py +++ b/tests/logging/__init__.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + + +class LoggerCleanupMixin: + def get_logger(self, handler): + """ + Attach a handler to a logger and add clean-ups to remove revert this. + """ + # Create a logger and add the handler to it. + logger = logging.getLogger(__name__) + logger.addHandler(handler) + + # Ensure the logger actually logs something. + logger.setLevel(logging.INFO) + + # Ensure the logger gets cleaned-up appropriately. + self.addCleanup(logger.removeHandler, handler) + self.addCleanup(logger.setLevel, logging.NOTSET) + + return logger diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py new file mode 100644 index 0000000000..4bc27a1d7d --- /dev/null +++ b/tests/logging/test_remote_handler.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.test.proto_helpers import AccumulatingProtocol + +from synapse.logging import RemoteHandler + +from tests.logging import LoggerCleanupMixin +from tests.server import FakeTransport, get_clock +from tests.unittest import TestCase + + +def connect_logging_client(reactor, client_id): + # This is essentially tests.server.connect_client, but disabling autoflush on + # the client transport. This is necessary to avoid an infinite loop due to + # sending of data via the logging transport causing additional logs to be + # written. + factory = reactor.tcpClients.pop(client_id)[2] + client = factory.buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, reactor)) + client.makeConnection(FakeTransport(server, reactor, autoflush=False)) + + return client, server + + +class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): + def setUp(self): + self.reactor, _ = get_clock() + + def test_log_output(self): + """ + The remote handler delivers logs over TCP. + """ + handler = RemoteHandler("127.0.0.1", 9000, _reactor=self.reactor) + logger = self.get_logger(handler) + + logger.info("Hello there, %s!", "wally") + + # Trigger the connection + client, server = connect_logging_client(self.reactor, 0) + + # Trigger data being sent + client.transport.flush() + + # One log message, with a single trailing newline + logs = server.data.decode("utf8").splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(server.data.count(b"\n"), 1) + + # Ensure the data passed through properly. + self.assertEqual(logs[0], "Hello there, wally!") + + def test_log_backpressure_debug(self): + """ + When backpressure is hit, DEBUG logs will be shed. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 7): + logger.info("info %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # Only the 7 infos made it through, the debugs were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 7) + self.assertNotIn(b"debug", server.data) + + def test_log_backpressure_info(self): + """ + When backpressure is hit, DEBUG and INFO logs will be shed. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send some debug messages + for i in range(0, 3): + logger.debug("debug %s" % (i,)) + + # Send a bunch of useful messages + for i in range(0, 10): + logger.warning("warn %s" % (i,)) + + # Send a bunch of info messages + for i in range(0, 3): + logger.info("info %s" % (i,)) + + # The last debug message pushes it past the maximum buffer + logger.debug("too much debug") + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # The 10 warnings made it through, the debugs and infos were elided + logs = server.data.splitlines() + self.assertEqual(len(logs), 10) + self.assertNotIn(b"debug", server.data) + self.assertNotIn(b"info", server.data) + + def test_log_backpressure_cut_middle(self): + """ + When backpressure is hit, and no more DEBUG and INFOs cannot be culled, + it will cut the middle messages out. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send a bunch of useful messages + for i in range(0, 20): + logger.warning("warn %s" % (i,)) + + # Allow the reconnection + client, server = connect_logging_client(self.reactor, 0) + client.transport.flush() + + # The first five and last five warnings made it through, the debugs and + # infos were elided + logs = server.data.decode("utf8").splitlines() + self.assertEqual( + ["warn %s" % (i,) for i in range(5)] + + ["warn %s" % (i,) for i in range(15, 20)], + logs, + ) + + def test_cancel_connection(self): + """ + Gracefully handle the connection being cancelled. + """ + handler = RemoteHandler( + "127.0.0.1", 9000, maximum_buffer=10, _reactor=self.reactor + ) + logger = self.get_logger(handler) + + # Send a message. + logger.info("Hello there, %s!", "wally") + + # Do not accept the connection and shutdown. This causes the pending + # connection to be cancelled (and should not raise any exceptions). + handler.close() diff --git a/tests/logging/test_structured.py b/tests/logging/test_structured.py deleted file mode 100644 index d36f5f426c..0000000000 --- a/tests/logging/test_structured.py +++ /dev/null @@ -1,214 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import os -import os.path -import shutil -import sys -import textwrap - -from twisted.logger import Logger, eventAsText, eventsFromJSONLogFile - -from synapse.config.logger import setup_logging -from synapse.logging._structured import setup_structured_logging -from synapse.logging.context import LoggingContext - -from tests.unittest import DEBUG, HomeserverTestCase - - -class FakeBeginner: - def beginLoggingTo(self, observers, **kwargs): - self.observers = observers - - -class StructuredLoggingTestBase: - """ - Test base that registers a cleanup handler to reset the stdlib log handler - to 'unset'. - """ - - def prepare(self, reactor, clock, hs): - def _cleanup(): - logging.getLogger("synapse").setLevel(logging.NOTSET) - - self.addCleanup(_cleanup) - - -class StructuredLoggingTestCase(StructuredLoggingTestBase, HomeserverTestCase): - """ - Tests for Synapse's structured logging support. - """ - - def test_output_to_json_round_trip(self): - """ - Synapse logs can be outputted to JSON and then read back again. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - json_log_file = os.path.abspath(os.path.join(temp_dir, "out.json")) - - log_config = { - "drains": {"jsonfile": {"type": "file_json", "location": json_log_file}} - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(json_log_file, "r") as f: - logged_events = list(eventsFromJSONLogFile(f)) - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertEqual( - eventAsText(logged_events[0], includeTimestamp=False), - "[tests.logging.test_structured#info] Hello there, wally!", - ) - - def test_output_to_text(self): - """ - Synapse logs can be outputted to text. - """ - temp_dir = self.mktemp() - os.mkdir(temp_dir) - self.addCleanup(shutil.rmtree, temp_dir) - - log_file = os.path.abspath(os.path.join(temp_dir, "out.log")) - - log_config = {"drains": {"file": {"type": "file", "location": log_file}}} - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Read the log file and check it has the event we sent - with open(log_file, "r") as f: - logged_events = f.read().strip().split("\n") - self.assertEqual(len(logged_events), 1) - - # The event pulled from the file should render fine - self.assertTrue( - logged_events[0].endswith( - " - tests.logging.test_structured - INFO - None - Hello there, wally!" - ) - ) - - def test_collects_logcontext(self): - """ - Test that log outputs have the attached logging context. - """ - log_config = {"drains": {}} - - # Begin the logger with our config - beginner = FakeBeginner() - publisher = setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logs = [] - - publisher.addObserver(logs.append) - - # Make a logger and send an event - logger = Logger( - namespace="tests.logging.test_structured", observer=beginner.observers[0] - ) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - self.assertEqual(len(logs), 1) - self.assertEqual(logs[0]["request"], "somereq") - - -class StructuredLoggingConfigurationFileTestCase( - StructuredLoggingTestBase, HomeserverTestCase -): - def make_homeserver(self, reactor, clock): - - tempdir = self.mktemp() - os.mkdir(tempdir) - log_config_file = os.path.abspath(os.path.join(tempdir, "log.config.yaml")) - self.homeserver_log = os.path.abspath(os.path.join(tempdir, "homeserver.log")) - - config = self.default_config() - config["log_config"] = log_config_file - - with open(log_config_file, "w") as f: - f.write( - textwrap.dedent( - """\ - structured: true - - drains: - file: - type: file_json - location: %s - """ - % (self.homeserver_log,) - ) - ) - - self.addCleanup(self._sys_cleanup) - - return self.setup_test_homeserver(config=config) - - def _sys_cleanup(self): - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - # Do not remove! We need the logging system to be set other than WARNING. - @DEBUG - def test_log_output(self): - """ - When a structured logging config is given, Synapse will use it. - """ - beginner = FakeBeginner() - publisher = setup_logging(self.hs, self.hs.config, logBeginner=beginner) - - # Make a logger and send an event - logger = Logger(namespace="tests.logging.test_structured", observer=publisher) - - with LoggingContext("testcontext", request="somereq"): - logger.info("Hello there, {name}!", name="steve") - - with open(self.homeserver_log, "r") as f: - logged_events = [ - eventAsText(x, includeTimestamp=False) for x in eventsFromJSONLogFile(f) - ] - - logs = "\n".join(logged_events) - self.assertTrue("***** STARTING SERVER *****" in logs) - self.assertTrue("Hello there, steve!" in logs) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index fd128b88e0..73f469b802 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -14,57 +14,33 @@ # limitations under the License. import json -from collections import Counter +import logging +from io import StringIO -from twisted.logger import Logger +from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter -from synapse.logging._structured import setup_structured_logging +from tests.logging import LoggerCleanupMixin +from tests.unittest import TestCase -from tests.server import connect_client -from tests.unittest import HomeserverTestCase -from .test_structured import FakeBeginner, StructuredLoggingTestBase - - -class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): - def test_log_output(self): +class TerseJsonTestCase(LoggerCleanupMixin, TestCase): + def test_terse_json_output(self): """ - The Terse JSON outputter delivers simplified structured logs over TCP. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - } - } - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, self.hs.config, log_config, logBeginner=beginner - ) - - logger = Logger( - namespace="tests.logging.test_terse_json", observer=beginner.observers[0] - ) - logger.info("Hello there, {name}!", name="wally") - - # Trigger the connection - self.pump() + output = StringIO() - _, server = connect_client(self.reactor, 0) + handler = logging.StreamHandler(output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - # Trigger data being sent - self.pump() + logger.info("Hello there, %s!", "wally") - # One log message, with a single trailing newline - logs = server.data.decode("utf8").splitlines() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() self.assertEqual(len(logs), 1) - self.assertEqual(server.data.count(b"\n"), 1) - + self.assertEqual(data.count("\n"), 1) log = json.loads(logs[0]) # The terse logger should give us these keys. @@ -72,163 +48,74 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase): "log", "time", "level", - "log_namespace", - "request", - "scope", - "server_name", - "name", + "namespace", ] self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") - # It contains the data we expect. - self.assertEqual(log["name"], "wally") - - def test_log_backpressure_debug(self): + def test_extra_data(self): """ - When backpressure is hit, DEBUG logs will be shed. + Additional information can be included in the structured logging. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) - - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + output = StringIO() - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) + handler = logging.StreamHandler(output) + handler.setFormatter(TerseJsonFormatter()) + logger = self.get_logger(handler) - # Send a bunch of useful messages - for i in range(0, 7): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") - - # Allow the reconnection - _, server = connect_client(self.reactor, 0) - self.pump() - - # Only the 7 infos made it through, the debugs were elided - logs = server.data.splitlines() - self.assertEqual(len(logs), 7) - - def test_log_backpressure_info(self): - """ - When backpressure is hit, DEBUG and INFO logs will be shed. - """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) - - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] + logger.info( + "Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True} ) - # Send some debug messages - for i in range(0, 3): - logger.debug("debug %s" % (i,)) - - # Send a bunch of useful messages - for i in range(0, 10): - logger.warn("test warn %s" % (i,)) - - # Send a bunch of info messages - for i in range(0, 3): - logger.info("test message %s" % (i,)) - - # The last debug message pushes it past the maximum buffer - logger.debug("too much debug") - - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + log = json.loads(logs[0]) - # The 10 warnings made it through, the debugs and infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "time", + "level", + "namespace", + # The additional keys given via extra. + "foo", + "int", + "bool", + ] + self.assertCountEqual(log.keys(), expected_log_keys) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) + # Check the values of the extra fields. + self.assertEqual(log["foo"], "bar") + self.assertEqual(log["int"], 3) + self.assertIs(log["bool"], True) - def test_log_backpressure_cut_middle(self): + def test_json_output(self): """ - When backpressure is hit, and no more DEBUG and INFOs cannot be culled, - it will cut the middle messages out. + The Terse JSON formatter converts log messages to JSON. """ - log_config = { - "loggers": {"synapse": {"level": "DEBUG"}}, - "drains": { - "tersejson": { - "type": "network_json_terse", - "host": "127.0.0.1", - "port": 8000, - "maximum_buffer": 10, - } - }, - } - - # Begin the logger with our config - beginner = FakeBeginner() - setup_structured_logging( - self.hs, - self.hs.config, - log_config, - logBeginner=beginner, - redirect_stdlib_logging=False, - ) + output = StringIO() - logger = Logger( - namespace="synapse.logging.test_terse_json", observer=beginner.observers[0] - ) + handler = logging.StreamHandler(output) + handler.setFormatter(JsonFormatter()) + logger = self.get_logger(handler) - # Send a bunch of useful messages - for i in range(0, 20): - logger.warn("test warn", num=i) + logger.info("Hello there, %s!", "wally") - # Allow the reconnection - client, server = connect_client(self.reactor, 0) - self.pump() + # One log message, with a single trailing newline. + data = output.getvalue() + logs = data.splitlines() + self.assertEqual(len(logs), 1) + self.assertEqual(data.count("\n"), 1) + log = json.loads(logs[0]) - # The first five and last five warnings made it through, the debugs and - # infos were elided - logs = list(map(json.loads, server.data.decode("utf8").splitlines())) - self.assertEqual(len(logs), 10) - self.assertEqual(Counter([x["level"] for x in logs]), {"WARN": 10}) - self.assertEqual([0, 1, 2, 3, 4, 15, 16, 17, 18, 19], [x["num"] for x in logs]) + # The terse logger should give us these keys. + expected_log_keys = [ + "log", + "level", + "namespace", + ] + self.assertCountEqual(log.keys(), expected_log_keys) + self.assertEqual(log["log"], "Hello there, wally!") diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 55545d9341..bcdcafa5a9 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(self.access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.pusher = self.get_success( self.hs.get_pusherpool().add_pusher( @@ -131,6 +131,35 @@ class EmailPusherTests(HomeserverTestCase): # We should get emailed about that message self._check_for_mail() + def test_invite_sends_email(self): + # Create a room and invite the user to it + room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) + self.helper.invite( + room=room, + src=self.others[0].id, + tok=self.others[0].token, + targ=self.user_id, + ) + + # We should get emailed about the invite + self._check_for_mail() + + def test_invite_to_empty_room_sends_email(self): + # Create a room and invite the user to it + room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) + self.helper.invite( + room=room, + src=self.others[0].id, + tok=self.others[0].token, + targ=self.user_id, + ) + + # Then have the original user leave + self.helper.leave(room, self.others[0].id, tok=self.others[0].token) + + # We should get emailed about the invite + self._check_for_mail() + def test_multiple_members_email(self): # We want to test multiple notifications, so we pause processing of push # while we send messages. diff --git a/tests/push/test_http.py b/tests/push/test_http.py index b567868b02..8571924b29 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( @@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase): user_tuple = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_tuple["token_id"] + token_id = user_tuple.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 093e2faac7..5c633ac6df 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -16,7 +16,6 @@ import logging from typing import Any, Callable, List, Optional, Tuple import attr -import hiredis from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol @@ -39,12 +38,22 @@ from synapse.util import Clock from tests import unittest from tests.server import FakeTransport, render +try: + import hiredis +except ImportError: + hiredis = None + logger = logging.getLogger(__name__) class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + # hiredis is an optional dependency so we don't want to require it for running + # the tests. + if not hiredis: + skip = "Requires hiredis" + servlets = [ streams.register_servlets, ] @@ -269,7 +278,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): homeserver_to_use=GenericWorkerServer, config=config, reactor=self.reactor, - **kwargs + **kwargs, ) # If the instance is in the `instance_map` config then workers may try diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index c9998e88e6..bad0df08cf 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase): sender=sender, type="test_event", content={"body": body}, - **kwargs + **kwargs, ) ) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py new file mode 100644 index 0000000000..77c261dbf7 --- /dev/null +++ b/tests/replication/test_multi_media_repo.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from binascii import unhexlify +from typing import Tuple + +from twisted.internet.protocol import Factory +from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.web.http import HTTPChannel +from twisted.web.server import Request + +from synapse.rest import admin +from synapse.rest.client.v1 import login +from synapse.server import HomeServer + +from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, FakeTransport + +logger = logging.getLogger(__name__) + +test_server_connection_factory = None + + +class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): + """Checks running multiple media repos work correctly. + """ + + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.user_id = self.register_user("user", "pass") + self.access_token = self.login("user", "pass") + + self.reactor.lookups["example.com"] = "127.0.0.2" + + def default_config(self): + conf = super().default_config() + conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] + return conf + + def _get_media_req( + self, hs: HomeServer, target: str, media_id: str + ) -> Tuple[FakeChannel, Request]: + """Request some remote media from the given HS by calling the download + API. + + This then triggers an outbound request from the HS to the target. + + Returns: + The channel for the *client* request and the *outbound* request for + the media which the caller should respond to. + """ + + request, channel = self.make_request( + "GET", + "/{}/{}".format(target, media_id), + shorthand=False, + access_token=self.access_token, + ) + request.render(hs.get_media_repository_resource().children[b"download"]) + self.pump() + + clients = self.reactor.tcpClients + self.assertGreaterEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop() + + # build the test server + server_tls_protocol = _build_test_server(get_connection_factory()) + + # now, tell the client protocol factory to build the client protocol (it will be a + # _WrappingProtocol, around a TLSMemoryBIOProtocol, around an + # HTTP11ClientProtocol) and wire the output of said protocol up to the server via + # a FakeTransport. + # + # Normally this would be done by the TCP socket code in Twisted, but we are + # stubbing that out here. + client_protocol = client_factory.buildProtocol(None) + client_protocol.makeConnection( + FakeTransport(server_tls_protocol, self.reactor, client_protocol) + ) + + # tell the server tls protocol to send its stuff back to the client, too + server_tls_protocol.makeConnection( + FakeTransport(client_protocol, self.reactor, server_tls_protocol) + ) + + # fish the test server back out of the server-side TLS protocol. + http_server = server_tls_protocol.wrappedProtocol + + # give the reactor a pump to get the TLS juices flowing. + self.reactor.pump((0.1,)) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + + self.assertEqual(request.method, b"GET") + self.assertEqual( + request.path, + "/_matrix/media/r0/download/{}/{}".format(target, media_id).encode("utf-8"), + ) + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] + ) + + return channel, request + + def test_basic(self): + """Test basic fetching of remote media from a single worker. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + + channel, request = self._get_media_req(hs1, "example.com:443", "ABC123") + + request.setResponseCode(200) + request.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request.write(b"Hello!") + request.finish() + + self.pump(0.1) + + self.assertEqual(channel.code, 200) + self.assertEqual(channel.result["body"], b"Hello!") + + def test_download_simple_file_race(self): + """Test that fetching remote media from two different processes at the + same time works. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_media() + + # Make two requests without responding to the outbound media requests. + channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123") + + # Respond to the first outbound media request and check that the client + # request is successful + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request1.write(b"Hello!") + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], b"Hello!") + + # Now respond to the second with the same content. + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"text/plain"]) + request2.write(b"Hello!") + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], b"Hello!") + + # We expect only one new file to have been persisted. + self.assertEqual(start_count + 1, self._count_remote_media()) + + def test_download_image_race(self): + """Test that fetching remote *images* from two different processes at + the same time works. + + This checks that races generating thumbnails are handled correctly. + """ + hs1 = self.make_worker_hs("synapse.app.generic_worker") + hs2 = self.make_worker_hs("synapse.app.generic_worker") + + start_count = self._count_remote_thumbnails() + + channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1") + channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1") + + png_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + request1.setResponseCode(200) + request1.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request1.write(png_data) + request1.finish() + + self.pump(0.1) + + self.assertEqual(channel1.code, 200, channel1.result["body"]) + self.assertEqual(channel1.result["body"], png_data) + + request2.setResponseCode(200) + request2.responseHeaders.setRawHeaders(b"Content-Type", [b"image/png"]) + request2.write(png_data) + request2.finish() + + self.pump(0.1) + + self.assertEqual(channel2.code, 200, channel2.result["body"]) + self.assertEqual(channel2.result["body"], png_data) + + # We expect only three new thumbnails to have been persisted. + self.assertEqual(start_count + 3, self._count_remote_thumbnails()) + + def _count_remote_media(self) -> int: + """Count the number of files in our remote media directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_content" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + def _count_remote_thumbnails(self) -> int: + """Count the number of files in our remote thumbnails directory. + """ + path = os.path.join( + self.hs.get_media_repository().primary_base_path, "remote_thumbnail" + ) + return sum(len(files) for _, _, files in os.walk(path)) + + +def get_connection_factory(): + # this needs to happen once, but not until we are ready to run the first test + global test_server_connection_factory + if test_server_connection_factory is None: + test_server_connection_factory = TestServerTLSConnectionFactory( + sanlist=[b"DNS:example.com"] + ) + return test_server_connection_factory + + +def _build_test_server(connection_creator): + """Construct a test server + + This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol + + Args: + connection_creator (IOpenSSLServerConnectionCreator): thing to build + SSL connections + sanlist (list[bytes]): list of the SAN entries for the cert returned + by the server + + Returns: + TLSMemoryBIOProtocol + """ + server_factory = Factory.forProtocol(HTTPChannel) + # Request.finish expects the factory to have a 'log' method. + server_factory.log = _log_request + + server_tls_factory = TLSMemoryBIOFactory( + connection_creator, isClient=False, wrappedFactory=server_factory + ) + + return server_tls_factory.buildProtocol(None) + + +def _log_request(request): + """Implements Factory.log, which is expected by Request.finish""" + logger.info("Completed request %s", request) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 2bdc6edbb1..67c27a089f 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): user_dict = self.get_success( self.hs.get_datastore().get_user_by_access_token(access_token) ) - token_id = user_dict["token_id"] + token_id = user_dict.token_id self.get_success( self.hs.get_pusherpool().add_pusher( diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 92c9058887..d89eb90cfe 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -393,6 +393,22 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) + def test_user_has_no_devices(self): + """ + Tests that a normal lookup for devices is successfully + if user has no devices + """ + + # Get devices + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["devices"])) + def test_get_devices(self): """ Tests that a normal lookup for devices is successfully @@ -409,6 +425,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase): self.render(request) self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_devices, channel.json_body["total"]) self.assertEqual(number_devices, len(channel.json_body["devices"])) self.assertEqual(self.other_user, channel.json_body["devices"][0]["user_id"]) # Check that all fields are available diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index bf79086f78..303622217f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -70,6 +70,16 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/event_reports" + def test_no_auth(self): + """ + Try to get an event report without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + def test_requester_is_no_admin(self): """ If the user is not a server admin, an error 403 is returned. @@ -266,7 +276,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): def test_limit_is_negative(self): """ - Testing that a negative list parameter returns a 400 + Testing that a negative limit parameter returns a 400 """ request, channel = self.make_request( @@ -360,7 +370,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) def _check_fields(self, content): - """Checks that all attributes are present in a event report + """Checks that all attributes are present in an event report """ for c in content: self.assertIn("id", c) @@ -368,15 +378,175 @@ class EventReportsTestCase(unittest.HomeserverTestCase): self.assertIn("room_id", c) self.assertIn("event_id", c) self.assertIn("user_id", c) - self.assertIn("reason", c) - self.assertIn("content", c) self.assertIn("sender", c) - self.assertIn("room_alias", c) - self.assertIn("event_json", c) - self.assertIn("score", c["content"]) - self.assertIn("reason", c["content"]) - self.assertIn("auth_events", c["event_json"]) - self.assertIn("type", c["event_json"]) - self.assertIn("room_id", c["event_json"]) - self.assertIn("sender", c["event_json"]) - self.assertIn("content", c["event_json"]) + self.assertIn("canonical_alias", c) + self.assertIn("name", c) + self.assertIn("score", c) + self.assertIn("reason", c) + + +class EventReportDetailTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id1 = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id1, user=self.admin_user, tok=self.admin_user_tok) + + self._create_event_and_report( + room_id=self.room_id1, user_tok=self.other_user_tok, + ) + + # first created event report gets `id`=2 + self.url = "/_synapse/admin/v1/event_reports/2" + + def test_no_auth(self): + """ + Try to get event report without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error 403 is returned. + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.other_user_tok, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_default_success(self): + """ + Testing get a reported event + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self._check_fields(channel.json_body) + + def test_invalid_report_id(self): + """ + Testing that an invalid `report_id` returns a 400. + """ + + # `report_id` is negative + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/-123", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is a non-numerical string + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/abcdef", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + # `report_id` is undefined + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "The report_id parameter must be a string representing a positive integer.", + channel.json_body["error"], + ) + + def test_report_id_not_found(self): + """ + Testing that a not existing `report_id` returns a 404. + """ + + request, channel = self.make_request( + "GET", + "/_synapse/admin/v1/event_reports/123", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + self.assertEqual("Event report not found", channel.json_body["error"]) + + def _create_event_and_report(self, room_id, user_tok): + """Create and report events + """ + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + request, channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({"score": -100, "reason": "this makes me sad"}), + access_token=user_tok, + ) + self.render(request) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + + def _check_fields(self, content): + """Checks that all attributes are present in a event report + """ + self.assertIn("id", content) + self.assertIn("received_ts", content) + self.assertIn("room_id", content) + self.assertIn("event_id", content) + self.assertIn("user_id", content) + self.assertIn("sender", content) + self.assertIn("canonical_alias", content) + self.assertIn("name", content) + self.assertIn("event_json", content) + self.assertIn("score", content) + self.assertIn("reason", content) + self.assertIn("auth_events", content["event_json"]) + self.assertIn("type", content["event_json"]) + self.assertIn("room_id", content["event_json"]) + self.assertIn("sender", content["event_json"]) + self.assertIn("content", content["event_json"]) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py new file mode 100644 index 0000000000..721fa1ed51 --- /dev/null +++ b/tests/rest/admin/test_media.py @@ -0,0 +1,568 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from binascii import unhexlify + +import synapse.rest.admin +from synapse.api.errors import Codes +from synapse.rest.client.v1 import login, profile, room +from synapse.rest.media.v1.filepath import MediaFilePaths + +from tests import unittest + + +class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media_store_path) + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request("DELETE", url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_media_does_not_exist(self): + """ + Tests that a lookup for a media that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_media_is_not_local(self): + """ + Tests that a lookup for a media that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/media/%s/%s" % ("unknown_domain", "12345") + + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only delete local media", channel.json_body["error"]) + + def test_delete_media(self): + """ + Tests that delete a media is successfully + """ + + download_resource = self.media_repo.children[b"download"] + upload_resource = self.media_repo.children[b"upload"] + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name, media_id = server_and_media_id.split("/") + + self.assertEqual(server_name, self.server_name) + + # Attempt to access media + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + + # Should be successful + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" % server_and_media_id + ), + ) + + # Test if the file exists + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + + url = "/_synapse/admin/v1/media/%s/%s" % (self.server_name, media_id) + + # Delete media + request, channel = self.make_request( + "DELETE", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + media_id, channel.json_body["deleted_media"][0], + ) + + # Attempt to access media + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + self.assertEqual( + 404, + channel.code, + msg=( + "Expected to receive a 404 on accessing deleted media: %s" + % server_and_media_id + ), + ) + + # Test if the file is deleted + self.assertFalse(os.path.exists(local_path)) + + +class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.handler = hs.get_device_handler() + self.media_repo = hs.get_media_repository_resource() + self.server_name = hs.hostname + self.clock = hs.clock + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name + + def test_no_auth(self): + """ + Try to delete media without authentication. + """ + + request, channel = self.make_request("POST", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "POST", self.url, access_token=self.other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_media_is_not_local(self): + """ + Tests that a lookup for media that is not local returns a 400 + """ + url = "/_synapse/admin/v1/media/%s/delete" % "unknown_domain" + + request, channel = self.make_request( + "POST", url + "?before_ts=1234", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only delete local media", channel.json_body["error"]) + + def test_missing_parameter(self): + """ + If the parameter `before_ts` is missing, an error is returned. + """ + request, channel = self.make_request( + "POST", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Missing integer query parameter b'before_ts'", channel.json_body["error"] + ) + + def test_invalid_parameter(self): + """ + If parameters are invalid, an error is returned. + """ + request, channel = self.make_request( + "POST", self.url + "?before_ts=-1234", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter before_ts must be a string representing a positive integer.", + channel.json_body["error"], + ) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=1234&size_gt=-1234", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + self.assertEqual( + "Query parameter size_gt must be a string representing a positive integer.", + channel.json_body["error"], + ) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=1234&keep_profiles=not_bool", + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + self.assertEqual( + "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']", + channel.json_body["error"], + ) + + def test_delete_media_never_accessed(self): + """ + Tests that media deleted if it is older than `before_ts` and never accessed + `last_access_ts` is `NULL` and `created_ts` < `before_ts` + """ + + # upload and do not access + server_and_media_id = self._create_media() + self.pump(1.0) + + # test that the file exists + media_id = server_and_media_id.split("/")[1] + local_path = self.filepaths.local_media_filepath(media_id) + self.assertTrue(os.path.exists(local_path)) + + # timestamp after upload/create + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + media_id, channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_date(self): + """ + Tests that media is not deleted if it is newer than `before_ts` + """ + + # timestamp before upload + now_ms = self.clock.time_msec() + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + # timestamp after upload + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_size(self): + """ + Tests that media is not deleted if its size is smaller than or equal + to `size_gt` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&size_gt=67", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&size_gt=66", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_user_avatar(self): + """ + Tests that we do not delete media if is used as a user avatar + Tests parameter `keep_profiles` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + # set media as avatar + request, channel = self.make_request( + "PUT", + "/profile/%s/avatar_url" % (self.admin_user,), + content=json.dumps({"avatar_url": "mxc://%s" % (server_and_media_id,)}), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def test_keep_media_by_room_avatar(self): + """ + Tests that we do not delete media if it is used as a room avatar + Tests parameter `keep_profiles` + """ + server_and_media_id = self._create_media() + + self._access_media(server_and_media_id) + + # set media as room avatar + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + request, channel = self.make_request( + "PUT", + "/rooms/%s/state/m.room.avatar" % (room_id,), + content=json.dumps({"url": "mxc://%s" % (server_and_media_id,)}), + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=true", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + self._access_media(server_and_media_id) + + now_ms = self.clock.time_msec() + request, channel = self.make_request( + "POST", + self.url + "?before_ts=" + str(now_ms) + "&keep_profiles=false", + access_token=self.admin_user_tok, + ) + self.render(request) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual( + server_and_media_id.split("/")[1], channel.json_body["deleted_media"][0], + ) + + self._access_media(server_and_media_id, False) + + def _create_media(self): + """ + Create a media and return media_id and server_and_media_id + """ + upload_resource = self.media_repo.children[b"upload"] + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + server_name = server_and_media_id.split("/")[0] + + # Check that new media is a local and not remote + self.assertEqual(server_name, self.server_name) + + return server_and_media_id + + def _access_media(self, server_and_media_id, expect_success=True): + """ + Try to access a media and check the result + """ + download_resource = self.media_repo.children[b"download"] + + media_id = server_and_media_id.split("/")[1] + local_path = self.filepaths.local_media_filepath(media_id) + + request, channel = self.make_request( + "GET", + server_and_media_id, + shorthand=False, + access_token=self.admin_user_tok, + ) + request.render(download_resource) + self.pump(1.0) + + if expect_success: + self.assertEqual( + 200, + channel.code, + msg=( + "Expected to receive a 200 on accessing media: %s" + % server_and_media_id + ), + ) + # Test that the file exists + self.assertTrue(os.path.exists(local_path)) + else: + self.assertEqual( + 404, + channel.code, + msg=( + "Expected to receive a 404 on accessing deleted media: %s" + % (server_and_media_id) + ), + ) + # Test that the file is deleted + self.assertFalse(os.path.exists(local_path)) diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 98d0623734..7df32e5093 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -17,6 +17,7 @@ import hashlib import hmac import json import urllib.parse +from binascii import unhexlify from mock import Mock @@ -1016,7 +1017,6 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, - sync.register_servlets, room.register_servlets, ] @@ -1082,6 +1082,21 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Can only lookup local users", channel.json_body["error"]) + def test_no_memberships(self): + """ + Tests that a normal lookup for rooms is successfully + if user has no memberships + """ + # Get rooms + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["joined_rooms"])) + def test_get_rooms(self): """ Tests that a normal lookup for rooms is successfully @@ -1101,3 +1116,408 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) + + +class PushersRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list pushers of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/pushers" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_get_pushers(self): + """ + Tests that a normal lookup for pushers is successfully + """ + + # Get pushers + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + + # Register the pusher + other_user_token = self.login("user", "pass") + user_tuple = self.get_success( + self.store.get_user_by_access_token(other_user_token) + ) + token_id = user_tuple.token_id + + self.get_success( + self.hs.get_pusherpool().add_pusher( + user_id=self.other_user, + access_token=token_id, + kind="http", + app_id="m.http", + app_display_name="HTTP Push Notifications", + device_display_name="pushy push", + pushkey="a@example.com", + lang=None, + data={"url": "example.com"}, + ) + ) + + # Get pushers + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(1, channel.json_body["total"]) + + for p in channel.json_body["pushers"]: + self.assertIn("pushkey", p) + self.assertIn("kind", p) + self.assertIn("app_id", p) + self.assertIn("app_display_name", p) + self.assertIn("device_display_name", p) + self.assertIn("profile_tag", p) + self.assertIn("lang", p) + self.assertIn("url", p["data"]) + + +class UserMediaRestTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.media_repo = hs.get_media_repository_resource() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + self.other_user = self.register_user("user", "pass") + self.url = "/_synapse/admin/v1/users/%s/media" % urllib.parse.quote( + self.other_user + ) + + def test_no_auth(self): + """ + Try to list media of an user without authentication. + """ + request, channel = self.make_request("GET", self.url, b"{}") + self.render(request) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + def test_requester_is_no_admin(self): + """ + If the user is not a server admin, an error is returned. + """ + other_user_token = self.login("user", "pass") + + request, channel = self.make_request( + "GET", self.url, access_token=other_user_token, + ) + self.render(request) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_user_does_not_exist(self): + """ + Tests that a lookup for a user that does not exist returns a 404 + """ + url = "/_synapse/admin/v1/users/@unknown_person:test/media" + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(404, channel.code, msg=channel.json_body) + self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) + + def test_user_is_not_local(self): + """ + Tests that a lookup for a user that is not a local returns a 400 + """ + url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/media" + + request, channel = self.make_request( + "GET", url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, channel.code, msg=channel.json_body) + self.assertEqual("Can only lookup local users", channel.json_body["error"]) + + def test_limit(self): + """ + Testing list of media with limit + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?limit=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 5) + self.assertEqual(channel.json_body["next_token"], 5) + self._check_fields(channel.json_body["media"]) + + def test_from(self): + """ + Testing list of media with a defined starting point (from) + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?from=5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 15) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["media"]) + + def test_limit_and_from(self): + """ + Testing list of media with a defined starting point and limit + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(channel.json_body["next_token"], 15) + self.assertEqual(len(channel.json_body["media"]), 10) + self._check_fields(channel.json_body["media"]) + + def test_limit_is_negative(self): + """ + Testing that a negative limit parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?limit=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_from_is_negative(self): + """ + Testing that a negative from parameter returns a 400 + """ + + request, channel = self.make_request( + "GET", self.url + "?from=-5", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) + + def test_next_token(self): + """ + Testing that `next_token` appears at the right place + """ + + number_media = 20 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + # `next_token` does not appear + # Number of results is the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=20", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), number_media) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does not appear + # Number of max results is larger than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=21", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), number_media) + self.assertNotIn("next_token", channel.json_body) + + # `next_token` does appear + # Number of max results is smaller than the number of entries + request, channel = self.make_request( + "GET", self.url + "?limit=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 19) + self.assertEqual(channel.json_body["next_token"], 19) + + # Check + # Set `from` to value of `next_token` for request remaining entries + # `next_token` does not appear + request, channel = self.make_request( + "GET", self.url + "?from=19", access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(channel.json_body["total"], number_media) + self.assertEqual(len(channel.json_body["media"]), 1) + self.assertNotIn("next_token", channel.json_body) + + def test_user_has_no_media(self): + """ + Tests that a normal lookup for media is successfully + if user has no media created + """ + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(0, channel.json_body["total"]) + self.assertEqual(0, len(channel.json_body["media"])) + + def test_get_media(self): + """ + Tests that a normal lookup for media is successfully + """ + + number_media = 5 + other_user_tok = self.login("user", "pass") + self._create_media(other_user_tok, number_media) + + request, channel = self.make_request( + "GET", self.url, access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(number_media, channel.json_body["total"]) + self.assertEqual(number_media, len(channel.json_body["media"])) + self.assertNotIn("next_token", channel.json_body) + self._check_fields(channel.json_body["media"]) + + def _create_media(self, user_token, number_media): + """ + Create a number of media for a specific user + """ + upload_resource = self.media_repo.children[b"upload"] + for i in range(number_media): + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + self.helper.upload_media( + upload_resource, image_data, tok=user_token, expect_code=200 + ) + + def _check_fields(self, content): + """Checks that all attributes are present in content + """ + for m in content: + self.assertIn("media_id", m) + self.assertIn("media_type", m) + self.assertIn("media_length", m) + self.assertIn("upload_name", m) + self.assertIn("created_ts", m) + self.assertIn("last_access_ts", m) + self.assertIn("quarantined_by", m) + self.assertIn("safe_from_quarantine", m) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 2fc3a60fc5..98c3887bbf 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, + sender="@as:test", ) self.hs.get_datastore().services_cache.append(appservice) diff --git a/tests/server.py b/tests/server.py index 4d33b84097..3dd2cfc072 100644 --- a/tests/server.py +++ b/tests/server.py @@ -46,7 +46,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() - result = attr.ib(default=attr.Factory(dict)) + result = attr.ib(type=dict, default=attr.Factory(dict)) _producer = None @property @@ -380,7 +380,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runWithConnection, func, *args, - **kwargs + **kwargs, ) def runInteraction(interaction, *args, **kwargs): @@ -390,7 +390,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): pool._runInteraction, interaction, *args, - **kwargs + **kwargs, ) pool.runWithConnection = runWithConnection @@ -571,12 +571,10 @@ def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol reactor factory: The connecting factory to build. """ - factory = reactor.tcpClients[client_id][2] + factory = reactor.tcpClients.pop(client_id)[2] client = factory.buildProtocol(None) server = AccumulatingProtocol() server.makeConnection(FakeTransport(client, reactor)) client.makeConnection(FakeTransport(server, reactor)) - reactor.tcpClients.pop(client_id) - return client, server diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 080761d1d2..5a1e5c4e66 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -22,7 +22,7 @@ import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client.v1 import login, room from synapse.storage import prepare_database -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID("alice", "test") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] @@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): # Create a test user and room self.user = UserID.from_string(self.register_user("user1", "password")) self.token1 = self.login("user1", "password") - self.requester = Requester(self.user, None, False, False, None, None) + self.requester = create_requester(self.user) info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 755c70db31..e96ca1c8ca 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): "GET", "/_matrix/client/r0/admin/users/" + self.user_id, access_token=access_token, - **make_request_args + **make_request_args, ) request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza") diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py index 3957471f3f..7691f2d790 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/test_event_metrics.py @@ -14,7 +14,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, generate_latest -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests.unittest import HomeserverTestCase @@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase): room_creator = self.hs.get_room_creation_handler() user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) # Real events, forward extremities events = [(3, 2), (6, 2), (4, 6)] diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 6b582771fe..c8c7a90e5d 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase): self.store.get_user_by_access_token(self.tokens[1]) ) - self.assertDictContainsSubset( - {"name": self.user_id, "device_id": self.device_id}, result - ) - - self.assertTrue("token_id" in result) + self.assertEqual(result.user_id, self.user_id) + self.assertEqual(result.device_id, self.device_id) + self.assertIsNotNone(result.token_id) @defer.inlineCallbacks def test_user_delete_access_tokens(self): @@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase): user = yield defer.ensureDeferred( self.store.get_user_by_access_token(self.tokens[0]) ) - self.assertEqual(self.user_id, user["name"]) + self.assertEqual(self.user_id, user.user_id) # now delete the rest yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 12ccc1f53e..ff972daeaa 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -19,7 +19,7 @@ from unittest.mock import Mock from synapse.api.constants import Membership from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client.v1 import login, room -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from tests import unittest from tests.test_utils import event_injection @@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase): # Now let's create a room, which will insert a membership user = UserID("alice", "test") - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) self.get_success(self.room_creator.create_room(requester, {})) # Register the background update to run again. diff --git a/tests/test_federation.py b/tests/test_federation.py index d39e792580..1ce4ea3a01 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -20,7 +20,7 @@ from twisted.internet.defer import succeed from synapse.api.errors import FederationError from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext -from synapse.types import Requester, UserID +from synapse.types import UserID, create_requester from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination @@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) user_id = UserID("us", "test") - our_user = Requester(user_id, None, False, False, None, None) + our_user = create_requester(user_id) room_creator = self.homeserver.get_room_creation_handler() self.room_id = self.get_success( room_creator.create_room( diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index a298cc0fd3..d232b72264 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -17,8 +17,10 @@ """ Utilities for running the unit tests """ +import sys +import warnings from asyncio import Future -from typing import Any, Awaitable, TypeVar +from typing import Any, Awaitable, Callable, TypeVar TV = TypeVar("TV") @@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]: future = Future() # type: ignore future.set_result(result) return future + + +def setup_awaitable_errors() -> Callable[[], None]: + """ + Convert warnings from a non-awaited coroutines into errors. + """ + warnings.simplefilter("error", RuntimeWarning) + + # unraisablehook was added in Python 3.8. + if not hasattr(sys, "unraisablehook"): + return lambda: None + + # State shared between unraisablehook and check_for_unraisable_exceptions. + unraisable_exceptions = [] + orig_unraisablehook = sys.unraisablehook # type: ignore + + def unraisablehook(unraisable): + unraisable_exceptions.append(unraisable.exc_value) + + def cleanup(): + """ + A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions. + """ + sys.unraisablehook = orig_unraisablehook # type: ignore + if unraisable_exceptions: + raise unraisable_exceptions.pop() + + sys.unraisablehook = unraisablehook # type: ignore + + return cleanup diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index e93aa84405..c3c4a93e1f 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -50,7 +50,7 @@ async def inject_member_event( sender=sender, state_key=target, content=content, - **kwargs + **kwargs, ) diff --git a/tests/unittest.py b/tests/unittest.py index 040b126a27..08cf9b10c5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -44,7 +44,7 @@ from synapse.logging.context import ( set_current_context, ) from synapse.server import HomeServer -from synapse.types import Requester, UserID, create_requester +from synapse.types import UserID, create_requester from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import ( @@ -54,7 +54,7 @@ from tests.server import ( render, setup_test_homeserver, ) -from tests.test_utils import event_injection +from tests.test_utils import event_injection, setup_awaitable_errors from tests.test_utils.logging_setup import setup_logging from tests.utils import default_config, setupdb @@ -119,6 +119,10 @@ class TestCase(unittest.TestCase): logging.getLogger().setLevel(level) + # Trial messes with the warnings configuration, thus this has to be + # done in the context of an individual TestCase. + self.addCleanup(setup_awaitable_errors()) + return orig() @around(self) @@ -627,7 +631,7 @@ class HomeserverTestCase(TestCase): """ event_creator = self.hs.get_event_creation_handler() secrets = self.hs.get_secrets() - requester = Requester(user, None, False, False, None, None) + requester = create_requester(user) event, context = self.get_success( event_creator.create_event( diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 2ad08f541b..cf1e3203a4 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -29,13 +29,46 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, lru_cache from tests import unittest +from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) +class LruCacheDecoratorTestCase(unittest.TestCase): + def test_base(self): + class Cls: + def __init__(self): + self.mock = mock.Mock() + + @lru_cache() + def fn(self, arg1, arg2): + return self.mock(arg1, arg2) + + obj = Cls() + obj.mock.return_value = "fish" + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + obj.mock.assert_called_once_with(1, 2) + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = "chips" + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_called_once_with(1, 3) + obj.mock.reset_mock() + + # the two values should now be cached + r = obj.fn(1, 2) + self.assertEqual(r, "fish") + r = obj.fn(1, 3) + self.assertEqual(r, "chips") + obj.mock.assert_not_called() + + def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) @@ -362,6 +395,31 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) + def test_invalidate_cascade(self): + """Invalidations should cascade up through cache contexts""" + + class Cls: + @cached(cache_context=True) + async def func1(self, key, cache_context): + return await self.func2(key, on_invalidate=cache_context.invalidate) + + @cached(cache_context=True) + async def func2(self, key, cache_context): + return self.func3(key, on_invalidate=cache_context.invalidate) + + @lru_cache(cache_context=True) + def func3(self, key, cache_context): + self.invalidate = cache_context.invalidate + return 42 + + obj = Cls() + + top_invalidate = mock.Mock() + r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate)) + self.assertEqual(r, 42) + obj.invalidate() + top_invalidate.assert_called_once() + class CacheDecoratorTestCase(unittest.HomeserverTestCase): """More tests for @cached diff --git a/tox.ini b/tox.ini index 6dcc439a40..c232676826 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = packaging, py35, py36, py37, py38, check_codestyle, check_isort +envlist = packaging, py35, py36, py37, py38, py39, check_codestyle, check_isort [base] extras = test @@ -24,11 +24,6 @@ deps = pip>=10 setenv = - # we have a pyproject.toml, but don't want pip to use it for building. - # (otherwise we get an error about 'editable mode is not supported for - # pyproject.toml-style projects'). - PIP_USE_PEP517 = false - PYTHONDONTWRITEBYTECODE = no_byte_code COVERAGE_PROCESS_START = {toxinidir}/.coveragerc |