92 files changed, 1421 insertions, 386 deletions
diff --git a/changelog.d/8957.feature b/changelog.d/8957.feature new file mode 100644 index 0000000000..fa8961f840 --- /dev/null +++ b/changelog.d/8957.feature @@ -0,0 +1 @@ +Add rate limiters to cross-user key sharing requests. diff --git a/changelog.d/8978.feature b/changelog.d/8978.feature new file mode 100644 index 0000000000..042e257bf0 --- /dev/null +++ b/changelog.d/8978.feature @@ -0,0 +1 @@ +Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel. \ No newline at end of file diff --git a/changelog.d/9203.feature b/changelog.d/9203.feature new file mode 100644 index 0000000000..36b66a47a8 --- /dev/null +++ b/changelog.d/9203.feature @@ -0,0 +1 @@ +Add some configuration settings to make users' profile data more private. diff --git a/changelog.d/9358.misc b/changelog.d/9358.misc new file mode 100644 index 0000000000..cc7614afc0 --- /dev/null +++ b/changelog.d/9358.misc @@ -0,0 +1 @@ +Invalidate the cache for empty, timed-out sync responses. \ No newline at end of file diff --git a/changelog.d/9383.feature b/changelog.d/9383.feature new file mode 100644 index 0000000000..8957c9cc5e --- /dev/null +++ b/changelog.d/9383.feature @@ -0,0 +1 @@ +Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. \ No newline at end of file diff --git a/changelog.d/9385.feature b/changelog.d/9385.feature new file mode 100644 index 0000000000..cbe3922de8 --- /dev/null +++ b/changelog.d/9385.feature @@ -0,0 +1 @@ +Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. \ No newline at end of file diff --git a/changelog.d/9402.bugfix b/changelog.d/9402.bugfix new file mode 100644 index 0000000000..7729225ba2 --- /dev/null +++ b/changelog.d/9402.bugfix @@ -0,0 +1 @@ +Fix a bug where a lot of unnecessary presence updates were sent when joining a room. diff --git a/changelog.d/9416.bugfix b/changelog.d/9416.bugfix new file mode 100644 index 0000000000..4d79cb2228 --- /dev/null +++ b/changelog.d/9416.bugfix @@ -0,0 +1 @@ +Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results. \ No newline at end of file diff --git a/changelog.d/9432.misc b/changelog.d/9432.misc new file mode 100644 index 0000000000..1e07da2033 --- /dev/null +++ b/changelog.d/9432.misc @@ -0,0 +1 @@ +Add documentation and type hints to `parse_duration`. diff --git a/changelog.d/9438.feature b/changelog.d/9438.feature new file mode 100644 index 0000000000..a1f6be5563 --- /dev/null +++ b/changelog.d/9438.feature @@ -0,0 +1 @@ +Add support for regenerating thumbnails if they have been deleted but the original image is still stored. diff --git a/changelog.d/9440.bugfix b/changelog.d/9440.bugfix new file mode 100644 index 0000000000..47b9842b37 --- /dev/null +++ b/changelog.d/9440.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined. diff --git a/changelog.d/9449.bugfix b/changelog.d/9449.bugfix new file mode 100644 index 0000000000..54214a7e4a --- /dev/null +++ b/changelog.d/9449.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`. 
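The `order_by` changelog entry above corresponds to the `docs/admin_api/user_admin_api.rst` changes later in this diff. As a quick illustration of how an administrator might exercise the new parameters — a minimal sketch, assuming a reachable homeserver and a valid admin access token; the URL, token, and user ID below are placeholders:

```python
# Minimal sketch: list a user's local media, largest files first, using the
# new `order_by`/`dir` parameters. All identifiers here are placeholders.
import requests

base_url = "https://matrix.example.com"  # placeholder homeserver URL
admin_token = "<admin_access_token>"     # placeholder admin token
user_id = "@alice:example.com"           # placeholder user ID

resp = requests.get(
    f"{base_url}/_synapse/admin/v1/users/{user_id}/media",
    headers={"Authorization": f"Bearer {admin_token}"},
    # order_by=media_length sorts by file size; dir=b reverses the order.
    params={"order_by": "media_length", "dir": "b", "limit": 10},
)
resp.raise_for_status()
# The response lists the user's media under the "media" key.
for item in resp.json().get("media", []):
    print(item["media_id"], item.get("upload_name"))
```

Per the documentation below, only `media_id`, `user_id` and `created_ts` are indexed, so sorting on the other columns can be expensive on large databases.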
diff --git a/changelog.d/9462.misc b/changelog.d/9462.misc new file mode 100644 index 0000000000..1b245bf85d --- /dev/null +++ b/changelog.d/9462.misc @@ -0,0 +1 @@ +Remove vestiges of `uploads_path` configuration setting. diff --git a/changelog.d/9463.doc b/changelog.d/9463.doc new file mode 100644 index 0000000000..c9cedd147d --- /dev/null +++ b/changelog.d/9463.doc @@ -0,0 +1 @@ +Update the example systemd config to propagate reloads to individual units. diff --git a/changelog.d/9464.misc b/changelog.d/9464.misc new file mode 100644 index 0000000000..39fcf85d40 --- /dev/null +++ b/changelog.d/9464.misc @@ -0,0 +1 @@ +Add a comment about systemd-python. diff --git a/changelog.d/9465.bugfix b/changelog.d/9465.bugfix new file mode 100644 index 0000000000..2ab4f315c1 --- /dev/null +++ b/changelog.d/9465.bugfix @@ -0,0 +1 @@ +Fix deleting pushers when using sharded pushers. diff --git a/changelog.d/9466.bugfix b/changelog.d/9466.bugfix new file mode 100644 index 0000000000..2ab4f315c1 --- /dev/null +++ b/changelog.d/9466.bugfix @@ -0,0 +1 @@ +Fix deleting pushers when using sharded pushers. diff --git a/changelog.d/9470.bugfix b/changelog.d/9470.bugfix new file mode 100644 index 0000000000..c1b7dbb17d --- /dev/null +++ b/changelog.d/9470.bugfix @@ -0,0 +1 @@ +Fix missing startup checks for the consistency of certain PostgreSQL sequences. diff --git a/changelog.d/9472.feature b/changelog.d/9472.feature new file mode 100644 index 0000000000..2ea14e2d62 --- /dev/null +++ b/changelog.d/9472.feature @@ -0,0 +1 @@ +Add support for `X-Forwarded-Proto` header when using a reverse proxy. Administrators using a reverse proxy should ensure this header is set to avoid warnings. See [docs/workers.md](docs/workers.md) for example configurations. diff --git a/changelog.d/9479.bugfix b/changelog.d/9479.bugfix new file mode 100644 index 0000000000..2ab4f315c1 --- /dev/null +++ b/changelog.d/9479.bugfix @@ -0,0 +1 @@ +Fix deleting pushers when using sharded pushers. diff --git a/docker/README.md b/docker/README.md index c8f27b8566..7b138df4d3 100644 --- a/docker/README.md +++ b/docker/README.md @@ -11,7 +11,6 @@ The image also does *not* provide a TURN server. By default, the image expects a single volume, located at ``/data``, that will hold: * configuration files; -* temporary files during uploads; * uploaded media and thumbnails; * the SQLite database if you do not configure postgres; * the appservices configuration. diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml index 2ed570a5d1..0dea62a87d 100644 --- a/docker/conf/homeserver.yaml +++ b/docker/conf/homeserver.yaml @@ -89,7 +89,6 @@ federation_rc_concurrent: 3 ## Files ## media_store_path: "/data/media" -uploads_path: "/data/uploads" max_upload_size: "{{ SYNAPSE_MAX_UPLOAD_SIZE or "50M" }}" max_image_pixels: "32M" dynamic_thumbnails: false diff --git a/docs/admin_api/user_admin_api.rst b/docs/admin_api/user_admin_api.rst index 33dfbcfb49..8d4ec5a6f9 100644 --- a/docs/admin_api/user_admin_api.rst +++ b/docs/admin_api/user_admin_api.rst @@ -379,11 +379,12 @@ The following fields are returned in the JSON response body: - ``total`` - Number of rooms. -List media of an user -================================ +List media of a user +==================== Gets a list of all local media that a specific ``user_id`` has created. -The response is ordered by creation date descending and media ID descending. -The newest media is on top. +By default, the response is ordered by descending creation date and ascending media ID. 
+The newest media is on top. You can change the order with parameters +``order_by`` and ``dir``. The API is:: @@ -440,6 +441,35 @@ The following parameters should be set in the URL: denoting the offset in the returned results. This should be treated as an opaque value and not explicitly set to anything other than the return value of ``next_token`` from a previous call. Defaults to ``0``. +- ``order_by`` - The method by which to sort the returned list of media. + If the ordered field has duplicates, the second order is always by ascending ``media_id``, + which guarantees a stable ordering. Valid values are: + + - ``media_id`` - Media are ordered alphabetically by ``media_id``. + - ``upload_name`` - Media are ordered alphabetically by the name the media was uploaded with. + - ``created_ts`` - Media are ordered by the timestamp (in ms) at which the content was uploaded, + smallest to largest. This is the default. + - ``last_access_ts`` - Media are ordered by the timestamp (in ms) at which the content was last accessed, + smallest to largest. + - ``media_length`` - Media are ordered by the length of the media in bytes, + smallest to largest. + - ``media_type`` - Media are ordered alphabetically by MIME-type. + - ``quarantined_by`` - Media are ordered alphabetically by the user ID that + initiated the quarantine request for this media. + - ``safe_from_quarantine`` - Media are ordered by whether the media is marked as safe + from quarantining. + +- ``dir`` - Direction of media order. Either ``f`` for forwards or ``b`` for backwards. + Setting this value to ``b`` will reverse the above sort order. Defaults to ``f``. + +If neither ``order_by`` nor ``dir`` is set, the default order is newest media on top +(corresponds to ``order_by`` = ``created_ts`` and ``dir`` = ``b``). + +Caution: the database only has indexes on the columns ``media_id``, +``user_id`` and ``created_ts``. This means that if a different sort order is used +(``upload_name``, ``last_access_ts``, ``media_length``, ``media_type``, +``quarantined_by`` or ``safe_from_quarantine``), this can cause a large load on the +database, especially for large environments. **Response** diff --git a/docs/reverse_proxy.md b/docs/reverse_proxy.md index 04b6e24124..bb7caa8bb9 100644 --- a/docs/reverse_proxy.md +++ b/docs/reverse_proxy.md @@ -9,23 +9,23 @@ of doing so is that it means that you can expose the default https port (443) to Matrix clients without needing to run Synapse with root privileges. -**NOTE**: Your reverse proxy must not `canonicalise` or `normalise` -the requested URI in any way (for example, by decoding `%xx` escapes). -Beware that Apache *will* canonicalise URIs unless you specify -`nocanon`. - -When setting up a reverse proxy, remember that Matrix clients and other -Matrix servers do not necessarily need to connect to your server via the -same server name or port. Indeed, clients will use port 443 by default, -whereas servers default to port 8448. Where these are different, we -refer to the 'client port' and the 'federation port'. See [the Matrix +You should configure your reverse proxy to forward requests to `/_matrix` or +`/_synapse/client` to Synapse, and have it set the `X-Forwarded-For` and +`X-Forwarded-Proto` request headers. + +You should remember that Matrix clients and other Matrix servers do not +necessarily need to connect to your server via the same server name or +port. Indeed, clients will use port 443 by default, whereas servers default to +port 8448. Where these are different, we refer to the 'client port' and the +'federation port'. 
See [the Matrix specification](https://matrix.org/docs/spec/server_server/latest#resolving-server-names) for more details of the algorithm used for federation connections, and [delegate.md](<delegate.md>) for instructions on setting up delegation. -Endpoints that are part of the standardised Matrix specification are -located under `/_matrix`, whereas endpoints specific to Synapse are -located under `/_synapse/client`. +**NOTE**: Your reverse proxy must not `canonicalise` or `normalise` +the requested URI in any way (for example, by decoding `%xx` escapes). +Beware that Apache *will* canonicalise URIs unless you specify +`nocanon`. Let's assume that we expect clients to connect to our server at `https://matrix.example.com`, and other servers to connect at @@ -52,6 +52,7 @@ server { location ~* ^(\/_matrix|\/_synapse\/client) { proxy_pass http://localhost:8008; proxy_set_header X-Forwarded-For $remote_addr; + proxy_set_header X-Forwarded-Proto $scheme; # Nginx by default only allows file uploads up to 1M in size # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml client_max_body_size 50M; @@ -102,6 +103,7 @@ example.com:8448 { SSLEngine on ServerName matrix.example.com; + RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME} AllowEncodedSlashes NoDecode ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix @@ -113,6 +115,7 @@ example.com:8448 { SSLEngine on ServerName example.com; + RequestHeader set "X-Forwarded-Proto" expr=%{REQUEST_SCHEME} AllowEncodedSlashes NoDecode ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix @@ -134,6 +137,9 @@ example.com:8448 { ``` frontend https bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1 + http-request set-header X-Forwarded-Proto https if { ssl_fc } + http-request set-header X-Forwarded-Proto http if !{ ssl_fc } + http-request set-header X-Forwarded-For %[src] # Matrix client traffic acl matrix-host hdr(host) -i matrix.example.com @@ -144,6 +150,10 @@ frontend https frontend matrix-federation bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1 + http-request set-header X-Forwarded-Proto https if { ssl_fc } + http-request set-header X-Forwarded-Proto http if !{ ssl_fc } + http-request set-header X-Forwarded-For %[src] + default_backend matrix backend matrix diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 52380dfb04..4dbef41b7e 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -101,6 +101,14 @@ pid_file: DATADIR/homeserver.pid # #limit_profile_requests_to_users_who_share_rooms: true +# Uncomment to prevent a user's profile data from being retrieved and +# displayed in a room until they have joined it. By default, a user's +# profile data is included in an invite event, regardless of the values +# of the above two settings, and whether or not the users share a server. +# Defaults to 'true'. +# +#include_profile_data_on_invite: false + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. @@ -699,6 +707,12 @@ acme: # - matrix.org # - example.com +# Uncomment to disable profile lookup over federation. By default, the +# Federation API allows other homeservers to obtain profile data of any user +# on this homeserver. Defaults to 'true'. 
+# +#allow_profile_lookup_over_federation: false + ## Caching ## @@ -2530,19 +2544,35 @@ spam_checker: # User Directory configuration # -# 'enabled' defines whether users can search the user directory. If -# false then empty responses are returned to all queries. Defaults to -# true. -# -# 'search_all_users' defines whether to search all users visible to your HS -# when searching the user directory, rather than limiting to users visible -# in public rooms. Defaults to false. If you set it True, you'll have to -# rebuild the user_directory search indexes, see -# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md -# -#user_directory: -# enabled: true -# search_all_users: false +user_directory: + # Defines whether users can search the user directory. If false then + # empty responses are returned to all queries. Defaults to true. + # + # Uncomment to disable the user directory. + # + #enabled: false + + # Defines whether to search all users visible to your HS when searching + # the user directory, rather than limiting to users visible in public + # rooms. Defaults to false. + # + # If you set it true, you'll have to rebuild the user_directory search + # indexes, see: + # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md + # + # Uncomment to return search results containing all known users, even if that + # user does not share a room with the requester. + # + #search_all_users: true + + # Defines whether to prefer local users in search query results. + # If True, local users are more likely to appear above remote users + # when searching the user directory. Defaults to false. + # + # Uncomment to prefer local over remote users in user directory search + # results. + # + #prefer_local_users: true # User Consent configuration diff --git a/docs/spam_checker.md b/docs/spam_checker.md index 47a27bf85c..e615ac9910 100644 --- a/docs/spam_checker.md +++ b/docs/spam_checker.md @@ -25,7 +25,7 @@ well as some specific methods: * `check_username_for_spam` * `check_registration_for_spam` -The details of the each of these methods (as well as their inputs and outputs) +The details of each of these methods (as well as their inputs and outputs) are documented in the `synapse.events.spamcheck.SpamChecker` class. The `ModuleApi` class provides a way for the custom spam checker class to diff --git a/docs/systemd-with-workers/system/matrix-synapse-worker@.service b/docs/systemd-with-workers/system/matrix-synapse-worker@.service index cb5ac0ac87..d164e8ce1f 100644 --- a/docs/systemd-with-workers/system/matrix-synapse-worker@.service +++ b/docs/systemd-with-workers/system/matrix-synapse-worker@.service @@ -4,6 +4,7 @@ AssertPathExists=/etc/matrix-synapse/workers/%i.yaml # This service should be restarted when the synapse target is restarted. PartOf=matrix-synapse.target +ReloadPropagatedFrom=matrix-synapse.target # if this is started at the same time as the main, let the main process start # first, to initialise the database schema. diff --git a/docs/systemd-with-workers/system/matrix-synapse.service b/docs/systemd-with-workers/system/matrix-synapse.service index c7b5ddfa49..f6b6dfd3ce 100644 --- a/docs/systemd-with-workers/system/matrix-synapse.service +++ b/docs/systemd-with-workers/system/matrix-synapse.service @@ -3,6 +3,7 @@ Description=Synapse master # This service should be restarted when the synapse target is restarted. 
PartOf=matrix-synapse.target +ReloadPropagatedFrom=matrix-synapse.target [Service] Type=notify diff --git a/docs/tcp_replication.md b/docs/tcp_replication.md index ad145439b4..15df949deb 100644 --- a/docs/tcp_replication.md +++ b/docs/tcp_replication.md @@ -220,10 +220,6 @@ Asks the server for the current position of all streams. Acknowledge receipt of some federation data -#### REMOVE_PUSHER (C) - - Inform the server a pusher should be removed - ### REMOTE_SERVER_UP (S, C) Inform other processes that a remote server may have come back online. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 69bf9110a6..d2aaea08f5 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -22,7 +22,7 @@ import logging import sys import time import traceback -from typing import Dict, Optional, Set +from typing import Dict, Iterable, Optional, Set import yaml @@ -629,7 +629,13 @@ class Porter(object): await self._setup_state_group_id_seq() await self._setup_user_id_seq() await self._setup_events_stream_seqs() - await self._setup_device_inbox_seq() + await self._setup_sequence( + "device_inbox_sequence", ("device_inbox", "device_federation_outbox") + ) + await self._setup_sequence( + "account_data_sequence", ("room_account_data", "room_tags_revisions", "account_data")) + await self._setup_sequence("receipts_sequence", ("receipts_linearized", )) + await self._setup_auth_chain_sequence() # Step 3. Get tables. self.progress.set_state("Fetching tables") @@ -854,7 +860,7 @@ class Porter(object): return done, remaining + done - async def _setup_state_group_id_seq(self): + async def _setup_state_group_id_seq(self) -> None: curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol( table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True ) @@ -868,7 +874,7 @@ class Porter(object): await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r) - async def _setup_user_id_seq(self): + async def _setup_user_id_seq(self) -> None: curr_id = await self.sqlite_store.db_pool.runInteraction( "setup_user_id_seq", find_max_generated_user_id_localpart ) @@ -877,9 +883,9 @@ class Porter(object): next_id = curr_id + 1 txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,)) - return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) + await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r) - async def _setup_events_stream_seqs(self): + async def _setup_events_stream_seqs(self) -> None: """Set the event stream sequences to the correct values. """ @@ -908,35 +914,46 @@ class Porter(object): (curr_backward_id + 1,), ) - return await self.postgres_store.db_pool.runInteraction( + await self.postgres_store.db_pool.runInteraction( "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos, ) - async def _setup_device_inbox_seq(self): - """Set the device inbox sequence to the correct value. + async def _setup_sequence(self, sequence_name: str, stream_id_tables: Iterable[str]) -> None: + """Set a sequence to the correct value. 
""" - curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol( - table="device_inbox", - keyvalues={}, - retcol="COALESCE(MAX(stream_id), 1)", - allow_none=True, - ) + current_stream_ids = [] + for stream_id_table in stream_id_tables: + max_stream_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table=stream_id_table, + keyvalues={}, + retcol="COALESCE(MAX(stream_id), 1)", + allow_none=True, + ) + current_stream_ids.append(max_stream_id) - curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol( - table="device_federation_outbox", - keyvalues={}, - retcol="COALESCE(MAX(stream_id), 1)", - allow_none=True, - ) + next_id = max(current_stream_ids) + 1 + + def r(txn): + sql = "ALTER SEQUENCE %s RESTART WITH" % (sequence_name, ) + txn.execute(sql + " %s", (next_id, )) - next_id = max(curr_local_id, curr_federation_id) + 1 + await self.postgres_store.db_pool.runInteraction("_setup_%s" % (sequence_name,), r) + + async def _setup_auth_chain_sequence(self) -> None: + curr_chain_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="event_auth_chains", keyvalues={}, retcol="MAX(chain_id)", allow_none=True + ) def r(txn): txn.execute( - "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,) + "ALTER SEQUENCE event_auth_chain_id RESTART WITH %s", + (curr_chain_id,), ) - return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r) + await self.postgres_store.db_pool.runInteraction( + "_setup_event_auth_chain_id", r, + ) + ############################################## diff --git a/synapse/api/constants.py b/synapse/api/constants.py index af8d59cf87..691f8f9adf 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -98,11 +98,14 @@ class EventTypes: Retention = "m.room.retention" - Presence = "m.presence" - Dummy = "org.matrix.dummy_event" +class EduTypes: + Presence = "m.presence" + RoomKeyRequest = "m.room_key_request" + + class RejectedReason: AUTH_ERROR = "auth_error" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 5d9d5a228f..c3f07bc1a3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Any, Optional, Tuple +from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError from synapse.types import Requester @@ -42,7 +42,9 @@ class Ratelimiter: # * How many times an action has occurred since a point in time # * The point in time # * The rate_hz of this particular entry. 
This can vary per request - self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + self.actions = ( + OrderedDict() + ) # type: OrderedDict[Hashable, Tuple[float, int, float]] def can_requester_do_action( self, @@ -82,7 +84,7 @@ class Ratelimiter: def can_do_action( self, - key: Any, + key: Hashable, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -175,7 +177,7 @@ class Ratelimiter: def ratelimit( self, - key: Any, + key: Hashable, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index b4bd4d8e7a..9f99651aa2 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -210,7 +210,9 @@ def start(config_options): config.update_user_directory = False config.run_background_tasks = False config.start_pushers = False + config.pusher_shard_config.instances = [] config.send_federation = False + config.federation_shard_config.instances = [] synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 6526acb2f2..dc0d3eb725 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -645,9 +645,6 @@ class GenericWorkerServer(HomeServer): self.get_tcp_replication().start_replication(self) - async def remove_pusher(self, app_id, push_key, user_id): - self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) - @cache_in_self def get_replication_data_handler(self): return GenericWorkerReplicationHandler(self) @@ -922,22 +919,6 @@ def start(config_options): # For other worker types we force this to off. config.appservice.notify_appservices = False - if config.worker_app == "synapse.app.pusher": - if config.server.start_pushers: - sys.stderr.write( - "\nThe pushers must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``start_pushers: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.server.start_pushers = True - else: - # For other worker types we force this to off. - config.server.start_pushers = False - if config.worker_app == "synapse.app.user_dir": if config.server.update_user_directory: sys.stderr.write( @@ -954,22 +935,6 @@ def start(config_options): # For other worker types we force this to off. config.server.update_user_directory = False - if config.worker_app == "synapse.app.federation_sender": - if config.worker.send_federation: - sys.stderr.write( - "\nThe send_federation must be disabled in the main synapse process" - "\nbefore they can be run in a separate worker." - "\nPlease add ``send_federation: false`` to the main config" - "\n" - ) - sys.exit(1) - - # Force the pushers to start since they will be disabled in the main config - config.worker.send_federation = True - else: - # For other worker types we force this to off. 
- config.worker.send_federation = False - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts hs = GenericWorkerServer( diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 97399eb9ba..4026966711 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -21,7 +21,7 @@ import os from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Iterable, List, MutableMapping, Optional +from typing import Any, Iterable, List, MutableMapping, Optional, Union import attr import jinja2 @@ -147,7 +147,20 @@ class Config: return int(value) * size @staticmethod - def parse_duration(value): + def parse_duration(value: Union[str, int]) -> int: + """Convert a duration as a string or integer to a number of milliseconds. + + If an integer is provided it is treated as milliseconds and is unchanged. + + String durations can have a suffix of 's', 'm', 'h', 'd', 'w', or 'y'. + No suffix is treated as milliseconds. + + Args: + value: The duration to parse. + + Returns: + The number of milliseconds in the duration. + """ if isinstance(value, int): return value second = 1000 @@ -831,22 +844,23 @@ class ShardedWorkerHandlingConfig: def should_handle(self, instance_name: str, key: str) -> bool: """Whether this instance is responsible for handling the given key.""" - # If multiple instances are not defined we always return true - if not self.instances or len(self.instances) == 1: - return True + # If no instances are defined we assume some other worker is handling + # this. + if not self.instances: + return False - return self.get_instance(key) == instance_name + return self._get_instance(key) == instance_name - def get_instance(self, key: str) -> str: + def _get_instance(self, key: str) -> str: """Get the instance responsible for handling the given key. - Note: For things like federation sending the config for which instance - is sending is known only to the sender instance if there is only one. - Therefore `should_handle` should be used where possible. + Note: For federation sending and pushers the config for which instance + is sending is known only to the sender instance, so we don't expose this + method by default. """ if not self.instances: - return "master" + raise Exception("Unknown worker") if len(self.instances) == 1: return self.instances[0] @@ -863,4 +877,21 @@ class ShardedWorkerHandlingConfig: return self.instances[remainder] +@attr.s +class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): + """A version of `ShardedWorkerHandlingConfig` that is used for config + options where all instances know which instances are responsible for the + sharded work. + """ + + def __attrs_post_init__(self): + # We require that `self.instances` is non-empty. + if not self.instances: + raise Exception("Got empty list of instances for shard config") + + def get_instance(self, key: str) -> str: + """Get the instance responsible for handling the given key.""" + return self._get_instance(key) + + __all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 70025b5d60..db16c86f50 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -149,4 +149,6 @@ class ShardedWorkerHandlingConfig: instances: List[str] def __init__(self, instances: List[str]) -> None: ... def should_handle(self, instance_name: str, key: str) -> bool: ... 
+ +class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): def get_instance(self, key: str) -> str: ... diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 9f3c57e6a1..55e4db5442 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -41,6 +41,10 @@ class FederationConfig(Config): ) self.federation_metrics_domains = set(federation_metrics_domains) + self.allow_profile_lookup_over_federation = config.get( + "allow_profile_lookup_over_federation", True + ) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ ## Federation ## @@ -66,6 +70,12 @@ class FederationConfig(Config): #federation_metrics_domains: # - matrix.org # - example.com + + # Uncomment to disable profile lookup over federation. By default, the + # Federation API allows other homeservers to obtain profile data of any user + # on this homeserver. Defaults to 'true'. + # + #allow_profile_lookup_over_federation: false """ diff --git a/synapse/config/push.py b/synapse/config/push.py index 3adbfb73e6..7831a2ef79 100644 --- a/synapse/config/push.py +++ b/synapse/config/push.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import Config, ShardedWorkerHandlingConfig +from ._base import Config class PushConfig(Config): @@ -27,9 +27,6 @@ class PushConfig(Config): "group_unread_count_by_room", True ) - pusher_instances = config.get("pusher_instances") or [] - self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) - # There was a 'redact_content' setting but it was mistakenly read from the # 'email' section. Check for the flag in the 'push' section, and log, # but do not honour it to avoid nasty surprises when people upgrade. diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index def33a60ad..847d25122c 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -102,6 +102,16 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + # Ratelimit cross-user key requests: + # * For local requests this is keyed by the sending device. + # * For requests received over federation this is keyed by the origin. + # + # Note that this isn't exposed in the configuration as it is obscure. 
+ self.rc_key_requests = RateLimitConfig( + config.get("rc_key_requests", {}), + defaults={"per_second": 20, "burst_count": 100}, + ) + self.rc_3pid_validation = RateLimitConfig( config.get("rc_3pid_validation") or {}, defaults={"per_second": 0.003, "burst_count": 5}, diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 52849c3256..69d9de5a43 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -206,7 +206,6 @@ class ContentRepositoryConfig(Config): def generate_config_section(self, data_dir_path, **kwargs): media_store = os.path.join(data_dir_path, "media_store") - uploads_path = os.path.join(data_dir_path, "uploads") formatted_thumbnail_sizes = "".join( THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES diff --git a/synapse/config/server.py b/synapse/config/server.py index 6f3325ff81..2afca36e7d 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -263,6 +263,12 @@ class ServerConfig(Config): False, ) + # Whether to retrieve and display profile data for a user when they + # are invited to a room + self.include_profile_data_on_invite = config.get( + "include_profile_data_on_invite", True + ) + if "restrict_public_rooms_to_local_users" in config and ( "allow_public_rooms_without_auth" in config or "allow_public_rooms_over_federation" in config @@ -391,7 +397,6 @@ class ServerConfig(Config): if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" - self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, # for testing. The value defines the number of milliseconds to pause before @@ -848,6 +853,14 @@ class ServerConfig(Config): # #limit_profile_requests_to_users_who_share_rooms: true + # Uncomment to prevent a user's profile data from being retrieved and + # displayed in a room until they have joined it. By default, a user's + # profile data is included in an invite event, regardless of the values + # of the above two settings, and whether or not the users share a server. + # Defaults to 'true'. + # + #include_profile_data_on_invite: false + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. 
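The `rc_key_requests` defaults above (20 per second, burst of 100) are consumed by the `Ratelimiter` whose keys this diff widens to `Hashable`: local requests are keyed by the sending `(user_id, device_id)` pair and federation traffic by the origin server, per the config comment above and the handlers further down. As a rough sketch of how such a per-key limiter behaves — a simplified illustration under those assumptions, not Synapse's actual class:

```python
# Simplified illustration of a per-key leaky-bucket limiter in the spirit of
# synapse.api.ratelimiting.Ratelimiter; NOT the actual implementation.
import time
from typing import Dict, Hashable, Tuple

class SketchRatelimiter:
    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz          # sustained actions allowed per second
        self.burst_count = burst_count  # maximum short-term burst
        # key -> (timestamp of last update, accumulated action count)
        self._actions: Dict[Hashable, Tuple[float, float]] = {}

    def can_do_action(self, key: Hashable) -> bool:
        """Returns True if the action is allowed, recording it if so."""
        now = time.monotonic()
        last_ts, count = self._actions.get(key, (now, 0.0))
        # The accumulated count "leaks" away at rate_hz.
        count = max(0.0, count - (now - last_ts) * self.rate_hz)
        if count >= self.burst_count:
            # Over the burst limit: store the decayed count and refuse.
            self._actions[key] = (now, count)
            return False
        self._actions[key] = (now, count + 1.0)
        return True

# Keys can be any Hashable, e.g. the sender's (user_id, device_id) for local
# key requests, or the origin server name for requests over federation.
limiter = SketchRatelimiter(rate_hz=20, burst_count=100)
assert limiter.can_do_action(("@alice:example.com", "ADEVICEID"))
```

The real `Ratelimiter` additionally tracks a per-entry `rate_hz` (which, per the comment in `synapse/api/ratelimiting.py` above, can vary per request) and accepts per-call `rate_hz`/`burst_count` overrides.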
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py index c8d19c5d6b..8d05ef173c 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py @@ -24,32 +24,46 @@ class UserDirectoryConfig(Config): section = "userdirectory" def read_config(self, config, **kwargs): - self.user_directory_search_enabled = True - self.user_directory_search_all_users = False - user_directory_config = config.get("user_directory", None) - if user_directory_config: - self.user_directory_search_enabled = user_directory_config.get( - "enabled", True - ) - self.user_directory_search_all_users = user_directory_config.get( - "search_all_users", False - ) + user_directory_config = config.get("user_directory") or {} + self.user_directory_search_enabled = user_directory_config.get("enabled", True) + self.user_directory_search_all_users = user_directory_config.get( + "search_all_users", False + ) + self.user_directory_search_prefer_local_users = user_directory_config.get( + "prefer_local_users", False + ) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ # User Directory configuration # - # 'enabled' defines whether users can search the user directory. If - # false then empty responses are returned to all queries. Defaults to - # true. - # - # 'search_all_users' defines whether to search all users visible to your HS - # when searching the user directory, rather than limiting to users visible - # in public rooms. Defaults to false. If you set it True, you'll have to - # rebuild the user_directory search indexes, see - # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md - # - #user_directory: - # enabled: true - # search_all_users: false + user_directory: + # Defines whether users can search the user directory. If false then + # empty responses are returned to all queries. Defaults to true. + # + # Uncomment to disable the user directory. + # + #enabled: false + + # Defines whether to search all users visible to your HS when searching + # the user directory, rather than limiting to users visible in public + # rooms. Defaults to false. + # + # If you set it true, you'll have to rebuild the user_directory search + # indexes, see: + # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md + # + # Uncomment to return search results containing all known users, even if that + # user does not share a room with the requester. + # + #search_all_users: true + + # Defines whether to prefer local users in search query results. + # If True, local users are more likely to appear above remote users + # when searching the user directory. Defaults to false. + # + # Uncomment to prefer local over remote users in user directory search + # results. + # + #prefer_local_users: true """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 7a0ca16da8..ac92375a85 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -17,9 +17,28 @@ from typing import List, Union import attr -from ._base import Config, ConfigError, ShardedWorkerHandlingConfig +from ._base import ( + Config, + ConfigError, + RoutableShardedWorkerHandlingConfig, + ShardedWorkerHandlingConfig, +) from .server import ListenerConfig, parse_listener_def +_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """ +The send_federation config option must be disabled in the main +synapse process before federation senders can be run in a separate worker. 
+ +Please add ``send_federation: false`` to the main config +""" + +_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """ +The start_pushers config option must be disabled in the main +synapse process before pushers can be run in a separate worker. + +Please add ``start_pushers: false`` to the main config +""" + def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: """Helper for allowing parsing a string or list of strings to a config @@ -103,6 +122,7 @@ class WorkerConfig(Config): self.worker_replication_secret = config.get("worker_replication_secret", None) self.worker_name = config.get("worker_name", self.worker_app) + self.instance_name = self.worker_name or "master" self.worker_main_http_uri = config.get("worker_main_http_uri", None) @@ -118,12 +138,41 @@ class WorkerConfig(Config): ) ) - # Whether to send federation traffic out in this process. This only - # applies to some federation traffic, and so shouldn't be used to - # "disable" federation - self.send_federation = config.get("send_federation", True) + # Handle federation sender configuration. + # + # There are two ways of configuring which instances handle federation + # sending: + # 1. The old way, where "send_federation" is set to false and a + # `synapse.app.federation_sender` worker app is run. + # 2. Specifying the workers sending federation in + # `federation_sender_instances`. + # + + send_federation = config.get("send_federation", True) + + federation_sender_instances = config.get("federation_sender_instances") + if federation_sender_instances is None: + # Default to an empty list, which means "another, unknown, worker is + # responsible for it". + federation_sender_instances = [] - federation_sender_instances = config.get("federation_sender_instances") or [] + # If no federation sender instances are set, we check if + # `send_federation` is set, which means the master handles it + if send_federation: + federation_sender_instances = ["master"] + + if self.worker_app == "synapse.app.federation_sender": + if send_federation: + # If we're running federation senders, and not using + # `federation_sender_instances`, then we should have + # explicitly set `send_federation` to false. + raise ConfigError( + _FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR ) + + federation_sender_instances = [self.worker_name] + + self.send_federation = self.instance_name in federation_sender_instances self.federation_shard_config = ShardedWorkerHandlingConfig( federation_sender_instances ) @@ -164,7 +213,37 @@ class WorkerConfig(Config): "Must only specify one instance to handle `receipts` messages." ) - self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) + if len(self.writers.events) == 0: + raise ConfigError("Must specify at least one instance to handle `events`.") + + self.events_shard_config = RoutableShardedWorkerHandlingConfig( + self.writers.events + ) + + # Handle sharded push + start_pushers = config.get("start_pushers", True) + pusher_instances = config.get("pusher_instances") + if pusher_instances is None: + # Default to an empty list, which means "another, unknown, worker is + # responsible for it". + pusher_instances = [] + + # If no pusher instances are set, we check if `start_pushers` is + # set, which means the master handles them + if start_pushers: + pusher_instances = ["master"] + + if self.worker_app == "synapse.app.pusher": + if start_pushers: + # If we're running pushers, and not using + # `pusher_instances`, then we should have explicitly set + # `start_pushers` to false. 
+ raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR) + + pusher_instances = [self.instance_name] + + self.start_pushers = self.instance_name in pusher_instances + self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances) # Whether this worker should run background tasks or not. # diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8d4bb621e7..2f832b47f6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -44,6 +44,7 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json @@ -869,6 +870,13 @@ class FederationHandlerRegistry: # EDU received. self._edu_type_to_instance = {} # type: Dict[str, List[str]] + # A rate limiter for incoming room key requests per origin. + self._room_key_request_rate_limiter = Ratelimiter( + clock=self.clock, + rate_hz=self.config.rc_key_requests.per_second, + burst_count=self.config.rc_key_requests.burst_count, + ) + def register_edu_handler( self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] ): @@ -917,7 +925,15 @@ class FederationHandlerRegistry: self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict): - if not self.config.use_presence and edu_type == "m.presence": + if not self.config.use_presence and edu_type == EduTypes.Presence: + return + + # If the incoming room key requests from a particular origin are over + # the limit, drop them. + if ( + edu_type == EduTypes.RoomKeyRequest + and not self._room_key_request_rate_limiter.can_do_action(origin) + ): return # Check if we have a handler on this instance diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 97fc4d0a82..24ebc4b803 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -474,7 +474,7 @@ class FederationSender: self._processing_pending_presence = False def send_presence_to_destinations( - self, states: List[UserPresenceState], destinations: List[str] + self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. 
destinations (list[str]) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cce83704d4..2cf935f38d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -484,10 +484,9 @@ class FederationQueryServlet(BaseFederationServlet): # This is when we receive a server-server Query async def on_GET(self, origin, content, query, query_type): - return await self.handler.on_query_request( - query_type, - {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}, - ) + args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} + args["origin"] = origin + return await self.handler.on_query_request(query_type, args) class FederationMakeJoinServlet(BaseFederationServlet): diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 1aa7d803b5..7db4f48965 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -16,7 +16,9 @@ import logging from typing import TYPE_CHECKING, Any, Dict +from synapse.api.constants import EduTypes from synapse.api.errors import SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( get_active_span_text_map, @@ -25,7 +27,7 @@ from synapse.logging.opentracing import ( start_active_span, ) from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.stringutils import random_string @@ -78,6 +80,12 @@ class DeviceMessageHandler: ReplicationUserDevicesResyncRestServlet.make_client(hs) ) + self._ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.rc_key_requests.per_second, + burst_count=hs.config.rc_key_requests.burst_count, + ) + async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: local_messages = {} sender_user_id = content["sender"] @@ -168,15 +176,27 @@ class DeviceMessageHandler: async def send_device_message( self, - sender_user_id: str, + requester: Requester, message_type: str, messages: Dict[str, Dict[str, JsonDict]], ) -> None: + sender_user_id = requester.user.to_string() + set_tag("number_of_messages", len(messages)) set_tag("sender", sender_user_id) local_messages = {} remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] for user_id, by_device in messages.items(): + # Ratelimit local cross-user key requests by the sending device. 
+ if ( + message_type == EduTypes.RoomKeyRequest + and user_id != sender_user_id + and not self._ratelimiter.can_do_action( + (sender_user_id, requester.device_id) + ) + ): + continue + # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): messages_by_device = { diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 3e23f82cf7..f46cab7325 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -17,7 +17,7 @@ import logging import random from typing import TYPE_CHECKING, Iterable, List, Optional -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state @@ -113,7 +113,7 @@ class EventStreamHandler(BaseHandler): states = await presence_handler.get_states(users) to_add.extend( { - "type": EventTypes.Presence, + "type": EduTypes.Presence, "content": format_user_presence_state(state, time_now), } for state in states diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 78c3e5a10b..71a5076672 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state @@ -412,7 +412,7 @@ class InitialSyncHandler(BaseHandler): return [ { - "type": EventTypes.Presence, + "type": EduTypes.Presence, "content": format_user_presence_state(s, time_now), } for s in states diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c03f6c997b..1b7c065b34 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -387,6 +387,12 @@ class EventCreationHandler: self.room_invite_state_types = self.hs.config.room_invite_state_types + self.membership_types_to_include_profile_data_in = ( + {Membership.JOIN, Membership.INVITE} + if self.hs.config.include_profile_data_on_invite + else {Membership.JOIN} + ) + self.send_event = ReplicationSendEventRestServlet.make_client(hs) # This is only used to get at ratelimit function, and maybe_kick_guest_users @@ -500,7 +506,7 @@ class EventCreationHandler: membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership in {Membership.JOIN, Membership.INVITE}: + if membership in self.membership_types_to_include_profile_data_in: # If event doesn't include a display name, add one. profile = self.profile_handler content = builder.content diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index fb85b19770..b6a9ce4f38 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -849,6 +849,9 @@ class PresenceHandler(BasePresenceHandler): """Process current state deltas to find new joins that need to be handled. 
""" + # A map of destination to a set of user state that they should receive + presence_destinations = {} # type: Dict[str, Set[UserPresenceState]] + for delta in deltas: typ = delta["type"] state_key = delta["state_key"] @@ -858,6 +861,7 @@ class PresenceHandler(BasePresenceHandler): logger.debug("Handling: %r %r, %s", typ, state_key, event_id) + # Drop any event that isn't a membership join if typ != EventTypes.Member: continue @@ -880,13 +884,38 @@ class PresenceHandler(BasePresenceHandler): # Ignore changes to join events. continue - await self._on_user_joined_room(room_id, state_key) + # Retrieve any user presence state updates that need to be sent as a result, + # and the destinations that need to receive it + destinations, user_presence_states = await self._on_user_joined_room( + room_id, state_key + ) + + # Insert the destinations and respective updates into our destinations dict + for destination in destinations: + presence_destinations.setdefault(destination, set()).update( + user_presence_states + ) + + # Send out user presence updates for each destination + for destination, user_state_set in presence_destinations.items(): + self.federation.send_presence_to_destinations( + destinations=[destination], states=user_state_set + ) - async def _on_user_joined_room(self, room_id: str, user_id: str) -> None: + async def _on_user_joined_room( + self, room_id: str, user_id: str + ) -> Tuple[List[str], List[UserPresenceState]]: """Called when we detect a user joining the room via the current state - delta stream. - """ + delta stream. Returns the destinations that need to be updated and the + presence updates to send to them. + + Args: + room_id: The ID of the room that the user has joined. + user_id: The ID of the user that has joined the room. + Returns: + A tuple of destinations and presence updates to send to them. + """ if self.is_mine_id(user_id): # If this is a local user then we need to send their presence # out to hosts in the room (who don't already have it) @@ -894,15 +923,15 @@ class PresenceHandler(BasePresenceHandler): # TODO: We should be able to filter the hosts down to those that # haven't previously seen the user - state = await self.current_state_for_user(user_id) - hosts = await self.state.get_current_hosts_in_room(room_id) + remote_hosts = await self.state.get_current_hosts_in_room(room_id) # Filter out ourselves. - hosts = {host for host in hosts if host != self.server_name} + filtered_remote_hosts = [ + host for host in remote_hosts if host != self.server_name + ] - self.federation.send_presence_to_destinations( - states=[state], destinations=hosts - ) + state = await self.current_state_for_user(user_id) + return filtered_remote_hosts, [state] else: # A remote user has joined the room, so we need to: # 1. Check if this is a new server in the room @@ -915,6 +944,8 @@ class PresenceHandler(BasePresenceHandler): # TODO: Check that this is actually a new server joining the # room. 
+ remote_host = get_domain_from_id(user_id) + users = await self.state.get_current_users_in_room(room_id) user_ids = list(filter(self.is_mine_id, users)) @@ -934,10 +965,7 @@ class PresenceHandler(BasePresenceHandler): or state.status_msg is not None ] - if states: - self.federation.send_presence_to_destinations( - states=states, destinations=[get_domain_from_id(user_id)] - ) + return [remote_host], states def should_notify(old_state, new_state): diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 2f62d84fb5..dd59392bda 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -310,6 +310,15 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) async def on_profile_query(self, args: JsonDict) -> JsonDict: + """Handles federation profile query requests.""" + + if not self.hs.config.allow_profile_lookup_over_federation: + raise SynapseError( + 403, + "Profile lookup over federation is disabled on this homeserver", + Codes.FORBIDDEN, + ) + user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this homeserver") diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 4e8ed7b33f..ce644e01ad 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -277,8 +277,9 @@ class SyncHandler: user_id = sync_config.user.to_string() await self.auth.check_auth_blocking(requester=requester) - res = await self.response_cache.wrap( + res = await self.response_cache.wrap_conditional( sync_config.request_key, + lambda result: since_token != result.next_batch, self._wait_for_sync_for_user, sync_config, since_token, diff --git a/synapse/http/site.py b/synapse/http/site.py index 4a4fb5ef26..30153237e3 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -16,6 +16,10 @@ import logging import time from typing import Optional, Union +import attr +from zope.interface import implementer + +from twisted.internet.interfaces import IAddress from twisted.python.failure import Failure from twisted.web.server import Request, Site @@ -333,26 +337,77 @@ class SynapseRequest(Request): class XForwardedForRequest(SynapseRequest): - def __init__(self, *args, **kw): - SynapseRequest.__init__(self, *args, **kw) + """Request object which honours proxy headers + Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with + information from request headers. """ - Add a layer on top of another request that only uses the value of an - X-Forwarded-For header as the result of C{getClientIP}. - """ - def getClientIP(self): + # the client IP and ssl flag, as extracted from the headers. + _forwarded_for = None # type: Optional[_XForwardedForAddress] + _forwarded_https = False # type: bool + + def requestReceived(self, command, path, version): + # this method is called by the Channel once the full request has been + # received, to dispatch the request to a resource. + # We can use it to set the IP address and protocol according to the + # headers. + self._process_forwarded_headers() + return super().requestReceived(command, path, version) + + def _process_forwarded_headers(self): + headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for") + if not headers: + return + + # for now, we just use the first x-forwarded-for header. Really, we ought + # to start from the client IP address, and check whether it is trusted; if it + # is, work backwards through the headers until we find an untrusted address. 
+ # see https://github.com/matrix-org/synapse/issues/9471 + self._forwarded_for = _XForwardedForAddress( + headers[0].split(b",")[0].strip().decode("ascii") + ) + + # if we got an x-forwarded-for header, also look for an x-forwarded-proto header + header = self.getHeader(b"x-forwarded-proto") + if header is not None: + self._forwarded_https = header.lower() == b"https" + else: + # this is done largely for backwards-compatibility so that people that + # haven't set an x-forwarded-proto header don't get a redirect loop. + logger.warning( + "forwarded request lacks an x-forwarded-proto header: assuming https" + ) + self._forwarded_https = True + + def isSecure(self): + if self._forwarded_https: + return True + return super().isSecure() + + def getClientIP(self) -> str: """ - @return: The client address (the first address) in the value of the - I{X-Forwarded-For header}. If the header is not present, return - C{b"-"}. + Return the IP address of the client who submitted this request. + + This method is deprecated. Use getClientAddress() instead. """ - return ( - self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0] - .split(b",")[0] - .strip() - .decode("ascii") - ) + if self._forwarded_for is not None: + return self._forwarded_for.host + return super().getClientIP() + + def getClientAddress(self) -> IAddress: + """ + Return the address of the client who submitted this request. + """ + if self._forwarded_for is not None: + return self._forwarded_for + return super().getClientAddress() + + +@implementer(IAddress) +@attr.s(frozen=True, slots=True) +class _XForwardedForAddress: + host = attr.ib(type=str) class SynapseSite(Site): diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index b9d3da2e0a..f4d7e199e9 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -74,6 +74,7 @@ class HttpPusher(Pusher): self.timed_call = None self._is_processing = False self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room + self._pusherpool = hs.get_pusherpool() self.data = pusher_config.data if self.data is None: @@ -299,7 +300,7 @@ class HttpPusher(Pusher): ) else: logger.info("Pushkey %s was rejected: removing", pk) - await self.hs.remove_pusher(self.app_id, pk, self.user_id) + await self._pusherpool.remove_pusher(self.app_id, pk, self.user_id) return True async def _build_notification_dict( diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index ae1145be0e..21f14f05f0 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -25,6 +25,7 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.push.pusher import PusherFactory +from synapse.replication.http.push import ReplicationRemovePusherRestServlet from synapse.types import JsonDict, RoomStreamToken from synapse.util.async_helpers import concurrently_execute @@ -58,7 +59,6 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs self.pusher_factory = PusherFactory(hs) - self._should_start_pushers = hs.config.start_pushers self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() @@ -67,6 +67,16 @@ class PusherPool: # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config self._instance_name = hs.get_instance_name() + self._should_start_pushers = ( + self._instance_name in self._pusher_shard_config.instances + ) + + # We can only delete pushers on master. 
+ self._remove_pusher_client = None + if hs.config.worker.worker_app: + self._remove_pusher_client = ReplicationRemovePusherRestServlet.make_client( + hs + ) # Record the last stream ID that we were poked about so we can get # changes since then. We set this to the current max stream ID on @@ -175,9 +185,6 @@ class PusherPool: user_id: user to remove pushers for access_tokens: access token *ids* to remove pushers for """ - if not self._pusher_shard_config.should_handle(self._instance_name, user_id): - return - tokens = set(access_tokens) for p in await self.store.get_pushers_by_user_id(user_id): if p.access_token in tokens: @@ -380,6 +387,12 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() - await self.store.delete_pusher_by_app_id_pushkey_user_id( - app_id, pushkey, user_id - ) + # We can only delete pushers on master. + if self._remove_pusher_client: + await self._remove_pusher_client( + app_id=app_id, pushkey=pushkey, user_id=user_id + ) + else: + await self.store.delete_pusher_by_app_id_pushkey_user_id( + app_id, pushkey, user_id + ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8a2b73b75e..321a333820 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -106,6 +106,9 @@ CONDITIONAL_REQUIREMENTS = { "pysaml2>=4.5.0;python_version>='3.6'", ], "oidc": ["authlib>=0.14.0"], + # systemd-python is necessary for logging to the systemd journal via + # `systemd.journal.JournalHandler`, as is documented in + # `contrib/systemd/log_config.yaml`. "systemd": ["systemd-python>=231"], "url_preview": ["lxml>=3.5.0"], "sentry": ["sentry-sdk>=0.7.2"], diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index dd527e807f..cb4a52dbe9 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -21,6 +21,7 @@ from synapse.replication.http import ( login, membership, presence, + push, register, send_event, streams, @@ -42,6 +43,7 @@ class ReplicationRestResource(JsonResource): membership.register_servlets(hs, self) streams.register_servlets(hs, self) account_data.register_servlets(hs, self) + push.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7a0dbb5b1a..8af53b4f28 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -213,8 +213,9 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): content = parse_json_object_from_request(request) args = content["args"] + args["origin"] = content["origin"] - logger.info("Got %r query", query_type) + logger.info("Got %r query from %s", query_type, args["origin"]) result = await self.registry.on_query(query_type, args) diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py new file mode 100644 index 0000000000..054ed64d34 --- /dev/null +++ b/synapse/replication/http/push.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
+    """Deletes the given pusher.
+
+    Request format:
+
+        POST /_synapse/replication/remove_pusher/:user_id
+
+        {
+            "app_id": "<some_id>",
+            "pushkey": "<some_key>"
+        }
+
+    """
+
+    NAME = "remove_pusher"
+    PATH_ARGS = ("user_id",)
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self.pusher_pool = hs.get_pusherpool()
+
+    @staticmethod
+    async def _serialize_payload(app_id, pushkey, user_id):
+        payload = {
+            "app_id": app_id,
+            "pushkey": pushkey,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id):
+        content = parse_json_object_from_request(request)
+
+        app_id = content["app_id"]
+        pushkey = content["pushkey"]
+
+        await self.pusher_pool.remove_pusher(app_id, pushkey, user_id)
+
+        return 200, {}
+
+
+def register_servlets(hs, http_server):
+    ReplicationRemovePusherRestServlet(hs).register(http_server)
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0a9da79c32..bb447f75b4 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -325,31 +325,6 @@ class FederationAckCommand(Command):
         return "%s %s" % (self.instance_name, self.token)
 
 
-class RemovePusherCommand(Command):
-    """Sent by the client to request the master remove the given pusher.
-
-    Format::
-
-        REMOVE_PUSHER <app_id> <push_key> <user_id>
-    """
-
-    NAME = "REMOVE_PUSHER"
-
-    def __init__(self, app_id, push_key, user_id):
-        self.user_id = user_id
-        self.app_id = app_id
-        self.push_key = push_key
-
-    @classmethod
-    def from_line(cls, line):
-        app_id, push_key, user_id = line.split(" ", 2)
-
-        return cls(app_id, push_key, user_id)
-
-    def to_line(self):
-        return " ".join((self.app_id, self.push_key, self.user_id))
-
-
 class UserIpCommand(Command):
     """Sent periodically when a worker sees activity from a client.
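For illustration, here is a minimal sketch (not taken from the patch itself) of the worker-side flow that replaces the removed `REMOVE_PUSHER` TCP command: a worker builds a replication client from the servlet defined in `synapse/replication/http/push.py` above and awaits it. The helper function name is hypothetical; `make_client` is the standard `ReplicationEndpoint` helper used elsewhere in this diff.

```python
from synapse.replication.http.push import ReplicationRemovePusherRestServlet

async def remove_pusher_from_worker(hs, app_id: str, pushkey: str, user_id: str):
    # Workers can no longer write to the pushers table directly, so they ask
    # the master to do the deletion over the replication HTTP API.
    remove_pusher_client = ReplicationRemovePusherRestServlet.make_client(hs)

    # Issues POST /_synapse/replication/remove_pusher/<user_id> with a JSON
    # body of {"app_id": ..., "pushkey": ...}; _handle_request() above then
    # calls PusherPool.remove_pusher() on the master.
    await remove_pusher_client(app_id=app_id, pushkey=pushkey, user_id=user_id)
```

This mirrors what `PusherPool.remove_pusher` does in `synapse/push/pusherpool.py` when `hs.config.worker.worker_app` is set.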
@@ -416,7 +391,6 @@ _COMMANDS = (
     ReplicateCommand,
     UserSyncCommand,
     FederationAckCommand,
-    RemovePusherCommand,
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
@@ -443,7 +417,6 @@ VALID_CLIENT_COMMANDS = (
     UserSyncCommand.NAME,
     ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
-    RemovePusherCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index d1d00c3717..a7245da152 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -44,7 +44,6 @@ from synapse.replication.tcp.commands import (
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
-    RemovePusherCommand,
     ReplicateCommand,
     UserIpCommand,
     UserSyncCommand,
@@ -373,23 +372,6 @@ class ReplicationCommandHandler:
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
-    def on_REMOVE_PUSHER(
-        self, conn: AbstractConnection, cmd: RemovePusherCommand
-    ) -> Optional[Awaitable[None]]:
-        remove_pusher_counter.inc()
-
-        if self._is_master:
-            return self._handle_remove_pusher(cmd)
-        else:
-            return None
-
-    async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
-        await self._store.delete_pusher_by_app_id_pushkey_user_id(
-            app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
-        )
-
-        self._notifier.on_new_replication_data()
-
     def on_USER_IP(
         self, conn: AbstractConnection, cmd: UserIpCommand
     ) -> Optional[Awaitable[None]]:
@@ -684,11 +666,6 @@ class ReplicationCommandHandler:
             UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
         )
 
-    def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
-        """Poke the master to remove a pusher for a user"""
-        cmd = RemovePusherCommand(app_id, push_key, user_id)
-        self.send_command(cmd)
-
     def send_user_ip(
         self,
         user_id: str,
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index f4fdc40b22..00e1dcdbb8 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -145,7 +145,7 @@
       <input type="submit" value="Continue" class="primary-button">
       {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %}
       <section class="idp-pick-details">
-        <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2>
+        <h2>{% if idp.idp_icon %}<img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>{% endif %}Information from {{ idp.idp_name }}</h2>
         {% if user_attributes.avatar_url %}
         <label class="idp-detail idp-avatar" for="idp-avatar">
           <div class="check-row">
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 998a0ef671..9c701c7348 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
     assert_user_is_admin,
 )
 from synapse.rest.client.v2_alpha._base import client_patterns
+from synapse.storage.databases.main.media_repository import MediaSortOrder
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
@@ -832,8 +833,33 @@ class UserMediaRestServlet(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
+        # If neither `order_by` nor `dir` is set, default to newest media
+        # on top, for backwards compatibility.
+ if b"order_by" not in request.args and b"dir" not in request.args: + order_by = MediaSortOrder.CREATED_TS.value + direction = "b" + else: + order_by = parse_string( + request, + "order_by", + default=MediaSortOrder.CREATED_TS.value, + allowed_values=( + MediaSortOrder.MEDIA_ID.value, + MediaSortOrder.UPLOAD_NAME.value, + MediaSortOrder.CREATED_TS.value, + MediaSortOrder.LAST_ACCESS_TS.value, + MediaSortOrder.MEDIA_LENGTH.value, + MediaSortOrder.MEDIA_TYPE.value, + MediaSortOrder.QUARANTINED_BY.value, + MediaSortOrder.SAFE_FROM_QUARANTINE.value, + ), + ) + direction = parse_string( + request, "dir", default="f", allowed_values=("f", "b") + ) + media, total = await self.store.get_local_media_by_user_paginate( - start, limit, user_id + start, limit, user_id, order_by, direction ) ret = {"media": media, "total": total} diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index a3dee14ed4..79c1b526ee 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -56,10 +56,8 @@ class SendToDeviceRestServlet(servlet.RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ("messages",)) - sender_user_id = requester.user.to_string() - await self.device_message_handler.send_device_message( - sender_user_id, message_type, content["messages"] + requester, message_type, content["messages"] ) response = (200, {}) # type: Tuple[int, dict] diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a0162d4255..3375455c43 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -509,7 +509,7 @@ class MediaRepository: t_height: int, t_method: str, t_type: str, - url_cache: str, + url_cache: Optional[str], ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 1057e638be..b1b1c9e6ec 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -244,7 +244,7 @@ class MediaStorage: await consumer.wait() return local_path - raise Exception("file could not be found") + raise NotFoundError() def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. 
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d653a58be9..3ab90e9f9b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -114,6 +114,7 @@ class ThumbnailResource(DirectServeJsonResource):
             m_type,
             thumbnail_infos,
             media_id,
+            media_id,
             url_cache=media_info["url_cache"],
             server_name=None,
         )
@@ -269,6 +270,7 @@ class ThumbnailResource(DirectServeJsonResource):
             method,
             m_type,
             thumbnail_infos,
+            media_id,
             media_info["filesystem_id"],
             url_cache=None,
             server_name=server_name,
@@ -282,6 +284,7 @@ class ThumbnailResource(DirectServeJsonResource):
         desired_method: str,
         desired_type: str,
         thumbnail_infos: List[Dict[str, Any]],
+        media_id: str,
         file_id: str,
         url_cache: Optional[str] = None,
         server_name: Optional[str] = None,
@@ -317,8 +320,59 @@ class ThumbnailResource(DirectServeJsonResource):
                 return
 
             responder = await self.media_storage.fetch_media(file_info)
+            if responder:
+                await respond_with_responder(
+                    request,
+                    responder,
+                    file_info.thumbnail_type,
+                    file_info.thumbnail_length,
+                )
+                return
+
+            # If we can't find the thumbnail we regenerate it. This can happen
+            # if e.g. we've deleted the thumbnails but still have the original
+            # image somewhere.
+            #
+            # Since we have an entry for the thumbnail in the DB we a) know we
+            # have successfully generated the thumbnail in the past (so we
+            # don't need to worry about repeatedly failing to generate
+            # thumbnails), and b) have already calculated the appropriate
+            # width/height/method, so we can just call the "generate exact"
+            # methods.

+            # First let's check that we do actually have the original image
+            # still. This will throw a 404 if we don't.
+            # TODO: We should refetch the thumbnails for remote media.
+            await self.media_storage.ensure_media_is_in_local_cache(
+                FileInfo(server_name, file_id, url_cache=url_cache)
+            )
+
+            if server_name:
+                await self.media_repo.generate_remote_exact_thumbnail(
+                    server_name,
+                    file_id=file_id,
+                    media_id=media_id,
+                    t_width=file_info.thumbnail_width,
+                    t_height=file_info.thumbnail_height,
+                    t_method=file_info.thumbnail_method,
+                    t_type=file_info.thumbnail_type,
+                )
+            else:
+                await self.media_repo.generate_local_exact_thumbnail(
+                    media_id=media_id,
+                    t_width=file_info.thumbnail_width,
+                    t_height=file_info.thumbnail_height,
+                    t_method=file_info.thumbnail_method,
+                    t_type=file_info.thumbnail_type,
+                    url_cache=url_cache,
+                )
+
+            responder = await self.media_storage.fetch_media(file_info)
             await respond_with_responder(
-                request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+                request,
+                responder,
+                file_info.thumbnail_type,
+                file_info.thumbnail_length,
             )
         else:
             logger.info("Failed to find any generated thumbnails")
diff --git a/synapse/server.py b/synapse/server.py
index 6b3892e3cd..4b9ec7f0ae 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -248,7 +248,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         self.start_time = None  # type: Optional[int]
 
         self._instance_id = random_string(5)
-        self._instance_name = config.worker_name or "master"
+        self._instance_name = config.worker.instance_name
 
         self.version_string = version_string
 
@@ -758,12 +758,6 @@ class HomeServer(metaclass=abc.ABCMeta):
             reconnect=True,
         )
 
-    async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
-        return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
-
     def should_send_federation(self) -> bool:
         "Should this server be sending federation traffic directly?"
- return self.config.send_federation and ( - not self.config.worker_app - or self.config.worker_app == "synapse.app.federation_sender" - ) + return self.config.send_federation diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 4646926449..f1ba529a2d 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor -from synapse.storage.util.sequence import build_sequence_generator from synapse.types import Collection # python 3 does not have a maximum int value @@ -381,7 +380,10 @@ class DatabasePool: _TXN_ID = 0 def __init__( - self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + self, + hs, + database_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, ): self.hs = hs self._clock = hs.get_clock() @@ -420,16 +422,6 @@ class DatabasePool: self._check_safe_to_upsert, ) - # We define this sequence here so that it can be referenced from both - # the DataStore and PersistEventStore. - def get_chain_id_txn(txn): - txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") - return txn.fetchone()[0] - - self.event_chain_id_gen = build_sequence_generator( - engine, get_chain_id_txn, "event_auth_chain_id" - ) - def is_running(self) -> bool: """Is the database pool currently running""" return self._db_pool.running diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index e84f8b42f7..379c78bb83 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -79,7 +79,7 @@ class Databases: # If we're on a process that can persist events also # instantiate a `PersistEventsStore` if hs.get_instance_name() in hs.config.worker.writers.events: - persist_events = PersistEventsStore(hs, database, main) + persist_events = PersistEventsStore(hs, database, main, db_conn) if "state" in database_config.databases: logger.info( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 287606cb4f..cd1ceac50e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -42,7 +42,9 @@ from synapse.logging.utils import log_function from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchEntry +from synapse.storage.types import Connection from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.storage.util.sequence import SequenceGenerator from synapse.types import StateMap, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically @@ -90,7 +92,11 @@ class PersistEventsStore: """ def __init__( - self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore" + self, + hs: "HomeServer", + db: DatabasePool, + main_data_store: "DataStore", + db_conn: Connection, ): self.hs = hs self.db_pool = db @@ -474,6 +480,7 @@ class PersistEventsStore: self._add_chain_cover_index( txn, self.db_pool, + self.store.event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, @@ -484,6 +491,7 @@ class PersistEventsStore: cls, txn, db_pool: 
DatabasePool, + event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], event_to_auth_chain: Dict[str, List[str]], @@ -630,6 +638,7 @@ class PersistEventsStore: new_chain_tuples = cls._allocate_chain_ids( txn, db_pool, + event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, @@ -768,6 +777,7 @@ class PersistEventsStore: def _allocate_chain_ids( txn, db_pool: DatabasePool, + event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], event_to_auth_chain: Dict[str, List[str]], @@ -880,7 +890,7 @@ class PersistEventsStore: chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] # Generate new chain IDs for all unallocated chain IDs. - newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( + newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn( txn, len(unallocated_chain_ids) ) diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 89274e75f7..c1626ccf28 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): PersistEventsStore._add_chain_cover_index( txn, self.db_pool, + self.event_chain_id_gen, event_to_room_id, event_to_types, event_to_auth_chain, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c8850a4707..edbe42f2bf 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.sequence import build_sequence_generator from synapse.types import Collection, JsonDict, get_domain_from_id from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore): self._event_fetch_list = [] self._event_fetch_ongoing = 0 + # We define this sequence here so that it can be referenced from both + # the DataStore and PersistEventStore. + def get_chain_id_txn(txn): + txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") + return txn.fetchone()[0] + + self.event_chain_id_gen = build_sequence_generator( + db_conn, + database.engine, + get_chain_id_txn, + "event_auth_chain_id", + table="event_auth_chains", + id_column="chain_id", + ) + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == EventsStream.NAME: self._stream_id_gen.advance(instance_name, token) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index a0313c3ccf..274f8de595 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from enum import Enum
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from synapse.storage._base import SQLBaseStore
@@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
 )
 
 
+class MediaSortOrder(Enum):
+    """
+    Enum to define the sorting method used when returning media with
+    get_local_media_by_user_paginate
+    """
+
+    MEDIA_ID = "media_id"
+    UPLOAD_NAME = "upload_name"
+    CREATED_TS = "created_ts"
+    LAST_ACCESS_TS = "last_access_ts"
+    MEDIA_LENGTH = "media_length"
+    MEDIA_TYPE = "media_type"
+    QUARANTINED_BY = "quarantined_by"
+    SAFE_FROM_QUARANTINE = "safe_from_quarantine"
+
+
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
@@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def get_local_media_by_user_paginate(
-        self, start: int, limit: int, user_id: str
+        self,
+        start: int,
+        limit: int,
+        user_id: str,
+        order_by: str = MediaSortOrder.CREATED_TS.value,
+        direction: str = "f",
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get a paginated list of metadata for a local piece of media
         which a user_id has uploaded
 
         Args:
             start: offset in the list
             limit: maximum amount of media_ids to retrieve
             user_id: fully-qualified user id
+            order_by: the sort order of the returned list
+            direction: sort ascending or descending
         Returns:
             A paginated list of all metadata of user's media,
             plus the total count of all the user's media
         """
 
         def get_local_media_by_user_paginate_txn(txn):
 
+            # Set ordering
+            order_by_column = MediaSortOrder(order_by).value
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
             args = [user_id]
             sql = """
                 SELECT COUNT(*) as total_media
@@ -155,9 +187,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                     "safe_from_quarantine"
                 FROM local_media_repository
                 WHERE user_id = ?
-                ORDER BY created_ts DESC, media_id DESC
+                ORDER BY {order_by_column} {order}, media_id ASC
                 LIMIT ? OFFSET ?
- """ + """.format( + order_by_column=order_by_column, + order=order, + ) args += [limit, start] txn.execute(sql, args) @@ -344,16 +379,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - await self.db_pool.simple_insert( - "local_media_repository_thumbnails", - { + await self.db_pool.simple_upsert( + table="local_media_repository_thumbnails", + keyvalues={ "media_id": media_id, "thumbnail_width": thumbnail_width, "thumbnail_height": thumbnail_height, "thumbnail_method": thumbnail_method, "thumbnail_type": thumbnail_type, - "thumbnail_length": thumbnail_length, }, + values={"thumbnail_length": thumbnail_length}, desc="store_local_thumbnail", ) @@ -498,18 +533,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): thumbnail_method, thumbnail_length, ): - await self.db_pool.simple_insert( - "remote_media_cache_thumbnails", - { + await self.db_pool.simple_upsert( + table="remote_media_cache_thumbnails", + keyvalues={ "media_origin": origin, "media_id": media_id, "thumbnail_width": thumbnail_width, "thumbnail_height": thumbnail_height, "thumbnail_method": thumbnail_method, "thumbnail_type": thumbnail_type, - "thumbnail_length": thumbnail_length, - "filesystem_id": filesystem_id, }, + values={"thumbnail_length": thumbnail_length}, + insertion_values={"filesystem_id": filesystem_id}, desc="store_remote_media_thumbnail", ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index d5b5507815..61a7556e56 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -23,7 +23,7 @@ import attr from synapse.api.constants import UserTypes from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.stats import StatsStore from synapse.storage.types import Connection, Cursor @@ -70,7 +70,12 @@ class TokenLookupResult: class RegistrationWorkerStore(CacheInvalidationWorkerStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self.config = hs.config @@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # call `find_max_generated_user_id_localpart` each time, which is # expensive if there are many entries. 
self._user_id_seq = build_sequence_generator( + db_conn, database.engine, find_max_generated_user_id_localpart, "user_id_seq", + table=None, + id_column=None, ) self._account_validity = hs.config.account_validity @@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._clock = hs.get_clock() diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql new file mode 100644 index 0000000000..2442eea6bc --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql @@ -0,0 +1,19 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Delete all pushers associated with deleted devices. This is to clear up after +-- a bug where they weren't correctly deleted when using workers. +DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens); diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 63f88eac51..1026f321e5 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -497,8 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): async def add_users_in_public_rooms( self, room_id: str, user_ids: Iterable[str] ) -> None: - """Insert entries into the users_who_share_private_rooms table. The first - user should be a local user. + """Insert entries into the users_in_public_rooms table. Args: room_id @@ -556,6 +555,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + self._prefer_local_users_in_search = ( + hs.config.user_directory_search_prefer_local_users + ) + self._server_name = hs.config.server_name + async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): self.db_pool.simple_delete_txn( @@ -665,7 +669,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - @cached() async def get_shared_rooms_for_users( self, user_id: str, other_user_id: str ) -> Set[str]: @@ -754,9 +757,24 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) """ + # We allow manipulating the ranking algorithm by injecting statements + # based on config options. 
+ additional_ordering_statements = [] + ordering_arguments = () + if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) + # If enabled, this config option will rank local users higher than those on + # remote instances. + if self._prefer_local_users_in_search: + # This statement checks whether a given user's user ID contains a server name + # that matches the local server + statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)" + additional_ordering_statements.append(statement) + + ordering_arguments += ("%:" + self._server_name,) + # We order by rank and then if they have profile info # The ranking algorithm is hand tweaked for "best" results. Broadly # the idea is we give a higher weight to exact matches. @@ -767,7 +785,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) WHERE - %s + %(where_clause)s AND vector @@ to_tsquery('simple', ?) ORDER BY (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) @@ -787,33 +805,54 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): 8 ) ) + %(order_case_statements)s DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? - """ % ( - where_clause, + """ % { + "where_clause": where_clause, + "order_case_statements": " ".join(additional_ordering_statements), + } + args = ( + join_args + + (full_query, exact_query, prefix_query) + + ordering_arguments + + (limit + 1,) ) - args = join_args + (full_query, exact_query, prefix_query, limit + 1) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) + # If enabled, this config option will rank local users higher than those on + # remote instances. + if self._prefer_local_users_in_search: + # This statement checks whether a given user's user ID contains a server name + # that matches the local server + # + # Note that we need to include a comma at the end for valid SQL + statement = "user_id LIKE ? DESC," + additional_ordering_statements.append(statement) + + ordering_arguments += ("%:" + self._server_name,) + sql = """ SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) WHERE - %s + %(where_clause)s AND value MATCH ? ORDER BY rank(matchinfo(user_directory_search)) DESC, + %(order_statements)s display_name IS NULL, avatar_url IS NULL LIMIT ? - """ % ( - where_clause, - ) - args = join_args + (search_query, limit + 1) + """ % { + "where_clause": where_clause, + "order_statements": " ".join(additional_ordering_statements), + } + args = join_args + (search_query,) + ordering_arguments + (limit + 1,) else: # This should be unreachable. 
raise Exception("Unrecognized database engine") diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index b16b9905d8..e2240703a7 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return txn.fetchone()[0] self._state_group_seq_gen = build_sequence_generator( - self.database_engine, get_max_state_group_txn, "state_group_id_seq" - ) - self._state_group_seq_gen.check_consistency( - db_conn, table="state_groups", id_column="id" + db_conn, + self.database_engine, + get_max_state_group_txn, + "state_group_id_seq", + table="state_groups", + id_column="id", ) @cached(max_entries=10000, iterable=True) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 3ea637b281..36a67e7019 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator): def build_sequence_generator( + db_conn: "LoggingDatabaseConnection", database_engine: BaseDatabaseEngine, get_first_callback: GetFirstCallbackType, sequence_name: str, + table: Optional[str], + id_column: Optional[str], + stream_name: Optional[str] = None, + positive: bool = True, ) -> SequenceGenerator: """Get the best impl of SequenceGenerator available @@ -265,8 +270,23 @@ def build_sequence_generator( get_first_callback: a callback which gets the next sequence ID. Used if we're on sqlite. sequence_name: the name of a postgres sequence to use. + table, id_column, stream_name, positive: If set then `check_consistency` + is called on the created sequence. See docstring for + `check_consistency` details. """ if isinstance(database_engine, PostgresEngine): - return PostgresSequenceGenerator(sequence_name) + seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator else: - return LocalSequenceGenerator(get_first_callback) + seq = LocalSequenceGenerator(get_first_callback) + + if table: + assert id_column + seq.check_consistency( + db_conn=db_conn, + table=table, + id_column=id_column, + stream_name=stream_name, + positive=positive, + ) + + return seq diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 32228f42ee..53f85195a7 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar from twisted.internet import defer @@ -40,6 +40,7 @@ class ResponseCache(Generic[T]): def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): # Requests that haven't finished yet. 
        self.pending_result_cache = {}  # type: Dict[T, ObservableDeferred]
+        self.pending_conditionals = {}  # type: Dict[T, Set[Callable[[Any], bool]]]
 
         self.clock = hs.get_clock()
         self.timeout_sec = timeout_ms / 1000.0
@@ -101,7 +102,11 @@ class ResponseCache(Generic[T]):
         self.pending_result_cache[key] = result
 
         def remove(r):
-            if self.timeout_sec:
+            should_cache = all(
+                func(r) for func in self.pending_conditionals.pop(key, [])
+            )
+
+            if self.timeout_sec and should_cache:
                 self.clock.call_later(
                     self.timeout_sec, self.pending_result_cache.pop, key, None
                 )
@@ -112,6 +117,31 @@ class ResponseCache(Generic[T]):
         result.addBoth(remove)
         return result.observe()
 
+    def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
+        self.pending_conditionals.setdefault(key, set()).add(conditional)
+
+    def wrap_conditional(
+        self,
+        key: T,
+        should_cache: Callable[[Any], bool],
+        callback: "Callable[..., Any]",
+        *args: Any,
+        **kwargs: Any
+    ) -> defer.Deferred:
+        """The same as wrap(), but adds a conditional to the final execution.
+
+        When the wrapped callback completes, *all* registered conditionals must
+        return True for the result to stay in the cache for the usual timeout;
+        otherwise the result is evicted as soon as it completes.
+        """
+
+        # See if there's already a result for this key that hasn't yet completed.
+        # Because the reactor is single-threaded, registering the conditional
+        # here cannot race with the pending lookup completing.
+        result = self.get(key)
+        if not result or (isinstance(result, defer.Deferred) and not result.called):
+            self.add_conditional(key, should_cache)
+
+        return self.wrap(key, callback, *args, **kwargs)
+
     def wrap(
         self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
     ) -> defer.Deferred:
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index be2ee26f07..996c614198 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -521,7 +521,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(expected_state.state, PresenceState.ONLINE)
         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations=["server2"], states=[expected_state]
+            destinations=["server2"], states={expected_state}
         )
 
         #
@@ -533,7 +533,7 @@
         self.federation_sender.send_presence.assert_not_called()
         self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations=["server3"], states=[expected_state]
+            destinations=["server3"], states={expected_state}
         )
 
     def test_remote_gets_presence_when_local_user_joins(self):
@@ -584,8 +584,14 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
             self.presence_handler.current_state_for_user("@test2:server")
         )
         self.assertEqual(expected_state.state, PresenceState.ONLINE)
-        self.federation_sender.send_presence_to_destinations.assert_called_once_with(
-            destinations={"server2", "server3"}, states=[expected_state]
+        self.assertEqual(
+            self.federation_sender.send_presence_to_destinations.call_count, 2
+        )
+        self.federation_sender.send_presence_to_destinations.assert_any_call(
+            destinations=["server3"], states={expected_state}
+        )
+        self.federation_sender.send_presence_to_destinations.assert_any_call(
+            destinations=["server2"], states={expected_state}
         )
 
     def _add_new_user(self, room_id, user_id):
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 18ca8b84f5..75c6a4e21c 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -161,7 +161,11 @@
class ProfileTestCase(unittest.HomeserverTestCase): response = self.get_success( self.query_handlers["profile"]( - {"user_id": "@caroline:test", "field": "displayname"} + { + "user_id": "@caroline:test", + "field": "displayname", + "origin": "servername.tld", + } ) ) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 3572e54c5d..98b2f5b383 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -18,6 +18,7 @@ from twisted.internet import defer import synapse.rest.admin from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes +from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import user_directory from synapse.storage.roommember import ProfileInfo @@ -46,6 +47,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.handler = hs.get_user_directory_handler() + self.event_builder_factory = self.hs.get_event_builder_factory() + self.event_creation_handler = self.hs.get_event_creation_handler() def test_handle_local_profile_change_with_support_user(self): support_user_id = "@support:test" @@ -547,6 +550,100 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, u4, 10)) self.assertEqual(len(s["results"]), 1) + @override_config( + { + "user_directory": { + "enabled": True, + "search_all_users": True, + "prefer_local_users": True, + } + } + ) + def test_prefer_local_users(self): + """Tests that local users are shown higher in search results when + user_directory.prefer_local_users is True. + """ + # Create a room and few users to test the directory with + searching_user = self.register_user("searcher", "password") + searching_user_tok = self.login("searcher", "password") + + room_id = self.helper.create_room_as( + searching_user, + room_version=RoomVersions.V1.identifier, + tok=searching_user_tok, + ) + + # Create a few local users and join them to the room + local_user_1 = self.register_user("user_xxxxx", "password") + local_user_2 = self.register_user("user_bbbbb", "password") + local_user_3 = self.register_user("user_zzzzz", "password") + + self._add_user_to_room(room_id, RoomVersions.V1, local_user_1) + self._add_user_to_room(room_id, RoomVersions.V1, local_user_2) + self._add_user_to_room(room_id, RoomVersions.V1, local_user_3) + + # Create a few "remote" users and join them to the room + remote_user_1 = "@user_aaaaa:remote_server" + remote_user_2 = "@user_yyyyy:remote_server" + remote_user_3 = "@user_ccccc:remote_server" + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1) + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2) + self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3) + + local_users = [local_user_1, local_user_2, local_user_3] + remote_users = [remote_user_1, remote_user_2, remote_user_3] + + # Populate the user directory via background update + self._add_background_updates() + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # The local searching user searches for the term "user", which other users have + # in their user id + results = self.get_success( + self.handler.search_users(searching_user, "user", 20) + )["results"] + received_user_id_ordering = 
[result["user_id"] for result in results]
+
+        # Typically we'd expect Synapse to return users in lexicographical order,
+        # assuming they have similar User IDs/display names, and profile information.
+
+        # Check that the order of returned results using our module is as we expect,
+        # i.e. our local users show up first, despite all users having lexicographically
+        # mixed user IDs.
+        [self.assertIn(user, local_users) for user in received_user_id_ordering[:3]]
+        [self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]]
+
+    def _add_user_to_room(
+        self,
+        room_id: str,
+        room_version: RoomVersion,
+        user_id: str,
+    ):
+        # Add a user to the room.
+        builder = self.event_builder_factory.for_room_version(
+            room_version,
+            {
+                "type": "m.room.member",
+                "sender": user_id,
+                "state_key": user_id,
+                "room_id": room_id,
+                "content": {"membership": "join"},
+            },
+        )
+
+        event, context = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+
+        self.get_success(
+            self.hs.get_storage().persistence.persist_event(event, context)
+        )
+
 
 class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
     user_id = "@test:test"
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index 2babea4e3e..aa4bf1c7e3 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -24,7 +24,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
         # enable federation sending on the worker
         config = super()._get_worker_hs_config()
         # TODO: make it so we don't need both of these
-        config["send_federation"] = True
+        config["send_federation"] = False
         config["worker_app"] = "synapse.app.federation_sender"
 
         return config
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 1853667558..f235f1bd83 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -27,7 +27,7 @@ class FederationAckTestCase(HomeserverTestCase):
     def default_config(self) -> dict:
         config = super().default_config()
         config["worker_app"] = "synapse.app.federation_sender"
-        config["send_federation"] = True
+        config["send_federation"] = False
         return config
 
     def make_homeserver(self, reactor, clock):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index fffdb742c8..2f2d117858 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -49,7 +49,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
 
         self.make_worker_hs(
             "synapse.app.federation_sender",
-            {"send_federation": True},
+            {"send_federation": False},
             federation_http_client=mock_client,
         )
 
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index f118fe32af..ab2988a6ba 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -95,7 +95,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
 
         self.make_worker_hs(
             "synapse.app.pusher",
-            {"start_pushers": True},
+            {"start_pushers": False},
             proxied_blacklisted_http_client=http_client_mock,
         )
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ba26895391..e58d5cf0db 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,7 +18,7 @@ import hmac
 import json
 import urllib.parse
 from binascii import unhexlify
-from typing import
Optional +from typing import List, Optional from mock import Mock @@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, sync from synapse.types import JsonDict from tests import unittest +from tests.server import FakeSite, make_request from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -1954,6 +1955,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() self.admin_user = self.register_user("admin", "pass", admin=True) @@ -2024,7 +2026,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2045,7 +2047,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2066,7 +2068,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2080,11 +2082,31 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(len(channel.json_body["media"]), 10) self._check_fields(channel.json_body["media"]) - def test_limit_is_negative(self): + def test_invalid_parameter(self): """ - Testing that a negative limit parameter returns a 400 + If parameters are invalid, an error is returned. 
""" + # unkown order_by + channel = self.make_request( + "GET", + self.url + "?order_by=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + # invalid search order + channel = self.make_request( + "GET", + self.url + "?dir=bar", + access_token=self.admin_user_tok, + ) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) + + # negative limit channel = self.make_request( "GET", self.url + "?limit=-5", @@ -2094,11 +2116,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) - def test_from_is_negative(self): - """ - Testing that a negative from parameter returns a 400 - """ - + # negative from channel = self.make_request( "GET", self.url + "?from=-5", @@ -2115,7 +2133,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 20 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) # `next_token` does not appear # Number of results is the number of entries @@ -2193,7 +2211,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): number_media = 5 other_user_tok = self.login("user", "pass") - self._create_media(other_user_tok, number_media) + self._create_media_for_user(other_user_tok, number_media) channel = self.make_request( "GET", @@ -2207,11 +2225,118 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.assertNotIn("next_token", channel.json_body) self._check_fields(channel.json_body["media"]) - def _create_media(self, user_token, number_media): + def test_order_by(self): + """ + Testing order list with parameter `order_by` + """ + + other_user_tok = self.login("user", "pass") + + # Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B + image_data1 = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + # Resolution: 1×1, MIME type: image/gif, Extension: gif, Size: 35 B + image_data2 = unhexlify( + b"47494638376101000100800100000000" + b"ffffff2c00000000010001000002024c" + b"01003b" + ) + # Resolution: 1×1, MIME type: image/bmp, Extension: bmp, Size: 54 B + image_data3 = unhexlify( + b"424d3a0000000000000036000000280000000100000001000000" + b"0100180000000000040000000000000000000000000000000000" + b"0000" + ) + + # create media and make sure they do not have the same timestamp + media1 = self._create_media_and_access(other_user_tok, image_data1, "image.png") + self.pump(1.0) + media2 = self._create_media_and_access(other_user_tok, image_data2, "image.gif") + self.pump(1.0) + media3 = self._create_media_and_access(other_user_tok, image_data3, "image.bmp") + self.pump(1.0) + + # Mark one media as safe from quarantine. 
+        self.get_success(self.store.mark_local_media_as_safe(media2))
+        # Quarantine one media
+        self.get_success(
+            self.store.quarantine_media_by_id("test", media3, self.admin_user)
+        )
+
+        # order by default ("created_ts")
+        # default is backwards
+        self._order_test([media3, media2, media1], None)
+        self._order_test([media1, media2, media3], None, "f")
+        self._order_test([media3, media2, media1], None, "b")
+
+        # sort by media_id
+        sorted_media = sorted([media1, media2, media3], reverse=False)
+        sorted_media_reverse = sorted(sorted_media, reverse=True)
+
+        # order by media_id
+        self._order_test(sorted_media, "media_id")
+        self._order_test(sorted_media, "media_id", "f")
+        self._order_test(sorted_media_reverse, "media_id", "b")
+
+        # order by upload_name
+        self._order_test([media3, media2, media1], "upload_name")
+        self._order_test([media3, media2, media1], "upload_name", "f")
+        self._order_test([media1, media2, media3], "upload_name", "b")
+
+        # order by media_type
+        # result is ordered by media_id,
+        # because the uploaded media_type is always 'application/json'
+        self._order_test(sorted_media, "media_type")
+        self._order_test(sorted_media, "media_type", "f")
+        self._order_test(sorted_media, "media_type", "b")
+
+        # order by media_length
+        self._order_test([media2, media3, media1], "media_length")
+        self._order_test([media2, media3, media1], "media_length", "f")
+        self._order_test([media1, media3, media2], "media_length", "b")
+
+        # order by created_ts
+        self._order_test([media1, media2, media3], "created_ts")
+        self._order_test([media1, media2, media3], "created_ts", "f")
+        self._order_test([media3, media2, media1], "created_ts", "b")
+
+        # order by last_access_ts
+        self._order_test([media1, media2, media3], "last_access_ts")
+        self._order_test([media1, media2, media3], "last_access_ts", "f")
+        self._order_test([media3, media2, media1], "last_access_ts", "b")
+
+        # order by quarantined_by
+        # one media is in quarantine, others are ordered by media_ids
+
+        # SQLite and PostgreSQL sort NULLs differently:
+        # if a media is not in quarantine, `quarantined_by` is NULL.
+        # SQLite considers NULL to be smaller than any other value.
+        # PostgreSQL considers NULL to be larger than any other value.
+
+        # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by")
+        # self._order_test(sorted([media1, media2]) + [media3], "quarantined_by", "f")
+        # self._order_test([media3] + sorted([media1, media2]), "quarantined_by", "b")
+
+        # order by safe_from_quarantine
+        # one media is safe from quarantine, others are ordered by media_ids
+        self._order_test(sorted([media1, media3]) + [media2], "safe_from_quarantine")
+        self._order_test(
+            sorted([media1, media3]) + [media2], "safe_from_quarantine", "f"
+        )
+        self._order_test(
+            [media2] + sorted([media1, media3]), "safe_from_quarantine", "b"
+        )
+
+    def _create_media_for_user(self, user_token: str, number_media: int):
         """
         Create a number of media for a specific user
+        Args:
+            user_token: Access token of the user
+            number_media: Number of media to be created for the user
         """
-        upload_resource = self.media_repo.children[b"upload"]
         for i in range(number_media):
             # file size is 67 Byte
             image_data = unhexlify(
@@ -2220,13 +2345,60 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
                 b"0a2db40000000049454e44ae426082"
             )
 
-            # Upload some media into the room
-            self.helper.upload_media(
-                upload_resource, image_data, tok=user_token, expect_code=200
-            )
+            self._create_media_and_access(user_token, image_data)
+
+    def _create_media_and_access(
+        self,
+        user_token: str,
+        image_data: bytes,
+        filename: str = "image1.png",
+    ) -> str:
+        """
+        Create one media item for a specific user, access it once, and
+        return its `media_id`
+        Args:
+            user_token: Access token of the user
+            image_data: binary data of image
+            filename: The filename of the media to be uploaded
+        Returns:
+            The ID of the newly created media.
+        """
+        upload_resource = self.media_repo.children[b"upload"]
+        download_resource = self.media_repo.children[b"download"]
+
+        # Upload some media into the room
+        response = self.helper.upload_media(
+            upload_resource, image_data, user_token, filename, expect_code=200
+        )
+
+        # Extract media ID from the response
+        server_and_media_id = response["content_uri"][6:]  # Cut off 'mxc://'
+        media_id = server_and_media_id.split("/")[1]
+
+        # Access the media once so that `last_access_ts` gets set
+        channel = make_request(
+            self.reactor,
+            FakeSite(download_resource),
+            "GET",
+            server_and_media_id,
+            shorthand=False,
+            access_token=user_token,
+        )
+
+        self.assertEqual(
+            200,
+            channel.code,
+            msg=(
+                "Expected to receive a 200 on accessing media: %s" % server_and_media_id
+            ),
+        )
 
-    def _check_fields(self, content):
-        """Checks that all attributes are present in content"""
+        return media_id
+
+    def _check_fields(self, content: JsonDict):
+        """Checks that the expected media attributes are present in content
+        Args:
+            content: the list of media entries to check
+        """
         for m in content:
             self.assertIn("media_id", m)
             self.assertIn("media_type", m)
@@ -2237,6 +2409,38 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
             self.assertIn("quarantined_by", m)
             self.assertIn("safe_from_quarantine", m)
 
+    def _order_test(
+        self,
+        expected_media_list: List[str],
+        order_by: Optional[str],
+        dir: Optional[str] = None,
+    ):
+        """Request the list of media in a certain order. Assert that order is what
+        we expect
+        Args:
+            expected_media_list: The list of media_ids in the order we expect to get
+                back from the server
+            order_by: The type of ordering to give the server
+            dir: The direction of ordering to give the server
+        """
+
+        url = self.url + "?"
+        if order_by is not None:
+            url += "order_by=%s&" % (order_by,)
+        if dir is not None and dir in ("b", "f"):
+            url += "dir=%s" % (dir,)
+        channel = self.make_request(
+            "GET",
+            url.encode("ascii"),
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual(channel.json_body["total"], len(expected_media_list))
+
+        returned_order = [row["media_id"] for row in channel.json_body["media"]]
+        self.assertEqual(expected_media_list, returned_order)
+        self._check_fields(channel.json_body["media"])
+

 class UserTokenRestTestCase(unittest.HomeserverTestCase):
     """Test for /_synapse/admin/v1/users/<user>/login"""
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/v2_alpha/test_shared_rooms.py
index 116ace1812..dd83a1f8ff 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/v2_alpha/test_shared_rooms.py
@@ -54,61 +54,62 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
         A room should show up in the shared list of rooms between two users
         if it is public.
         """
-        u1 = self.register_user("user1", "pass")
-        u1_token = self.login(u1, "pass")
-        u2 = self.register_user("user2", "pass")
-        u2_token = self.login(u2, "pass")
-
-        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
-        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
-        self.helper.join(room, user=u2, tok=u2_token)
-
-        channel = self._get_shared_rooms(u1_token, u2)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 1)
-        self.assertEquals(channel.json_body["joined"][0], room)
+        self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=True)
 
     def test_shared_room_list_private(self):
         """
         A room should show up in the shared list of rooms between two users
         if it is private.
         """
-        u1 = self.register_user("user1", "pass")
-        u1_token = self.login(u1, "pass")
-        u2 = self.register_user("user2", "pass")
-        u2_token = self.login(u2, "pass")
-
-        room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
-        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
-        self.helper.join(room, user=u2, tok=u2_token)
-
-        channel = self._get_shared_rooms(u1_token, u2)
-        self.assertEquals(200, channel.code, channel.result)
-        self.assertEquals(len(channel.json_body["joined"]), 1)
-        self.assertEquals(channel.json_body["joined"][0], room)
+        self._check_shared_rooms_with(
+            room_one_is_public=False, room_two_is_public=False
+        )
 
     def test_shared_room_list_mixed(self):
         """
         The shared room list between two users should contain both public and private
         rooms.
         """
+        self._check_shared_rooms_with(room_one_is_public=True, room_two_is_public=False)
+
+    def _check_shared_rooms_with(
+        self, room_one_is_public: bool, room_two_is_public: bool
+    ):
+        """Checks that shared public or private rooms between two users appear
+        in their shared room lists
+        """
         u1 = self.register_user("user1", "pass")
         u1_token = self.login(u1, "pass")
         u2 = self.register_user("user2", "pass")
         u2_token = self.login(u2, "pass")
 
-        room_public = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
-        room_private = self.helper.create_room_as(u2, is_public=False, tok=u2_token)
-        self.helper.invite(room_public, src=u1, targ=u2, tok=u1_token)
-        self.helper.invite(room_private, src=u2, targ=u1, tok=u2_token)
-        self.helper.join(room_public, user=u2, tok=u2_token)
-        self.helper.join(room_private, user=u1, tok=u1_token)
+        # Create a room. User1 invites user2, who joins.
+        room_id_one = self.helper.create_room_as(
+            u1, is_public=room_one_is_public, tok=u1_token
+        )
+        self.helper.invite(room_id_one, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room_id_one, user=u2, tok=u2_token)
 
+        # Check shared rooms from user1's perspective.
+        # We should see the one room in common
+        channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 1)
+        self.assertEquals(channel.json_body["joined"][0], room_id_one)
+
+        # Create another room and invite user2 to it
+        room_id_two = self.helper.create_room_as(
+            u1, is_public=room_two_is_public, tok=u1_token
+        )
+        self.helper.invite(room_id_two, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room_id_two, user=u2, tok=u2_token)
+
+        # Check shared rooms again. We should now see both rooms.
         channel = self._get_shared_rooms(u1_token, u2)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 2)
-        self.assertTrue(room_public in channel.json_body["joined"])
-        self.assertTrue(room_private in channel.json_body["joined"])
+        for room_id in channel.json_body["joined"]:
+            self.assertIn(room_id, [room_id_one, room_id_two])
 
     def test_shared_room_list_after_leave(self):
         """
@@ -132,6 +133,12 @@ class UserSharedRoomsTest(unittest.HomeserverTestCase):
 
         self.helper.leave(room, user=u1, tok=u1_token)
 
+        # Check user1's view of shared rooms with user2
+        channel = self._get_shared_rooms(u1_token, u2)
+        self.assertEquals(200, channel.code, channel.result)
+        self.assertEquals(len(channel.json_body["joined"]), 0)
+
+        # Check user2's view of shared rooms with user1
         channel = self._get_shared_rooms(u2_token, u1)
         self.assertEquals(200, channel.code, channel.result)
         self.assertEquals(len(channel.json_body["joined"]), 0)
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 0789b12392..36d1e6bc4a 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -231,9 +231,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
 
-        self.media_repo = hs.get_media_repository_resource()
-        self.download_resource = self.media_repo.children[b"download"]
-        self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
+        # `get_media_repository_resource` returns the HTTP resource tree; the
+        # `MediaRepository` itself is needed to look up file paths on disk.
+        media_resource = hs.get_media_repository_resource()
+        self.download_resource = media_resource.children[b"download"]
+        self.thumbnail_resource = media_resource.children[b"thumbnail"]
+        self.store = hs.get_datastore()
+        self.media_repo = hs.get_media_repository()
 
         self.media_id = "example.com/12345"
@@ -357,6 +359,67 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         """
         self._test_thumbnail("scale", None, False)
 
+    def test_thumbnail_repeated_thumbnail(self):
+        """Test that fetching the same thumbnail works, and that deleting the
+        on-disk thumbnail regenerates it.
+        """
+        self._test_thumbnail(
+            "scale", self.test_image.expected_scaled, self.test_image.expected_found
+        )
+
+        if not self.test_image.expected_found:
+            return
+
+        # Fetching again should work, without re-requesting the image from the
+        # remote.
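+        # (The first request cached the remote media locally, so this one can
+        # be served from the local media store.)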
+ params = "?width=32&height=32&method=scale" + channel = make_request( + self.reactor, + FakeSite(self.thumbnail_resource), + "GET", + self.media_id + params, + shorthand=False, + await_result=False, + ) + self.pump() + + self.assertEqual(channel.code, 200) + if self.test_image.expected_scaled: + self.assertEqual( + channel.result["body"], + self.test_image.expected_scaled, + channel.result["body"], + ) + + # Deleting the thumbnail on disk then re-requesting it should work as + # Synapse should regenerate missing thumbnails. + origin, media_id = self.media_id.split("/") + info = self.get_success(self.store.get_cached_remote_media(origin, media_id)) + file_id = info["filesystem_id"] + + thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir( + origin, file_id + ) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + channel = make_request( + self.reactor, + FakeSite(self.thumbnail_resource), + "GET", + self.media_id + params, + shorthand=False, + await_result=False, + ) + self.pump() + + self.assertEqual(channel.code, 200) + if self.test_image.expected_scaled: + self.assertEqual( + channel.result["body"], + self.test_image.expected_scaled, + channel.result["body"], + ) + def _test_thumbnail(self, method, expected_body, expected_found): params = "?width=32&height=32&method=" + method channel = make_request( diff --git a/tests/utils.py b/tests/utils.py index 4fb5098550..be80b13760 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -114,7 +114,6 @@ def default_config(name, parse=False): "server_name": name, "send_federation": False, "media_store_path": "media", - "uploads_path": "uploads", # the test signing key is just an arbitrary ed25519 key to keep the config # parser happy "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg", |