From dd71eb0f8ab5a6e0d8eda3be8c2d5ff01271d147 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 18 Mar 2021 15:52:26 +0000 Subject: Make federation catchup send last event from any server. (#9640) Currently federation catchup will send the last *local* event that we failed to send to the remote. This can cause issues for large rooms where lots of servers have sent events while the remote server was down, as when it comes back up again it'll be flooded with events from various points in the DAG. Instead, let's make it so that all the servers send the most recent events, even if its not theirs. The remote should deduplicate the events, so there shouldn't be much overhead in doing this. Alternatively, the servers could only send local events if they were also extremities and hope that the other server will send the event over, but that is a bit risky. --- synapse/federation/federation_server.py | 25 ++----------------------- 1 file changed, 2 insertions(+), 23 deletions(-) (limited to 'synapse/federation/federation_server.py') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9839d3d016..d84e362070 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -35,7 +35,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import ( AuthError, Codes, @@ -63,7 +63,7 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -727,27 +727,6 @@ class FederationServer(FederationBase): if the event was unacceptable for any other reason (eg, too large, too many prev_events, couldn't find the prev_events) """ - # check that it's actually being sent from a valid destination to - # workaround bug #1753 in 0.18.5 and 0.18.6 - if origin != get_domain_from_id(pdu.sender): - # We continue to accept join events from any server; this is - # necessary for the federation join dance to work correctly. - # (When we join over federation, the "helper" server is - # responsible for sending out the join event, rather than the - # origin. See bug #1893. This is also true for some third party - # invites). - if not ( - pdu.type == "m.room.member" - and pdu.content - and pdu.content.get("membership", None) - in (Membership.JOIN, Membership.INVITE) - ): - logger.info( - "Discarding PDU %s from invalid origin %s", pdu.event_id, origin - ) - return - else: - logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point room_version = await self.store.get_room_version(pdu.room_id) -- cgit 1.5.1 From 963f4309fe29206f3ba92b493e922280feea30ed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Mar 2021 12:06:09 +0100 Subject: Make RateLimiter class check for ratelimit overrides (#9711) This should fix a class of bug where we forget to check if e.g. the appservice shouldn't be ratelimited. We also check the `ratelimit_override` table to check if the user has ratelimiting disabled. That table is really only meant to override the event sender ratelimiting, so we don't use any values from it (as they might not make sense for different rate limits), but we do infer that if ratelimiting is disabled for the user we should disabled all ratelimits. Fixes #9663 --- changelog.d/9711.bugfix | 1 + synapse/api/ratelimiting.py | 100 +++++++++--------- synapse/federation/federation_server.py | 5 +- synapse/handlers/_base.py | 14 +-- synapse/handlers/auth.py | 24 +++-- synapse/handlers/devicemessage.py | 5 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 12 ++- synapse/handlers/register.py | 6 +- synapse/handlers/room_member.py | 23 +++-- synapse/replication/http/register.py | 2 +- synapse/rest/client/v1/login.py | 14 ++- synapse/rest/client/v2_alpha/account.py | 10 +- synapse/rest/client/v2_alpha/register.py | 8 +- synapse/server.py | 1 + tests/api/test_ratelimiting.py | 168 ++++++++++++++++++++----------- 16 files changed, 241 insertions(+), 154 deletions(-) create mode 100644 changelog.d/9711.bugfix (limited to 'synapse/federation/federation_server.py') diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix new file mode 100644 index 0000000000..4ca3438d46 --- /dev/null +++ b/changelog.d/9711.bugfix @@ -0,0 +1 @@ +Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index c3f07bc1a3..2244b8a340 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -31,10 +32,13 @@ class Ratelimiter: burst_count: How many actions that can be performed before being limited. """ - def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + def __init__( + self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + ): self.clock = clock self.rate_hz = rate_hz self.burst_count = burst_count + self.store = store # A ordered dictionary keeping track of actions, when they were last # performed and how often. Each entry is a mapping from a key of arbitrary type @@ -46,45 +50,10 @@ class Ratelimiter: OrderedDict() ) # type: OrderedDict[Hashable, Tuple[float, int, float]] - def can_requester_do_action( - self, - requester: Requester, - rate_hz: Optional[float] = None, - burst_count: Optional[int] = None, - update: bool = True, - _time_now_s: Optional[int] = None, - ) -> Tuple[bool, float]: - """Can the requester perform the action? - - Args: - requester: The requester to key off when rate limiting. The user property - will be used. - rate_hz: The long term number of actions that can be performed in a second. - Overrides the value set during instantiation if set. - burst_count: How many actions that can be performed before being limited. - Overrides the value set during instantiation if set. - update: Whether to count this check as performing the action - _time_now_s: The current time. Optional, defaults to the current time according - to self.clock. Only used by tests. - - Returns: - A tuple containing: - * A bool indicating if they can perform the action now - * The reactor timestamp for when the action can be performed next. - -1 if rate_hz is less than or equal to zero - """ - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return True, -1.0 - - return self.can_do_action( - requester.user.to_string(), rate_hz, burst_count, update, _time_now_s - ) - - def can_do_action( + async def can_do_action( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -92,9 +61,16 @@ class Ratelimiter: ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: The key we should use when rate limiting. Can be a user ID - (when sending events), an IP address, etc. + requester: The requester that is doing the action, if any. Used to check + if the user has ratelimits disabled in the database. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -109,6 +85,30 @@ class Ratelimiter: * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + + if requester: + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + # Check if ratelimiting has been disabled for the user. + # + # Note that we don't use the returned rate/burst count, as the table + # is specifically for the event sending ratelimiter. Instead, we + # only use it to (somewhat cheekily) infer whether the user should + # be subject to any rate limiting or not. + override = await self.store.get_ratelimit_for_user( + requester.authenticated_entity + ) + if override and not override.messages_per_second: + return True, -1.0 + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz @@ -175,9 +175,10 @@ class Ratelimiter: else: del self.actions[key] - def ratelimit( + async def ratelimit( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -185,8 +186,16 @@ class Ratelimiter: ): """Checks if an action can be performed. If not, raises a LimitExceededError + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: An arbitrary key used to classify an action + requester: The requester that is doing the action, if any. Used to check for + if the user has ratelimits disabled. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -201,7 +210,8 @@ class Ratelimiter: """ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - allowed, time_allowed = self.can_do_action( + allowed, time_allowed = await self.can_do_action( + requester, key, rate_hz=rate_hz, burst_count=burst_count, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d84e362070..71cb120ef7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -870,6 +870,7 @@ class FederationHandlerRegistry: # A rate limiter for incoming room key requests per origin. self._room_key_request_rate_limiter = Ratelimiter( + store=hs.get_datastore(), clock=self.clock, rate_hz=self.config.rc_key_requests.per_second, burst_count=self.config.rc_key_requests.burst_count, @@ -930,7 +931,9 @@ class FederationHandlerRegistry: # the limit, drop them. if ( edu_type == EduTypes.RoomKeyRequest - and not self._room_key_request_rate_limiter.can_do_action(origin) + and not await self._room_key_request_rate_limiter.can_do_action( + None, origin + ) ): return diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index aade2c4a3a..fb899aa90d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -49,7 +49,7 @@ class BaseHandler: # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 ) self._rc_message = self.hs.config.rc_message @@ -57,6 +57,7 @@ class BaseHandler: # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: self.admin_redaction_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, @@ -91,11 +92,6 @@ class BaseHandler: if app_service is not None: return # do not ratelimit app service senders - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return - messages_per_second = self._rc_message.per_second burst_count = self._rc_message.burst_count @@ -113,11 +109,11 @@ class BaseHandler: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) else: # Override rate and burst count per-user - self.request_ratelimiter.ratelimit( - user_id, + await self.request_ratelimiter.ratelimit( + requester, rate_hz=messages_per_second, burst_count=burst_count, update=update, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d537ea8137..08e413bc98 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -238,6 +238,7 @@ class AuthHandler(BaseHandler): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -248,6 +249,7 @@ class AuthHandler(BaseHandler): # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -352,7 +354,7 @@ class AuthHandler(BaseHandler): requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) + await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -373,7 +375,9 @@ class AuthHandler(BaseHandler): ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) + await self._failed_uia_attempts_ratelimiter.can_do_action( + requester, + ) raise # find the completed login type @@ -982,8 +986,8 @@ class AuthHandler(BaseHandler): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - (medium, address), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, (medium, address), update=False ) # Check for login providers that support 3pid login types @@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler): # this code path, which is fine as then the per-user ratelimit # will kick in below. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - (medium, address) + await self._failed_login_attempts_ratelimiter.can_do_action( + None, (medium, address) ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler): # Check if we've hit the failed ratelimit (but don't update it) if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, qualified_user_id.lower(), update=False ) try: @@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler): # exception and masking the LoginError. The actual ratelimiting # should have happened above. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - qualified_user_id.lower() + await self._failed_login_attempts_ratelimiter.can_do_action( + None, qualified_user_id.lower() ) raise diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index eb547743be..5ee48be6ff 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -81,6 +81,7 @@ class DeviceMessageHandler: ) self._ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.rc_key_requests.per_second, burst_count=hs.config.rc_key_requests.burst_count, @@ -191,8 +192,8 @@ class DeviceMessageHandler: if ( message_type == EduTypes.RoomKeyRequest and user_id != sender_user_id - and self._ratelimiter.can_do_action( - (sender_user_id, requester.device_id) + and await self._ratelimiter.can_do_action( + requester, (sender_user_id, requester.device_id) ) ): continue diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 598a66f74c..3ebee38ebe 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1711,7 +1711,7 @@ class FederationHandler(BaseHandler): member_handler = self.hs.get_room_member_handler() # We don't rate limit based on room ID, as that should be done by # sending server. - member_handler.ratelimit_invite(None, event.state_key) + await member_handler.ratelimit_invite(None, None, event.state_key) # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 5f346f6d6d..d89fa5fb30 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler): # Ratelimiters for `/requestToken` endpoints. self._3pid_validation_ratelimiter_ip = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) self._3pid_validation_ratelimiter_address = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) - def ratelimit_request_token_requests( + async def ratelimit_request_token_requests( self, request: SynapseRequest, medium: str, @@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler): address: The actual threepid ID, e.g. the phone number or email address """ - self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) - self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + await self._3pid_validation_ratelimiter_ip.ratelimit( + None, (medium, request.getClientIP()) + ) + await self._3pid_validation_ratelimiter_address.ratelimit( + None, (medium, address) + ) async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0fc2bf15d5..9701b76d0f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler): Raises: SynapseError if there was a problem registering. """ - self.check_registration_ratelimit(address) + await self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( threepid, @@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address: Optional[str]) -> None: + async def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler): if not address: return - self.ratelimiter.ratelimit(address) + await self.ratelimiter.ratelimit(None, address) async def register_with_store( self, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4d20ed8357..1cf12f3255 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -75,22 +75,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.allow_per_room_profiles = self.config.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) self._join_rate_limiter_remote = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) self._invites_per_room_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, ) self._invites_per_user_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, @@ -159,15 +163,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def forget(self, user: UserID, room_id: str) -> None: raise NotImplementedError() - def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): + async def ratelimit_invite( + self, + requester: Optional[Requester], + room_id: Optional[str], + invitee_user_id: str, + ): """Ratelimit invites by room and by target user. If room ID is missing then we just rate limit by target user. """ if room_id: - self._invites_per_room_limiter.ratelimit(room_id) + await self._invites_per_room_limiter.ratelimit(requester, room_id) - self._invites_per_user_limiter.ratelimit(invitee_user_id) + await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) async def _local_membership_update( self, @@ -237,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ( allowed, time_allowed, - ) = self._join_rate_limiter_local.can_requester_do_action(requester) + ) = await self._join_rate_limiter_local.can_do_action(requester) if not allowed: raise LimitExceededError( @@ -421,9 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if effective_membership_state == Membership.INVITE: target_id = target.to_string() if ratelimit: - # Don't ratelimit application services. - if not requester.app_service or requester.app_service.is_rate_limited(): - self.ratelimit_invite(room_id, target_id) + await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: @@ -534,7 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ( allowed, time_allowed, - ) = self._join_rate_limiter_remote.can_requester_do_action( + ) = await self._join_rate_limiter_remote.can_do_action( requester, ) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index d005f38767..73d7477854 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - self.registration_handler.check_registration_ratelimit(content["address"]) + await self.registration_handler.check_registration_ratelimit(content["address"]) await self.registration_handler.register_with_store( user_id=user_id, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e4c352f572..3151e72d4f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, @@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientIP() + ) result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_token_login(login_submission) else: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet): # too often. This happens here rather than before as we don't # necessarily know the user before now. if ratelimit: - self._account_ratelimiter.ratelimit(user_id.lower()) + await self._account_ratelimiter.ratelimit(None, user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c2ba790bab..411fb57c47 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to @@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) if next_link: # Raise if the provided next_link value isn't valid @@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 8f68d8dfc8..c212da0cb2 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email @@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) @@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - self.ratelimiter.ratelimit(client_addr, update=False) + await self.ratelimiter.ratelimit(None, client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index e85b9391fa..e42f7b1a18 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -329,6 +329,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( + store=self.get_datastore(), clock=self.get_clock(), rate_hz=self.config.rc_registration.per_second, burst_count=self.config.rc_registration.burst_count, diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483418192c..fa96ba07a5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -5,38 +5,25 @@ from synapse.types import create_requester from tests import unittest -class TestRatelimiter(unittest.TestCase): +class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) - self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) - self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) - self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) - - def test_allowed_user_via_can_requester_do_action(self): - user_requester = create_requester("@user:example.com") - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) def test_allowed_via_ratelimit(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=0) + self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0)) # Should raise with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key="test_id", _time_now_s=5) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=5) + ) self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=10) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=10) + ) def test_allowed_via_can_do_action_and_overriding_parameters(self): """Test that we can override options of can_do_action that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=0, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=0, + ) ) self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=1, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=1, + ) ) self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, rate_hz=10.0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0) ) self.assertTrue(allowed) self.assertEqual(1.1, time_allowed) # Similarly if we allow a burst of 10 actions - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, burst_count=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10) ) self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) @@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase): fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - limiter.ratelimit(key=("test_id",), _time_now_s=0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=0) + ) # Second attempt, 1s later, will fail with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1) + ) self.assertEqual(context.exception.retry_after_ms, 9000) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0) + ) # Similarly if we allow a burst of 10 actions - limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) + ) def test_pruning(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - limiter.can_do_action(key="test_id_1", _time_now_s=0) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_1", _time_now_s=0) + ) self.assertIn("test_id_1", limiter.actions) - limiter.can_do_action(key="test_id_2", _time_now_s=10) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_2", _time_now_s=10) + ) self.assertNotIn("test_id_1", limiter.actions) + + def test_db_user_override(self): + """Test that users that have ratelimiting disabled in the DB aren't + ratelimited. + """ + store = self.hs.get_datastore() + + user_id = "@user:test" + requester = create_requester(user_id) + + self.get_success( + store.db_pool.simple_insert( + table="ratelimit_override", + values={ + "user_id": user_id, + "messages_per_second": None, + "burst_count": None, + }, + desc="test_db_user_override", + ) + ) + + limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + for _ in range(20): + self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) -- cgit 1.5.1 From d959d28730ec6a0765ab72b10bcc96b1507233ac Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 6 Apr 2021 07:21:57 -0400 Subject: Add type hints to the federation handler and server. (#9743) --- changelog.d/9743.misc | 1 + synapse/federation/federation_server.py | 26 +++--- synapse/federation/transport/server.py | 4 +- synapse/handlers/federation.py | 161 ++++++++++++++++---------------- 4 files changed, 97 insertions(+), 95 deletions(-) create mode 100644 changelog.d/9743.misc (limited to 'synapse/federation/federation_server.py') diff --git a/changelog.d/9743.misc b/changelog.d/9743.misc new file mode 100644 index 0000000000..c2f75c1df9 --- /dev/null +++ b/changelog.d/9743.misc @@ -0,0 +1 @@ +Add missing type hints to federation handler and server. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 71cb120ef7..b9f8d966a6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -739,22 +739,20 @@ class FederationServer(FederationBase): await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) - def __str__(self): + def __str__(self) -> str: return "" % self.server_name async def exchange_third_party_invite( self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict - ): - ret = await self.handler.exchange_third_party_invite( + ) -> None: + await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) - return ret - async def on_exchange_third_party_invite_request(self, event_dict: Dict): - ret = await self.handler.on_exchange_third_party_invite_request(event_dict) - return ret + async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None: + await self.handler.on_exchange_third_party_invite_request(event_dict) - async def check_server_matches_acl(self, server_name: str, room_id: str): + async def check_server_matches_acl(self, server_name: str, room_id: str) -> None: """Check if the given server is allowed by the server ACLs in the room Args: @@ -878,7 +876,7 @@ class FederationHandlerRegistry: def register_edu_handler( self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] - ): + ) -> None: """Sets the handler callable that will be used to handle an incoming federation EDU of the given type. @@ -897,7 +895,7 @@ class FederationHandlerRegistry: def register_query_handler( self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] - ): + ) -> None: """Sets the handler callable that will be used to handle an incoming federation query of the given type. @@ -915,15 +913,17 @@ class FederationHandlerRegistry: self.query_handlers[query_type] = handler - def register_instance_for_edu(self, edu_type: str, instance_name: str): + def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None: """Register that the EDU handler is on a different instance than master.""" self._edu_type_to_instance[edu_type] = [instance_name] - def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): + def register_instances_for_edu( + self, edu_type: str, instance_names: List[str] + ) -> None: """Register that the EDU handler is on multiple instances.""" self._edu_type_to_instance[edu_type] = instance_names - async def on_edu(self, edu_type: str, origin: str, content: dict): + async def on_edu(self, edu_type: str, origin: str, content: dict) -> None: if not self.config.use_presence and edu_type == EduTypes.Presence: return diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 84e39c5a46..5ef0556ef7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): PATH = "/exchange_third_party_invite/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id): - content = await self.handler.on_exchange_third_party_invite_request(content) - return 200, content + await self.handler.on_exchange_third_party_invite_request(content) + return 200, {} class FederationClientKeysQueryServlet(BaseFederationServlet): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3ebee38ebe..5ea8a7b603 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -21,7 +21,17 @@ import itertools import logging from collections.abc import Container from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import attr from signedjson.key import decode_verify_key_bytes @@ -171,15 +181,17 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: + async def on_receive_pdu( + self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False + ) -> None: """Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events Args: - origin (str): server which initiated the /send/ transaction. Will + origin: server which initiated the /send/ transaction. Will be used to fetch missing events or state. - pdu (FrozenEvent): received PDU - sent_to_us_directly (bool): True if this event was pushed to us; False if + pdu: received PDU + sent_to_us_directly: True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. """ @@ -411,13 +423,15 @@ class FederationHandler(BaseHandler): await self._process_received_pdu(origin, pdu, state=state) - async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu( + self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int + ) -> None: """ Args: - origin (str): Origin of the pdu. Will be called to get the missing events + origin: Origin of the pdu. Will be called to get the missing events pdu: received pdu - prevs (set(str)): List of event ids which we are missing - min_depth (int): Minimum depth of events to return. + prevs: List of event ids which we are missing + min_depth: Minimum depth of events to return. """ room_id = pdu.room_id @@ -778,7 +792,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - ): + ) -> None: """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. @@ -887,7 +901,9 @@ class FederationHandler(BaseHandler): logger.exception("Failed to resync device for %s", sender) @log_function - async def backfill(self, dest, room_id, limit, extremities): + async def backfill( + self, dest: str, room_id: str, limit: int, extremities: List[str] + ) -> List[EventBase]: """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler): curr_state = await self.state_handler.get_current_state(room_id) - def get_domains_from_state(state): + def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]: """Get joined domains from state Args: - state (dict[tuple, FrozenEvent]): State map from type/state - key to event. + state: State map from type/state key to event. Returns: - list[tuple[str, int]]: Returns a list of servers with the - lowest depth of their joins. Sorted by lowest depth first. + Returns a list of servers with the lowest depth of their joins. + Sorted by lowest depth first. """ joined_users = [ (state_key, int(event.depth)) @@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler): domain for domain, depth in curr_domains if domain != self.server_name ] - async def try_backfill(domains): + async def try_backfill(domains: List[str]) -> bool: # TODO: Should we try multiple of these at a time? for dom in domains: try: @@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler): } for e_id, _ in sorted_extremeties_tuple: - likely_domains = get_domains_from_state(states[e_id]) + likely_extremeties_domains = get_domains_from_state(states[e_id]) success = await try_backfill( - [dom for dom, _ in likely_domains if dom not in tried_domains] + [ + dom + for dom, _ in likely_extremeties_domains + if dom not in tried_domains + ] ) if success: return True - tried_domains.update(dom for dom, _ in likely_domains) + tried_domains.update(dom for dom, _ in likely_extremeties_domains) return False async def _get_events_and_persist( self, destination: str, room_id: str, events: Iterable[str] - ): + ) -> None: """Fetch the given events from a server, and persist them as outliers. This function *does not* recursively get missing auth events of the @@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler): event_infos, ) - def _sanity_check_event(self, ev): + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event @@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler): or cascade of event fetches. Args: - ev (synapse.events.EventBase): event to be checked - - Returns: None + ev: event to be checked Raises: SynapseError if the event does not pass muster @@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler): ) raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") - async def send_invite(self, target_host, event): + async def send_invite(self, target_host: str, event: EventBase) -> EventBase: """Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. @@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler): run_in_background(self._handle_queued_pdus, room_queue) - async def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus( + self, room_queue: List[Tuple[EventBase, str]] + ) -> None: """Process PDUs which got queued up while we were busy send_joining. Args: - room_queue (list[FrozenEvent, str]): list of PDUs to be processed - and the servers that sent them + room_queue: list of PDUs to be processed and the servers that sent them """ for p, origin in room_queue: try: @@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler): return event - async def on_send_join_request(self, origin, pdu): + async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict: """We have received a join event for a room. Fully process it and respond with the current state and auth chains. """ @@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler): async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion - ): + ) -> EventBase: """We've got an invite event. Process and persist it. Sign it. Respond with the now signed event. @@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler): return event - async def on_send_leave_request(self, origin, pdu): + async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None: """ We have received a leave event for a room. Fully process it.""" event = pdu @@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler): else: return None - async def get_min_depth_for_context(self, context): + async def get_min_depth_for_context(self, context: str) -> int: return await self.store.get_min_depth(context) async def _handle_new_event( - self, origin, event, state=None, auth_events=None, backfilled=False - ): + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]] = None, + auth_events: Optional[MutableStateMap[EventBase]] = None, + backfilled: bool = False, + ) -> EventContext: context = await self._prep_event( origin, event, state=state, auth_events=auth_events, backfilled=backfilled ) @@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler): logger.warning("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True - async def on_query_auth( - self, origin, event_id, room_id, remote_auth_chain, rejects, missing - ): - in_room = await self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - - event = await self.store.get_event(event_id, check_room_id=room_id) - - # Just go through and process each event in `remote_auth_chain`. We - # don't want to fall into the trap of `missing` being wrong. - for e in remote_auth_chain: - try: - await self._handle_new_event(origin, e) - except AuthError: - pass - - # Now get the current auth_chain for the event. - local_auth_chain = await self.store.get_auth_chain( - room_id, list(event.auth_event_ids()), include_given=True - ) - - # TODO: Check if we would now reject event_id. If so we need to tell - # everyone. - - ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain) - - logger.debug("on_query_auth returning: %s", ret) - - return ret - async def on_get_missing_events( - self, origin, room_id, earliest_events, latest_events, limit - ): + self, + origin: str, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[EventBase]: in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler): assumes that we have already processed all events in remote_auth Params: - local_auth (list) - remote_auth (list) + local_auth + remote_auth Returns: dict @@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler): @log_function async def exchange_third_party_invite( - self, sender_user_id, target_user_id, room_id, signed - ): + self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict + ) -> None: third_party_invite = {"signed": signed} event_dict = { @@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler): await member_handler.send_membership_event(None, event, context) async def add_display_name_to_third_party_invite( - self, room_version, event_dict, event, context - ): + self, + room_version: str, + event_dict: JsonDict, + event: EventBase, + context: EventContext, + ) -> Tuple[EventBase, EventContext]: key = ( EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"], @@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler): EventValidator().validate_new(event, self.config) return (event, context) - async def _check_signature(self, event, context): + async def _check_signature(self, event: EventBase, context: EventContext) -> None: """ Checks that the signature in the event is consistent with its invite. Args: - event (Event): The m.room.member event to check - context (EventContext): + event: The m.room.member event to check + context: Raises: AuthError: if signature didn't match any keys, or key has been @@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler): raise last_exception - async def _check_key_revocation(self, public_key, url): + async def _check_key_revocation(self, public_key: str, url: str) -> None: """ Checks whether public_key has been revoked. Args: - public_key (str): base-64 encoded public key. - url (str): Key revocation URL. + public_key: base-64 encoded public key. + url: Key revocation URL. Raises: AuthError: if they key has been revoked. -- cgit 1.5.1