From 27c06a6e0699f92bcd02b9e930dc8191ab87305e Mon Sep 17 00:00:00 2001 From: "Michael[tm] Smith" Date: Wed, 23 Jun 2021 19:25:03 +0900 Subject: Drop Origin & Accept from Access-Control-Allow-Headers value (#10114) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Drop Origin & Accept from Access-Control-Allow-Headers value This change drops the Origin and Accept header names from the value of the Access-Control-Allow-Headers response header sent by Synapse. Per the CORS protocol, it’s not necessary or useful to include those header names. Details: Per-spec at https://fetch.spec.whatwg.org/#forbidden-header-name, Origin is a “forbidden header name” set by the browser and that frontend JavaScript code is never allowed to set. So the value of Access-Control-Allow-Headers isn’t relevant to Origin or in general to other headers set by the browser itself — the browser never ever consults the Access-Control-Allow-Headers value to confirm that it’s OK for the request to include an Origin header. And per-spec at https://fetch.spec.whatwg.org/#cors-safelisted-request-header, Accept is a “CORS-safelisted request-header”, which means that browsers allow requests to contain the Accept header regardless of whether the Access-Control-Allow-Headers value contains "Accept". So it’s unnecessary for the Access-Control-Allow-Headers to explicitly include Accept. Browsers will not perform a CORS preflight for requests containing an Accept request header. Related: https://github.com/matrix-org/matrix-doc/pull/3225 Signed-off-by: Michael[tm] Smith --- synapse/http/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/http/server.py b/synapse/http/server.py index 845651e606..efbc6d5b25 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -728,7 +728,7 @@ def set_cors_headers(request: Request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date", + b"X-Requested-With, Content-Type, Authorization, Date", ) -- cgit 1.5.1 From 8beead66ae48aa11f1e25da42256eb92b8bce099 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 23 Jun 2021 12:54:50 +0100 Subject: Send out invite rejections and knocks over federation (#10223) ensure that events sent via `send_leave` and `send_knock` are sent on to the rest of the federation. --- changelog.d/10223.bugfix | 1 + scripts-dev/complement.sh | 2 +- synapse/handlers/federation.py | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10223.bugfix (limited to 'synapse') diff --git a/changelog.d/10223.bugfix b/changelog.d/10223.bugfix new file mode 100644 index 0000000000..4e42f6b608 --- /dev/null +++ b/changelog.d/10223.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which meant that invite rejections and knocks were not sent out over federation in a timely manner. diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index ba060104c3..aca32edc17 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -65,4 +65,4 @@ if [[ -n "$1" ]]; then fi # Run the tests! -go test -v -tags synapse_blacklist,msc2946,msc3083,msc2716 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests +go test -v -tags synapse_blacklist,msc2946,msc3083,msc2716,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1b566dbf2d..74d169a2ac 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1980,6 +1980,13 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False + # Send this event on behalf of the other server. + # + # The remote server isn't a full participant in the room at this point, so + # may not have an up-to-date list of the other homeservers participating in + # the room, so we send it on their behalf. + event.internal_metadata.send_on_behalf_of = origin + context = await self.state_handler.compute_event_context(event) await self._auth_and_persist_event(origin, event, context) @@ -2084,6 +2091,13 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False + # Send this event on behalf of the other server. + # + # The remote server isn't a full participant in the room at this point, so + # may not have an up-to-date list of the other homeservers participating in + # the room, so we send it on their behalf. + event.internal_metadata.send_on_behalf_of = origin + context = await self.state_handler.compute_event_context(event) event_allowed = await self.third_party_event_rules.check_event_allowed( -- cgit 1.5.1 From e19e3d452d7553cad974556c723b7a17e6f11a9d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 23 Jun 2021 16:14:52 +0200 Subject: Improve the reliability of auto-joining remote rooms (#10237) If a room is remote and we don't have a user in it, always try to join it. It might fail if the room is invite-only, but we don't have a user to invite with, so at this point it's the best we can do. Fixes #10233 (at least to some extent) --- changelog.d/10237.misc | 1 + synapse/handlers/register.py | 63 ++++++++++++++++++++++++++++++----------- tests/handlers/test_register.py | 49 +++++++++++++++++++++++++++++++- 3 files changed, 96 insertions(+), 17 deletions(-) create mode 100644 changelog.d/10237.misc (limited to 'synapse') diff --git a/changelog.d/10237.misc b/changelog.d/10237.misc new file mode 100644 index 0000000000..d76c119a41 --- /dev/null +++ b/changelog.d/10237.misc @@ -0,0 +1 @@ +Improve the reliability of auto-joining remote rooms. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ca1ed6a5c0..4b4b579741 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -386,11 +386,32 @@ class RegistrationHandler(BaseHandler): room_alias = RoomAlias.from_string(r) if self.hs.hostname != room_alias.domain: - logger.warning( - "Cannot create room alias %s, " - "it does not match server domain", + # If the alias is remote, try to join the room. This might fail + # because the room might be invite only, but we don't have any local + # user in the room to invite this one with, so at this point that's + # the best we can do. + logger.info( + "Cannot automatically create room with alias %s as it isn't" + " local, trying to join the room instead", r, ) + + ( + room, + remote_room_hosts, + ) = await room_member_handler.lookup_room_alias(room_alias) + room_id = room.to_string() + + await room_member_handler.update_membership( + requester=create_requester( + user_id, authenticated_entity=self._server_name + ), + target=UserID.from_string(user_id), + room_id=room_id, + remote_room_hosts=remote_room_hosts, + action="join", + ratelimit=False, + ) else: # A shallow copy is OK here since the only key that is # modified is room_alias_name. @@ -448,22 +469,32 @@ class RegistrationHandler(BaseHandler): ) # Calculate whether the room requires an invite or can be - # joined directly. Note that unless a join rule of public exists, - # it is treated as requiring an invite. - requires_invite = True - - state = await self.store.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) + # joined directly. By default, we consider the room as requiring an + # invite if the homeserver is in the room (unless told otherwise by the + # join rules). Otherwise we consider it as being joinable, at the risk of + # failing to join, but in this case there's little more we can do since + # we don't have a local user in the room to craft up an invite with. + requires_invite = await self.store.is_host_joined( + room_id, + self.server_name, ) - event_id = state.get((EventTypes.JoinRules, "")) - if event_id: - join_rules_event = await self.store.get_event( - event_id, allow_none=True + if requires_invite: + # If the server is in the room, check if the room is public. + state = await self.store.get_filtered_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) ) - if join_rules_event: - join_rule = join_rules_event.content.get("join_rule", None) - requires_invite = join_rule and join_rule != JoinRules.PUBLIC + + event_id = state.get((EventTypes.JoinRules, "")) + if event_id: + join_rules_event = await self.store.get_event( + event_id, allow_none=True + ) + if join_rules_event: + join_rule = join_rules_event.content.get("join_rule", None) + requires_invite = ( + join_rule and join_rule != JoinRules.PUBLIC + ) # Send the invite, if necessary. if requires_invite: diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index a9fd3036dc..c901003225 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -18,7 +18,7 @@ from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.spam_checker_api import RegistrationBehaviour -from synapse.types import RoomAlias, UserID, create_requester +from synapse.types import RoomAlias, RoomID, UserID, create_requester from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -643,3 +643,50 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) return user_id, token + + +class RemoteAutoJoinTestCase(unittest.HomeserverTestCase): + """Tests auto-join on remote rooms.""" + + def make_homeserver(self, reactor, clock): + self.room_id = "!roomid:remotetest" + + async def update_membership(*args, **kwargs): + pass + + async def lookup_room_alias(*args, **kwargs): + return RoomID.from_string(self.room_id), ["remotetest"] + + self.room_member_handler = Mock(spec=["update_membership", "lookup_room_alias"]) + self.room_member_handler.update_membership.side_effect = update_membership + self.room_member_handler.lookup_room_alias.side_effect = lookup_room_alias + + hs = self.setup_test_homeserver(room_member_handler=self.room_member_handler) + return hs + + def prepare(self, reactor, clock, hs): + self.handler = self.hs.get_registration_handler() + self.store = self.hs.get_datastore() + + @override_config({"auto_join_rooms": ["#room:remotetest"]}) + def test_auto_create_auto_join_remote_room(self): + """Tests that we don't attempt to create remote rooms, and that we don't attempt + to invite ourselves to rooms we're not in.""" + + # Register a first user; this should call _create_and_join_rooms + self.get_success(self.handler.register_user(localpart="jeff")) + + _, kwargs = self.room_member_handler.update_membership.call_args + + self.assertEqual(kwargs["room_id"], self.room_id) + self.assertEqual(kwargs["action"], "join") + self.assertEqual(kwargs["remote_room_hosts"], ["remotetest"]) + + # Register a second user; this should call _join_rooms + self.get_success(self.handler.register_user(localpart="jeff2")) + + _, kwargs = self.room_member_handler.update_membership.call_args + + self.assertEqual(kwargs["room_id"], self.room_id) + self.assertEqual(kwargs["action"], "join") + self.assertEqual(kwargs["remote_room_hosts"], ["remotetest"]) -- cgit 1.5.1 From 394673055db4df49bfd58c2f6118834a6d928563 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 23 Jun 2021 15:57:41 +0100 Subject: Re-introduce "Leave out optional keys from /sync" change (#10214) Required some fixes due to merge conflicts with #6739, but nothing too hairy. The first commit is the same as the original (after merge conflict resolution) then two more for compatibility with the latest sync code. --- changelog.d/10214.feature | 1 + synapse/rest/client/v2_alpha/sync.py | 69 ++++++++++++++-------- tests/rest/client/v2_alpha/test_sync.py | 30 +--------- .../test_resource_limits_server_notices.py | 8 ++- 4 files changed, 53 insertions(+), 55 deletions(-) create mode 100644 changelog.d/10214.feature (limited to 'synapse') diff --git a/changelog.d/10214.feature b/changelog.d/10214.feature new file mode 100644 index 0000000000..a3818c9d25 --- /dev/null +++ b/changelog.d/10214.feature @@ -0,0 +1 @@ +Omit empty fields from the `/sync` response. Contributed by @deepbluev7. \ No newline at end of file diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 042e1788b6..ecbbcf3851 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -13,6 +13,7 @@ # limitations under the License. import itertools import logging +from collections import defaultdict from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple from synapse.api.constants import Membership, PresenceState @@ -232,29 +233,51 @@ class SyncRestServlet(RestServlet): ) logger.debug("building sync response dict") - return { - "account_data": {"events": sync_result.account_data}, - "to_device": {"events": sync_result.to_device}, - "device_lists": { - "changed": list(sync_result.device_lists.changed), - "left": list(sync_result.device_lists.left), - }, - "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now), - "rooms": { - Membership.JOIN: joined, - Membership.INVITE: invited, - Membership.KNOCK: knocked, - Membership.LEAVE: archived, - }, - "groups": { - Membership.JOIN: sync_result.groups.join, - Membership.INVITE: sync_result.groups.invite, - Membership.LEAVE: sync_result.groups.leave, - }, - "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types, - "next_batch": await sync_result.next_batch.to_string(self.store), - } + + response: dict = defaultdict(dict) + response["next_batch"] = await sync_result.next_batch.to_string(self.store) + + if sync_result.account_data: + response["account_data"] = {"events": sync_result.account_data} + if sync_result.presence: + response["presence"] = SyncRestServlet.encode_presence( + sync_result.presence, time_now + ) + + if sync_result.to_device: + response["to_device"] = {"events": sync_result.to_device} + + if sync_result.device_lists.changed: + response["device_lists"]["changed"] = list(sync_result.device_lists.changed) + if sync_result.device_lists.left: + response["device_lists"]["left"] = list(sync_result.device_lists.left) + + if sync_result.device_one_time_keys_count: + response[ + "device_one_time_keys_count" + ] = sync_result.device_one_time_keys_count + if sync_result.device_unused_fallback_key_types: + response[ + "org.matrix.msc2732.device_unused_fallback_key_types" + ] = sync_result.device_unused_fallback_key_types + + if joined: + response["rooms"][Membership.JOIN] = joined + if invited: + response["rooms"][Membership.INVITE] = invited + if knocked: + response["rooms"][Membership.KNOCK] = knocked + if archived: + response["rooms"][Membership.LEAVE] = archived + + if sync_result.groups.join: + response["groups"][Membership.JOIN] = sync_result.groups.join + if sync_result.groups.invite: + response["groups"][Membership.INVITE] = sync_result.groups.invite + if sync_result.groups.leave: + response["groups"][Membership.LEAVE] = sync_result.groups.leave + + return response @staticmethod def encode_presence(events, time_now): diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index 012910f136..cdca3a3e23 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -41,35 +41,7 @@ class FilterTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", "/sync") self.assertEqual(channel.code, 200) - self.assertTrue( - { - "next_batch", - "rooms", - "presence", - "account_data", - "to_device", - "device_lists", - }.issubset(set(channel.json_body.keys())) - ) - - def test_sync_presence_disabled(self): - """ - When presence is disabled, the key does not appear in /sync. - """ - self.hs.config.use_presence = False - - channel = self.make_request("GET", "/sync") - - self.assertEqual(channel.code, 200) - self.assertTrue( - { - "next_batch", - "rooms", - "account_data", - "to_device", - "device_lists", - }.issubset(set(channel.json_body.keys())) - ) + self.assertIn("next_batch", channel.json_body) class SyncFilterTestCase(unittest.HomeserverTestCase): diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index d46521ccdc..3245aa91ca 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -306,8 +306,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): channel = self.make_request("GET", "/sync?timeout=0", access_token=tok) - invites = channel.json_body["rooms"]["invite"] - self.assertEqual(len(invites), 0, invites) + self.assertNotIn( + "rooms", channel.json_body, "Got invites without server notice" + ) def test_invite_with_notice(self): """Tests that, if the MAU limit is hit, the server notices user invites each user @@ -364,7 +365,8 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): # We could also pick another user and sync with it, which would return an # invite to a system notices room, but it doesn't matter which user we're # using so we use the last one because it saves us an extra sync. - invites = channel.json_body["rooms"]["invite"] + if "rooms" in channel.json_body: + invites = channel.json_body["rooms"]["invite"] # Make sure we have an invite to process. self.assertEqual(len(invites), 1, invites) -- cgit 1.5.1 From bd4919fb72b2a75f1c0a7f0c78bd619fd2ae30e8 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 24 Jun 2021 15:33:20 +0200 Subject: MSC2918 Refresh tokens implementation (#9450) This implements refresh tokens, as defined by MSC2918 This MSC has been implemented client side in Hydrogen Web: vector-im/hydrogen-web#235 The basics of the MSC works: requesting refresh tokens on login, having the access tokens expire, and using the refresh token to get a new one. Signed-off-by: Quentin Gliech --- changelog.d/9450.feature | 1 + scripts/synapse_port_db | 4 +- synapse/api/auth.py | 5 + synapse/config/registration.py | 21 ++ synapse/handlers/auth.py | 132 ++++++++++++- synapse/handlers/register.py | 52 ++++- synapse/module_api/__init__.py | 2 +- synapse/replication/http/login.py | 13 +- synapse/rest/client/v1/login.py | 171 +++++++++++++--- synapse/rest/client/v2_alpha/register.py | 88 +++++++-- synapse/storage/databases/main/registration.py | 207 ++++++++++++++++++- .../schema/main/delta/59/14refresh_tokens.sql | 34 ++++ tests/api/test_auth.py | 1 + tests/handlers/test_device.py | 2 +- tests/rest/client/v2_alpha/test_auth.py | 220 ++++++++++++++++++++- 15 files changed, 892 insertions(+), 61 deletions(-) create mode 100644 changelog.d/9450.feature create mode 100644 synapse/storage/schema/main/delta/59/14refresh_tokens.sql (limited to 'synapse') diff --git a/changelog.d/9450.feature b/changelog.d/9450.feature new file mode 100644 index 0000000000..455936a41d --- /dev/null +++ b/changelog.d/9450.feature @@ -0,0 +1 @@ +Implement refresh tokens as specified by [MSC2918](https://github.com/matrix-org/matrix-doc/pull/2918). diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 86eb76cbca..2bbaf5557d 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -93,6 +93,7 @@ BOOLEAN_COLUMNS = { "local_media_repository": ["safe_from_quarantine"], "users": ["shadow_banned"], "e2e_fallback_keys_json": ["used"], + "access_tokens": ["used"], } @@ -307,7 +308,8 @@ class Porter(object): information_schema.table_constraints AS tc INNER JOIN information_schema.constraint_column_usage AS ccu USING (table_schema, constraint_name) - WHERE tc.constraint_type = 'FOREIGN KEY'; + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name != ccu.table_name; """ txn.execute(sql) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index edf1b918eb..29cf257633 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -245,6 +245,11 @@ class Auth: errcode=Codes.GUEST_ACCESS_FORBIDDEN, ) + # Mark the token as used. This is used to invalidate old refresh + # tokens after some time. + if not user_info.token_used and token_id is not None: + await self.store.mark_access_token_as_used(token_id) + requester = create_requester( user_info.user_id, token_id, diff --git a/synapse/config/registration.py b/synapse/config/registration.py index d9dc55a0c3..0ad919b139 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -119,6 +119,27 @@ class RegistrationConfig(Config): session_lifetime = self.parse_duration(session_lifetime) self.session_lifetime = session_lifetime + # The `access_token_lifetime` applies for tokens that can be renewed + # using a refresh token, as per MSC2918. If it is `None`, the refresh + # token mechanism is disabled. + # + # Since it is incompatible with the `session_lifetime` mechanism, it is set to + # `None` by default if a `session_lifetime` is set. + access_token_lifetime = config.get( + "access_token_lifetime", "5m" if session_lifetime is None else None + ) + if access_token_lifetime is not None: + access_token_lifetime = self.parse_duration(access_token_lifetime) + self.access_token_lifetime = access_token_lifetime + + if session_lifetime is not None and access_token_lifetime is not None: + raise ConfigError( + "The refresh token mechanism is incompatible with the " + "`session_lifetime` option. Consider disabling the " + "`session_lifetime` option or disabling the refresh token " + "mechanism by removing the `access_token_lifetime` option." + ) + # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 1971e373ed..e2ac595a62 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -30,6 +30,7 @@ from typing import ( Optional, Tuple, Union, + cast, ) import attr @@ -72,6 +73,7 @@ from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: + from synapse.rest.client.v1.login import LoginResponse from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -777,6 +779,108 @@ class AuthHandler(BaseHandler): "params": params, } + async def refresh_token( + self, + refresh_token: str, + valid_until_ms: Optional[int], + ) -> Tuple[str, str]: + """ + Consumes a refresh token and generate both a new access token and a new refresh token from it. + + The consumed refresh token is considered invalid after the first use of the new access token or the new refresh token. + + Args: + refresh_token: The token to consume. + valid_until_ms: The expiration timestamp of the new access token. + + Returns: + A tuple containing the new access token and refresh token + """ + + # Verify the token signature first before looking up the token + if not self._verify_refresh_token(refresh_token): + raise SynapseError(401, "invalid refresh token", Codes.UNKNOWN_TOKEN) + + existing_token = await self.store.lookup_refresh_token(refresh_token) + if existing_token is None: + raise SynapseError(401, "refresh token does not exist", Codes.UNKNOWN_TOKEN) + + if ( + existing_token.has_next_access_token_been_used + or existing_token.has_next_refresh_token_been_refreshed + ): + raise SynapseError( + 403, "refresh token isn't valid anymore", Codes.FORBIDDEN + ) + + ( + new_refresh_token, + new_refresh_token_id, + ) = await self.get_refresh_token_for_user_id( + user_id=existing_token.user_id, device_id=existing_token.device_id + ) + access_token = await self.get_access_token_for_user_id( + user_id=existing_token.user_id, + device_id=existing_token.device_id, + valid_until_ms=valid_until_ms, + refresh_token_id=new_refresh_token_id, + ) + await self.store.replace_refresh_token( + existing_token.token_id, new_refresh_token_id + ) + return access_token, new_refresh_token + + def _verify_refresh_token(self, token: str) -> bool: + """ + Verifies the shape of a refresh token. + + Args: + token: The refresh token to verify + + Returns: + Whether the token has the right shape + """ + parts = token.split("_", maxsplit=4) + if len(parts) != 4: + return False + + type, localpart, rand, crc = parts + + # Refresh tokens are prefixed by "syr_", let's check that + if type != "syr": + return False + + # Check the CRC + base = f"{type}_{localpart}_{rand}" + expected_crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + if crc != expected_crc: + return False + + return True + + async def get_refresh_token_for_user_id( + self, + user_id: str, + device_id: str, + ) -> Tuple[str, int]: + """ + Creates a new refresh token for the user with the given user ID. + + Args: + user_id: canonical user ID + device_id: the device ID to associate with the token. + + Returns: + The newly created refresh token and its ID in the database + """ + refresh_token = self.generate_refresh_token(UserID.from_string(user_id)) + refresh_token_id = await self.store.add_refresh_token_to_user( + user_id=user_id, + token=refresh_token, + device_id=device_id, + ) + return refresh_token, refresh_token_id + async def get_access_token_for_user_id( self, user_id: str, @@ -784,6 +888,7 @@ class AuthHandler(BaseHandler): valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, is_appservice_ghost: bool = False, + refresh_token_id: Optional[int] = None, ) -> str: """ Creates a new access token for the user with the given user ID. @@ -801,6 +906,8 @@ class AuthHandler(BaseHandler): valid_until_ms: when the token is valid until. None for no expiry. is_appservice_ghost: Whether the user is an application ghost user + refresh_token_id: the refresh token ID that will be associated with + this access token. Returns: The access token for the user's session. Raises: @@ -836,6 +943,7 @@ class AuthHandler(BaseHandler): device_id=device_id, valid_until_ms=valid_until_ms, puppets_user_id=puppets_user_id, + refresh_token_id=refresh_token_id, ) # the device *should* have been registered before we got here; however, @@ -928,7 +1036,7 @@ class AuthHandler(BaseHandler): self, login_submission: Dict[str, Any], ratelimit: bool = False, - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate auth types which don't @@ -1073,7 +1181,7 @@ class AuthHandler(BaseHandler): self, username: str, login_submission: Dict[str, Any], - ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Helper for validate_login Handles login, once we've mapped 3pids onto userids @@ -1151,7 +1259,7 @@ class AuthHandler(BaseHandler): async def check_password_provider_3pid( self, medium: str, address: str, password: str - ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: + ) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -1215,6 +1323,19 @@ class AuthHandler(BaseHandler): crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) return f"{base}_{crc}" + def generate_refresh_token(self, for_user: UserID) -> str: + """Generates an opaque string, for use as a refresh token""" + + # we use the following format for refresh tokens: + # syr___ + + b64local = unpaddedbase64.encode_base64(for_user.localpart.encode("utf-8")) + random_string = stringutils.random_string(20) + base = f"syr_{b64local}_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + async def validate_short_term_login_token( self, login_token: str ) -> LoginTokenAttributes: @@ -1563,7 +1684,7 @@ class AuthHandler(BaseHandler): ) respond_with_html(request, 200, html) - async def _sso_login_callback(self, login_result: JsonDict) -> None: + async def _sso_login_callback(self, login_result: "LoginResponse") -> None: """ A login callback which might add additional attributes to the login response. @@ -1577,7 +1698,8 @@ class AuthHandler(BaseHandler): extra_attributes = self._extra_attributes.get(login_result["user_id"]) if extra_attributes: - login_result.update(extra_attributes.extra_attributes) + login_result_dict = cast(Dict[str, Any], login_result) + login_result_dict.update(extra_attributes.extra_attributes) def _expire_sso_extra_attributes(self) -> None: """ diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4b4b579741..26ef016179 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -15,9 +15,10 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from prometheus_client import Counter +from typing_extensions import TypedDict from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -54,6 +55,16 @@ login_counter = Counter( ["guest", "auth_provider"], ) +LoginDict = TypedDict( + "LoginDict", + { + "device_id": str, + "access_token": str, + "valid_until_ms": Optional[int], + "refresh_token": Optional[str], + }, +) + class RegistrationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): @@ -85,6 +96,7 @@ class RegistrationHandler(BaseHandler): self.pusher_pool = hs.get_pusherpool() self.session_lifetime = hs.config.session_lifetime + self.access_token_lifetime = hs.config.access_token_lifetime async def check_username( self, @@ -696,7 +708,8 @@ class RegistrationHandler(BaseHandler): is_guest: bool = False, is_appservice_ghost: bool = False, auth_provider_id: Optional[str] = None, - ) -> Tuple[str, str]: + should_issue_refresh_token: bool = False, + ) -> Tuple[str, str, Optional[int], Optional[str]]: """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. @@ -708,8 +721,9 @@ class RegistrationHandler(BaseHandler): is_guest: Whether this is a guest account auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: Whether it should also issue a refresh token Returns: - Tuple of device ID and access token + Tuple of device ID, access token, access token expiration time and refresh token """ res = await self._register_device_client( user_id=user_id, @@ -717,6 +731,7 @@ class RegistrationHandler(BaseHandler): initial_display_name=initial_display_name, is_guest=is_guest, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) login_counter.labels( @@ -724,7 +739,12 @@ class RegistrationHandler(BaseHandler): auth_provider=(auth_provider_id or ""), ).inc() - return res["device_id"], res["access_token"] + return ( + res["device_id"], + res["access_token"], + res["valid_until_ms"], + res["refresh_token"], + ) async def register_device_inner( self, @@ -733,7 +753,8 @@ class RegistrationHandler(BaseHandler): initial_display_name: Optional[str], is_guest: bool = False, is_appservice_ghost: bool = False, - ) -> Dict[str, str]: + should_issue_refresh_token: bool = False, + ) -> LoginDict: """Helper for register_device Does the bits that need doing on the main process. Not for use outside this @@ -748,6 +769,9 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime + refresh_token = None + refresh_token_id = None + registered_device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) @@ -755,14 +779,30 @@ class RegistrationHandler(BaseHandler): assert valid_until_ms is None access_token = self.macaroon_gen.generate_guest_access_token(user_id) else: + if should_issue_refresh_token: + ( + refresh_token, + refresh_token_id, + ) = await self._auth_handler.get_refresh_token_for_user_id( + user_id, + device_id=registered_device_id, + ) + valid_until_ms = self.clock.time_msec() + self.access_token_lifetime + access_token = await self._auth_handler.get_access_token_for_user_id( user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms, is_appservice_ghost=is_appservice_ghost, + refresh_token_id=refresh_token_id, ) - return {"device_id": registered_device_id, "access_token": access_token} + return { + "device_id": registered_device_id, + "access_token": access_token, + "valid_until_ms": valid_until_ms, + "refresh_token": refresh_token, + } async def post_registration_actions( self, user_id: str, auth_result: dict, access_token: Optional[str] diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 58b255eb1b..721c45abac 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -168,7 +168,7 @@ class ModuleApi: "Using deprecated ModuleApi.register which creates a dummy user device." ) user_id = yield self.register_user(localpart, displayname, emails or []) - _, access_token = yield self.register_device(user_id) + _, access_token, _, _ = yield self.register_device(user_id) return user_id, access_token def register_user( diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index c2e8c00293..550bd5c95f 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -36,20 +36,29 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): @staticmethod async def _serialize_payload( - user_id, device_id, initial_display_name, is_guest, is_appservice_ghost + user_id, + device_id, + initial_display_name, + is_guest, + is_appservice_ghost, + should_issue_refresh_token, ): """ Args: + user_id (int) device_id (str|None): Device ID to use, if None a new one is generated. initial_display_name (str|None) is_guest (bool) + is_appservice_ghost (bool) + should_issue_refresh_token (bool) """ return { "device_id": device_id, "initial_display_name": initial_display_name, "is_guest": is_guest, "is_appservice_ghost": is_appservice_ghost, + "should_issue_refresh_token": should_issue_refresh_token, } async def _handle_request(self, request, user_id): @@ -59,6 +68,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] is_appservice_ghost = content["is_appservice_ghost"] + should_issue_refresh_token = content["should_issue_refresh_token"] res = await self.registration_handler.register_device_inner( user_id, @@ -66,6 +76,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): initial_display_name, is_guest, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) return 200, res diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index f6be5f1020..cbcb60fe31 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,7 +14,9 @@ import logging import re -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional + +from typing_extensions import TypedDict from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -25,6 +27,8 @@ from synapse.http import get_request_uri from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, + assert_params_in_dict, + parse_boolean, parse_bytes_from_args, parse_json_object_from_request, parse_string, @@ -40,6 +44,21 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +LoginResponse = TypedDict( + "LoginResponse", + { + "user_id": str, + "access_token": str, + "home_server": str, + "expires_in_ms": Optional[int], + "refresh_token": Optional[str], + "device_id": str, + "well_known": Optional[Dict[str, Any]], + }, + total=False, +) + + class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -48,6 +67,7 @@ class LoginRestServlet(RestServlet): JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" + REFRESH_TOKEN_PARAM = "org.matrix.msc2918.refresh_token" def __init__(self, hs: "HomeServer"): super().__init__() @@ -65,9 +85,12 @@ class LoginRestServlet(RestServlet): self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled self._msc2858_enabled = hs.config.experimental.msc2858_enabled + self._msc2918_enabled = hs.config.access_token_lifetime is not None self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self._sso_handler = hs.get_sso_handler() @@ -138,6 +161,15 @@ class LoginRestServlet(RestServlet): async def on_POST(self, request: SynapseRequest): login_submission = parse_json_object_from_request(request) + if self._msc2918_enabled: + # Check if this login should also issue a refresh token, as per + # MSC2918 + should_issue_refresh_token = parse_boolean( + request, name=LoginRestServlet.REFRESH_TOKEN_PARAM, default=False + ) + else: + should_issue_refresh_token = False + try: if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: appservice = self.auth.get_appservice_by_req(request) @@ -147,19 +179,32 @@ class LoginRestServlet(RestServlet): None, request.getClientIP() ) - result = await self._do_appservice_login(login_submission, appservice) + result = await self._do_appservice_login( + login_submission, + appservice, + should_issue_refresh_token=should_issue_refresh_token, + ) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_jwt_login(login_submission) + result = await self._do_jwt_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_token_login(login_submission) + result = await self._do_token_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) else: await self._address_ratelimiter.ratelimit(None, request.getClientIP()) - result = await self._do_other_login(login_submission) + result = await self._do_other_login( + login_submission, + should_issue_refresh_token=should_issue_refresh_token, + ) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -169,7 +214,10 @@ class LoginRestServlet(RestServlet): return 200, result async def _do_appservice_login( - self, login_submission: JsonDict, appservice: ApplicationService + self, + login_submission: JsonDict, + appservice: ApplicationService, + should_issue_refresh_token: bool = False, ): identifier = login_submission.get("identifier") logger.info("Got appservice login request with identifier: %r", identifier) @@ -198,14 +246,21 @@ class LoginRestServlet(RestServlet): raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) return await self._complete_login( - qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited() + qualified_user_id, + login_submission, + ratelimit=appservice.is_rate_limited(), + should_issue_refresh_token=should_issue_refresh_token, ) - async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_other_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: """Handle non-token/saml/jwt logins Args: login_submission: + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: HTTP response @@ -224,7 +279,10 @@ class LoginRestServlet(RestServlet): login_submission, ratelimit=True ) result = await self._complete_login( - canonical_user_id, login_submission, callback + canonical_user_id, + login_submission, + callback, + should_issue_refresh_token=should_issue_refresh_token, ) return result @@ -232,11 +290,12 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, + callback: Optional[Callable[[LoginResponse], Awaitable[None]]] = None, create_non_existent_users: bool = False, ratelimit: bool = True, auth_provider_id: Optional[str] = None, - ) -> Dict[str, str]: + should_issue_refresh_token: bool = False, + ) -> LoginResponse: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on all successful logins. @@ -253,6 +312,8 @@ class LoginRestServlet(RestServlet): ratelimit: Whether to ratelimit the login request. auth_provider_id: The SSO IdP the user used, if any (just used for the prometheus metrics). + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: result: Dictionary of account information after successful login. @@ -274,28 +335,48 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( + user_id, + device_id, + initial_display_name, + auth_provider_id=auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, ) - result = { - "user_id": user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "device_id": device_id, - } + result = LoginResponse( + user_id=user_id, + access_token=access_token, + home_server=self.hs.hostname, + device_id=device_id, + ) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token if callback is not None: await callback(result) return result - async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_token_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: """ Handle the final stage of SSO login. Args: - login_submission: The JSON request body. + login_submission: The JSON request body. + should_issue_refresh_token: True if this login should issue + a refresh token alongside the access token. Returns: The body of the JSON response. @@ -309,9 +390,12 @@ class LoginRestServlet(RestServlet): login_submission, self.auth_handler._sso_login_callback, auth_provider_id=res.auth_provider_id, + should_issue_refresh_token=should_issue_refresh_token, ) - async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: + async def _do_jwt_login( + self, login_submission: JsonDict, should_issue_refresh_token: bool = False + ) -> LoginResponse: token = login_submission.get("token", None) if token is None: raise LoginError( @@ -342,7 +426,10 @@ class LoginRestServlet(RestServlet): user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( - user_id, login_submission, create_non_existent_users=True + user_id, + login_submission, + create_non_existent_users=True, + should_issue_refresh_token=should_issue_refresh_token, ) return result @@ -371,6 +458,42 @@ def _get_auth_flow_dict_for_idp( return e +class RefreshTokenServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc2918.refresh_token/refresh$", releases=(), unstable=True + ) + + def __init__(self, hs: "HomeServer"): + self._auth_handler = hs.get_auth_handler() + self._clock = hs.get_clock() + self.access_token_lifetime = hs.config.access_token_lifetime + + async def on_POST( + self, + request: SynapseRequest, + ): + refresh_submission = parse_json_object_from_request(request) + + assert_params_in_dict(refresh_submission, ["refresh_token"]) + token = refresh_submission["refresh_token"] + if not isinstance(token, str): + raise SynapseError(400, "Invalid param: refresh_token", Codes.INVALID_PARAM) + + valid_until_ms = self._clock.time_msec() + self.access_token_lifetime + access_token, refresh_token = await self._auth_handler.refresh_token( + token, valid_until_ms + ) + expires_in_ms = valid_until_ms - self._clock.time_msec() + return ( + 200, + { + "access_token": access_token, + "refresh_token": refresh_token, + "expires_in_ms": expires_in_ms, + }, + ) + + class SsoRedirectServlet(RestServlet): PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ re.compile( @@ -477,6 +600,8 @@ class CasTicketServlet(RestServlet): def register_servlets(hs, http_server): LoginRestServlet(hs).register(http_server) + if hs.config.access_token_lifetime is not None: + RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index a30a5df1b1..4d31584acd 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -41,11 +41,13 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_json_object_from_request, parse_string, ) from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -399,6 +401,7 @@ class RegisterRestServlet(RestServlet): self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() self._registration_enabled = self.hs.config.enable_registration + self._msc2918_enabled = hs.config.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -424,6 +427,15 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) + if self._msc2918_enabled: + # Check if this registration should also issue a refresh token, as + # per MSC2918 + should_issue_refresh_token = parse_boolean( + request, name="org.matrix.msc2918.refresh_token", default=False + ) + else: + should_issue_refresh_token = False + # Pull out the provided username and do basic sanity checks early since # the auth layer will store these in sessions. desired_username = None @@ -462,7 +474,10 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Desired Username is missing or not a string") result = await self._do_appservice_registration( - desired_username, access_token, body + desired_username, + access_token, + body, + should_issue_refresh_token=should_issue_refresh_token, ) return 200, result @@ -665,7 +680,9 @@ class RegisterRestServlet(RestServlet): registered = True return_dict = await self._create_registration_details( - registered_user_id, params + registered_user_id, + params, + should_issue_refresh_token=should_issue_refresh_token, ) if registered: @@ -677,7 +694,9 @@ class RegisterRestServlet(RestServlet): return 200, return_dict - async def _do_appservice_registration(self, username, as_token, body): + async def _do_appservice_registration( + self, username, as_token, body, should_issue_refresh_token: bool = False + ): user_id = await self.registration_handler.appservice_register( username, as_token ) @@ -685,19 +704,27 @@ class RegisterRestServlet(RestServlet): user_id, body, is_appservice_ghost=True, + should_issue_refresh_token=should_issue_refresh_token, ) async def _create_registration_details( - self, user_id, params, is_appservice_ghost=False + self, + user_id: str, + params: JsonDict, + is_appservice_ghost: bool = False, + should_issue_refresh_token: bool = False, ): """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. Args: - (str) user_id: full canonical @user:id - (object) params: registration parameters, from which we pull - device_id, initial_device_name and inhibit_login + user_id: full canonical @user:id + params: registration parameters, from which we pull device_id, + initial_device_name and inhibit_login + is_appservice_ghost + should_issue_refresh_token: True if this registration should issue + a refresh token alongside the access token. Returns: dictionary for response from /register """ @@ -705,15 +732,29 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False, is_appservice_ghost=is_appservice_ghost, + should_issue_refresh_token=should_issue_refresh_token, ) result.update({"access_token": access_token, "device_id": device_id}) + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + return result async def _do_guest_registration(self, params, address=None): @@ -727,19 +768,30 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = await self.registration_handler.register_device( + ( + device_id, + access_token, + valid_until_ms, + refresh_token, + ) = await self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True ) - return ( - 200, - { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - }, - ) + result = { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + } + + if valid_until_ms is not None: + expires_in_ms = valid_until_ms - self.clock.time_msec() + result["expires_in_ms"] = expires_in_ms + + if refresh_token is not None: + result["refresh_token"] = refresh_token + + return 200, result def _calculate_registration_flows( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index e5c5cf8ff0..e31c5864ac 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -53,6 +53,9 @@ class TokenLookupResult: valid_until_ms: The timestamp the token expires, if any. token_owner: The "owner" of the token. This is either the same as the user, or a server admin who is logged in as the user. + token_used: True if this token was used at least once in a request. + This field can be out of date since `get_user_by_access_token` is + cached. """ user_id = attr.ib(type=str) @@ -62,6 +65,7 @@ class TokenLookupResult: device_id = attr.ib(type=Optional[str], default=None) valid_until_ms = attr.ib(type=Optional[int], default=None) token_owner = attr.ib(type=str) + token_used = attr.ib(type=bool, default=False) # Make the token owner default to the user ID, which is the common case. @token_owner.default @@ -69,6 +73,29 @@ class TokenLookupResult: return self.user_id +@attr.s(frozen=True, slots=True) +class RefreshTokenLookupResult: + """Result of looking up a refresh token.""" + + user_id = attr.ib(type=str) + """The user this token belongs to.""" + + device_id = attr.ib(type=str) + """The device associated with this refresh token.""" + + token_id = attr.ib(type=int) + """The ID of this refresh token.""" + + next_token_id = attr.ib(type=Optional[int]) + """The ID of the refresh token which replaced this one.""" + + has_next_refresh_token_been_refreshed = attr.ib(type=bool) + """True if the next refresh token was used for another refresh.""" + + has_next_access_token_been_used = attr.ib(type=bool) + """True if the next access token was already used at least once.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): access_tokens.id as token_id, access_tokens.device_id, access_tokens.valid_until_ms, - access_tokens.user_id as token_owner + access_tokens.user_id as token_owner, + access_tokens.used as token_used FROM users INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id) WHERE token = ? @@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): txn.execute(sql, (token,)) rows = self.db_pool.cursor_to_dict(txn) + if rows: - return TokenLookupResult(**rows[0]) + row = rows[0] + + # This field is nullable, ensure it comes out as a boolean + if row["token_used"] is None: + row["token_used"] = False + + return TokenLookupResult(**row) return None @@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="update_access_token_last_validated", ) + @cached() + async def mark_access_token_as_used(self, token_id: int) -> None: + """ + Mark the access token as used, which invalidates the refresh token used + to obtain it. + + Because get_user_by_access_token is cached, this function might be + called multiple times for the same token, effectively doing unnecessary + SQL updates. Because updating the `used` field only goes one way (from + False to True) it is safe to cache this function as well to avoid this + issue. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. + """ + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"used": True}, + desc="mark_access_token_as_used", + ) + + async def lookup_refresh_token( + self, token: str + ) -> Optional[RefreshTokenLookupResult]: + """Lookup a refresh token with hints about its validity.""" + + def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]: + txn.execute( + """ + SELECT + rt.id token_id, + rt.user_id, + rt.device_id, + rt.next_token_id, + (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed, + at.used has_next_access_token_been_used + FROM refresh_tokens rt + LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id + LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id + WHERE rt.token = ? + """, + (token,), + ) + row = txn.fetchone() + + if row is None: + return None + + return RefreshTokenLookupResult( + token_id=row[0], + user_id=row[1], + device_id=row[2], + next_token_id=row[3], + has_next_refresh_token_been_refreshed=row[4], + # This column is nullable, ensure it's a boolean + has_next_access_token_been_used=(row[5] or False), + ) + + return await self.db_pool.runInteraction( + "lookup_refresh_token", _lookup_refresh_token_txn + ) + + async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None: + """ + Set the successor of a refresh token, removing the existing successor + if any. + + Args: + token_id: ID of the refresh token to update. + next_token_id: ID of its successor. + """ + + def _replace_refresh_token_txn(txn) -> None: + # First check if there was an existing refresh token + old_next_token_id = self.db_pool.simple_select_one_onecol_txn( + txn, + "refresh_tokens", + {"id": token_id}, + "next_token_id", + allow_none=True, + ) + + self.db_pool.simple_update_one_txn( + txn, + "refresh_tokens", + {"id": token_id}, + {"next_token_id": next_token_id}, + ) + + # Delete the old "next" token if it exists. This should cascade and + # delete the associated access_token + if old_next_token_id is not None: + self.db_pool.simple_delete_one_txn( + txn, + "refresh_tokens", + {"id": old_next_token_id}, + ) + + await self.db_pool.runInteraction( + "replace_refresh_token", _replace_refresh_token_txn + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__( @@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") + self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") async def add_access_token_to_user( self, @@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str], valid_until_ms: Optional[int], puppets_user_id: Optional[str] = None, + refresh_token_id: Optional[int] = None, ) -> int: """Adds an access token for the given user. Args: user_id: The user ID. token: The new access token to add. - device_id: ID of the device to associate with the access token + device_id: ID of the device to associate with the access token. valid_until_ms: when the token is valid until. None for no expiry. + puppets_user_id + refresh_token_id: ID of the refresh token generated alongside this + access token. Raises: StoreError if there was a problem adding this. Returns: @@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "valid_until_ms": valid_until_ms, "puppets_user_id": puppets_user_id, "last_validated": now, + "refresh_token_id": refresh_token_id, + "used": False, }, desc="add_access_token_to_user", ) return next_id + async def add_refresh_token_to_user( + self, + user_id: str, + token: str, + device_id: Optional[str], + ) -> int: + """Adds a refresh token for the given user. + + Args: + user_id: The user ID. + token: The new access token to add. + device_id: ID of the device to associate with the refresh token. + Raises: + StoreError if there was a problem adding this. + Returns: + The token ID + """ + next_id = self._refresh_tokens_id_gen.get_next() + + await self.db_pool.simple_insert( + "refresh_tokens", + { + "id": next_id, + "user_id": user_id, + "device_id": device_id, + "token": token, + "next_token_id": None, + }, + desc="add_refresh_token_to_user", + ) + + return next_id + def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, "access_tokens", {"token": token}, "device_id" @@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): device_id: Optional[str] = None, ) -> List[Tuple[str, int, Optional[str]]]: """ - Invalidate access tokens belonging to a user + Invalidate access and refresh tokens belonging to a user Args: user_id: ID of user the tokens belong to @@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) values = [v for _, v in items] # type: List[Union[str, int]] + # Conveniently, refresh_tokens and access_tokens both use the user_id and device_id fields. Only caveat + # is the `except_token_id` param that is tricky to get right, so for now we're just using the same where + # clause and values before we handle that. This seems to be only used in the "set password" handler. + refresh_where_clause = where_clause + refresh_values = values.copy() if except_token_id: + # TODO: support that for refresh tokens where_clause += " AND id != ?" values.append(except_token_id) @@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values) + txn.execute( + "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause, + refresh_values, + ) + return tokens_and_devices return await self.db_pool.runInteraction("user_delete_access_tokens", f) @@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): await self.db_pool.runInteraction("delete_access_token", f) + async def delete_refresh_token(self, refresh_token: str) -> None: + def f(txn): + self.db_pool.simple_delete_one_txn( + txn, table="refresh_tokens", keyvalues={"token": refresh_token} + ) + + await self.db_pool.runInteraction("delete_refresh_token", f) + async def add_user_pending_deactivation(self, user_id: str) -> None: """ Adds a user to the table of users who need to be parted from all the rooms they're diff --git a/synapse/storage/schema/main/delta/59/14refresh_tokens.sql b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql new file mode 100644 index 0000000000..9a6bce1e3e --- /dev/null +++ b/synapse/storage/schema/main/delta/59/14refresh_tokens.sql @@ -0,0 +1,34 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Holds MSC2918 refresh tokens +CREATE TABLE refresh_tokens ( + id BIGINT PRIMARY KEY, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + token TEXT NOT NULL, + -- When consumed, a new refresh token is generated, which is tracked by + -- this foreign key + next_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE, + UNIQUE(token) +); + +-- Add a reference to the refresh token generated alongside each access token +ALTER TABLE "access_tokens" + ADD COLUMN refresh_token_id BIGINT REFERENCES refresh_tokens (id) ON DELETE CASCADE; + +-- Add a flag whether the token was already used or not +ALTER TABLE "access_tokens" + ADD COLUMN used BOOLEAN; diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 1b0a815757..f76fea4f66 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -58,6 +58,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id=self.test_user, token_id=5, device_id="device" ) self.store.get_user_by_access_token = simple_async_mock(user_info) + self.store.mark_access_token_as_used = simple_async_mock(None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 84c38b295d..3ac48e5e95 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -257,7 +257,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) # Create a new login for the user and dehydrated the device - device_id, access_token = self.get_success( + device_id, access_token, _expiration_time, _refresh_token = self.get_success( self.registration.register_device( user_id=user_id, device_id=None, diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 485e3650c3..6b90f838b6 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -20,7 +20,7 @@ import synapse.rest.admin from synapse.api.constants import LoginType from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client.v1 import login -from synapse.rest.client.v2_alpha import auth, devices, register +from synapse.rest.client.v2_alpha import account, auth, devices, register from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.types import JsonDict, UserID @@ -498,3 +498,221 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_device( self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}} ) + + +class RefreshAuthTests(unittest.HomeserverTestCase): + servlets = [ + auth.register_servlets, + account.register_servlets, + login.register_servlets, + synapse.rest.admin.register_servlets_for_client_rest_resource, + register.register_servlets, + ] + hijack_auth = False + + def prepare(self, reactor, clock, hs): + self.user_pass = "pass" + self.user = self.register_user("test", self.user_pass) + + def test_login_issue_refresh_token(self): + """ + A login response should include a refresh_token only if asked. + """ + # Test login + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + + login_without_refresh = self.make_request( + "POST", "/_matrix/client/r0/login", body + ) + self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result) + self.assertNotIn("refresh_token", login_without_refresh.json_body) + + login_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result) + self.assertIn("refresh_token", login_with_refresh.json_body) + self.assertIn("expires_in_ms", login_with_refresh.json_body) + + def test_register_issue_refresh_token(self): + """ + A register response should include a refresh_token only if asked. + """ + register_without_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register", + { + "username": "test2", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual( + register_without_refresh.code, 200, register_without_refresh.result + ) + self.assertNotIn("refresh_token", register_without_refresh.json_body) + + register_with_refresh = self.make_request( + "POST", + "/_matrix/client/r0/register?org.matrix.msc2918.refresh_token=true", + { + "username": "test3", + "password": self.user_pass, + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result) + self.assertIn("refresh_token", register_with_refresh.json_body) + self.assertIn("expires_in_ms", register_with_refresh.json_body) + + def test_token_refresh(self): + """ + A refresh token can be used to issue a new access token. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertIn("access_token", refresh_response.json_body) + self.assertIn("refresh_token", refresh_response.json_body) + self.assertIn("expires_in_ms", refresh_response.json_body) + + # The access and refresh tokens should be different from the original ones after refresh + self.assertNotEqual( + login_response.json_body["access_token"], + refresh_response.json_body["access_token"], + ) + self.assertNotEqual( + login_response.json_body["refresh_token"], + refresh_response.json_body["refresh_token"], + ) + + @override_config({"access_token_lifetime": "1m"}) + def test_refresh_token_expiration(self): + """ + The access token should have some time as specified in the config. + """ + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + self.assertApproximates( + login_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual(refresh_response.code, 200, refresh_response.result) + self.assertApproximates( + refresh_response.json_body["expires_in_ms"], 60 * 1000, 100 + ) + + def test_refresh_token_invalidation(self): + """Refresh tokens are invalidated after first use of the next token. + + A refresh token is considered invalid if: + - it was already used at least once + - and either + - the next access token was used + - the next refresh token was used + + The chain of tokens goes like this: + + login -|-> first_refresh -> third_refresh (fails) + |-> second_refresh -> fifth_refresh + |-> fourth_refresh (fails) + """ + + body = {"type": "m.login.password", "user": "test", "password": self.user_pass} + login_response = self.make_request( + "POST", + "/_matrix/client/r0/login?org.matrix.msc2918.refresh_token=true", + body, + ) + self.assertEqual(login_response.code, 200, login_response.result) + + # This first refresh should work properly + first_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + first_refresh_response.code, 200, first_refresh_response.result + ) + + # This one as well, since the token in the first one was never used + second_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + second_refresh_response.code, 200, second_refresh_response.result + ) + + # This one should not, since the token from the first refresh is not valid anymore + third_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": first_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + third_refresh_response.code, 401, third_refresh_response.result + ) + + # The associated access token should also be invalid + whoami_response = self.make_request( + "GET", + "/_matrix/client/r0/account/whoami", + access_token=first_refresh_response.json_body["access_token"], + ) + self.assertEqual(whoami_response.code, 401, whoami_response.result) + + # But all other tokens should work (they will expire after some time) + for access_token in [ + second_refresh_response.json_body["access_token"], + login_response.json_body["access_token"], + ]: + whoami_response = self.make_request( + "GET", "/_matrix/client/r0/account/whoami", access_token=access_token + ) + self.assertEqual(whoami_response.code, 200, whoami_response.result) + + # Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail + fourth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": login_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fourth_refresh_response.code, 403, fourth_refresh_response.result + ) + + # But refreshing from the last valid refresh token still works + fifth_refresh_response = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc2918.refresh_token/refresh", + {"refresh_token": second_refresh_response.json_body["refresh_token"]}, + ) + self.assertEqual( + fifth_refresh_response.code, 200, fifth_refresh_response.result + ) -- cgit 1.5.1 From 6e8fb42be7657f9d4958c02d87cff865225714d2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 24 Jun 2021 15:30:49 +0100 Subject: Improve validation for `send_{join,leave,knock}` (#10225) The idea here is to stop people sending things that aren't joins/leaves/knocks through these endpoints: previously you could send anything you liked through them. I wasn't able to find any security holes from doing so, but it doesn't sound like a good thing. --- changelog.d/10225.feature | 1 + synapse/federation/federation_server.py | 121 +++++++++------ synapse/federation/transport/server.py | 12 +- synapse/handlers/federation.py | 177 +++++++--------------- tests/handlers/test_federation.py | 2 +- tests/replication/test_federation_sender_shard.py | 2 +- 6 files changed, 132 insertions(+), 183 deletions(-) create mode 100644 changelog.d/10225.feature (limited to 'synapse') diff --git a/changelog.d/10225.feature b/changelog.d/10225.feature new file mode 100644 index 0000000000..d16f66ffe9 --- /dev/null +++ b/changelog.d/10225.feature @@ -0,0 +1 @@ +Improve validation on federation `send_{join,leave,knock}` endpoints. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b07f18529..341965047a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventTypes +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -46,6 +46,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction @@ -537,26 +538,21 @@ class FederationServer(FederationBase): return {"event": ret_pdu.get_pdu_json(time_now)} async def on_send_join_request( - self, origin: str, content: JsonDict + self, origin: str, content: JsonDict, room_id: str ) -> Dict[str, Any]: - logger.debug("on_send_join_request: content: %s", content) - - assert_params_in_dict(content, ["room_id"]) - room_version = await self.store.get_room_version(content["room_id"]) - pdu = event_from_pdu_json(content, room_version) - - origin_host, _ = parse_server_name(origin) - await self.check_server_matches_acl(origin_host, pdu.room_id) - - logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) + context = await self._on_send_membership_event( + origin, content, Membership.JOIN, room_id + ) - pdu = await self._check_sigs_and_hash(room_version, pdu) + prev_state_ids = await context.get_prev_state_ids() + state_ids = list(prev_state_ids.values()) + auth_chain = await self.store.get_auth_chain(room_id, state_ids) + state = await self.store.get_events(state_ids) - res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], + "state": [p.get_pdu_json(time_now) for p in state.values()], + "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], } async def on_make_leave_request( @@ -571,21 +567,11 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict: + async def on_send_leave_request( + self, origin: str, content: JsonDict, room_id: str + ) -> dict: logger.debug("on_send_leave_request: content: %s", content) - - assert_params_in_dict(content, ["room_id"]) - room_version = await self.store.get_room_version(content["room_id"]) - pdu = event_from_pdu_json(content, room_version) - - origin_host, _ = parse_server_name(origin) - await self.check_server_matches_acl(origin_host, pdu.room_id) - - logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - - pdu = await self._check_sigs_and_hash(room_version, pdu) - - await self.handler.on_send_leave_request(origin, pdu) + await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id) return {} async def on_make_knock_request( @@ -651,39 +637,76 @@ class FederationServer(FederationBase): Returns: The stripped room state. """ - logger.debug("on_send_knock_request: content: %s", content) + event_context = await self._on_send_membership_event( + origin, content, Membership.KNOCK, room_id + ) + + # Retrieve stripped state events from the room and send them back to the remote + # server. This will allow the remote server's clients to display information + # related to the room while the knock request is pending. + stripped_room_state = ( + await self.store.get_stripped_room_state_from_event_context( + event_context, self._room_prejoin_state_types + ) + ) + return {"knock_state_events": stripped_room_state} + + async def _on_send_membership_event( + self, origin: str, content: JsonDict, membership_type: str, room_id: str + ) -> EventContext: + """Handle an on_send_{join,leave,knock} request + + Does some preliminary validation before passing the request on to the + federation handler. + + Args: + origin: The (authenticated) requesting server + content: The body of the send_* request - a complete membership event + membership_type: The expected membership type (join or leave, depending + on the endpoint) + room_id: The room_id from the request, to be validated against the room_id + in the event + + Returns: + The context of the event after inserting it into the room graph. + + Raises: + SynapseError if there is a problem with the request, including things like + the room_id not matching or the event not being authorized. + """ + assert_params_in_dict(content, ["room_id"]) + if content["room_id"] != room_id: + raise SynapseError( + 400, + "Room ID in body does not match that in request path", + Codes.BAD_JSON, + ) room_version = await self.store.get_room_version(room_id) - # Check that this room supports knocking as defined by its room version - if not room_version.msc2403_knocking: + if membership_type == Membership.KNOCK and not room_version.msc2403_knocking: raise SynapseError( 403, "This room version does not support knocking", errcode=Codes.FORBIDDEN, ) - pdu = event_from_pdu_json(content, room_version) + event = event_from_pdu_json(content, room_version) - origin_host, _ = parse_server_name(origin) - await self.check_server_matches_acl(origin_host, pdu.room_id) + if event.type != EventTypes.Member or not event.is_state(): + raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON) - logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures) + if event.content.get("membership") != membership_type: + raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON) - pdu = await self._check_sigs_and_hash(room_version, pdu) + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, event.room_id) - # Handle the event, and retrieve the EventContext - event_context = await self.handler.on_send_knock_request(origin, pdu) + logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures) - # Retrieve stripped state events from the room and send them back to the remote - # server. This will allow the remote server's clients to display information - # related to the room while the knock request is pending. - stripped_room_state = ( - await self.store.get_stripped_room_state_from_event_context( - event_context, self._room_prejoin_state_types - ) - ) - return {"knock_state_events": stripped_room_state} + event = await self._check_sigs_and_hash(room_version, event) + + return await self.handler.on_send_membership_event(origin, event) async def on_event_auth( self, origin: str, room_id: str, event_id: str diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bed47f8abd..676fbd3750 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -553,7 +553,7 @@ class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content) + content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, (200, content) @@ -563,7 +563,7 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content) + content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, content @@ -602,9 +602,9 @@ class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): - # TODO(paul): assert that room_id/event_id parsed from path actually + # TODO(paul): assert that event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content) + content = await self.handler.on_send_join_request(origin, content, room_id) return 200, (200, content) @@ -614,9 +614,9 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX async def on_PUT(self, origin, content, query, room_id, event_id): - # TODO(paul): assert that room_id/event_id parsed from path actually + # TODO(paul): assert that event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content) + content = await self.handler.on_send_join_request(origin, content, room_id) return 200, content diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 74d169a2ac..12f3d85342 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1711,80 +1711,6 @@ class FederationHandler(BaseHandler): return event - async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict: - """We have received a join event for a room. Fully process it and - respond with the current state and auth chains. - """ - event = pdu - - logger.debug( - "on_send_join_request from %s: Got event: %s, signatures: %s", - origin, - event.event_id, - event.signatures, - ) - - if get_domain_from_id(event.sender) != origin: - logger.info( - "Got /send_join request for user %r from different origin %s", - event.sender, - origin, - ) - raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - - event.internal_metadata.outlier = False - # Send this event on behalf of the origin server. - # - # The reasons we have the destination server rather than the origin - # server send it are slightly mysterious: the origin server should have - # all the necessary state once it gets the response to the send_join, - # so it could send the event itself if it wanted to. It may be that - # doing it this way reduces failure modes, or avoids certain attacks - # where a new server selectively tells a subset of the federation that - # it has joined. - # - # The fact is that, as of the current writing, Synapse doesn't send out - # the join event over federation after joining, and changing it now - # would introduce the danger of backwards-compatibility problems. - event.internal_metadata.send_on_behalf_of = origin - - # Calculate the event context. - context = await self.state_handler.compute_event_context(event) - - # Get the state before the new event. - prev_state_ids = await context.get_prev_state_ids() - - # Check if the user is already in the room or invited to the room. - user_id = event.state_key - prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - prev_member_event = None - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - - # Check if the member should be allowed access via membership in a space. - await self._event_auth_handler.check_restricted_join_rules( - prev_state_ids, - event.room_version, - user_id, - prev_member_event, - ) - - # Persist the event. - await self._auth_and_persist_event(origin, event, context) - - logger.debug( - "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s", - event.event_id, - event.signatures, - ) - - state_ids = list(prev_state_ids.values()) - auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) - - state = await self.store.get_events(list(prev_state_ids.values())) - - return {"state": list(state.values()), "auth_chain": auth_chain} - async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion ) -> EventBase: @@ -1960,44 +1886,6 @@ class FederationHandler(BaseHandler): return event - async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None: - """We have received a leave event for a room. Fully process it.""" - event = pdu - - logger.debug( - "on_send_leave_request: Got event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - if get_domain_from_id(event.sender) != origin: - logger.info( - "Got /send_leave request for user %r from different origin %s", - event.sender, - origin, - ) - raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) - - event.internal_metadata.outlier = False - - # Send this event on behalf of the other server. - # - # The remote server isn't a full participant in the room at this point, so - # may not have an up-to-date list of the other homeservers participating in - # the room, so we send it on their behalf. - event.internal_metadata.send_on_behalf_of = origin - - context = await self.state_handler.compute_event_context(event) - await self._auth_and_persist_event(origin, event, context) - - logger.debug( - "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s", - event.event_id, - event.signatures, - ) - - return None - @log_function async def on_make_knock_request( self, origin: str, room_id: str, user_id: str @@ -2061,34 +1949,38 @@ class FederationHandler(BaseHandler): return event @log_function - async def on_send_knock_request( + async def on_send_membership_event( self, origin: str, event: EventBase ) -> EventContext: """ - We have received a knock event for a room. Verify that event and send it into the room - on the knocking homeserver's behalf. + We have received a join/leave/knock event for a room. + + Verify that event and send it into the room on the remote homeserver's behalf. Args: - origin: The remote homeserver of the knocking user. - event: The knocking member event that has been signed by the remote homeserver. + origin: The homeserver of the remote (joining/invited/knocking) user. + event: The member event that has been signed by the remote homeserver. Returns: The context of the event after inserting it into the room graph. """ logger.debug( - "on_send_knock_request: Got event: %s, signatures: %s", + "on_send_membership_event: Got event: %s, signatures: %s", event.event_id, event.signatures, ) if get_domain_from_id(event.sender) != origin: logger.info( - "Got /send_knock request for user %r from different origin %s", + "Got send_membership request for user %r from different origin %s", event.sender, origin, ) raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + if event.sender != event.state_key: + raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) + event.internal_metadata.outlier = False # Send this event on behalf of the other server. @@ -2100,19 +1992,52 @@ class FederationHandler(BaseHandler): context = await self.state_handler.compute_event_context(event) - event_allowed = await self.third_party_event_rules.check_event_allowed( - event, context - ) - if not event_allowed: - logger.info("Sending of knock %s forbidden by third-party rules", event) - raise SynapseError( - 403, "This event is not allowed in this context", Codes.FORBIDDEN + # for joins, we need to check the restrictions of restricted rooms + if event.membership == Membership.JOIN: + await self._check_join_restrictions(context, event) + + # for knock events, we run the third-party event rules. It's not entirely clear + # why we don't do this for other sorts of membership events. + if event.membership == Membership.KNOCK: + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context ) + if not event_allowed: + logger.info("Sending of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) await self._auth_and_persist_event(origin, event, context) return context + async def _check_join_restrictions( + self, context: EventContext, event: EventBase + ) -> None: + """Check that restrictions in restricted join rules are matched + + Called when we receive a join event via send_join. + + Raises an auth error if the restrictions are not matched. + """ + prev_state_ids = await context.get_prev_state_ids() + + # Check if the user is already in the room or invited to the room. + user_id = event.state_key + prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) + prev_member_event = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + + # Check if the member should be allowed access via membership in a space. + await self._event_auth_handler.check_restricted_join_rules( + prev_state_ids, + event.room_version, + user_id, + prev_member_event, + ) + async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event.""" diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 8796af45ed..ba8cf44f46 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -251,7 +251,7 @@ class FederationTestCase(unittest.HomeserverTestCase): join_event.signatures[other_server] = {"x": "y"} with LoggingContext("send_join"): d = run_in_background( - self.handler.on_send_join_request, other_server, join_event + self.handler.on_send_membership_event, other_server, join_event ) self.get_success(d) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 584da58371..a0c710f855 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -228,7 +228,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) ) - self.get_success(federation.on_send_join_request(remote_server, join_event)) + self.get_success(federation.on_send_membership_event(remote_server, join_event)) self.replicate() return room -- cgit 1.5.1 From 8165ba48b1d7d6a265683b06e32d08935f41fa69 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 24 Jun 2021 16:00:08 +0100 Subject: Return errors from `send_join` etc if the event is rejected (#10243) Rather than persisting rejected events via `send_join` and friends, raise a 403 if someone tries to pull a fast one. --- changelog.d/10243.feature | 1 + synapse/handlers/federation.py | 46 ++++++++++++++++++++++++----- tests/federation/transport/test_knocking.py | 4 +-- 3 files changed, 41 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10243.feature (limited to 'synapse') diff --git a/changelog.d/10243.feature b/changelog.d/10243.feature new file mode 100644 index 0000000000..d16f66ffe9 --- /dev/null +++ b/changelog.d/10243.feature @@ -0,0 +1 @@ +Improve validation on federation `send_{join,leave,knock}` endpoints. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 12f3d85342..d929c65131 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1953,16 +1953,31 @@ class FederationHandler(BaseHandler): self, origin: str, event: EventBase ) -> EventContext: """ - We have received a join/leave/knock event for a room. + We have received a join/leave/knock event for a room via send_join/leave/knock. Verify that event and send it into the room on the remote homeserver's behalf. + This is quite similar to on_receive_pdu, with the following principal + differences: + * only membership events are permitted (and only events with + sender==state_key -- ie, no kicks or bans) + * *We* send out the event on behalf of the remote server. + * We enforce the membership restrictions of restricted rooms. + * Rejected events result in an exception rather than being stored. + + There are also other differences, however it is not clear if these are by + design or omission. In particular, we do not attempt to backfill any missing + prev_events. + Args: origin: The homeserver of the remote (joining/invited/knocking) user. event: The member event that has been signed by the remote homeserver. Returns: The context of the event after inserting it into the room graph. + + Raises: + SynapseError if the event is not accepted into the room """ logger.debug( "on_send_membership_event: Got event: %s, signatures: %s", @@ -1981,7 +1996,7 @@ class FederationHandler(BaseHandler): if event.sender != event.state_key: raise SynapseError(400, "state_key and sender must match", Codes.BAD_JSON) - event.internal_metadata.outlier = False + assert not event.internal_metadata.outlier # Send this event on behalf of the other server. # @@ -1991,6 +2006,11 @@ class FederationHandler(BaseHandler): event.internal_metadata.send_on_behalf_of = origin context = await self.state_handler.compute_event_context(event) + context = await self._check_event_auth(origin, event, context) + if context.rejected: + raise SynapseError( + 403, f"{event.membership} event was rejected", Codes.FORBIDDEN + ) # for joins, we need to check the restrictions of restricted rooms if event.membership == Membership.JOIN: @@ -2008,8 +2028,8 @@ class FederationHandler(BaseHandler): 403, "This event is not allowed in this context", Codes.FORBIDDEN ) - await self._auth_and_persist_event(origin, event, context) - + # all looks good, we can persist the event. + await self._run_push_actions_and_persist_event(event, context) return context async def _check_join_restrictions( @@ -2179,6 +2199,18 @@ class FederationHandler(BaseHandler): backfilled=backfilled, ) + await self._run_push_actions_and_persist_event(event, context, backfilled) + + async def _run_push_actions_and_persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ): + """Run the push actions for a received event, and persist it. + + Args: + event: The event itself. + context: The event context. + backfilled: True if the event was backfilled. + """ try: if ( not event.internal_metadata.is_outlier() @@ -2492,9 +2524,9 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, context: EventContext, - state: Optional[Iterable[EventBase]], - auth_events: Optional[MutableStateMap[EventBase]], - backfilled: bool, + state: Optional[Iterable[EventBase]] = None, + auth_events: Optional[MutableStateMap[EventBase]] = None, + backfilled: bool = False, ) -> EventContext: """ Checks whether an event should be rejected (for failing auth checks). diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 8c215d50f2..aab44bce4a 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -205,9 +205,7 @@ class FederationKnockingTestCase( # Have this homeserver skip event auth checks. This is necessary due to # event auth checks ensuring that events were signed by the sender's homeserver. - async def _check_event_auth( - origin, event, context, state, auth_events, backfilled - ): + async def _check_event_auth(origin, event, context, *args, **kwargs): return context homeserver.get_federation_handler()._check_event_auth = _check_event_auth -- cgit 1.5.1 From 0555d7b0dc18fff489a31afccb47b79afa082113 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Jun 2021 07:36:41 -0400 Subject: Add additional types to the federation transport server. (#10213) --- changelog.d/10213.misc | 1 + synapse/federation/transport/server.py | 588 ++++++++++++++++++++++++++------- synapse/http/servlet.py | 50 ++- 3 files changed, 521 insertions(+), 118 deletions(-) create mode 100644 changelog.d/10213.misc (limited to 'synapse') diff --git a/changelog.d/10213.misc b/changelog.d/10213.misc new file mode 100644 index 0000000000..9adb0fbd02 --- /dev/null +++ b/changelog.d/10213.misc @@ -0,0 +1 @@ +Add type hints to the federation servlets. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 676fbd3750..d37d9565fc 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -15,7 +15,19 @@ import functools import logging import re -from typing import Container, Mapping, Optional, Sequence, Tuple, Type +from typing import ( + Container, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from typing_extensions import Literal import synapse from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH @@ -56,15 +68,15 @@ logger = logging.getLogger(__name__) class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" - def __init__(self, hs, servlet_groups=None): + def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): """Initialize the TransportLayerServer Will by default register all servlets. For custom behaviour, pass in a list of servlet_groups to register. Args: - hs (synapse.server.HomeServer): homeserver - servlet_groups (list[str], optional): List of servlet groups to register. + hs: homeserver + servlet_groups: List of servlet groups to register. Defaults to ``DEFAULT_SERVLET_GROUPS``. """ self.hs = hs @@ -78,7 +90,7 @@ class TransportLayerServer(JsonResource): self.register_servlets() - def register_servlets(self): + def register_servlets(self) -> None: register_servlets( self.hs, resource=self, @@ -91,14 +103,10 @@ class TransportLayerServer(JsonResource): class AuthenticationError(SynapseError): """There was a problem authenticating the request""" - pass - class NoAuthenticationError(AuthenticationError): """The request had no authentication information""" - pass - class Authenticator: def __init__(self, hs: HomeServer): @@ -410,13 +418,18 @@ class FederationSendServlet(BaseFederationServerServlet): RATELIMIT = False # This is when someone is trying to send us a bunch of data. - async def on_PUT(self, origin, content, query, transaction_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + transaction_id: str, + ) -> Tuple[int, JsonDict]: """Called on PUT /send// Args: - request (twisted.web.http.Request): The HTTP request. - transaction_id (str): The transaction_id associated with this - request. This is *not* None. + transaction_id: The transaction_id associated with this request. This + is *not* None. Returns: Tuple of `(code, response)`, where @@ -461,7 +474,13 @@ class FederationEventServlet(BaseFederationServerServlet): PATH = "/event/(?P[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. - async def on_GET(self, origin, content, query, event_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + event_id: str, + ) -> Tuple[int, Union[JsonDict, str]]: return await self.handler.on_pdu_request(origin, event_id) @@ -469,7 +488,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet): PATH = "/state/(?P[^/]*)/?" # This is when someone asks for all data for a given room. - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_room_state_request( origin, room_id, @@ -480,7 +505,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet): class FederationStateIdsServlet(BaseFederationServerServlet): PATH = "/state_ids/(?P[^/]*)/?" - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_state_ids_request( origin, room_id, @@ -491,7 +522,13 @@ class FederationStateIdsServlet(BaseFederationServerServlet): class FederationBackfillServlet(BaseFederationServerServlet): PATH = "/backfill/(?P[^/]*)/?" - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) @@ -505,7 +542,13 @@ class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P[^/]*)" # This is when we receive a server-server Query - async def on_GET(self, origin, content, query, query_type): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + query_type: str, + ) -> Tuple[int, JsonDict]: args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} args["origin"] = origin return await self.handler.on_query_request(query_type, args) @@ -514,47 +557,66 @@ class FederationQueryServlet(BaseFederationServerServlet): class FederationMakeJoinServlet(BaseFederationServerServlet): PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, _content, query, room_id, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: """ Args: - origin (unicode): The authenticated server_name of the calling server + origin: The authenticated server_name of the calling server - _content (None): (GETs don't have bodies) + content: (GETs don't have bodies) - query (dict[bytes, list[bytes]]): Query params from the request. + query: Query params from the request. - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. + **kwargs: the dict mapping keys to path components as specified in + the path match regexp. Returns: - Tuple[int, object]: (response code, response object) + Tuple of (response code, response object) """ - versions = query.get(b"ver") - if versions is not None: - supported_versions = [v.decode("utf-8") for v in versions] - else: + supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") + if supported_versions is None: supported_versions = ["1"] - content = await self.handler.on_make_join_request( + result = await self.handler.on_make_join_request( origin, room_id, user_id, supported_versions=supported_versions ) - return 200, content + return 200, result class FederationMakeLeaveServlet(BaseFederationServerServlet): PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, user_id): - content = await self.handler.on_make_leave_request(origin, room_id, user_id) - return 200, content + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_make_leave_request(origin, room_id, user_id) + return 200, result class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) - return 200, (200, content) + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: + result = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, result) class FederationV2SendLeaveServlet(BaseFederationServerServlet): @@ -562,50 +624,84 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) - return 200, content + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, result class FederationMakeKnockServlet(BaseFederationServerServlet): PATH = "/make_knock/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, user_id): - try: - # Retrieve the room versions the remote homeserver claims to support - supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") - except KeyError: - raise SynapseError(400, "Missing required query parameter 'ver'") + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + # Retrieve the room versions the remote homeserver claims to support + supported_versions = parse_strings_from_args( + query, "ver", required=True, encoding="utf-8" + ) - content = await self.handler.on_make_knock_request( + result = await self.handler.on_make_knock_request( origin, room_id, user_id, supported_versions=supported_versions ) - return 200, content + return 200, result class FederationV1SendKnockServlet(BaseFederationServerServlet): PATH = "/send_knock/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_knock_request(origin, content, room_id) - return 200, content + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_send_knock_request(origin, content, room_id) + return 200, result class FederationEventAuthServlet(BaseFederationServerServlet): PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, event_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_event_auth(origin, room_id, event_id) class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: # TODO(paul): assert that event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, room_id) - return 200, (200, content) + result = await self.handler.on_send_join_request(origin, content, room_id) + return 200, (200, result) class FederationV2SendJoinServlet(BaseFederationServerServlet): @@ -613,28 +709,42 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: # TODO(paul): assert that event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, room_id) - return 200, content + result = await self.handler.on_send_join_request(origin, content, room_id) + return 200, result class FederationV1InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: # We don't get a room version, so we have to assume its EITHER v1 or # v2. This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing # invites - content = await self.handler.on_invite_request( + result = await self.handler.on_invite_request( origin, content, room_version_id=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` # due to a historical bug. - return 200, (200, content) + return 200, (200, result) class FederationV2InviteServlet(BaseFederationServerServlet): @@ -642,7 +752,14 @@ class FederationV2InviteServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content @@ -655,16 +772,22 @@ class FederationV2InviteServlet(BaseFederationServerServlet): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state - content = await self.handler.on_invite_request( + result = await self.handler.on_invite_request( origin, event, room_version_id=room_version ) - return 200, content + return 200, result class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): PATH = "/exchange_third_party_invite/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: await self.handler.on_exchange_third_party_invite_request(content) return 200, {} @@ -672,21 +795,31 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): class FederationClientKeysQueryServlet(BaseFederationServerServlet): PATH = "/user/keys/query" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: return await self.handler.on_query_client_keys(origin, content) class FederationUserDevicesQueryServlet(BaseFederationServerServlet): PATH = "/user/devices/(?P[^/]*)" - async def on_GET(self, origin, content, query, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + user_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_query_user_devices(origin, user_id) class FederationClientKeysClaimServlet(BaseFederationServerServlet): PATH = "/user/keys/claim" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: response = await self.handler.on_claim_client_keys(origin, content) return 200, response @@ -695,12 +828,18 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P[^/]*)/?" - async def on_POST(self, origin, content, query, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: limit = int(content.get("limit", 10)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) - content = await self.handler.on_get_missing_events( + result = await self.handler.on_get_missing_events( origin, room_id=room_id, earliest_events=earliest_events, @@ -708,7 +847,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): limit=limit, ) - return 200, content + return 200, result class On3pidBindServlet(BaseFederationServerServlet): @@ -716,7 +855,9 @@ class On3pidBindServlet(BaseFederationServerServlet): REQUIRE_AUTH = False - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: if "invites" in content: last_exception = None for invite in content["invites"]: @@ -762,15 +903,20 @@ class OpenIdUserInfo(BaseFederationServerServlet): REQUIRE_AUTH = False - async def on_GET(self, origin, content, query): - token = query.get(b"access_token", [None])[0] + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: + token = parse_string_from_args(query, "access_token") if token is None: return ( 401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, ) - user_id = await self.handler.on_openid_userinfo(token.decode("ascii")) + user_id = await self.handler.on_openid_userinfo(token) if user_id is None: return ( @@ -829,7 +975,9 @@ class PublicRoomList(BaseFederationServlet): self.handler = hs.get_room_list_handler() self.allow_access = allow_access - async def on_GET(self, origin, content, query): + async def on_GET( + self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: if not self.allow_access: raise FederationDeniedError(origin) @@ -858,7 +1006,9 @@ class PublicRoomList(BaseFederationServlet): ) return 200, data - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: # This implements MSC2197 (Search Filtering over Federation) if not self.allow_access: raise FederationDeniedError(origin) @@ -904,7 +1054,12 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False - async def on_GET(self, origin, content, query): + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: return ( 200, {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, @@ -933,7 +1088,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/profile" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -942,7 +1103,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): return 200, new_content - async def on_POST(self, origin, content, query, group_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -957,7 +1124,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): class FederationGroupsSummaryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/summary" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -972,7 +1145,13 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/rooms" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -987,7 +1166,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" - async def on_POST(self, origin, content, query, group_id, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -998,7 +1184,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): return 200, new_content - async def on_DELETE(self, origin, content, query, group_id, room_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1018,7 +1211,15 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): "/config/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, room_id, config_key): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + config_key: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1035,7 +1236,13 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1050,7 +1257,13 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/invited_users" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1067,7 +1280,14 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1084,7 +1304,14 @@ class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1098,7 +1325,14 @@ class FederationGroupsJoinServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1112,7 +1346,14 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1146,7 +1387,14 @@ class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") @@ -1164,7 +1412,14 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, None]: if get_domain_from_id(group_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1172,11 +1427,9 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): self.handler, GroupsLocalHandler ), "Workers cannot handle group removals." - new_content = await self.handler.user_removed_from_group( - group_id, user_id, content - ) + await self.handler.user_removed_from_group(group_id, user_id, content) - return 200, new_content + return 200, None class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): @@ -1194,7 +1447,14 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_groups_attestation_renewer() - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: # We don't need to check auth here as we check the attestation signatures new_content = await self.handler.on_renew_attestation( @@ -1218,7 +1478,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): "/rooms/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, category_id, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1246,7 +1514,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1266,7 +1542,13 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/categories/?" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1281,7 +1563,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" - async def on_GET(self, origin, content, query, group_id, category_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1292,7 +1581,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): return 200, resp - async def on_POST(self, origin, content, query, group_id, category_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1314,7 +1610,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, category_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1334,7 +1637,13 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/roles/?" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1349,7 +1658,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" - async def on_GET(self, origin, content, query, group_id, role_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1358,7 +1674,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): return 200, resp - async def on_POST(self, origin, content, query, group_id, role_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1382,7 +1705,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, role_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1411,7 +1741,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): "/users/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, role_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1437,7 +1775,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1457,7 +1803,9 @@ class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): PATH = "/get_groups_publicised" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: resp = await self.handler.bulk_get_publicised_groups( content["user_ids"], proxy=False ) @@ -1470,7 +1818,13 @@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/settings/m.join_policy" - async def on_PUT(self, origin, content, query, group_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1499,7 +1853,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): async def on_GET( self, origin: str, - content: JsonDict, + content: Literal[None], query: Mapping[bytes, Sequence[bytes]], room_id: str, ) -> Tuple[int, JsonDict]: @@ -1571,7 +1925,13 @@ class RoomComplexityServlet(BaseFederationServlet): super().__init__(hs, authenticator, ratelimiter, server_name) self._store = self.hs.get_datastore() - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: is_public = await self._store.is_room_world_readable_or_publicly_joinable( room_id ) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index fda8da21b7..6ba2ce1e53 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -109,12 +109,22 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default +@overload +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[bytes] = None, +) -> Optional[bytes]: + ... + + @overload def parse_bytes_from_args( args: Dict[bytes, List[bytes]], name: str, default: Literal[None] = None, - required: Literal[True] = True, + *, + required: Literal[True], ) -> bytes: ... @@ -197,7 +207,12 @@ def parse_string( """ args = request.args # type: Dict[bytes, List[bytes]] # type: ignore return parse_string_from_args( - args, name, default, required, allowed_values, encoding + args, + name, + default, + required=required, + allowed_values=allowed_values, + encoding=encoding, ) @@ -227,7 +242,20 @@ def parse_strings_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[List[str]] = None, - required: Literal[True] = True, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[List[str]]: + ... + + +@overload +def parse_strings_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[List[str]] = None, + *, + required: Literal[True], allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", ) -> List[str]: @@ -239,6 +267,7 @@ def parse_strings_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[List[str]] = None, + *, required: bool = False, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -299,7 +328,20 @@ def parse_string_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[str] = None, - required: Literal[True] = True, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + *, + required: Literal[True], allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", ) -> str: -- cgit 1.5.1 From cdf569e46811cb498e17eccf81d8f7d645aa60e9 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Tue, 29 Jun 2021 10:15:34 +0100 Subject: 1.37.0 --- CHANGES.md | 6 ++++++ debian/changelog | 6 ++++++ synapse/__init__.py | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index 2c7f24487c..5b924e2471 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,9 @@ +Synapse 1.37.0 (2021-06-29) +=========================== + +No significant changes. + + Synapse 1.37.0rc1 (2021-06-24) ============================== diff --git a/debian/changelog b/debian/changelog index e640dadde9..cf190b7dba 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.37.0) stable; urgency=medium + + * New synapse release 1.37.0. + + -- Synapse Packaging team Tue, 29 Jun 2021 10:15:25 +0100 + matrix-synapse-py3 (1.36.0) stable; urgency=medium * New synapse release 1.36.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index 6d1c6d6f72..c865d2e100 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.37.0rc1" +__version__ = "1.37.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From a0ed0f363eb84f273b2cc706fcc5542d77a94463 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 29 Jun 2021 11:08:06 +0100 Subject: Soft-fail spammy events received over federation (#10263) --- changelog.d/10263.feature | 1 + synapse/federation/federation_base.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10263.feature (limited to 'synapse') diff --git a/changelog.d/10263.feature b/changelog.d/10263.feature new file mode 100644 index 0000000000..7b1d2fe60f --- /dev/null +++ b/changelog.d/10263.feature @@ -0,0 +1 @@ +Mark events received over federation which fail a spam check as "soft-failed". diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index c066617b92..2bfe6a3d37 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -89,12 +89,12 @@ class FederationBase: result = await self.spam_checker.check_event_for_spam(pdu) if result: - logger.warning( - "Event contains spam, redacting %s: %s", - pdu.event_id, - pdu.get_pdu_json(), - ) - return prune_event(pdu) + logger.warning("Event contains spam, soft-failing %s", pdu.event_id) + # we redact (to save disk space) as well as soft-failing (to stop + # using the event in prev_events). + redacted_event = prune_event(pdu) + redacted_event.internal_metadata.soft_failed = True + return redacted_event return pdu -- cgit 1.5.1 From 60efc51a2bbc31f18a71ad1338afc430bfa65597 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 29 Jun 2021 11:25:34 +0100 Subject: Migrate stream_ordering to a bigint (#10264) * Move background update names out to a separate class `EventsBackgroundUpdatesStore` gets inherited and we don't really want to further pollute the namespace. * Migrate stream_ordering to a bigint * changelog --- changelog.d/10264.bugfix | 1 + .../storage/databases/main/events_bg_updates.py | 136 ++++++++++++++++++--- synapse/storage/schema/__init__.py | 2 +- .../60/01recreate_stream_ordering.sql.postgres | 40 ++++++ 4 files changed, 163 insertions(+), 16 deletions(-) create mode 100644 changelog.d/10264.bugfix create mode 100644 synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres (limited to 'synapse') diff --git a/changelog.d/10264.bugfix b/changelog.d/10264.bugfix new file mode 100644 index 0000000000..7ebda7cdc2 --- /dev/null +++ b/changelog.d/10264.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return errors after 231 events were handled by the server. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index cbe4be1437..39aaee743c 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -29,6 +29,25 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +_REPLACE_STREAM_ORDRING_SQL_COMMANDS = ( + # there should be no leftover rows without a stream_ordering2, but just in case... + "UPDATE events SET stream_ordering2 = stream_ordering WHERE stream_ordering2 IS NULL", + # finally, we can drop the rule and switch the columns + "DROP RULE populate_stream_ordering2 ON events", + "ALTER TABLE events DROP COLUMN stream_ordering", + "ALTER TABLE events RENAME COLUMN stream_ordering2 TO stream_ordering", +) + + +class _BackgroundUpdates: + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" + DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" + POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2" + INDEX_STREAM_ORDERING2 = "index_stream_ordering2" + REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column" + + @attr.s(slots=True, frozen=True) class _CalculateChainCover: """Return value for _calculate_chain_cover_txn.""" @@ -48,19 +67,15 @@ class _CalculateChainCover: class EventsBackgroundUpdatesStore(SQLBaseStore): - - EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" - EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities" - def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) self.db_pool.updates.register_background_update_handler( - self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, + self._background_reindex_origin_server_ts, ) self.db_pool.updates.register_background_update_handler( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, self._background_reindex_fields_sender, ) @@ -85,7 +100,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) self.db_pool.updates.register_background_update_handler( - self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update + _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES, + self._cleanup_extremities_bg_update, ) self.db_pool.updates.register_background_update_handler( @@ -139,6 +155,24 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self._purged_chain_cover_index, ) + # bg updates for replacing stream_ordering with a BIGINT + # (these only run on postgres.) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.POPULATE_STREAM_ORDERING2, + self._background_populate_stream_ordering2, + ) + self.db_pool.updates.register_background_index_update( + _BackgroundUpdates.INDEX_STREAM_ORDERING2, + index_name="events_stream_ordering", + table="events", + columns=["stream_ordering2"], + unique=True, + ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN, + self._background_replace_stream_ordering_column, + ) + async def _background_reindex_fields_sender(self, progress, batch_size): target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] @@ -190,18 +224,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): } self.db_pool.updates._background_update_progress_txn( - txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress + txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress ) return len(rows) result = await self.db_pool.runInteraction( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn ) if not result: await self.db_pool.updates._end_background_update( - self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME + _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME ) return result @@ -264,18 +298,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): } self.db_pool.updates._background_update_progress_txn( - txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress ) return len(rows_to_update) result = await self.db_pool.runInteraction( - self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn ) if not result: await self.db_pool.updates._end_background_update( - self.EVENT_ORIGIN_SERVER_TS_NAME + _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME ) return result @@ -454,7 +488,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): if not num_handled: await self.db_pool.updates._end_background_update( - self.DELETE_SOFT_FAILED_EXTREMITIES + _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES ) def _drop_table_txn(txn): @@ -1009,3 +1043,75 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.updates._end_background_update("purged_chain_cover") return result + + async def _background_populate_stream_ordering2( + self, progress: JsonDict, batch_size: int + ) -> int: + """Populate events.stream_ordering2, then replace stream_ordering + + This is to deal with the fact that stream_ordering was initially created as a + 32-bit integer field. + """ + batch_size = max(batch_size, 1) + + def process(txn: Cursor) -> int: + # if this is the first pass, find the minimum stream ordering + last_stream = progress.get("last_stream") + if last_stream is None: + txn.execute( + """ + SELECT stream_ordering FROM events ORDER BY stream_ordering LIMIT 1 + """ + ) + rows = txn.fetchall() + if not rows: + return 0 + last_stream = rows[0][0] - 1 + + txn.execute( + """ + UPDATE events SET stream_ordering2=stream_ordering + WHERE stream_ordering > ? AND stream_ordering <= ? + """, + (last_stream, last_stream + batch_size), + ) + row_count = txn.rowcount + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.POPULATE_STREAM_ORDERING2, + {"last_stream": last_stream + batch_size}, + ) + return row_count + + result = await self.db_pool.runInteraction( + "_background_populate_stream_ordering2", process + ) + + if result != 0: + return result + + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.POPULATE_STREAM_ORDERING2 + ) + return 0 + + async def _background_replace_stream_ordering_column( + self, progress: JsonDict, batch_size: int + ) -> int: + """Drop the old 'stream_ordering' column and rename 'stream_ordering2' into its place.""" + + def process(txn: Cursor) -> None: + for sql in _REPLACE_STREAM_ORDRING_SQL_COMMANDS: + logger.info("completing stream_ordering migration: %s", sql) + txn.execute(sql) + + await self.db_pool.runInteraction( + "_background_replace_stream_ordering_column", process + ) + + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN + ) + + return 0 diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index d36ba1d773..0a53b73ccc 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 59 +SCHEMA_VERSION = 60 """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the diff --git a/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres new file mode 100644 index 0000000000..88c9f8bd0d --- /dev/null +++ b/synapse/storage/schema/main/delta/60/01recreate_stream_ordering.sql.postgres @@ -0,0 +1,40 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This migration handles the process of changing the type of `stream_ordering` to +-- a BIGINT. +-- +-- Note that this is only a problem on postgres as sqlite only has one "integer" type +-- which can cope with values up to 2^63. + +-- First add a new column to contain the bigger stream_ordering +ALTER TABLE events ADD COLUMN stream_ordering2 BIGINT; + +-- Create a rule which will populate it for new rows. +CREATE OR REPLACE RULE "populate_stream_ordering2" AS + ON INSERT TO events + DO UPDATE events SET stream_ordering2=NEW.stream_ordering WHERE stream_ordering=NEW.stream_ordering; + +-- Start a bg process to populate it for old events +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6001, 'populate_stream_ordering2', '{}'); + +-- ... and another to build an index on it +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (6001, 'index_stream_ordering2', '{}', 'populate_stream_ordering2'); + +-- ... and another to do the switcheroo +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (6001, 'replace_stream_ordering_column', '{}', 'index_stream_ordering2'); -- cgit 1.5.1 From 7647b0337fb5d936c88c5949fa92c07bf2137ad0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 29 Jun 2021 12:43:36 +0100 Subject: Fix `populate_stream_ordering2` background job (#10267) It was possible for us not to find any rows in a batch, and hence conclude that we had finished. Let's not do that. --- changelog.d/10267.bugfix | 1 + .../storage/databases/main/events_bg_updates.py | 28 ++++++++++------------ 2 files changed, 13 insertions(+), 16 deletions(-) create mode 100644 changelog.d/10267.bugfix (limited to 'synapse') diff --git a/changelog.d/10267.bugfix b/changelog.d/10267.bugfix new file mode 100644 index 0000000000..7ebda7cdc2 --- /dev/null +++ b/changelog.d/10267.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where Synapse would return errors after 231 events were handled by the server. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 39aaee743c..da3a7df27b 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -1055,32 +1055,28 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): batch_size = max(batch_size, 1) def process(txn: Cursor) -> int: - # if this is the first pass, find the minimum stream ordering - last_stream = progress.get("last_stream") - if last_stream is None: - txn.execute( - """ - SELECT stream_ordering FROM events ORDER BY stream_ordering LIMIT 1 - """ - ) - rows = txn.fetchall() - if not rows: - return 0 - last_stream = rows[0][0] - 1 - + last_stream = progress.get("last_stream", -(1 << 31)) txn.execute( """ UPDATE events SET stream_ordering2=stream_ordering - WHERE stream_ordering > ? AND stream_ordering <= ? + WHERE stream_ordering IN ( + SELECT stream_ordering FROM events WHERE stream_ordering > ? + ORDER BY stream_ordering LIMIT ? + ) + RETURNING stream_ordering; """, - (last_stream, last_stream + batch_size), + (last_stream, batch_size), ) row_count = txn.rowcount + if row_count == 0: + return 0 + last_stream = max(row[0] for row in txn) + logger.info("populated stream_ordering2 up to %i", last_stream) self.db_pool.updates._background_update_progress_txn( txn, _BackgroundUpdates.POPULATE_STREAM_ORDERING2, - {"last_stream": last_stream + batch_size}, + {"last_stream": last_stream}, ) return row_count -- cgit 1.5.1 From f55836929d3c64f3f8d883d8f3643a88b6c9cbca Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 29 Jun 2021 12:00:04 -0400 Subject: Do not recurse into non-spaces in the spaces summary. (#10256) Previously m.child.room events in non-space rooms would be treated as part of the room graph, but this is no longer supported. --- changelog.d/10256.misc | 1 + synapse/api/constants.py | 6 +++++ synapse/handlers/space_summary.py | 11 +++++++-- tests/handlers/test_space_summary.py | 48 +++++++++++++++++++----------------- tests/rest/client/v1/utils.py | 3 ++- 5 files changed, 43 insertions(+), 26 deletions(-) create mode 100644 changelog.d/10256.misc (limited to 'synapse') diff --git a/changelog.d/10256.misc b/changelog.d/10256.misc new file mode 100644 index 0000000000..adef12fcb9 --- /dev/null +++ b/changelog.d/10256.misc @@ -0,0 +1 @@ +Improve the performance of the spaces summary endpoint by only recursing into spaces (and not rooms in general). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 414e4c019a..8363c2bb0f 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -201,6 +201,12 @@ class EventContentFields: ) +class RoomTypes: + """Understood values of the room_type field of m.room.create events.""" + + SPACE = "m.space" + + class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" DEFAULT = MEGOLM_V1_AES_SHA2 diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py index 17fc47ce16..266f369883 100644 --- a/synapse/handlers/space_summary.py +++ b/synapse/handlers/space_summary.py @@ -25,6 +25,7 @@ from synapse.api.constants import ( EventTypes, HistoryVisibility, Membership, + RoomTypes, ) from synapse.events import EventBase from synapse.events.utils import format_event_for_client_v2 @@ -318,7 +319,8 @@ class SpaceSummaryHandler: Returns: A tuple of: - An iterable of a single value of the room. + The room information, if the room should be returned to the + user. None, otherwise. An iterable of the sorted children events. This may be limited to a maximum size or may include all children. @@ -328,7 +330,11 @@ class SpaceSummaryHandler: room_entry = await self._build_room_entry(room_id) - # look for child rooms/spaces. + # If the room is not a space, return just the room information. + if room_entry.get("room_type") != RoomTypes.SPACE: + return room_entry, () + + # Otherwise, look for child rooms/spaces. child_events = await self._get_child_events(room_id) if suggested_only: @@ -348,6 +354,7 @@ class SpaceSummaryHandler: event_format=format_event_for_client_v2, ) ) + return room_entry, events_result async def _summarize_remote_room( diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py index 131d362ccc..9771d3fb3b 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py @@ -14,6 +14,7 @@ from typing import Any, Iterable, Optional, Tuple from unittest import mock +from synapse.api.constants import EventContentFields, RoomTypes from synapse.api.errors import AuthError from synapse.handlers.space_summary import _child_events_comparison_key from synapse.rest import admin @@ -97,9 +98,21 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self.hs = hs self.handler = self.hs.get_space_summary_handler() + # Create a user. self.user = self.register_user("user", "pass") self.token = self.login("user", "pass") + # Create a space and a child room. + self.space = self.helper.create_room_as( + self.user, + tok=self.token, + extra_content={ + "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE} + }, + ) + self.room = self.helper.create_room_as(self.user, tok=self.token) + self._add_child(self.space, self.room, self.token) + def _add_child(self, space_id: str, room_id: str, token: str) -> None: """Add a child room to a space.""" self.helper.send_state( @@ -128,43 +141,32 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): def test_simple_space(self): """Test a simple space with a single room.""" - space = self.helper.create_room_as(self.user, tok=self.token) - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(space, room, self.token) - - result = self.get_success(self.handler.get_space_summary(self.user, space)) + result = self.get_success(self.handler.get_space_summary(self.user, self.space)) # The result should have the space and the room in it, along with a link # from space -> room. - self._assert_rooms(result, [space, room]) - self._assert_events(result, [(space, room)]) + self._assert_rooms(result, [self.space, self.room]) + self._assert_events(result, [(self.space, self.room)]) def test_visibility(self): """A user not in a space cannot inspect it.""" - space = self.helper.create_room_as(self.user, tok=self.token) - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(space, room, self.token) - user2 = self.register_user("user2", "pass") token2 = self.login("user2", "pass") # The user cannot see the space. - self.get_failure(self.handler.get_space_summary(user2, space), AuthError) + self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError) # Joining the room causes it to be visible. - self.helper.join(space, user2, tok=token2) - result = self.get_success(self.handler.get_space_summary(user2, space)) + self.helper.join(self.space, user2, tok=token2) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) # The result should only have the space, but includes the link to the room. - self._assert_rooms(result, [space]) - self._assert_events(result, [(space, room)]) + self._assert_rooms(result, [self.space]) + self._assert_events(result, [(self.space, self.room)]) def test_world_readable(self): """A world-readable room is visible to everyone.""" - space = self.helper.create_room_as(self.user, tok=self.token) - room = self.helper.create_room_as(self.user, tok=self.token) - self._add_child(space, room, self.token) self.helper.send_state( - space, + self.space, event_type="m.room.history_visibility", body={"history_visibility": "world_readable"}, tok=self.token, @@ -173,6 +175,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): user2 = self.register_user("user2", "pass") # The space should be visible, as well as the link to the room. - result = self.get_success(self.handler.get_space_summary(user2, space)) - self._assert_rooms(result, [space]) - self._assert_events(result, [(space, room)]) + result = self.get_success(self.handler.get_space_summary(user2, self.space)) + self._assert_rooms(result, [self.space]) + self._assert_events(result, [(self.space, self.room)]) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index ed55a640af..69798e95c3 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -52,6 +52,7 @@ class RestHelper: room_version: str = None, tok: str = None, expect_code: int = 200, + extra_content: Optional[Dict] = None, ) -> str: """ Create a room. @@ -72,7 +73,7 @@ class RestHelper: temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" - content = {} + content = extra_content or {} if not is_public: content["visibility"] = "private" if room_version: -- cgit 1.5.1 From 85d237eba789a667109ced140026d2494b210310 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Jun 2021 19:15:47 +0100 Subject: Add a distributed lock (#10269) This adds a simple best effort locking mechanism that works cross workers. --- changelog.d/10269.misc | 1 + synapse/app/generic_worker.py | 2 + synapse/storage/databases/main/__init__.py | 2 + synapse/storage/databases/main/lock.py | 334 +++++++++++++++++++++++ synapse/storage/schema/main/delta/59/15locks.sql | 37 +++ tests/storage/databases/main/test_lock.py | 100 +++++++ 6 files changed, 476 insertions(+) create mode 100644 changelog.d/10269.misc create mode 100644 synapse/storage/databases/main/lock.py create mode 100644 synapse/storage/schema/main/delta/59/15locks.sql create mode 100644 tests/storage/databases/main/test_lock.py (limited to 'synapse') diff --git a/changelog.d/10269.misc b/changelog.d/10269.misc new file mode 100644 index 0000000000..23e590490c --- /dev/null +++ b/changelog.d/10269.misc @@ -0,0 +1 @@ +Add a distributed lock implementation. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index af8a1833f3..5b041fcaad 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -108,6 +108,7 @@ from synapse.server import HomeServer from synapse.storage.databases.main.censor_events import CensorEventsStore from synapse.storage.databases.main.client_ips import ClientIpWorkerStore from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyStore +from synapse.storage.databases.main.lock import LockStore from synapse.storage.databases.main.media_repository import MediaRepositoryStore from synapse.storage.databases.main.metrics import ServerMetricsStore from synapse.storage.databases.main.monthly_active_users import ( @@ -249,6 +250,7 @@ class GenericWorkerSlavedStore( ServerMetricsStore, SearchStore, TransactionWorkerStore, + LockStore, BaseSlavedStore, ): pass diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 9cce62ae6c..a3fddea042 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -46,6 +46,7 @@ from .events_forward_extremities import EventForwardExtremitiesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore +from .lock import LockStore from .media_repository import MediaRepositoryStore from .metrics import ServerMetricsStore from .monthly_active_users import MonthlyActiveUsersStore @@ -119,6 +120,7 @@ class DataStore( CacheInvalidationWorkerStore, ServerMetricsStore, EventForwardExtremitiesStore, + LockStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py new file mode 100644 index 0000000000..e76188328c --- /dev/null +++ b/synapse/storage/databases/main/lock.py @@ -0,0 +1,334 @@ +# Copyright 2021 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from types import TracebackType +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type + +from twisted.internet.interfaces import IReactorCore + +from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.types import Connection +from synapse.util import Clock +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +logger = logging.getLogger(__name__) + + +# How often to renew an acquired lock by updating the `last_renewed_ts` time in +# the lock table. +_RENEWAL_INTERVAL_MS = 30 * 1000 + +# How long before an acquired lock times out. +_LOCK_TIMEOUT_MS = 2 * 60 * 1000 + + +class LockStore(SQLBaseStore): + """Provides a best effort distributed lock between worker instances. + + Locks are identified by a name and key. A lock is acquired by inserting into + the `worker_locks` table if a) there is no existing row for the name/key or + b) the existing row has a `last_renewed_ts` older than `_LOCK_TIMEOUT_MS`. + + When a lock is taken out the instance inserts a random `token`, the instance + that holds that token holds the lock until it drops (or times out). + + The instance that holds the lock should regularly update the + `last_renewed_ts` column with the current time. + """ + + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + super().__init__(database, db_conn, hs) + + self._reactor = hs.get_reactor() + self._instance_name = hs.get_instance_id() + + # A map from `(lock_name, lock_key)` to the token of any locks that we + # think we currently hold. + self._live_tokens: Dict[Tuple[str, str], str] = {} + + # When we shut down we want to remove the locks. Technically this can + # lead to a race, as we may drop the lock while we are still processing. + # However, a) it should be a small window, b) the lock is best effort + # anyway and c) we want to really avoid leaking locks when we restart. + hs.get_reactor().addSystemEventTrigger( + "before", + "shutdown", + self._on_shutdown, + ) + + @wrap_as_background_process("LockStore._on_shutdown") + async def _on_shutdown(self) -> None: + """Called when the server is shutting down""" + logger.info("Dropping held locks due to shutdown") + + for (lock_name, lock_key), token in self._live_tokens.items(): + await self._drop_lock(lock_name, lock_key, token) + + logger.info("Dropped locks due to shutdown") + + async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Lock"]: + """Try to acquire a lock for the given name/key. Will return an async + context manager if the lock is successfully acquired, which *must* be + used (otherwise the lock will leak). + """ + + now = self._clock.time_msec() + token = random_string(6) + + if self.db_pool.engine.can_native_upsert: + + def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool: + # We take out the lock if either a) there is no row for the lock + # already or b) the existing row has timed out. + sql = """ + INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (lock_name, lock_key) + DO UPDATE + SET + token = EXCLUDED.token, + instance_name = EXCLUDED.instance_name, + last_renewed_ts = EXCLUDED.last_renewed_ts + WHERE + worker_locks.last_renewed_ts < ? + """ + txn.execute( + sql, + ( + lock_name, + lock_key, + self._instance_name, + token, + now, + now - _LOCK_TIMEOUT_MS, + ), + ) + + # We only acquired the lock if we inserted or updated the table. + return bool(txn.rowcount) + + did_lock = await self.db_pool.runInteraction( + "try_acquire_lock", + _try_acquire_lock_txn, + # We can autocommit here as we're executing a single query, this + # will avoid serialization errors. + db_autocommit=True, + ) + if not did_lock: + return None + + else: + # If we're on an old SQLite we emulate the above logic by first + # clearing out any existing stale locks and then upserting. + + def _try_acquire_lock_emulated_txn(txn: LoggingTransaction) -> bool: + sql = """ + DELETE FROM worker_locks + WHERE + lock_name = ? + AND lock_key = ? + AND last_renewed_ts < ? + """ + txn.execute( + sql, + (lock_name, lock_key, now - _LOCK_TIMEOUT_MS), + ) + + inserted = self.db_pool.simple_upsert_txn_emulated( + txn, + table="worker_locks", + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + }, + values={}, + insertion_values={ + "token": token, + "last_renewed_ts": self._clock.time_msec(), + "instance_name": self._instance_name, + }, + ) + + return inserted + + did_lock = await self.db_pool.runInteraction( + "try_acquire_lock_emulated", _try_acquire_lock_emulated_txn + ) + + if not did_lock: + return None + + self._live_tokens[(lock_name, lock_key)] = token + + return Lock( + self._reactor, + self._clock, + self, + lock_name=lock_name, + lock_key=lock_key, + token=token, + ) + + async def _is_lock_still_valid( + self, lock_name: str, lock_key: str, token: str + ) -> bool: + """Checks whether this instance still holds the lock.""" + last_renewed_ts = await self.db_pool.simple_select_one_onecol( + table="worker_locks", + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + "token": token, + }, + retcol="last_renewed_ts", + allow_none=True, + desc="is_lock_still_valid", + ) + return ( + last_renewed_ts is not None + and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts + ) + + async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None: + """Attempt to renew the lock if we still hold it.""" + await self.db_pool.simple_update( + table="worker_locks", + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + "token": token, + }, + updatevalues={"last_renewed_ts": self._clock.time_msec()}, + desc="renew_lock", + ) + + async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None: + """Attempt to drop the lock, if we still hold it""" + await self.db_pool.simple_delete( + table="worker_locks", + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + "token": token, + }, + desc="drop_lock", + ) + + self._live_tokens.pop((lock_name, lock_key), None) + + +class Lock: + """An async context manager that manages an acquired lock, ensuring it is + regularly renewed and dropping it when the context manager exits. + + The lock object has an `is_still_valid` method which can be used to + double-check the lock is still valid, if e.g. processing work in a loop. + + For example: + + lock = await self.store.try_acquire_lock(...) + if not lock: + return + + async with lock: + for item in work: + await process(item) + + if not await lock.is_still_valid(): + break + """ + + def __init__( + self, + reactor: IReactorCore, + clock: Clock, + store: LockStore, + lock_name: str, + lock_key: str, + token: str, + ) -> None: + self._reactor = reactor + self._clock = clock + self._store = store + self._lock_name = lock_name + self._lock_key = lock_key + + self._token = token + + self._looping_call = clock.looping_call( + self._renew, _RENEWAL_INTERVAL_MS, store, lock_name, lock_key, token + ) + + self._dropped = False + + @staticmethod + @wrap_as_background_process("Lock._renew") + async def _renew( + store: LockStore, + lock_name: str, + lock_key: str, + token: str, + ) -> None: + """Renew the lock. + + Note: this is a static method, rather than using self.*, so that we + don't end up with a reference to `self` in the reactor, which would stop + this from being cleaned up if we dropped the context manager. + """ + await store._renew_lock(lock_name, lock_key, token) + + async def is_still_valid(self) -> bool: + """Check if the lock is still held by us""" + return await self._store._is_lock_still_valid( + self._lock_name, self._lock_key, self._token + ) + + async def __aenter__(self) -> None: + if self._dropped: + raise Exception("Cannot reuse a Lock object") + + async def __aexit__( + self, + _exctype: Optional[Type[BaseException]], + _excinst: Optional[BaseException], + _exctb: Optional[TracebackType], + ) -> bool: + if self._looping_call.running: + self._looping_call.stop() + + await self._store._drop_lock(self._lock_name, self._lock_key, self._token) + self._dropped = True + + return False + + def __del__(self) -> None: + if not self._dropped: + # We should not be dropped without the lock being released (unless + # we're shutting down), but if we are then let's at least stop + # renewing the lock. + if self._looping_call.running: + self._looping_call.stop() + + if self._reactor.running: + logger.error( + "Lock for (%s, %s) dropped without being released", + self._lock_name, + self._lock_key, + ) diff --git a/synapse/storage/schema/main/delta/59/15locks.sql b/synapse/storage/schema/main/delta/59/15locks.sql new file mode 100644 index 0000000000..8b2999ff3e --- /dev/null +++ b/synapse/storage/schema/main/delta/59/15locks.sql @@ -0,0 +1,37 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A noddy implementation of a distributed lock across workers. While a worker +-- has taken a lock out they should regularly update the `last_renewed_ts` +-- column, a lock will be considered dropped if `last_renewed_ts` is from ages +-- ago. +CREATE TABLE worker_locks ( + lock_name TEXT NOT NULL, + lock_key TEXT NOT NULL, + -- We write the instance name to ease manual debugging, we don't ever read + -- from it. + -- Note: instance names aren't guarenteed to be unique. + instance_name TEXT NOT NULL, + -- A random string generated each time an instance takes out a lock. Used by + -- the instance to tell whether the lock is still held by it (e.g. in the + -- case where the process stalls for a long time the lock may time out and + -- be taken out by another instance, at which point the original instance + -- can tell it no longer holds the lock as the tokens no longer match). + token TEXT NOT NULL, + last_renewed_ts BIGINT NOT NULL +); + +CREATE UNIQUE INDEX worker_locks_key ON worker_locks (lock_name, lock_key); diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py new file mode 100644 index 0000000000..9ca70e7367 --- /dev/null +++ b/tests/storage/databases/main/test_lock.py @@ -0,0 +1,100 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.server import HomeServer +from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS + +from tests import unittest + + +class LockTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs: HomeServer): + self.store = hs.get_datastore() + + def test_simple_lock(self): + """Test that we can take out a lock and that while we hold it nobody + else can take it out. + """ + # First to acquire this lock, so it should complete + lock = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock) + + # Enter the context manager + self.get_success(lock.__aenter__()) + + # Attempting to acquire the lock again fails. + lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNone(lock2) + + # Calling `is_still_valid` reports true. + self.assertTrue(self.get_success(lock.is_still_valid())) + + # Drop the lock + self.get_success(lock.__aexit__(None, None, None)) + + # We can now acquire the lock again. + lock3 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock3) + self.get_success(lock3.__aenter__()) + self.get_success(lock3.__aexit__(None, None, None)) + + def test_maintain_lock(self): + """Test that we don't time out locks while they're still active""" + + lock = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock) + + self.get_success(lock.__aenter__()) + + # Wait for ages with the lock, we should not be able to get the lock. + self.reactor.advance(5 * _LOCK_TIMEOUT_MS / 1000) + + lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNone(lock2) + + self.get_success(lock.__aexit__(None, None, None)) + + def test_timeout_lock(self): + """Test that we time out locks if they're not updated for ages""" + + lock = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock) + + self.get_success(lock.__aenter__()) + + # We simulate the process getting stuck by cancelling the looping call + # that keeps the lock active. + lock._looping_call.stop() + + # Wait for the lock to timeout. + self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000) + + lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock2) + + self.assertFalse(self.get_success(lock.is_still_valid())) + + def test_drop(self): + """Test that dropping the context manager means we stop renewing the lock""" + + lock = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock) + + del lock + + # Wait for the lock to timeout. + self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000) + + lock2 = self.get_success(self.store.try_acquire_lock("name", "key")) + self.assertIsNotNone(lock2) -- cgit 1.5.1 From c54db67d0ea5b5967b7ea918c66a222a75b8ced1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Jun 2021 19:55:22 +0100 Subject: Handle inbound events from federation asynchronously (#10272) Fixes #9490 This will break a couple of SyTest that are expecting failures to be added to the response of a federation /send, which obviously doesn't happen now that things are asynchronous. Two drawbacks: Currently there is no logic to handle any events left in the staging area after restart, and so they'll only be handled on the next incoming event in that room. That can be fixed separately. We now only process one event per room at a time. This can be fixed up further down the line. --- changelog.d/10272.bugfix | 1 + synapse/federation/federation_server.py | 98 +++++++++++++++++- synapse/storage/databases/main/event_federation.py | 109 ++++++++++++++++++++- .../main/delta/59/16federation_inbound_staging.sql | 32 ++++++ sytest-blacklist | 6 ++ 5 files changed, 241 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10272.bugfix create mode 100644 synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql (limited to 'synapse') diff --git a/changelog.d/10272.bugfix b/changelog.d/10272.bugfix new file mode 100644 index 0000000000..3cefa05788 --- /dev/null +++ b/changelog.d/10272.bugfix @@ -0,0 +1 @@ +Handle inbound events from federation asynchronously. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b07f18529..1d050e54e2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -44,7 +44,7 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions @@ -57,10 +57,12 @@ from synapse.logging.context import ( ) from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace from synapse.logging.utils import log_function +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) +from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute @@ -96,6 +98,11 @@ last_pdu_ts_metric = Gauge( ) +# The name of the lock to use when process events in a room received over +# federation. +_INBOUND_EVENT_HANDLING_LOCK_NAME = "federation_inbound_pdu" + + class FederationServer(FederationBase): def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -834,7 +841,94 @@ class FederationServer(FederationBase): except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + # Add the event to our staging area + await self.store.insert_received_event_to_staging(origin, pdu) + + # Try and acquire the processing lock for the room, if we get it start a + # background process for handling the events in the room. + lock = await self.store.try_acquire_lock( + _INBOUND_EVENT_HANDLING_LOCK_NAME, pdu.room_id + ) + if lock: + self._process_incoming_pdus_in_room_inner( + pdu.room_id, room_version, lock, origin, pdu + ) + + @wrap_as_background_process("_process_incoming_pdus_in_room_inner") + async def _process_incoming_pdus_in_room_inner( + self, + room_id: str, + room_version: RoomVersion, + lock: Lock, + latest_origin: str, + latest_event: EventBase, + ) -> None: + """Process events in the staging area for the given room. + + The latest_origin and latest_event args are the latest origin and event + received. + """ + + # The common path is for the event we just received be the only event in + # the room, so instead of pulling the event out of the DB and parsing + # the event we just pull out the next event ID and check if that matches. + next_origin, next_event_id = await self.store.get_next_staged_event_id_for_room( + room_id + ) + if next_origin == latest_origin and next_event_id == latest_event.event_id: + origin = latest_origin + event = latest_event + else: + next = await self.store.get_next_staged_event_for_room( + room_id, room_version + ) + if not next: + return + + origin, event = next + + # We loop round until there are no more events in the room in the + # staging area, or we fail to get the lock (which means another process + # has started processing). + while True: + async with lock: + try: + await self.handler.on_receive_pdu( + origin, event, sent_to_us_directly=True + ) + except FederationError as e: + # XXX: Ideally we'd inform the remote we failed to process + # the event, but we can't return an error in the transaction + # response (as we've already responded). + logger.warning("Error handling PDU %s: %s", event.event_id, e) + except Exception: + f = failure.Failure() + logger.error( + "Failed to handle PDU %s", + event.event_id, + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + + await self.store.remove_received_event_from_staging( + origin, event.event_id + ) + + # We need to do this check outside the lock to avoid a race between + # a new event being inserted by another instance and it attempting + # to acquire the lock. + next = await self.store.get_next_staged_event_for_room( + room_id, room_version + ) + if not next: + break + + origin, event = next + + lock = await self.store.try_acquire_lock( + _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id + ) + if not lock: + return def __str__(self) -> str: return "" % self.server_name diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index c0ea445550..f23f8c6ecf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,18 +14,20 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Collection, Dict, Iterable, List, Set, Tuple +from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError -from synapse.events import EventBase +from synapse.api.room_versions import RoomVersion +from synapse.events import EventBase, make_event_from_dict from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -1044,6 +1046,107 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + async def insert_received_event_to_staging( + self, origin: str, event: EventBase + ) -> None: + """Insert a newly received event from federation into the staging area.""" + + # We use an upsert here to handle the case where we see the same event + # from the same server multiple times. + await self.db_pool.simple_upsert( + table="federation_inbound_events_staging", + keyvalues={ + "origin": origin, + "event_id": event.event_id, + }, + values={}, + insertion_values={ + "room_id": event.room_id, + "received_ts": self._clock.time_msec(), + "event_json": json_encoder.encode(event.get_dict()), + "internal_metadata": json_encoder.encode( + event.internal_metadata.get_dict() + ), + }, + desc="insert_received_event_to_staging", + ) + + async def remove_received_event_from_staging( + self, + origin: str, + event_id: str, + ) -> None: + """Remove the given event from the staging area""" + await self.db_pool.simple_delete( + table="federation_inbound_events_staging", + keyvalues={ + "origin": origin, + "event_id": event_id, + }, + desc="remove_received_event_from_staging", + ) + + async def get_next_staged_event_id_for_room( + self, + room_id: str, + ) -> Optional[Tuple[str, str]]: + """Get the next event ID in the staging area for the given room.""" + + def _get_next_staged_event_id_for_room_txn(txn): + sql = """ + SELECT origin, event_id + FROM federation_inbound_events_staging + WHERE room_id = ? + ORDER BY received_ts ASC + LIMIT 1 + """ + + txn.execute(sql, (room_id,)) + + return txn.fetchone() + + return await self.db_pool.runInteraction( + "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn + ) + + async def get_next_staged_event_for_room( + self, + room_id: str, + room_version: RoomVersion, + ) -> Optional[Tuple[str, EventBase]]: + """Get the next event in the staging area for the given room.""" + + def _get_next_staged_event_for_room_txn(txn): + sql = """ + SELECT event_json, internal_metadata, origin + FROM federation_inbound_events_staging + WHERE room_id = ? + ORDER BY received_ts ASC + LIMIT 1 + """ + txn.execute(sql, (room_id,)) + + return txn.fetchone() + + row = await self.db_pool.runInteraction( + "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn + ) + + if not row: + return None + + event_d = db_to_json(row[0]) + internal_metadata_d = db_to_json(row[1]) + origin = row[2] + + event = make_event_from_dict( + event_dict=event_d, + room_version=room_version, + internal_metadata_dict=internal_metadata_d, + ) + + return origin, event + class EventFederationStore(EventFederationWorkerStore): """Responsible for storing and serving up the various graphs associated diff --git a/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql b/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql new file mode 100644 index 0000000000..43bc5c025f --- /dev/null +++ b/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql @@ -0,0 +1,32 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A staging area for newly received events over federation. +-- +-- Note we may store the same event multiple times if it comes from different +-- servers; this is to handle the case if we get a redacted and non-redacted +-- versions of the event. +CREATE TABLE federation_inbound_events_staging ( + origin TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + received_ts BIGINT NOT NULL, + event_json TEXT NOT NULL, + internal_metadata TEXT NOT NULL +); + +CREATE INDEX federation_inbound_events_staging_room ON federation_inbound_events_staging(room_id, received_ts); +CREATE UNIQUE INDEX federation_inbound_events_staging_instance_event ON federation_inbound_events_staging(origin, event_id); diff --git a/sytest-blacklist b/sytest-blacklist index de9986357b..89c4e828fd 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -41,3 +41,9 @@ We can't peek into rooms with invited history_visibility We can't peek into rooms with joined history_visibility Local users can peek by room alias Peeked rooms only turn up in the sync for the device who peeked them + + +# Blacklisted due to changes made in #10272 +Outbound federation will ignore a missing event with bad JSON for room version 6 +Backfilled events whose prev_events are in a different room do not allow cross-room back-pagination +Federation rejects inbound events where the prev_events cannot be found -- cgit 1.5.1 From f99e9cc2da6afe49ed7a1fbe18ab08e68befa614 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 29 Jun 2021 19:58:25 +0100 Subject: v1.37.1a1 --- synapse/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse') diff --git a/synapse/__init__.py b/synapse/__init__.py index c865d2e100..0900492619 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.37.0" +__version__ = "1.37.1a1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1 From d561367c18db3300804dee182e74b4a8fb7998e6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 29 Jun 2021 21:39:30 +0100 Subject: 1.37.1rc1 --- CHANGES.md | 9 +++++++++ changelog.d/10269.bugfix | 1 - changelog.d/10272.bugfix | 1 - synapse/__init__.py | 2 +- 4 files changed, 10 insertions(+), 3 deletions(-) delete mode 100644 changelog.d/10269.bugfix delete mode 100644 changelog.d/10272.bugfix (limited to 'synapse') diff --git a/CHANGES.md b/CHANGES.md index eac91ffe02..8de3bad906 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,12 @@ +Synapse 1.37.1rc1 (2021-06-29) +============================== + +Features +-------- + +- Handle inbound events from federation asynchronously. ([\#10269](https://github.com/matrix-org/synapse/issues/10269), [\#10272](https://github.com/matrix-org/synapse/issues/10272)) + + Synapse 1.37.0 (2021-06-29) =========================== diff --git a/changelog.d/10269.bugfix b/changelog.d/10269.bugfix deleted file mode 100644 index 3cefa05788..0000000000 --- a/changelog.d/10269.bugfix +++ /dev/null @@ -1 +0,0 @@ -Handle inbound events from federation asynchronously. diff --git a/changelog.d/10272.bugfix b/changelog.d/10272.bugfix deleted file mode 100644 index 3cefa05788..0000000000 --- a/changelog.d/10272.bugfix +++ /dev/null @@ -1 +0,0 @@ -Handle inbound events from federation asynchronously. diff --git a/synapse/__init__.py b/synapse/__init__.py index 0900492619..2070724c34 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -47,7 +47,7 @@ try: except ImportError: pass -__version__ = "1.37.1a1" +__version__ = "1.37.1rc1" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when -- cgit 1.5.1