From dbc630a628e4fc6eb5eff09ce5edba062c0e9955 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 10:32:33 -0400 Subject: Use the JSON encoder without whitespace in more places. (#8124) --- synapse/handlers/devicemessage.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 610b08d00b..dcb4c82244 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -16,8 +16,6 @@ import logging from typing import Any, Dict -from canonicaljson import json - from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -27,6 +25,7 @@ from synapse.logging.opentracing import ( start_active_span, ) from synapse.types import UserID, get_domain_from_id +from synapse.util import json_encoder from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) @@ -174,7 +173,7 @@ class DeviceMessageHandler(object): "sender": sender_user_id, "type": message_type, "message_id": message_id, - "org.matrix.opentracing_context": json.dumps(context), + "org.matrix.opentracing_context": json_encoder.encode(context), } log_kv({"local_messages": local_messages}) -- cgit 1.5.1 From 592cdf73be50af837d8c255ebcb5fbcd429c2954 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 10:39:41 -0400 Subject: Improve the error code when trying to register using a name reserved for guests. (#8135) --- changelog.d/8135.bugfix | 1 + synapse/handlers/register.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8135.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/8135.bugfix b/changelog.d/8135.bugfix new file mode 100644 index 0000000000..9d5c60ea00 --- /dev/null +++ b/changelog.d/8135.bugfix @@ -0,0 +1 @@ +Clarify the error code if a user tries to register with a numeric ID. This bug was introduced in v1.15.0. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 999bc6efb5..ccd96e4626 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -124,7 +124,9 @@ class RegistrationHandler(BaseHandler): try: int(localpart) raise SynapseError( - 400, "Numeric user IDs are reserved for guest users." + 400, + "Numeric user IDs are reserved for guest users.", + errcode=Codes.INVALID_USERNAME, ) except ValueError: pass -- cgit 1.5.1 From e259d63f73fd7599520d0c4a6f5082e5cd383d25 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:07:42 -0400 Subject: Stop shadow-banned users from sending invites. (#8095) --- changelog.d/8095.feature | 1 + synapse/api/errors.py | 8 +++ synapse/handlers/room.py | 16 +++++- synapse/handlers/room_member.py | 62 ++++++++++++++++++++++- synapse/rest/admin/rooms.py | 3 ++ synapse/rest/client/v1/room.py | 67 +++++++++++++++---------- tests/rest/client/v1/test_rooms.py | 100 +++++++++++++++++++++++++++++++++++++ 7 files changed, 226 insertions(+), 31 deletions(-) create mode 100644 changelog.d/8095.feature (limited to 'synapse/handlers') diff --git a/changelog.d/8095.feature b/changelog.d/8095.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8095.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/api/errors.py b/synapse/api/errors.py index a3f314118a..4888c0ec4d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -604,3 +604,11 @@ class HttpResponseException(CodeMessageException): errmsg = j.pop("error", self.msg) return ProxiedRequestError(self.code, errmsg, errcode, j) + + +class ShadowBanError(Exception): + """ + Raised when a shadow-banned user attempts to perform an action. + + This should be caught and a proper "fake" success response sent to the user. + """ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 442cca28e6..0fc71475c3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -20,6 +20,7 @@ import itertools import logging import math +import random import string from collections import OrderedDict from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple @@ -626,6 +627,7 @@ class RoomCreationHandler(BaseHandler): if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) + invite_3pid_list = config.get("invite_3pid", []) invite_list = config.get("invite", []) for i in invite_list: try: @@ -634,6 +636,14 @@ class RoomCreationHandler(BaseHandler): except Exception: raise SynapseError(400, "Invalid user_id: %s" % (i,)) + if (invite_list or invite_3pid_list) and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + + # Allow the request to go through, but remove any associated invites. + invite_3pid_list = [] + invite_list = [] + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") @@ -648,8 +658,6 @@ class RoomCreationHandler(BaseHandler): % (user_id,), ) - invite_3pid_list = config.get("invite_3pid", []) - visibility = config.get("visibility", None) is_public = visibility == "public" @@ -744,6 +752,8 @@ class RoomCreationHandler(BaseHandler): if is_direct: content["is_direct"] = is_direct + # Note that update_membership with an action of "invite" can raise a + # ShadowBanError, but this was handled above by emptying invite_list. _, last_stream_id = await self.room_member_handler.update_membership( requester, UserID.from_string(invitee), @@ -758,6 +768,8 @@ class RoomCreationHandler(BaseHandler): id_access_token = invite_3pid.get("id_access_token") # optional address = invite_3pid["address"] medium = invite_3pid["medium"] + # Note that do_3pid_invite can raise a ShadowBanError, but this was + # handled above by emptying invite_3pid_list. last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite( room_id, requester.user, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index aa1ccde211..3a6ee6378d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -15,6 +15,7 @@ import abc import logging +import random from http import HTTPStatus from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union @@ -22,7 +23,13 @@ from unpaddedbase64 import encode_base64 from synapse import types from synapse.api.constants import MAX_DEPTH, EventTypes, Membership -from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + LimitExceededError, + ShadowBanError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import EventFormatVersions from synapse.crypto.event_signing import compute_event_reference_hash @@ -285,6 +292,31 @@ class RoomMemberHandler(object): content: Optional[dict] = None, require_consent: bool = True, ) -> Tuple[str, int]: + """Update a user's membership in a room. + + Params: + requester: The user who is performing the update. + target: The user whose membership is being updated. + room_id: The room ID whose membership is being updated. + action: The membership change, see synapse.api.constants.Membership. + txn_id: The transaction ID, if given. + remote_room_hosts: Remote servers to send the update to. + third_party_signed: Information from a 3PID invite. + ratelimit: Whether to rate limit the request. + content: The content of the created event. + require_consent: Whether consent is required. + + Returns: + A tuple of the new event ID and stream ID. + + Raises: + ShadowBanError if a shadow-banned requester attempts to send an invite. + """ + if action == Membership.INVITE and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -773,6 +805,25 @@ class RoomMemberHandler(object): txn_id: Optional[str], id_access_token: Optional[str] = None, ) -> int: + """Invite a 3PID to a room. + + Args: + room_id: The room to invite the 3PID to. + inviter: The user sending the invite. + medium: The 3PID's medium. + address: The 3PID's address. + id_server: The identity server to use. + requester: The user making the request. + txn_id: The transaction ID this is part of, or None if this is not + part of a transaction. + id_access_token: The optional identity server access token. + + Returns: + The new stream ID. + + Raises: + ShadowBanError if the requester has been shadow-banned. + """ if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -780,6 +831,11 @@ class RoomMemberHandler(object): 403, "Invites have been disabled on this server", Codes.FORBIDDEN ) + if requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. await self.base_handler.ratelimit(requester) @@ -804,6 +860,8 @@ class RoomMemberHandler(object): ) if invitee: + # Note that update_membership with an action of "invite" can raise + # a ShadowBanError, but this was done above already. _, stream_id = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) @@ -1042,7 +1100,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): return event_id, stream_id # The room is too large. Leave. - requester = types.create_requester(user, None, False, None) + requester = types.create_requester(user, None, False, False, None) await self.update_membership( requester=requester, target=user, room_id=room_id, action="leave" ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 7c292ef3f9..09726d52d6 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -316,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet): join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + # update_membership with an action of "invite" can raise a + # ShadowBanError. This is not handled since it is assumed that + # an admin isn't going to call this API with a shadow-banned user. await self.room_member_handler.update_membership( requester=requester, target=fake_requester.user, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index f216382636..a9dd3a6aec 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -27,6 +27,7 @@ from synapse.api.errors import ( Codes, HttpResponseException, InvalidClientCredentialsError, + ShadowBanError, SynapseError, ) from synapse.api.filtering import Filter @@ -45,6 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder +from synapse.util.stringutils import random_string MYPY = False if MYPY: @@ -200,14 +202,17 @@ class RoomStateEventRestServlet(TransactionRestServlet): event_dict["state_key"] = state_key if event_type == EventTypes.Member: - membership = content.get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - requester, - target=UserID.from_string(state_key), - room_id=room_id, - action=membership, - content=content, - ) + try: + membership = content.get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + requester, + target=UserID.from_string(state_key), + room_id=room_id, + action=membership, + content=content, + ) + except ShadowBanError: + event_id = "$" + random_string(43) else: ( event, @@ -719,16 +724,20 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - await self.room_member_handler.do_3pid_invite( - room_id, - requester.user, - content["medium"], - content["address"], - content["id_server"], - requester, - txn_id, - content.get("id_access_token"), - ) + try: + await self.room_member_handler.do_3pid_invite( + room_id, + requester.user, + content["medium"], + content["address"], + content["id_server"], + requester, + txn_id, + content.get("id_access_token"), + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return 200, {} target = requester.user @@ -740,15 +749,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content: event_content = {"reason": content["reason"]} - await self.room_member_handler.update_membership( - requester=requester, - target=target, - room_id=room_id, - action=membership_action, - txn_id=txn_id, - third_party_signed=content.get("third_party_signed", None), - content=event_content, - ) + try: + await self.room_member_handler.update_membership( + requester=requester, + target=target, + room_id=room_id, + action=membership_action, + txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), + content=event_content, + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return_value = {} diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index ef6b775ed2..e674eb90d7 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1974,3 +1974,103 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase): """An alias which does not point to the room raises a SynapseError.""" self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400) self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400) + + +class ShadowBannedTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + directory.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver): + self.banned_user_id = self.register_user("banned", "test") + self.banned_access_token = self.login("banned", "test") + + self.store = self.hs.get_datastore() + + self.get_success( + self.store.db_pool.simple_update( + table="users", + keyvalues={"name": self.banned_user_id}, + updatevalues={"shadow_banned": True}, + desc="shadow_ban", + ) + ) + + self.other_user_id = self.register_user("otheruser", "pass") + self.other_access_token = self.login("otheruser", "pass") + + def test_invite(self): + """Invites from shadow-banned users don't actually get sent.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + self.helper.invite( + room=room_id, + src=self.banned_user_id, + tok=self.banned_access_token, + targ=self.other_user_id, + ) + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + def test_invite_3pid(self): + """Ensure that a 3PID invite does not attempt to contact the identity server.""" + identity_handler = self.hs.get_handlers().identity_handler + identity_handler.lookup_3pid = Mock( + side_effect=AssertionError("This should not get called") + ) + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + # Inviting the user completes successfully. + request, channel = self.make_request( + "POST", + "/rooms/%s/invite" % (room_id,), + {"id_server": "test", "medium": "email", "address": "test@test.test"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + + # This should have raised an error earlier, but double check this wasn't called. + identity_handler.lookup_3pid.assert_not_called() + + def test_create_room(self): + """Invitations during a room creation should be discarded, but the room still gets created.""" + # The room creation is successful. + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/createRoom", + {"visibility": "public", "invite": [self.other_user_id]}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + room_id = channel.json_body["room_id"] + + # But the user wasn't actually invited. + invited_rooms = self.get_success( + self.store.get_invited_rooms_for_local_user(self.other_user_id) + ) + self.assertEqual(invited_rooms, []) + + # Since a real room was created, the other user should be able to join it. + self.helper.join(room_id, self.other_user_id, tok=self.other_access_token) + + # Both users should be in the room. + users = self.get_success(self.store.get_users_in_room(room_id)) + self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) -- cgit 1.5.1 From 3f91638da6ea0aeaf789ddc8ca1e624a11b7ebb2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:42:58 -0400 Subject: Allow denying or shadow banning registrations via the spam checker (#8034) --- changelog.d/8034.feature | 1 + synapse/events/spamcheck.py | 35 ++++++++++++++- synapse/handlers/auth.py | 8 ++++ synapse/handlers/cas_handler.py | 11 ++++- synapse/handlers/oidc_handler.py | 21 +++++++-- synapse/handlers/register.py | 26 ++++++++++- synapse/handlers/saml_handler.py | 18 +++++++- synapse/rest/client/v2_alpha/register.py | 5 +++ synapse/spam_checker_api/__init__.py | 11 +++++ .../main/schema/delta/58/07persist_ui_auth_ips.sql | 25 +++++++++++ synapse/storage/databases/main/ui_auth.py | 39 +++++++++++++++- tests/handlers/test_oidc.py | 18 ++++++-- tests/handlers/test_register.py | 52 +++++++++++++++++++++- tests/handlers/test_user_directory.py | 6 +-- 14 files changed, 258 insertions(+), 18 deletions(-) create mode 100644 changelog.d/8034.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql (limited to 'synapse/handlers') diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8034.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 1ffc9525d1..a7cddac974 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,9 +15,10 @@ # limitations under the License. import inspect -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple -from synapse.spam_checker_api import SpamCheckerApi +from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi +from synapse.types import Collection MYPY = False if MYPY: @@ -160,3 +161,33 @@ class SpamChecker(object): return True return False + + def check_registration_for_spam( + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: + """Checks if we should allow the given registration request. + + Args: + email_threepid: The email threepid used for registering, if any + username: The request user name, if any + request_info: List of tuples of user agent and IP that + were used during the registration process. + + Returns: + Enum for how the request should be handled + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_registration_for_spam", None) + if checker: + behaviour = checker(email_threepid, username, request_info) + assert isinstance(behaviour, RegistrationBehaviour) + if behaviour != RegistrationBehaviour.ALLOW: + return behaviour + + return RegistrationBehaviour.ALLOW diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 68d6870e40..654f58ddae 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -364,6 +364,14 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + + await self.store.add_user_agent_ip_to_ui_auth_session( + session.session_id, user_agent, clientip + ) + if not authdict: raise InteractiveAuthIncompleteError( session.session_id, self._auth_dict_for_flows(flows, session.session_id) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 786e608fa2..a4cc4b9a5a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -35,6 +35,7 @@ class CasHandler: """ def __init__(self, hs): + self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -210,8 +211,16 @@ class CasHandler: else: if not registered_user_id: + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders( + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name + localpart=localpart, + default_display_name=user_display_name, + user_agent_ips=(user_agent, ip_address), ) await self._auth_handler.complete_sso_login( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index dd3703cbd2..c5bd2fea68 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -93,6 +93,7 @@ class OidcHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( @@ -689,9 +690,17 @@ class OidcHandler: self._render_error(request, "invalid_token", str(e)) return + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user(userinfo, token) + user_id = await self._map_userinfo_to_user( + userinfo, token, user_agent, ip_address + ) except MappingException as e: logger.exception("Could not map user") self._render_error(request, "mapping_error", str(e)) @@ -828,7 +837,9 @@ class OidcHandler: now = self._clock.time_msec() return now < expiry - async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + async def _map_userinfo_to_user( + self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str + ) -> str: """Maps a UserInfo object to a mxid. UserInfo should have a claim that uniquely identifies users. This claim @@ -843,6 +854,8 @@ class OidcHandler: Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Raises: MappingException: if there was an error while mapping some properties @@ -899,7 +912,9 @@ class OidcHandler: # It's the first time this user is logging in and the mapped mxid was # not taken, register the user registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=attributes["display_name"], + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ccd96e4626..cde2dbca92 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,6 +26,7 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) +from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester @@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid + self.spam_checker = hs.get_spam_checker() + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, - shadow_banned=False, + user_agent_ips=None, ): """Registers a new client on the server. @@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. - shadow_banned (bool): Shadow-ban the created user. + user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + during the registration process. Returns: str: user_id Raises: @@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) + result = self.spam_checker.check_registration_for_spam( + threepid, localpart, user_agent_ips or [], + ) + + if result == RegistrationBehaviour.DENY: + logger.info( + "Blocked registration of %r", localpart, + ) + # We return a 429 to make it not obvious that they've been + # denied. + raise SynapseError(429, "Rate limited") + + shadow_banned = result == RegistrationBehaviour.SHADOW_BAN + if shadow_banned: + logger.info( + "Shadow banning registration of %r", localpart, + ) + # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: await self.auth.check_auth_blocking(threepid=threepid) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index c1fcb98454..b426199aa6 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -54,6 +54,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "synapse.server.HomeServer"): + self.hs = hs self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -133,8 +134,14 @@ class SamlHandler: # the dict. self.expire_sessions() + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state + resp_bytes, relay_state, user_agent, ip_address ) # Complete the interactive auth session or the login. @@ -147,7 +154,11 @@ class SamlHandler: await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( - self, resp_bytes: str, client_redirect_url: str + self, + resp_bytes: str, + client_redirect_url: str, + user_agent: str, + ip_address: str, ) -> Tuple[str, Optional[Saml2SessionData]]: """ Given a sample response, retrieve the cached session and user for it. @@ -155,6 +166,8 @@ class SamlHandler: Args: resp_bytes: The SAML response. client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Returns: Tuple of the user ID and SAML session associated with this response. @@ -291,6 +304,7 @@ class SamlHandler: localpart=localpart, default_display_name=displayname, bind_emails=emails, + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 7290fd0756..be0e680ac5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, + user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being # written to the db diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 7f63f1bfa0..9be92e2565 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from enum import Enum from twisted.internet import defer @@ -25,6 +26,16 @@ if MYPY: logger = logging.getLogger(__name__) +class RegistrationBehaviour(Enum): + """ + Enum to define whether a registration request should allowed, denied, or shadow-banned. + """ + + ALLOW = "allow" + SHADOW_BAN = "shadow_ban" + DENY = "deny" + + class SpamCheckerApi(object): """A proxy object that gets passed to spam checkers so they can get access to rooms and other relevant information. diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644 index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 6281a41a3d..9eef8e57c5 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr @@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class UIAuthStore(UIAuthWorkerStore): def delete_old_ui_auth_sessions(self, expiration_time: int): @@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore): txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1bb25ab684..f92f3b8c15 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) code = "code" state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" + user_agent = "Browser" + ip_address = "10.0.0.1" session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, @@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] + request.getClientIP.return_value = ip_address + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( @@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e364b1bd62..5c92d0e8c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,18 +17,21 @@ from mock import Mock from twisted.internet import defer +from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester from tests.test_utils import make_awaitable from tests.unittest import override_config +from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers(object): +class RegistrationHandlers: def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) @@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_spam_checker_deny(self): + """A spam checker can deny registration, which results in an error.""" + + class DenyAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.DENY + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [DenyAll()] + + self.get_failure(self.handler.register_user(localpart="user"), SynapseError) + + def test_spam_checker_shadow_ban(self): + """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" + + class BanAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.SHADOW_BAN + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanAll()] + + user_id = self.get_success(self.handler.register_user(localpart="user")) + + # Get an access token. + token = self.macaroon_generator.generate_access_token(user_id) + self.get_success( + self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) + ) + + # Ensure the user was marked as shadow-banned. + request = Mock(args={}) + request.args[b"access_token"] = [token.encode("ascii")] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + auth = Auth(self.hs) + requester = self.get_success(auth.get_user_by_req(request)) + + self.assertTrue(requester.shadow_banned) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 31ed89a5cd..87be94111f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_spam_checker(self): """ - A user which fails to the spam checks will not appear in search results. + A user which fails the spam checks will not appear in search results. """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") @@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that does not filter any users. spam_checker = self.hs.get_spam_checker() - class AllowAll(object): + class AllowAll: def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - class BlockAll(object): + class BlockAll: def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1 From cbbf9126cbd2ace90c1c0f615b87bcec30fdcbd8 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Fri, 21 Aug 2020 15:07:56 +0100 Subject: Do not apply ratelimiting on joins to appservices (#8139) Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited. Co-authored-by: Patrick Cloke Co-authored-by: Erik Johnston --- changelog.d/8139.bugfix | 1 + synapse/api/ratelimiting.py | 37 +++++++++++++++++++++ synapse/handlers/room_member.py | 14 ++++---- tests/api/test_ratelimiting.py | 73 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 changelog.d/8139.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/8139.bugfix b/changelog.d/8139.bugfix new file mode 100644 index 0000000000..21f65d87b7 --- /dev/null +++ b/changelog.d/8139.bugfix @@ -0,0 +1 @@ +Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index ec6b3a69a2..e62ae50ac2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.types import Requester from synapse.util import Clock @@ -43,6 +44,42 @@ class Ratelimiter(object): # * The rate_hz of this particular entry. This can vary per request self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + def can_requester_do_action( + self, + requester: Requester, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: + """Can the requester perform the action? + + Args: + requester: The requester to key off when rate limiting. The user property + will be used. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Returns: + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero + """ + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + return self.can_do_action( + requester.user.to_string(), rate_hz, burst_count, update, _time_now_s + ) + def can_do_action( self, key: Any, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 3a6ee6378d..a03cb02792 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -491,9 +491,10 @@ class RoomMemberHandler(object): if is_host_in_room: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_local.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_local.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( @@ -502,9 +503,10 @@ class RoomMemberHandler(object): else: time_now_s = self.clock.time() - allowed, time_allowed = self._join_rate_limiter_remote.can_do_action( - requester.user.to_string(), - ) + ( + allowed, + time_allowed, + ) = self._join_rate_limiter_remote.can_requester_do_action(requester,) if not allowed: raise LimitExceededError( diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index d580e729c5..1e1f30d790 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,4 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter +from synapse.appservice import ApplicationService +from synapse.types import create_requester from tests import unittest @@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase): self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) + def test_allowed_user_via_can_requester_do_action(self): + user_requester = create_requester("@user:example.com") + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + user_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=True, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertFalse(allowed) + self.assertEquals(10.0, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(20.0, time_allowed) + + def test_allowed_appservice_via_can_requester_do_action(self): + appservice = ApplicationService( + None, "example.com", id="foo", rate_limited=False, + ) + as_requester = create_requester("@user:example.com", app_service=appservice) + + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=0 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=5 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + + allowed, time_allowed = limiter.can_requester_do_action( + as_requester, _time_now_s=10 + ) + self.assertTrue(allowed) + self.assertEquals(-1, time_allowed) + def test_allowed_via_ratelimit(self): limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) -- cgit 1.5.1 From 420484a334a79b31e689bdcca2e57d9a23f7e3d4 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 24 Aug 2020 18:21:04 +0100 Subject: Allow capping a room's retention policy (#8104) --- changelog.d/8104.bugfix | 1 + docs/sample_config.yaml | 22 +++++---- synapse/config/server.py | 22 +++++---- synapse/events/validator.py | 59 ++--------------------- synapse/handlers/pagination.py | 36 +++++++++++--- tests/rest/client/test_retention.py | 94 ++++++++++++++++++++++++++----------- 6 files changed, 127 insertions(+), 107 deletions(-) create mode 100644 changelog.d/8104.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/8104.bugfix b/changelog.d/8104.bugfix new file mode 100644 index 0000000000..e32e2996c4 --- /dev/null +++ b/changelog.d/8104.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.7.2 impacting message retention policies that would allow federated homeservers to dictate a retention period that's lower than the configured minimum allowed duration in the configuration file. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index f168853f67..3528d9e11f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -378,11 +378,10 @@ retention: # min_lifetime: 1d # max_lifetime: 1y - # Retention policy limits. If set, a user won't be able to send a - # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' - # that's not within this range. This is especially useful in closed federations, - # in which server admins can make sure every federating server applies the same - # rules. + # Retention policy limits. If set, and the state of a room contains a + # 'm.room.retention' event in its state which contains a 'min_lifetime' or a + # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy + # to these limits when running purge jobs. # #allowed_lifetime_min: 1d #allowed_lifetime_max: 1y @@ -408,12 +407,19 @@ retention: # (e.g. every 12h), but not want that purge to be performed by a job that's # iterating over every room it knows, which could be heavy on the server. # + # If any purge job is configured, it is strongly recommended to have at least + # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' + # set, or one job without 'shortest_max_lifetime' and one job without + # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if + # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a + # room's policy to these values is done after the policies are retrieved from + # Synapse's database (which is done using the range specified in a purge job's + # configuration). + # #purge_jobs: - # - shortest_max_lifetime: 1d - # longest_max_lifetime: 3d + # - longest_max_lifetime: 3d # interval: 12h # - shortest_max_lifetime: 3d - # longest_max_lifetime: 1y # interval: 1d # Inhibits the /requestToken endpoints from returning an error that might leak diff --git a/synapse/config/server.py b/synapse/config/server.py index ed66f3eba1..526a90b26a 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -961,11 +961,10 @@ class ServerConfig(Config): # min_lifetime: 1d # max_lifetime: 1y - # Retention policy limits. If set, a user won't be able to send a - # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' - # that's not within this range. This is especially useful in closed federations, - # in which server admins can make sure every federating server applies the same - # rules. + # Retention policy limits. If set, and the state of a room contains a + # 'm.room.retention' event in its state which contains a 'min_lifetime' or a + # 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy + # to these limits when running purge jobs. # #allowed_lifetime_min: 1d #allowed_lifetime_max: 1y @@ -991,12 +990,19 @@ class ServerConfig(Config): # (e.g. every 12h), but not want that purge to be performed by a job that's # iterating over every room it knows, which could be heavy on the server. # + # If any purge job is configured, it is strongly recommended to have at least + # a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime' + # set, or one job without 'shortest_max_lifetime' and one job without + # 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if + # 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a + # room's policy to these values is done after the policies are retrieved from + # Synapse's database (which is done using the range specified in a purge job's + # configuration). + # #purge_jobs: - # - shortest_max_lifetime: 1d - # longest_max_lifetime: 3d + # - longest_max_lifetime: 3d # interval: 12h # - shortest_max_lifetime: 3d - # longest_max_lifetime: 1y # interval: 1d # Inhibits the /requestToken endpoints from returning an error that might leak diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 588d222f36..5ce3874fba 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -74,15 +74,14 @@ class EventValidator(object): ) if event.type == EventTypes.Retention: - self._validate_retention(event, config) + self._validate_retention(event) - def _validate_retention(self, event, config): + def _validate_retention(self, event): """Checks that an event that defines the retention policy for a room respects the - boundaries imposed by the server's administrator. + format enforced by the spec. Args: event (FrozenEvent): The event to validate. - config (Config): The homeserver's configuration. """ min_lifetime = event.content.get("min_lifetime") max_lifetime = event.content.get("max_lifetime") @@ -95,32 +94,6 @@ class EventValidator(object): errcode=Codes.BAD_JSON, ) - if ( - config.retention_allowed_lifetime_min is not None - and min_lifetime < config.retention_allowed_lifetime_min - ): - raise SynapseError( - code=400, - msg=( - "'min_lifetime' can't be lower than the minimum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - - if ( - config.retention_allowed_lifetime_max is not None - and min_lifetime > config.retention_allowed_lifetime_max - ): - raise SynapseError( - code=400, - msg=( - "'min_lifetime' can't be greater than the maximum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - if max_lifetime is not None: if not isinstance(max_lifetime, int): raise SynapseError( @@ -129,32 +102,6 @@ class EventValidator(object): errcode=Codes.BAD_JSON, ) - if ( - config.retention_allowed_lifetime_min is not None - and max_lifetime < config.retention_allowed_lifetime_min - ): - raise SynapseError( - code=400, - msg=( - "'max_lifetime' can't be lower than the minimum allowed value" - " enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - - if ( - config.retention_allowed_lifetime_max is not None - and max_lifetime > config.retention_allowed_lifetime_max - ): - raise SynapseError( - code=400, - msg=( - "'max_lifetime' can't be greater than the maximum allowed" - " value enforced by the server's administrator" - ), - errcode=Codes.BAD_JSON, - ) - if ( min_lifetime is not None and max_lifetime is not None diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 487420bb5d..ac3418d69d 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -82,6 +82,9 @@ class PaginationHandler(object): self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min + self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max + if hs.config.retention_enabled: # Run the purge jobs described in the configuration file. for job in hs.config.retention_purge_jobs: @@ -111,7 +114,7 @@ class PaginationHandler(object): the range to handle (inclusive). If None, it means that the range has no upper limit. """ - # We want the storage layer to to include rooms with no retention policy in its + # We want the storage layer to include rooms with no retention policy in its # return value only if a default retention policy is defined in the server's # configuration and that policy's 'max_lifetime' is either lower (or equal) than # max_ms or higher than min_ms (or both). @@ -152,13 +155,32 @@ class PaginationHandler(object): ) continue - max_lifetime = retention_policy["max_lifetime"] + # If max_lifetime is None, it means that the room has no retention policy. + # Given we only retrieve such rooms when there's a default retention policy + # defined in the server's configuration, we can safely assume that's the + # case and use it for this room. + max_lifetime = ( + retention_policy["max_lifetime"] or self._retention_default_max_lifetime + ) - if max_lifetime is None: - # If max_lifetime is None, it means that include_null equals True, - # therefore we can safely assume that there is a default policy defined - # in the server's configuration. - max_lifetime = self._retention_default_max_lifetime + # Cap the effective max_lifetime to be within the range allowed in the + # config. + # We do this in two steps: + # 1. Make sure it's higher or equal to the minimum allowed value, and if + # it's not replace it with that value. This is because the server + # operator can be required to not delete information before a given + # time, e.g. to comply with freedom of information laws. + # 2. Make sure the resulting value is lower or equal to the maximum allowed + # value, and if it's not replace it with that value. This is because the + # server operator can be required to delete any data after a specific + # amount of time. + if self._retention_allowed_lifetime_min is not None: + max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime) + + if self._retention_allowed_lifetime_max is not None: + max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max) + + logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime) # Figure out what token we should start purging at. ts = self.clock.time_msec() - max_lifetime diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 0b191d13c6..d4e7fa1293 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -45,50 +45,63 @@ class RetentionTestCase(unittest.HomeserverTestCase): } self.hs = self.setup_test_homeserver(config=config) + return self.hs def prepare(self, reactor, clock, homeserver): self.user_id = self.register_user("user", "password") self.token = self.login("user", "password") - def test_retention_state_event(self): - """Tests that the server configuration can limit the values a user can set to the - room's retention policy. + self.store = self.hs.get_datastore() + self.serializer = self.hs.get_event_client_serializer() + self.clock = self.hs.get_clock() + + def test_retention_event_purged_with_state_event(self): + """Tests that expired events are correctly purged when the room's retention policy + is defined by a state event. """ room_id = self.helper.create_room_as(self.user_id, tok=self.token) + # Set the room's retention period to 2 days. + lifetime = one_day_ms * 2 self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": one_day_ms * 4}, + body={"max_lifetime": lifetime}, tok=self.token, - expect_code=400, ) + self._test_retention_event_purged(room_id, one_day_ms * 1.5) + + def test_retention_event_purged_with_state_event_outside_allowed(self): + """Tests that the server configuration can override the policy for a room when + running the purge jobs. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) + + # Set a max_lifetime higher than the maximum allowed value. self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": one_hour_ms}, + body={"max_lifetime": one_day_ms * 4}, tok=self.token, - expect_code=400, ) - def test_retention_event_purged_with_state_event(self): - """Tests that expired events are correctly purged when the room's retention policy - is defined by a state event. - """ - room_id = self.helper.create_room_as(self.user_id, tok=self.token) + # Check that the event is purged after waiting for the maximum allowed duration + # instead of the one specified in the room's policy. + self._test_retention_event_purged(room_id, one_day_ms * 1.5) - # Set the room's retention period to 2 days. - lifetime = one_day_ms * 2 + # Set a max_lifetime lower than the minimum allowed value. self.helper.send_state( room_id=room_id, event_type=EventTypes.Retention, - body={"max_lifetime": lifetime}, + body={"max_lifetime": one_hour_ms}, tok=self.token, ) - self._test_retention_event_purged(room_id, one_day_ms * 1.5) + # Check that the event is purged after waiting for the minimum allowed duration + # instead of the one specified in the room's policy. + self._test_retention_event_purged(room_id, one_day_ms * 0.5) def test_retention_event_purged_without_state_event(self): """Tests that expired events are correctly purged when the room's retention policy @@ -140,7 +153,27 @@ class RetentionTestCase(unittest.HomeserverTestCase): # That event should be the second, not outdated event. self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events) - def _test_retention_event_purged(self, room_id, increment): + def _test_retention_event_purged(self, room_id: str, increment: float): + """Run the following test scenario to test the message retention policy support: + + 1. Send event 1 + 2. Increment time by `increment` + 3. Send event 2 + 4. Increment time by `increment` + 5. Check that event 1 has been purged + 6. Check that event 2 has not been purged + 7. Check that state events that were sent before event 1 aren't purged. + The main reason for sending a second event is because currently Synapse won't + purge the latest message in a room because it would otherwise result in a lack of + forward extremities for this room. It's also a good thing to ensure the purge jobs + aren't too greedy and purge messages they shouldn't. + + Args: + room_id: The ID of the room to test retention in. + increment: The number of milliseconds to advance the clock each time. Must be + defined so that events in the room aren't purged if they are `increment` + old but are purged if they are `increment * 2` old. + """ # Get the create event to, later, check that we can still access it. message_handler = self.hs.get_message_handler() create_event = self.get_success( @@ -156,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase): expired_event_id = resp.get("event_id") # Check that we can retrieve the event. - expired_event = self.get_event(room_id, expired_event_id) + expired_event = self.get_event(expired_event_id) self.assertEqual( expired_event.get("content", {}).get("body"), "1", expired_event ) @@ -174,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase): # one should still be kept. self.reactor.advance(increment / 1000) - # Check that the event has been purged from the database. - self.get_event(room_id, expired_event_id, expected_code=404) + # Check that the first event has been purged from the database, i.e. that we + # can't retrieve it anymore, because it has expired. + self.get_event(expired_event_id, expect_none=True) - # Check that the event that hasn't been purged can still be retrieved. - valid_event = self.get_event(room_id, valid_event_id) + # Check that the event that hasn't expired can still be retrieved. + valid_event = self.get_event(valid_event_id) self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event) # Check that we can still access state events that were sent before the event that # has been purged. self.get_event(room_id, create_event.event_id) - def get_event(self, room_id, event_id, expected_code=200): - url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + def get_event(self, event_id, expect_none=False): + event = self.get_success(self.store.get_event(event_id, allow_none=True)) - request, channel = self.make_request("GET", url, access_token=self.token) - self.render(request) + if expect_none: + self.assertIsNone(event) + return {} - self.assertEqual(channel.code, expected_code, channel.result) + self.assertIsNotNone(event) - return channel.json_body + time_now = self.clock.time_msec() + serialized = self.get_success(self.serializer.serialize_event(event, time_now)) + + return serialized class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1 From cbd8d83da7d24d7434c749c4c6cfece0c507b0b9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Aug 2020 13:58:56 -0400 Subject: Stop shadow-banned users from sending non-member events. (#8142) --- changelog.d/8142.feature | 1 + synapse/handlers/directory.py | 6 ++ synapse/handlers/message.py | 10 +++ synapse/handlers/room.py | 19 +++++- synapse/rest/client/v1/room.py | 74 +++++++++++++--------- synapse/rest/client/v2_alpha/relations.py | 18 ++++-- .../client/v2_alpha/room_upgrade_rest_servlet.py | 14 ++-- tests/rest/client/v1/test_rooms.py | 55 +++++++++++++++- 8 files changed, 155 insertions(+), 42 deletions(-) create mode 100644 changelog.d/8142.feature (limited to 'synapse/handlers') diff --git a/changelog.d/8142.feature b/changelog.d/8142.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8142.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 79a2df6201..46826eb784 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -23,6 +23,7 @@ from synapse.api.errors import ( CodeMessageException, Codes, NotFoundError, + ShadowBanError, StoreError, SynapseError, ) @@ -199,6 +200,8 @@ class DirectoryHandler(BaseHandler): try: await self._update_canonical_alias(requester, user_id, room_id, room_alias) + except ShadowBanError as e: + logger.info("Failed to update alias events due to shadow-ban: %s", e) except AuthError as e: logger.info("Failed to update alias events: %s", e) @@ -292,6 +295,9 @@ class DirectoryHandler(BaseHandler): """ Send an updated canonical alias event if the removed alias was set as the canonical alias or listed in the alt_aliases field. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ alias_event = await self.state.get_current_state( room_id, EventTypes.CanonicalAlias, "" diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index c955a86be0..593c0cc6f1 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -34,6 +35,7 @@ from synapse.api.errors import ( Codes, ConsentNotGivenError, NotFoundError, + ShadowBanError, SynapseError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions @@ -716,12 +718,20 @@ class EventCreationHandler(object): event_dict: dict, ratelimit: bool = True, txn_id: Optional[str] = None, + ignore_shadow_ban: bool = False, ) -> Tuple[EventBase, int]: """ Creates an event, then sends it. See self.create_event and self.send_nonmember_event. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ + if not ignore_shadow_ban and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() # We limit the number of concurrent event sends in a room so that we # don't fork the DAG too much. If we don't limit then we can end up in diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 0fc71475c3..e4788ef86b 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -136,6 +136,9 @@ class RoomCreationHandler(BaseHandler): Returns: the new room id + + Raises: + ShadowBanError if the requester is shadow-banned. """ await self.ratelimit(requester) @@ -171,6 +174,15 @@ class RoomCreationHandler(BaseHandler): async def _upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ): + """ + Args: + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_versions: the version to upgrade the room to + + Raises: + ShadowBanError if the requester is shadow-banned. + """ user_id = requester.user.to_string() # start by allocating a new room id @@ -257,6 +269,9 @@ class RoomCreationHandler(BaseHandler): old_room_id: the id of the room to be replaced new_room_id: the id of the replacement room old_room_state: the state map for the old room + + Raises: + ShadowBanError if the requester is shadow-banned. """ old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, "")) @@ -829,11 +844,13 @@ class RoomCreationHandler(BaseHandler): async def send(etype: str, content: JsonDict, **kwargs) -> int: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) + # Allow these events to be sent even if the user is shadow-banned to + # allow the room creation to complete. ( _, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( - creator, event, ratelimit=False + creator, event, ratelimit=False, ignore_shadow_ban=True, ) return last_stream_id diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index a9dd3a6aec..11da8bc037 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -201,8 +201,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): if state_key is not None: event_dict["state_key"] = state_key - if event_type == EventTypes.Member: - try: + try: + if event_type == EventTypes.Member: membership = content.get("membership", None) event_id, _ = await self.room_member_handler.update_membership( requester, @@ -211,16 +211,16 @@ class RoomStateEventRestServlet(TransactionRestServlet): action=membership, content=content, ) - except ShadowBanError: - event_id = "$" + random_string(43) - else: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) - event_id = event.event_id + else: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) set_tag("event_id", event_id) ret = {"event_id": event_id} @@ -253,12 +253,19 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - set_tag("event_id", event.event_id) - return 200, {"event_id": event.event_id} + set_tag("event_id", event_id) + return 200, {"event_id": event_id} def on_GET(self, request, room_id, event_type, txn_id): return 200, "Not implemented" @@ -799,20 +806,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - set_tag("event_id", event.event_id) - return 200, {"event_id": event.event_id} + set_tag("event_id", event_id) + return 200, {"event_id": event_id} def on_PUT(self, request, room_id, event_id, txn_id): set_tag("txn_id", txn_id) diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index 89002ffbff..e29f49f7f5 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -22,7 +22,7 @@ any time to reflect changes in the MSC. import logging from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import ShadowBanError, SynapseError from synapse.http.servlet import ( RestServlet, parse_integer, @@ -35,6 +35,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) +from synapse.util.stringutils import random_string from ._base import client_patterns @@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - return 200, {"event_id": event.event_id} + return 200, {"event_id": event_id} class RelationPaginationServlet(RestServlet): diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index f357015a70..39a5518614 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -15,13 +15,14 @@ import logging -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.util import stringutils from ._base import client_patterns @@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) - new_version = content["new_version"] new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) if new_version is None: @@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet): Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = await self._room_creation_handler.upgrade_room( - requester, room_id, new_version - ) + try: + new_room_id = await self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) + except ShadowBanError: + # Generate a random room ID. + new_room_id = stringutils.random_string(18) ret = {"replacement_room": new_room_id} diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 286e0ccdcc..60fef13e9f 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -27,7 +27,7 @@ import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.handlers.pagination import PurgeStatus from synapse.rest.client.v1 import directory, login, profile, room -from synapse.rest.client.v2_alpha import account +from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet from synapse.types import JsonDict, RoomAlias from synapse.util.stringutils import random_string @@ -1984,6 +1984,7 @@ class ShadowBannedTestCase(unittest.HomeserverTestCase): directory.register_servlets, login.register_servlets, room.register_servlets, + room_upgrade_rest_servlet.register_servlets, ] def prepare(self, reactor, clock, homeserver): @@ -2076,3 +2077,55 @@ class ShadowBannedTestCase(unittest.HomeserverTestCase): # Both users should be in the room. users = self.get_success(self.store.get_users_in_room(room_id)) self.assertCountEqual(users, ["@banned:test", "@otheruser:test"]) + + def test_message(self): + """Messages from shadow-banned users don't actually get sent.""" + + room_id = self.helper.create_room_as( + self.other_user_id, tok=self.other_access_token + ) + + # The user should be in the room. + self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token) + + # Sending a message should complete successfully. + result = self.helper.send_event( + room_id=room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "with right label"}, + tok=self.banned_access_token, + ) + self.assertIn("event_id", result) + event_id = result["event_id"] + + latest_events = self.get_success( + self.store.get_latest_event_ids_in_room(room_id) + ) + self.assertNotIn(event_id, latest_events) + + def test_upgrade(self): + """A room upgrade should fail, but look like it succeeded.""" + + # The create works fine. + room_id = self.helper.create_room_as( + self.banned_user_id, tok=self.banned_access_token + ) + + request, channel = self.make_request( + "POST", + "/_matrix/client/r0/rooms/%s/upgrade" % (room_id,), + {"new_version": "6"}, + access_token=self.banned_access_token, + ) + self.render(request) + self.assertEquals(200, channel.code, channel.result) + # A new room_id should be returned. + self.assertIn("replacement_room", channel.json_body) + + new_room_id = channel.json_body["replacement_room"] + + # It doesn't really matter what API we use here, we just want to assert + # that the room doesn't exist. + summary = self.get_success(self.store.get_room_summary(new_room_id)) + # The summary should be empty since the room doesn't exist. + self.assertEqual(summary, {}) -- cgit 1.5.1 From 5758dcf30c245efa1032385cd1af7853d39642a9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Aug 2020 14:25:27 -0400 Subject: Add type hints for state. (#8140) --- changelog.d/8140.misc | 1 + stubs/frozendict.pyi | 47 +++++++ synapse/federation/sender/__init__.py | 4 +- synapse/handlers/federation.py | 10 +- synapse/handlers/presence.py | 6 +- synapse/handlers/room_member.py | 20 +-- synapse/state/__init__.py | 192 +++++++++++++++---------- synapse/state/v1.py | 87 ++++++++---- synapse/state/v2.py | 255 ++++++++++++++++++++++------------ tox.ini | 1 + 10 files changed, 420 insertions(+), 203 deletions(-) create mode 100644 changelog.d/8140.misc create mode 100644 stubs/frozendict.pyi (limited to 'synapse/handlers') diff --git a/changelog.d/8140.misc b/changelog.d/8140.misc new file mode 100644 index 0000000000..78d8834328 --- /dev/null +++ b/changelog.d/8140.misc @@ -0,0 +1 @@ +Add type hints to `synapse.state`. diff --git a/stubs/frozendict.pyi b/stubs/frozendict.pyi new file mode 100644 index 0000000000..3f3af59f26 --- /dev/null +++ b/stubs/frozendict.pyi @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Stub for frozendict. + +from typing import ( + Any, + Hashable, + Iterable, + Iterator, + Mapping, + overload, + Tuple, + TypeVar, +) + +_KT = TypeVar("_KT", bound=Hashable) # Key type. +_VT = TypeVar("_VT") # Value type. + +class frozendict(Mapping[_KT, _VT]): + @overload + def __init__(self, **kwargs: _VT) -> None: ... + @overload + def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ... + @overload + def __init__( + self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + def __getitem__(self, key: _KT) -> _VT: ... + def __contains__(self, key: Any) -> bool: ... + def copy(self, **add_or_replace: Any) -> frozendict: ... + def __iter__(self) -> Iterator[_KT]: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + def __hash__(self) -> int: ... diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index e53b6ac456..4662008bfd 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -329,10 +329,10 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = await self.state.get_current_hosts_in_room(room_id) + domains_set = await self.state.get_current_hosts_in_room(room_id) domains = [ d - for d in domains + for d in domains_set if d != self.server_name and self._federation_shard_config.should_handle(self._instance_name, d) ] diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 5b270228e7..f8b234cee2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2134,10 +2134,10 @@ class FederationHandler(BaseHandler): ) state_sets = list(state_sets.values()) state_sets.append(state) - current_state_ids = await self.state_handler.resolve_events( + current_states = await self.state_handler.resolve_events( room_version, state_sets, event ) - current_state_ids = {k: e.event_id for k, e in current_state_ids.items()} + current_state_ids = {k: e.event_id for k, e in current_states.items()} else: current_state_ids = await self.state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -2149,9 +2149,11 @@ class FederationHandler(BaseHandler): # Now check if event pass auth against said current state auth_types = auth_types_for_event(event) - current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] + current_state_ids_list = [ + e for k, e in current_state_ids.items() if k in auth_types + ] - auth_events_map = await self.store.get_events(current_state_ids) + auth_events_map = await self.store.get_events(current_state_ids_list) current_auth_events = { (e.type, e.state_key): e for e in auth_events_map.values() } diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 24e1940ee5..1846068150 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import Collection, JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches.descriptors import cached from synapse.util.metrics import Measure @@ -1318,7 +1318,7 @@ async def get_interested_parties( async def get_interested_remotes( store: DataStore, states: List[UserPresenceState], state_handler: StateHandler -) -> List[Tuple[List[str], List[UserPresenceState]]]: +) -> List[Tuple[Collection[str], List[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -1334,7 +1334,7 @@ async def get_interested_remotes( each tuple the list of UserPresenceState should be sent to each destination """ - hosts_and_states = [] + hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]] # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a03cb02792..52548087a9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,7 +17,7 @@ import abc import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from unpaddedbase64 import encode_base64 @@ -38,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID +from synapse.types import ( + Collection, + JsonDict, + Requester, + RoomAlias, + RoomID, + StateMap, + UserID, +) from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -738,9 +746,7 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join( - self, current_state_ids: Dict[Tuple[str, str], str] - ) -> bool: + async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool: """ Returns whether a guest can join a room based on its current state. """ @@ -969,9 +975,7 @@ class RoomMemberHandler(object): ) return stream_id - async def _is_host_in_room( - self, current_state_ids: Dict[Tuple[str, str], str] - ) -> bool: + async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: # Have we just created the room, and is this about to be the very # first member event? create_event_id = current_state_ids.get(("m.room.create", "")) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dba8d91eef..a601303fa3 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,11 +16,22 @@ import logging from collections import namedtuple -from typing import Awaitable, Dict, Iterable, List, Optional, Set +from typing import ( + Awaitable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Union, + overload, +) import attr from frozendict import frozendict from prometheus_client import Histogram +from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions @@ -30,7 +41,7 @@ from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo -from synapse.types import StateMap +from synapse.types import Collection, StateMap from synapse.util import Clock from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -68,8 +79,14 @@ def _gen_state_id(): class _StateCacheEntry(object): __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] - def __init__(self, state, state_group, prev_group=None, delta_ids=None): - # dict[(str, str), str] map from (type, state_key) to event_id + def __init__( + self, + state: StateMap[str], + state_group: Optional[int], + prev_group: Optional[int] = None, + delta_ids: Optional[StateMap[str]] = None, + ): + # A map from (type, state_key) to event_id. self.state = frozendict(state) # the ID of a state group if one and only one is involved. @@ -107,24 +124,49 @@ class StateHandler(object): self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() + @overload async def get_current_state( - self, room_id, event_type=None, state_key="", latest_event_ids=None - ): - """ Retrieves the current state for the room. This is done by + self, + room_id: str, + event_type: Literal[None] = None, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> StateMap[EventBase]: + ... + + @overload + async def get_current_state( + self, + room_id: str, + event_type: str, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> Optional[EventBase]: + ... + + async def get_current_state( + self, + room_id: str, + event_type: Optional[str] = None, + state_key: str = "", + latest_event_ids: Optional[List[str]] = None, + ) -> Union[Optional[EventBase], StateMap[EventBase]]: + """Retrieves the current state for the room. This is done by calling `get_latest_events_in_room` to get the leading edges of the event graph and then resolving any of the state conflicts. This is equivalent to getting the state of an event that were to send next before receiving any new events. - If `event_type` is specified, then the method returns only the one - event (or None) with that `event_type` and `state_key`. - Returns: - map from (type, state_key) to event + If `event_type` is specified, then the method returns only the one + event (or None) with that `event_type` and `state_key`. + + Otherwise, a map from (type, state_key) to event. """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) @@ -140,34 +182,30 @@ class StateHandler(object): state_map = await self.store.get_events( list(state.values()), get_prev_content=False ) - state = { + return { key: state_map[e_id] for key, e_id in state.items() if e_id in state_map } - return state - - async def get_current_state_ids(self, room_id, latest_event_ids=None): + async def get_current_state_ids( + self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None + ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room Args: - room_id (str): - - latest_event_ids (iterable[str]|None): if given, the forward - extremities to resolve. If None, we look them up from the - database (via a cache) + room_id: + latest_event_ids: if given, the forward extremities to resolve. If + None, we look them up from the database (via a cache). Returns: - Deferred[dict[(str, str), str)]]: the state dict, mapping from - (event_type, state_key) -> event_id + the state dict, mapping from (event_type, state_key) -> event_id """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = ret.state - - return state + return dict(ret.state) async def get_current_users_in_room( self, room_id: str, latest_event_ids: Optional[List[str]] = None @@ -183,32 +221,34 @@ class StateHandler(object): """ if not latest_event_ids: latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) + assert latest_event_ids is not None + logger.debug("calling resolve_state_groups from get_current_users_in_room") entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - joined_users = await self.store.get_joined_users_from_state(room_id, entry) - return joined_users + return await self.store.get_joined_users_from_state(room_id, entry) - async def get_current_hosts_in_room(self, room_id): + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: event_ids = await self.store.get_latest_event_ids_in_room(room_id) return await self.get_hosts_in_room_at_events(room_id, event_ids) - async def get_hosts_in_room_at_events(self, room_id, event_ids): + async def get_hosts_in_room_at_events( + self, room_id: str, event_ids: List[str] + ) -> Set[str]: """Get the hosts that were in a room at the given event ids Args: - room_id (str): - event_ids (list[str]): + room_id: + event_ids: Returns: - Deferred[list[str]]: the hosts in the room at the given events + The hosts in the room at the given events """ entry = await self.resolve_state_groups_for_events(room_id, event_ids) - joined_hosts = await self.store.get_joined_hosts(room_id, entry) - return joined_hosts + return await self.store.get_joined_hosts(room_id, entry) async def compute_event_context( self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None - ): + ) -> EventContext: """Build an EventContext structure for the event. This works out what the current state should be for the event, and @@ -221,7 +261,7 @@ class StateHandler(object): when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. Returns: - synapse.events.snapshot.EventContext: + The event context. """ if event.internal_metadata.is_outlier(): @@ -275,7 +315,7 @@ class StateHandler(object): event.room_id, event.prev_event_ids() ) - state_ids_before_event = entry.state + state_ids_before_event = dict(entry.state) state_group_before_event = entry.state_group state_group_before_event_prev_group = entry.prev_group deltas_to_state_group_before_event = entry.delta_ids @@ -346,19 +386,18 @@ class StateHandler(object): ) @measure_func() - async def resolve_state_groups_for_events(self, room_id, event_ids): + async def resolve_state_groups_for_events( + self, room_id: str, event_ids: Iterable[str] + ) -> _StateCacheEntry: """ Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. Args: - room_id (str) - event_ids (list[str]) - explicit_room_version (str|None): If set uses the the given room - version to choose the resolution algorithm. If None, then - checks the database for room version. + room_id + event_ids Returns: - Deferred[_StateCacheEntry]: resolved state + The resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) @@ -394,7 +433,12 @@ class StateHandler(object): ) return result - async def resolve_events(self, room_version, state_sets, event): + async def resolve_events( + self, + room_version: str, + state_sets: Collection[Iterable[EventBase]], + event: EventBase, + ) -> StateMap[EventBase]: logger.info( "Resolving state for %s with %d groups", event.room_id, len(state_sets) ) @@ -414,9 +458,7 @@ class StateHandler(object): state_res_store=StateResolutionStore(self.store), ) - new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()} - - return new_state + return {key: state_map[ev_id] for key, ev_id in new_state.items()} class StateResolutionHandler(object): @@ -444,7 +486,12 @@ class StateResolutionHandler(object): @log_function async def resolve_state_groups( - self, room_id, room_version, state_groups_ids, event_map, state_res_store + self, + room_id: str, + room_version: str, + state_groups_ids: Dict[int, StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", ): """Resolves conflicts between a set of state groups @@ -452,13 +499,13 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging and sanity checks) - room_version (str): version of the room - state_groups_ids (dict[int, dict[(str, str), str]]): - map from state group id to the state in that state group + room_id: room we are resolving for (used for logging and sanity checks) + room_version: version of the room + state_groups_ids: + A map from state group id to the state in that state group (where 'state' is a map from state key to event id) - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -466,10 +513,10 @@ class StateResolutionHandler(object): If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store Returns: - _StateCacheEntry: resolved state + The resolved state """ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) @@ -530,21 +577,22 @@ class StateResolutionHandler(object): return cache -def _make_state_cache_entry(new_state, state_groups_ids): +def _make_state_cache_entry( + new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]] +) -> _StateCacheEntry: """Given a resolved state, and a set of input state groups, pick one to base a new state group on (if any), and return an appropriately-constructed _StateCacheEntry. Args: - new_state (dict[(str, str), str]): resolved state map (mapping from - (type, state_key) to event_id) + new_state: resolved state map (mapping from (type, state_key) to event_id) - state_groups_ids (dict[int, dict[(str, str), str]]): - map from state group id to the state in that state group - (where 'state' is a map from state key to event id) + state_groups_ids: + map from state group id to the state in that state group (where + 'state' is a map from state key to event id) Returns: - _StateCacheEntry + The cache entry. """ # if the new state matches any of the input state groups, we can # use that state group again. Otherwise we will generate a state_id @@ -585,7 +633,7 @@ def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ) -> Awaitable[StateMap[str]]: @@ -633,15 +681,17 @@ class StateResolutionStore(object): store = attr.ib() - def get_events(self, event_ids, allow_rejected=False): + def get_events( + self, event_ids: Iterable[str], allow_rejected: bool = False + ) -> Awaitable[Dict[str, EventBase]]: """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - allow_rejected (bool): If True return rejected events. + event_ids: The event_ids of the events to fetch + allow_rejected: If True return rejected events. Returns: - Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. + An awaitable which resolves to a dict from event_id to event. """ return self.store.get_events( @@ -651,7 +701,9 @@ class StateResolutionStore(object): allow_rejected=allow_rejected, ) - def get_auth_chain_difference(self, state_sets: List[Set[str]]): + def get_auth_chain_difference( + self, state_sets: List[Set[str]] + ) -> Awaitable[Set[str]]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -660,7 +712,7 @@ class StateResolutionStore(object): chain. Returns: - Deferred[Set[str]]: Set of event IDs. + An awaitable that resolves to a set of event IDs. """ return self.store.get_auth_chain_difference(state_sets) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index ab5e24841d..0eb7fdd9e5 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,7 +15,17 @@ import hashlib import logging -from typing import Awaitable, Callable, Dict, List, Optional +from typing import ( + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, +) from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "") async def resolve_events_with_store( room_id: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_map_factory: Callable[[List[str]], Awaitable], -): + state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]], +) -> StateMap[str]: """ Args: room_id: the room we are working in @@ -56,8 +66,7 @@ async def resolve_events_with_store( an Awaitable that resolves to a dict of event_id to event. Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + A map from (type, state_key) to event_id. """ if len(state_sets) == 1: return state_sets[0] @@ -75,8 +84,8 @@ async def resolve_events_with_store( "Asking for %d/%d conflicted events", len(needed_events), needed_event_count ) - # dict[str, FrozenEvent]: a map from state event id to event. Only includes - # the state events which are in conflict (and those in event_map) + # A map from state event id to event. Only includes the state events which + # are in conflict (and those in event_map). state_map = await state_map_factory(needed_events) if event_map is not None: state_map.update(event_map) @@ -91,8 +100,6 @@ async def resolve_events_with_store( # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. - # - # dict[(str, str), str]: a map from state key to event id auth_events = _create_auth_events_from_maps( unconflicted_state, conflicted_state, state_map ) @@ -122,29 +129,30 @@ async def resolve_events_with_store( ) -def _seperate(state_sets): +def _seperate( + state_sets: Iterable[StateMap[str]], +) -> Tuple[StateMap[str], StateMap[Set[str]]]: """Takes the state_sets and figures out which keys are conflicted and which aren't. i.e., which have multiple different event_ids associated with them in different state sets. Args: - state_sets(iterable[dict[(str, str), str]]): + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. Returns: - (dict[(str, str), str], dict[(str, str), set[str]]): - A tuple of (unconflicted_state, conflicted_state), where: + A tuple of (unconflicted_state, conflicted_state), where: - unconflicted_state is a dict mapping (type, state_key)->event_id - for unconflicted state keys. + unconflicted_state is a dict mapping (type, state_key)->event_id + for unconflicted state keys. - conflicted_state is a dict mapping (type, state_key) to a set of - event ids for conflicted state keys. + conflicted_state is a dict mapping (type, state_key) to a set of + event ids for conflicted state keys. """ state_set_iterator = iter(state_sets) unconflicted_state = dict(next(state_set_iterator)) - conflicted_state = {} + conflicted_state = {} # type: StateMap[Set[str]] for state_set in state_set_iterator: for key, value in state_set.items(): @@ -171,7 +179,21 @@ def _seperate(state_sets): return unconflicted_state, conflicted_state -def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): +def _create_auth_events_from_maps( + unconflicted_state: StateMap[str], + conflicted_state: StateMap[Set[str]], + state_map: Dict[str, EventBase], +) -> StateMap[str]: + """ + + Args: + unconflicted_state: The unconflicted state map. + conflicted_state: The conflicted state map. + state_map: + + Returns: + A map from state key to event id. + """ auth_events = {} for event_ids in conflicted_state.values(): for event_id in event_ids: @@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma keys = event_auth.auth_types_for_event(state_map[event_id]) for key in keys: if key not in auth_events: - event_id = unconflicted_state.get(key, None) - if event_id: - auth_events[key] = event_id + auth_event_id = unconflicted_state.get(key, None) + if auth_event_id: + auth_events[key] = auth_event_id return auth_events def _resolve_with_state( - unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map + unconflicted_state_ids: StateMap[str], + conflicted_state_ids: StateMap[Set[str]], + auth_event_ids: StateMap[str], + state_map: Dict[str, EventBase], ): conflicted_state = {} for key, event_ids in conflicted_state_ids.items(): @@ -215,7 +240,9 @@ def _resolve_with_state( return new_state -def _resolve_state_events(conflicted_state, auth_events): +def _resolve_state_events( + conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase] +) -> StateMap[EventBase]: """ This is where we actually decide which of the conflicted state to use. @@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events): return resolved_state -def _resolve_auth_events(events, auth_events): +def _resolve_auth_events( + events: List[EventBase], auth_events: StateMap[EventBase] +) -> EventBase: reverse = list(reversed(_ordered_events(events))) auth_keys = { @@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events): return event -def _resolve_normal_events(events, auth_events): +def _resolve_normal_events( + events: List[EventBase], auth_events: StateMap[EventBase] +) -> EventBase: for event in _ordered_events(events): try: # The signatures have already been checked at this point @@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events): return event -def _ordered_events(events): +def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]: def key_func(e): # we have to use utf-8 rather than ascii here because it turns out we allow # people to send us events with non-ascii event IDs :/ diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 6634955cdc..0e9ffbd6e6 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,7 +16,21 @@ import heapq import itertools import logging -from typing import Dict, List, Optional +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + overload, +) + +from typing_extensions import Literal import synapse.state from synapse import event_auth @@ -40,10 +54,10 @@ async def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, - state_sets: List[StateMap[str]], + state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", -): +) -> StateMap[str]: """Resolves the state using the v2 state resolution algorithm Args: @@ -63,8 +77,7 @@ async def resolve_events_with_store( state_res_store: Returns: - Deferred[dict[(str, str), str]]: - a map from (type, state_key) to event_id. + A map from (type, state_key) to event_id. """ logger.debug("Computing conflicted state") @@ -171,18 +184,23 @@ async def resolve_events_with_store( return resolved_state -async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> int: """Return the power level of the sender of the given event according to their auth events. Args: - room_id (str) - event_id (str) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + room_id + event_id + event_map + state_res_store Returns: - Deferred[int] + The power level. """ event = await _get_event(room_id, event_id, event_map, state_res_store) @@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st return int(level) -async def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference( + state_sets: Sequence[StateMap[str]], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> Set[str]: """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. Args: - state_sets (list) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + state_sets + event_map + state_res_store Returns: - Deferred[set[str]]: Set of event IDs + Set of event IDs """ difference = await state_res_store.get_auth_chain_difference( @@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store): return difference -def _seperate(state_sets): +def _seperate( + state_sets: Iterable[StateMap[str]], +) -> Tuple[StateMap[str], StateMap[Set[str]]]: """Return the unconflicted and conflicted state. This is different than in the original algorithm, as this defines a key to be conflicted if one of the state sets doesn't have that key. Args: - state_sets (list) + state_sets Returns: - tuple[dict, dict]: A tuple of unconflicted and conflicted state. The - conflicted state dict is a map from type/state_key to set of event IDs + A tuple of unconflicted and conflicted state. The conflicted state dict + is a map from type/state_key to set of event IDs """ unconflicted_state = {} conflicted_state = {} @@ -260,18 +284,20 @@ def _seperate(state_sets): event_ids.discard(None) conflicted_state[key] = event_ids - return unconflicted_state, conflicted_state + # mypy doesn't understand that discarding None above means that conflicted + # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]]. + return unconflicted_state, conflicted_state # type: ignore -def _is_power_event(event): +def _is_power_event(event: EventBase) -> bool: """Return whether or not the event is a "power event", as defined by the v2 state resolution algorithm Args: - event (FrozenEvent) + event Returns: - boolean + True if the event is a power event. """ if (event.type, event.state_key) in ( (EventTypes.PowerLevels, ""), @@ -288,19 +314,23 @@ def _is_power_event(event): async def _add_event_and_auth_chain_to_graph( - graph, room_id, event_id, event_map, state_res_store, auth_diff -): + graph: Dict[str, Set[str]], + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + auth_diff: Set[str], +) -> None: """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph Args: - graph (dict[str, set[str]]): A map from event ID to the events auth - event IDs - room_id (str): the room we are working in - event_id (str): Event to add to the graph - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - auth_diff (set[str]): Set of event IDs that are in the auth difference. + graph: A map from event ID to the events auth event IDs + room_id: the room we are working in + event_id: Event to add to the graph + event_map + state_res_store + auth_diff: Set of event IDs that are in the auth difference. """ state = [event_id] @@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph( async def _reverse_topological_power_sort( - clock, room_id, event_ids, event_map, state_res_store, auth_diff -): + clock: Clock, + room_id: str, + event_ids: Iterable[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + auth_diff: Set[str], +) -> List[str]: """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: - clock (Clock) - room_id (str): the room we are working in - event_ids (list[str]): The events to sort - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - auth_diff (set[str]): Set of event IDs that are in the auth difference. + clock + room_id: the room we are working in + event_ids: The events to sort + event_map + state_res_store + auth_diff: Set of event IDs that are in the auth difference. Returns: - Deferred[list[str]]: The sorted list + The sorted list """ - graph = {} + graph = {} # type: Dict[str, Set[str]] for idx, event_id in enumerate(event_ids, start=1): await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff @@ -372,22 +407,28 @@ async def _reverse_topological_power_sort( async def _iterative_auth_checks( - clock, room_id, room_version, event_ids, base_state, event_map, state_res_store -): + clock: Clock, + room_id: str, + room_version: str, + event_ids: List[str], + base_state: StateMap[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> StateMap[str]: """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: - clock (Clock) - room_id (str) - room_version (str) - event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (StateMap[str]): The set of state to start with - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + clock + room_id + room_version + event_ids: Ordered list of events to apply auth checks to + base_state: The set of state to start with + event_map + state_res_store Returns: - Deferred[StateMap[str]]: Returns the final updated state + Returns the final updated state """ resolved_state = base_state.copy() room_version_obj = KNOWN_ROOM_VERSIONS[room_version] @@ -439,21 +480,26 @@ async def _iterative_auth_checks( async def _mainline_sort( - clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store -): + clock: Clock, + room_id: str, + event_ids: List[str], + resolved_power_event_id: Optional[str], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> List[str]: """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: - clock (Clock) - room_id (str): room we're working in - event_ids (list[str]): Events to sort - resolved_power_event_id (str): The final resolved power level event ID - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + clock + room_id: room we're working in + event_ids: Events to sort + resolved_power_event_id: The final resolved power level event ID + event_map + state_res_store Returns: - Deferred[list[str]]: The sorted list + The sorted list """ if not event_ids: # It's possible for there to be no event IDs here to sort, so we can @@ -505,59 +551,90 @@ async def _mainline_sort( async def _get_mainline_depth_for_event( - event, mainline_map, event_map, state_res_store -): + event: EventBase, + mainline_map: Dict[str, int], + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", +) -> int: """Get the mainline depths for the given event based on the mainline map Args: - event (FrozenEvent) - mainline_map (dict[str, int]): Map from event_id to mainline depth for - events in the mainline. - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) + event + mainline_map: Map from event_id to mainline depth for events in the mainline. + event_map + state_res_store Returns: - Deferred[int] + The mainline depth """ room_id = event.room_id + tmp_event = event # type: Optional[EventBase] # We do an iterative search, replacing `event with the power level in its # auth events (if any) - while event: + while tmp_event: depth = mainline_map.get(event.event_id) if depth is not None: return depth - auth_events = event.auth_event_ids() - event = None + auth_events = tmp_event.auth_event_ids() + tmp_event = None for aid in auth_events: aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): - event = aev + tmp_event = aev break # Didn't find a power level auth event, so we just return 0 return 0 -async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +@overload +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: Literal[False] = False, +) -> EventBase: + ... + + +@overload +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: Literal[True], +) -> Optional[EventBase]: + ... + + +async def _get_event( + room_id: str, + event_id: str, + event_map: Dict[str, EventBase], + state_res_store: "synapse.state.StateResolutionStore", + allow_none: bool = False, +) -> Optional[EventBase]: """Helper function to look up event in event_map, falling back to looking it up in the store Args: - room_id (str) - event_id (str) - event_map (dict[str,FrozenEvent]) - state_res_store (StateResolutionStore) - allow_none (bool): if the event is not found, return None rather than raising + room_id + event_id + event_map + state_res_store + allow_none: if the event is not found, return None rather than raising an exception Returns: - Deferred[Optional[FrozenEvent]] + The event, or none if the event does not exist (and allow_none is True). """ if event_id not in event_map: events = await state_res_store.get_events([event_id], allow_rejected=True) @@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F return event -def lexicographical_topological_sort(graph, key): +def lexicographical_topological_sort( + graph: Dict[str, Set[str]], key: Callable[[str], Any] +) -> Generator[str, None, None]: """Performs a lexicographic reverse topological sort on the graph. This returns a reverse topological sort (i.e. if node A references B then B @@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key): NOTE: `graph` is modified during the sort. Args: - graph (dict[str, set[str]]): A representation of the graph where each - node is a key in the dict and its value are the nodes edges. - key (func): A function that takes a node and returns a value that is - comparable and used to order nodes + graph: A representation of the graph where each node is a key in the + dict and its value are the nodes edges. + key: A function that takes a node and returns a value that is comparable + and used to order nodes Yields: - str: The next node in the topological sort + The next node in the topological sort """ # Note, this is basically Kahn's algorithm except we look at nodes with no # outgoing edges, c.f. # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm outdegree_map = graph - reverse_graph = {} + reverse_graph = {} # type: Dict[str, Set[str]] # Lists of nodes with zero out degree. Is actually a tuple of # `(key(node), node)` so that sorting does the right thing diff --git a/tox.ini b/tox.ini index ea804108b5..edeb757f7b 100644 --- a/tox.ini +++ b/tox.ini @@ -209,6 +209,7 @@ commands = mypy \ synapse/server.py \ synapse/server_notices \ synapse/spam_checker_api \ + synapse/state \ synapse/storage/databases/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ -- cgit 1.5.1 From 5099bd68da4cf27364671a46c5754ec06d7a7a34 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 25 Aug 2020 10:52:15 -0400 Subject: Do not allow send_nonmember_event to be called with shadow-banned users. (#8158) --- changelog.d/8158.feature | 1 + synapse/handlers/message.py | 39 ++++++++++++++++++++++++++++++++++----- 2 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 changelog.d/8158.feature (limited to 'synapse/handlers') diff --git a/changelog.d/8158.feature b/changelog.d/8158.feature new file mode 100644 index 0000000000..47c4c39167 --- /dev/null +++ b/changelog.d/8158.feature @@ -0,0 +1 @@ + Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 593c0cc6f1..02d624268b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -647,24 +647,35 @@ class EventCreationHandler(object): event: EventBase, context: EventContext, ratelimit: bool = True, + ignore_shadow_ban: bool = False, ) -> int: """ Persists and notifies local clients and federation of an event. Args: - requester - event the event to send. - context: the context of the event. + requester: The requester sending the event. + event: The event to send. + context: The context of the event. ratelimit: Whether to rate limit this send. + ignore_shadow_ban: True if shadow-banned users should be allowed to + send this event. Return: The stream_id of the persisted event. + + Raises: + ShadowBanError if the requester has been shadow-banned. """ if event.type == EventTypes.Member: raise SynapseError( 500, "Tried to send member event through non-member codepath" ) + if not ignore_shadow_ban and requester.shadow_banned: + # We randomly sleep a bit just to annoy the requester. + await self.clock.sleep(random.randint(1, 10)) + raise ShadowBanError() + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) @@ -725,6 +736,14 @@ class EventCreationHandler(object): See self.create_event and self.send_nonmember_event. + Args: + requester: The requester sending the event. + event_dict: An entire event. + ratelimit: Whether to rate limit this send. + txn_id: The transaction ID. + ignore_shadow_ban: True if shadow-banned users should be allowed to + send this event. + Raises: ShadowBanError if the requester has been shadow-banned. """ @@ -750,7 +769,11 @@ class EventCreationHandler(object): raise SynapseError(403, spam_error, Codes.FORBIDDEN) stream_id = await self.send_nonmember_event( - requester, event, context, ratelimit=ratelimit + requester, + event, + context, + ratelimit=ratelimit, + ignore_shadow_ban=ignore_shadow_ban, ) return event, stream_id @@ -1190,8 +1213,14 @@ class EventCreationHandler(object): event.internal_metadata.proactively_send = False + # Since this is a dummy-event it is OK if it is sent by a + # shadow-banned user. await self.send_nonmember_event( - requester, event, context, ratelimit=False + requester, + event, + context, + ratelimit=False, + ignore_shadow_ban=True, ) dummy_event_sent = True break -- cgit 1.5.1