From c389bfb6eac421b84d3f91b344e8f3a91e421e83 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 4 Jun 2020 20:03:40 +0100 Subject: Fix encryption algorithm typos in tests/comments (#7637) @uhoreg has confirmed these were both typos. They are only in comments and tests though, rather than anything critical. Introduced in: * https://github.com/matrix-org/synapse/pull/7157 * https://github.com/matrix-org/synapse/pull/5726 --- tests/handlers/test_e2e_keys.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 854eb6c024..e1e144b2e7 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -222,7 +222,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_1 = { "user_id": local_user, "device_id": "abc", - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], "keys": { "ed25519:abc": "base64+ed25519+key", "curve25519:abc": "base64+curve25519+key", @@ -232,7 +232,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_2 = { "user_id": local_user, "device_id": "def", - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], "keys": { "ed25519:def": "base64+ed25519+key", "curve25519:def": "base64+curve25519+key", @@ -315,7 +315,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key = { "user_id": local_user, "device_id": device_id, - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": ["m.olm.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey}, "signatures": {local_user: {"ed25519:xyz": "something"}}, } @@ -391,8 +391,8 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_id": local_user, "device_id": device_id, "algorithms": [ - "m.olm.curve25519-aes-sha256", - "m.megolm.v1.aes-sha", + "m.olm.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", ], "keys": { "curve25519:xyz": "curve25519+key", -- cgit 1.5.1 From f4e6495b5d3267976f34088fa7459b388b801eb6 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 5 Jun 2020 10:47:20 +0100 Subject: Performance improvements and refactor of Ratelimiter (#7595) While working on https://github.com/matrix-org/synapse/issues/5665 I found myself digging into the `Ratelimiter` class and seeing that it was both: * Rather undocumented, and * causing a *lot* of config checks This PR attempts to refactor and comment the `Ratelimiter` class, as well as encourage config file accesses to only be done at instantiation. Best to be reviewed commit-by-commit. 
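In practice, the API change is that `time_now_s`, `rate_hz` and `burst_count` move from every call site into the constructor (per-call overrides remain possible, e.g. for per-user limits loaded from the database), so config values are read once when the limiter is built rather than on each request. A minimal sketch of the new calling convention, modelled on the unit tests further down in this patch (`clock=None` and the `_time_now_s` keyword are test-only conveniences; production code passes the homeserver's `Clock` and omits `_time_now_s`):

    from synapse.api.ratelimiting import LimitExceededError, Ratelimiter

    # Old API: settings and the current time were passed on every call:
    #   allowed, time_allowed = limiter.can_do_action(
    #       key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
    #   )

    # New API: the clock and the default limits are bound at construction.
    limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)

    # can_do_action() now only needs the key; rate_hz/burst_count can still be
    # overridden per call if a per-user override exists.
    allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)

    # ratelimit() performs the same check but raises on denial.
    try:
        limiter.ratelimit(key="test_id", _time_now_s=5)
    except LimitExceededError as e:
        print("retry after %d ms" % e.retry_after_ms)  # 5000 ms in this example
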
--- changelog.d/7595.misc | 1 + synapse/api/ratelimiting.py | 153 +++++++++++++++++++++------- synapse/config/ratelimiting.py | 8 +- synapse/handlers/_base.py | 60 ++++++----- synapse/handlers/auth.py | 24 ++--- synapse/handlers/message.py | 1 - synapse/handlers/register.py | 9 +- synapse/rest/client/v1/login.py | 65 ++++-------- synapse/rest/client/v2_alpha/register.py | 16 +-- synapse/server.py | 17 ++-- synapse/util/ratelimitutils.py | 2 +- tests/api/test_ratelimiting.py | 96 +++++++++++++---- tests/handlers/test_profile.py | 6 +- tests/replication/slave/storage/_base.py | 9 +- tests/rest/client/v1/test_events.py | 8 +- tests/rest/client/v1/test_login.py | 49 +++++++-- tests/rest/client/v1/test_rooms.py | 9 +- tests/rest/client/v1/test_typing.py | 10 +- tests/rest/client/v2_alpha/test_register.py | 9 +- 19 files changed, 322 insertions(+), 230 deletions(-) create mode 100644 changelog.d/7595.misc (limited to 'tests/handlers') diff --git a/changelog.d/7595.misc b/changelog.d/7595.misc new file mode 100644 index 0000000000..7a0646b1a3 --- /dev/null +++ b/changelog.d/7595.misc @@ -0,0 +1 @@ +Refactor `Ratelimiter` to limit the amount of expensive config value accesses. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 7a049b3af7..ec6b3a69a2 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -1,4 +1,5 @@ # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,75 +17,157 @@ from collections import OrderedDict from typing import Any, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.util import Clock class Ratelimiter(object): """ - Ratelimit message sending by user. + Ratelimit actions marked by arbitrary keys. + + Args: + clock: A homeserver clock, for retrieving the current time + rate_hz: The long term number of actions that can be performed in a second. + burst_count: How many actions that can be performed before being limited. """ - def __init__(self): - self.message_counts = ( - OrderedDict() - ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]] + def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + self.clock = clock + self.rate_hz = rate_hz + self.burst_count = burst_count + + # A ordered dictionary keeping track of actions, when they were last + # performed and how often. Each entry is a mapping from a key of arbitrary type + # to a tuple representing: + # * How many times an action has occurred since a point in time + # * The point in time + # * The rate_hz of this particular entry. This can vary per request + self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] - def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True): + def can_do_action( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Args: key: The key we should use when rate limiting. Can be a user ID (when sending events), an IP address, etc. - time_now_s: The time now. - rate_hz: The long term number of messages a user can send in a - second. - burst_count: How many messages the user can send before being - limited. - update (bool): Whether to update the message rates or not. 
This is - useful to check if a message would be allowed to be sent before - its ready to be actually sent. + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + Returns: - A pair of a bool indicating if they can send a message now and a - time in seconds of when they can next send a message. + A tuple containing: + * A bool indicating if they can perform the action now + * The reactor timestamp for when the action can be performed next. + -1 if rate_hz is less than or equal to zero """ - self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.get( - key, (0.0, time_now_s, None) - ) + # Override default values if set + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() + rate_hz = rate_hz if rate_hz is not None else self.rate_hz + burst_count = burst_count if burst_count is not None else self.burst_count + + # Remove any expired entries + self._prune_message_counts(time_now_s) + + # Check if there is an existing count entry for this key + action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) + + # Check whether performing another action is allowed time_delta = time_now_s - time_start - sent_count = message_count - time_delta * rate_hz - if sent_count < 0: + performed_count = action_count - time_delta * rate_hz + if performed_count < 0: + # Allow, reset back to count 1 allowed = True time_start = time_now_s - message_count = 1.0 - elif sent_count > burst_count - 1.0: + action_count = 1.0 + elif performed_count > burst_count - 1.0: + # Deny, we have exceeded our burst count allowed = False else: + # We haven't reached our limit yet allowed = True - message_count += 1 + action_count += 1.0 if update: - self.message_counts[key] = (message_count, time_start, rate_hz) + self.actions[key] = (action_count, time_start, rate_hz) if rate_hz > 0: - time_allowed = time_start + (message_count - burst_count + 1) / rate_hz + # Find out when the count of existing actions expires + time_allowed = time_start + (action_count - burst_count + 1) / rate_hz + + # Don't give back a time in the past if time_allowed < time_now_s: time_allowed = time_now_s + else: + # XXX: Why is this -1? This seems to only be used in + # self.ratelimit. 
I guess so that clients get a time in the past and don't + # feel afraid to try again immediately time_allowed = -1 return allowed, time_allowed - def prune_message_counts(self, time_now_s): - for key in list(self.message_counts.keys()): - message_count, time_start, rate_hz = self.message_counts[key] + def _prune_message_counts(self, time_now_s: int): + """Remove message count entries that have not exceeded their defined + rate_hz limit + + Args: + time_now_s: The current time + """ + # We create a copy of the key list here as the dictionary is modified during + # the loop + for key in list(self.actions.keys()): + action_count, time_start, rate_hz = self.actions[key] + + # Rate limit = "seconds since we started limiting this action" * rate_hz + # If this limit has not been exceeded, wipe our record of this action time_delta = time_now_s - time_start - if message_count - time_delta * rate_hz > 0: - break + if action_count - time_delta * rate_hz > 0: + continue else: - del self.message_counts[key] + del self.actions[key] + + def ratelimit( + self, + key: Any, + rate_hz: Optional[float] = None, + burst_count: Optional[int] = None, + update: bool = True, + _time_now_s: Optional[int] = None, + ): + """Checks if an action can be performed. If not, raises a LimitExceededError + + Args: + key: An arbitrary key used to classify an action + rate_hz: The long term number of actions that can be performed in a second. + Overrides the value set during instantiation if set. + burst_count: How many actions that can be performed before being limited. + Overrides the value set during instantiation if set. + update: Whether to count this check as performing the action + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + + Raises: + LimitExceededError: If an action could not be performed, along with the time in + milliseconds until the action can be performed again + """ + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True): allowed, time_allowed = self.can_do_action( - key, time_now_s, rate_hz, burst_count, update + key, + rate_hz=rate_hz, + burst_count=burst_count, + update=update, + _time_now_s=time_now_s, ) if not allowed: diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 4a3bfc4354..2dd94bae2b 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -12,11 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Dict + from ._base import Config class RateLimitConfig(object): - def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}): + def __init__( + self, + config: Dict[str, float], + defaults={"per_second": 0.17, "burst_count": 3.0}, + ): self.per_second = config.get("per_second", defaults["per_second"]) self.burst_count = config.get("burst_count", defaults["burst_count"]) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 3b781d9836..61dc4beafe 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,7 +19,7 @@ from twisted.internet import defer import synapse.types from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import LimitExceededError +from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID logger = logging.getLogger(__name__) @@ -44,11 +44,26 @@ class BaseHandler(object): self.notifier = hs.get_notifier() self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() - self.ratelimiter = hs.get_ratelimiter() - self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter() self.clock = hs.get_clock() self.hs = hs + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + clock=self.clock, rate_hz=0, burst_count=0 + ) + self._rc_message = self.hs.config.rc_message + + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if self.hs.config.rc_admin_redaction: + self.admin_redaction_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_admin_redaction.per_second, + burst_count=self.hs.config.rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + self.server_name = hs.hostname self.event_builder_factory = hs.get_event_builder_factory() @@ -70,7 +85,6 @@ class BaseHandler(object): Raises: LimitExceededError if the request should be ratelimited """ - time_now = self.clock.time() user_id = requester.user.to_string() # The AS user itself is never rate limited. @@ -83,48 +97,32 @@ class BaseHandler(object): if requester.app_service and not requester.app_service.is_rate_limited(): return + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + # Check if there is a per user override in the DB. 
override = yield self.store.get_ratelimit_for_user(user_id) if override: - # If overriden with a null Hz then ratelimiting has been entirely + # If overridden with a null Hz then ratelimiting has been entirely # disabled for the user if not override.messages_per_second: return messages_per_second = override.messages_per_second burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_ids clash + self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) else: - # We default to different values if this is an admin redaction and - # the config is set - if is_admin_redaction and self.hs.config.rc_admin_redaction: - messages_per_second = self.hs.config.rc_admin_redaction.per_second - burst_count = self.hs.config.rc_admin_redaction.burst_count - else: - messages_per_second = self.hs.config.rc_message.per_second - burst_count = self.hs.config.rc_message.burst_count - - if is_admin_redaction and self.hs.config.rc_admin_redaction: - # If we have separate config for admin redactions we use a separate - # ratelimiter - allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action( - user_id, - time_now, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) - else: - allowed, time_allowed = self.ratelimiter.can_do_action( + # Override rate and burst count per-user + self.request_ratelimiter.ratelimit( user_id, - time_now, rate_hz=messages_per_second, burst_count=burst_count, update=update, ) - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) async def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 75b39e878c..119678e67b 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -108,7 +108,11 @@ class AuthHandler(BaseHandler): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. - self._failed_uia_attempts_ratelimiter = Ratelimiter() + self._failed_uia_attempts_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) self._clock = self.hs.get_clock() @@ -196,13 +200,7 @@ class AuthHandler(BaseHandler): user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit( - user_id, - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) + self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows flows = [[login_type] for login_type in self._supported_ui_auth_types] @@ -212,14 +210,8 @@ class AuthHandler(BaseHandler): flows, request, request_body, clientip, description ) except LoginError: - # Update the ratelimite to say we failed (`can_do_action` doesn't raise). 
- self._failed_uia_attempts_ratelimiter.can_do_action( - user_id, - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). + self._failed_uia_attempts_ratelimiter.can_do_action(user_id) raise # find the completed login type diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 681f92cafd..649ca1f08a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -362,7 +362,6 @@ class EventCreationHandler(object): self.profile_handler = hs.get_profile_handler() self.event_builder_factory = hs.get_event_builder_factory() self.server_name = hs.hostname - self.ratelimiter = hs.get_ratelimiter() self.notifier = hs.get_notifier() self.config = hs.config self.require_membership_for_aliases = hs.config.require_membership_for_aliases diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 55a03e53ea..cd746be7c8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -425,14 +425,7 @@ class RegistrationHandler(BaseHandler): if not address: return - time_now = self.clock.time() - - self.ratelimiter.ratelimit( - address, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - ) + self.ratelimiter.ratelimit(address) def register_with_store( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 6ac7c5142b..dceb2792fa 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -87,11 +87,22 @@ class LoginRestServlet(RestServlet): self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() - self._clock = hs.get_clock() self._well_known_builder = WellKnownBuilder(hs) - self._address_ratelimiter = Ratelimiter() - self._account_ratelimiter = Ratelimiter() - self._failed_attempts_ratelimiter = Ratelimiter() + self._address_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_address.per_second, + burst_count=self.hs.config.rc_login_address.burst_count, + ) + self._account_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_account.per_second, + burst_count=self.hs.config.rc_login_account.burst_count, + ) + self._failed_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) def on_GET(self, request): flows = [] @@ -124,13 +135,7 @@ class LoginRestServlet(RestServlet): return 200, {} async def on_POST(self, request): - self._address_ratelimiter.ratelimit( - request.getClientIP(), - time_now_s=self.hs.clock.time(), - rate_hz=self.hs.config.rc_login_address.per_second, - burst_count=self.hs.config.rc_login_address.burst_count, - update=True, - ) + self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: @@ -198,13 +203,7 @@ class LoginRestServlet(RestServlet): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. 
- self._failed_attempts_ratelimiter.ratelimit( - (medium, address), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, - ) + self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) # Check for login providers that support 3pid login types ( @@ -238,13 +237,7 @@ class LoginRestServlet(RestServlet): # If it returned None but the 3PID was bound then we won't hit # this code path, which is fine as then the per-user ratelimit # will kick in below. - self._failed_attempts_ratelimiter.can_do_action( - (medium, address), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + self._failed_attempts_ratelimiter.can_do_action((medium, address)) raise LoginError(403, "", errcode=Codes.FORBIDDEN) identifier = {"type": "m.id.user", "user": user_id} @@ -263,11 +256,7 @@ class LoginRestServlet(RestServlet): # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=False, + qualified_user_id.lower(), update=False ) try: @@ -279,13 +268,7 @@ class LoginRestServlet(RestServlet): # limiter. Using `can_do_action` avoids us raising a ratelimit # exception and masking the LoginError. The actual ratelimiting # should have happened above. - self._failed_attempts_ratelimiter.can_do_action( - qualified_user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - update=True, - ) + self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) raise result = await self._complete_login( @@ -318,13 +301,7 @@ class LoginRestServlet(RestServlet): # Before we actually log them in we check if they've already logged in # too often. This happens here rather than before as we don't # necessarily know the user before now. 
- self._account_ratelimiter.ratelimit( - user_id.lower(), - time_now_s=self._clock.time(), - rate_hz=self.hs.config.rc_login_account.per_second, - burst_count=self.hs.config.rc_login_account.burst_count, - update=True, - ) + self._account_ratelimiter.ratelimit(user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index addd4cae19..b9ffe86b2a 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -26,7 +26,6 @@ import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, - LimitExceededError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -396,20 +395,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - time_now = self.clock.time() - - allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, - time_now_s=time_now, - rate_hz=self.hs.config.rc_registration.per_second, - burst_count=self.hs.config.rc_registration.burst_count, - update=False, - ) - - if not allowed: - raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)) - ) + self.ratelimiter.ratelimit(client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index ca2deb49bb..fe94836a2c 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -242,9 +242,12 @@ class HomeServer(object): self.clock = Clock(reactor) self.distributor = Distributor() - self.ratelimiter = Ratelimiter() - self.admin_redaction_ratelimiter = Ratelimiter() - self.registration_ratelimiter = Ratelimiter() + + self.registration_ratelimiter = Ratelimiter( + clock=self.clock, + rate_hz=config.rc_registration.per_second, + burst_count=config.rc_registration.burst_count, + ) self.datastores = None @@ -314,15 +317,9 @@ class HomeServer(object): def get_distributor(self): return self.distributor - def get_ratelimiter(self): - return self.ratelimiter - - def get_registration_ratelimiter(self): + def get_registration_ratelimiter(self) -> Ratelimiter: return self.registration_ratelimiter - def get_admin_redaction_ratelimiter(self): - return self.admin_redaction_ratelimiter - def build_federation_client(self): return FederationClient(self) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 5ca4521ce3..e5efdfcd02 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -43,7 +43,7 @@ class FederationRateLimiter(object): self.ratelimiters = collections.defaultdict(new_limiter) def ratelimit(self, host): - """Used to ratelimit an incoming request from given host + """Used to ratelimit an incoming request from a given host Example usage: diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index dbdd427cac..d580e729c5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,39 +1,97 @@ -from synapse.api.ratelimiting import Ratelimiter +from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from tests import unittest class TestRatelimiter(unittest.TestCase): - def test_allowed(self): - limiter = Ratelimiter() - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1 - ) + def test_allowed_via_can_do_action(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + allowed, time_allowed = 
limiter.can_do_action(key="test_id", _time_now_s=0) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_do_action( - key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1 - ) + allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) - def test_pruning(self): - limiter = Ratelimiter() + def test_allowed_via_ratelimit(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=0) + + # Should raise + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key="test_id", _time_now_s=5) + self.assertEqual(context.exception.retry_after_ms, 5000) + + # Shouldn't raise + limiter.ratelimit(key="test_id", _time_now_s=10) + + def test_allowed_via_can_do_action_and_overriding_parameters(self): + """Test that we can override options of can_do_action that would otherwise fail + an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=0,) + self.assertTrue(allowed) + self.assertEqual(10.0, time_allowed) + + # Second attempt, 1s later, will fail + allowed, time_allowed = limiter.can_do_action(("test_id",), _time_now_s=1,) + self.assertFalse(allowed) + self.assertEqual(10.0, time_allowed) + + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. allowed, time_allowed = limiter.can_do_action( - key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1 + ("test_id",), _time_now_s=1, rate_hz=10.0 ) + self.assertTrue(allowed) + self.assertEqual(1.1, time_allowed) - self.assertIn("test_id_1", limiter.message_counts) - + # Similarly if we allow a burst of 10 actions allowed, time_allowed = limiter.can_do_action( - key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1 + ("test_id",), _time_now_s=1, burst_count=10 ) + self.assertTrue(allowed) + self.assertEqual(1.0, time_allowed) + + def test_allowed_via_ratelimit_and_overriding_parameters(self): + """Test that we can override options of the ratelimit method that would otherwise + fail an action + """ + # Create a Ratelimiter with a very low allowed rate_hz and burst_count + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + + # First attempt should be allowed + limiter.ratelimit(key=("test_id",), _time_now_s=0) + + # Second attempt, 1s later, will fail + with self.assertRaises(LimitExceededError) as context: + limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.assertEqual(context.exception.retry_after_ms, 9000) + + # But, if we allow 10 actions/sec for this request, we should be allowed + # to continue. 
+ limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + + # Similarly if we allow a burst of 10 actions + limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + + def test_pruning(self): + limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter.can_do_action(key="test_id_1", _time_now_s=0) + + self.assertIn("test_id_1", limiter.actions) + + limiter.can_do_action(key="test_id_2", _time_now_s=10) - self.assertNotIn("test_id_1", limiter.message_counts) + self.assertNotIn("test_id_1", limiter.actions) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 8aa56f1496..29dd7d9c6e 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,7 +14,7 @@ # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from twisted.internet import defer @@ -55,12 +55,8 @@ class ProfileTestCase(unittest.TestCase): federation_client=self.mock_federation, federation_server=Mock(), federation_registry=self.mock_registry, - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - self.store = hs.get_datastore() self.frank = UserID.from_string("@1234ABCD:test") diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 32cb04645f..56497b8476 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from mock import Mock, NonCallableMock +from mock import Mock from tests.replication._base import BaseStreamTestCase @@ -21,12 +21,7 @@ from tests.replication._base import BaseStreamTestCase class BaseSlavedStoreTestCase(BaseStreamTestCase): def make_homeserver(self, reactor, clock): - hs = self.setup_test_homeserver( - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), - ) - - hs.get_ratelimiter().can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(federation_client=Mock()) return hs diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index b54b06482b..f75520877f 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -15,7 +15,7 @@ """ Tests REST events for /events paths.""" -from mock import Mock, NonCallableMock +from mock import Mock import synapse.rest.admin from synapse.rest.client.v1 import events, login, room @@ -40,11 +40,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): config["enable_registration"] = True config["auto_join_rooms"] = [] - hs = self.setup_test_homeserver( - config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"]) - ) - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) + hs = self.setup_test_homeserver(config=config) hs.get_handlers().federation_handler = Mock() diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 0f0f7ca72d..9033f09fd2 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -29,7 +29,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor, clock): - self.hs = self.setup_test_homeserver() self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] @@ -38,10 +37,20 @@ class 
LoginRestServletTestCase(unittest.HomeserverTestCase): return self.hs + @override_config( + { + "rc_login": { + "address": {"per_second": 0.17, "burst_count": 5}, + # Prevent the account login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "account": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_address(self): - self.hs.config.rc_login_address.burst_count = 5 - self.hs.config.rc_login_address.per_second = 0.17 - # Create different users so we're sure not to be bothered by the per-user # ratelimiter. for i in range(0, 6): @@ -80,10 +89,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + "account": {"per_second": 0.17, "burst_count": 5}, + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + } + } + ) def test_POST_ratelimiting_per_account(self): - self.hs.config.rc_login_account.burst_count = 5 - self.hs.config.rc_login_account.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): @@ -119,10 +138,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config( + { + "rc_login": { + # Prevent the address login ratelimiter from raising first + # + # This is normally covered by the default test homeserver config + # which sets these values to 10000, but as we're overriding the entire + # rc_login dict here, we need to set this manually as well + "address": {"per_second": 10000, "burst_count": 10000}, + "failed_attempts": {"per_second": 0.17, "burst_count": 5}, + } + } + ) def test_POST_ratelimiting_per_account_failed_attempts(self): - self.hs.config.rc_login_failed_attempts.burst_count = 5 - self.hs.config.rc_login_failed_attempts.per_second = 0.17 - self.register_user("kermit", "monkey") for i in range(0, 6): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 7dd86d0c27..4886bbb401 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -20,7 +20,7 @@ import json -from mock import Mock, NonCallableMock +from mock import Mock from six.moves.urllib import parse as urlparse from twisted.internet import defer @@ -46,13 +46,8 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) - self.ratelimiter = self.hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) self.hs.get_federation_handler = Mock(return_value=Mock()) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 4bc3aaf02d..18260bb90e 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -16,7 +16,7 @@ """Tests REST events for /rooms paths.""" -from mock import Mock, NonCallableMock +from mock 
import Mock from twisted.internet import defer @@ -39,17 +39,11 @@ class RoomTypingTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): hs = self.setup_test_homeserver( - "red", - http_client=None, - federation_client=Mock(), - ratelimiter=NonCallableMock(spec_set=["can_do_action"]), + "red", http_client=None, federation_client=Mock(), ) self.event_source = hs.get_event_sources().sources["typing"] - self.ratelimiter = hs.get_ratelimiter() - self.ratelimiter.can_do_action.return_value = (True, 0) - hs.get_handlers().federation_handler = Mock() def get_user_by_access_token(token=None, allow_guest=False): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 5637ce2090..7deaf5b24a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -29,6 +29,7 @@ from synapse.rest.client.v1 import login, logout from synapse.rest.client.v2_alpha import account, account_validity, register, sync from tests import unittest +from tests.unittest import override_config class RegisterRestServletTestCase(unittest.HomeserverTestCase): @@ -146,10 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.json_body["error"], "Guest access is disabled") + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): url = self.url + b"?kind=guest" request, channel = self.make_request(b"POST", url, b"{}") @@ -168,10 +167,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) + @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self): - self.hs.config.rc_registration.burst_count = 5 - self.hs.config.rc_registration.per_second = 0.17 - for i in range(0, 6): params = { "username": "kermit" + str(i), -- cgit 1.5.1 From 09099313e6d527938013bb46640efc3768960d21 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Fri, 5 Jun 2020 11:18:15 -0600 Subject: Add an option to disable autojoin for guest accounts (#6637) Fixes https://github.com/matrix-org/synapse/issues/3177 --- changelog.d/6637.feature | 1 + docs/sample_config.yaml | 7 +++++++ synapse/config/registration.py | 8 ++++++++ synapse/handlers/register.py | 8 +++++++- tests/handlers/test_register.py | 10 ++++++++++ 5 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6637.feature (limited to 'tests/handlers') diff --git a/changelog.d/6637.feature b/changelog.d/6637.feature new file mode 100644 index 0000000000..5228ebc1e5 --- /dev/null +++ b/changelog.d/6637.feature @@ -0,0 +1 @@ +Add an option to disable autojoining rooms for guest accounts. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b06394a2bd..94e1ec698f 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1223,6 +1223,13 @@ account_threepid_delegates: # #autocreate_auto_join_rooms: true +# When auto_join_rooms is specified, setting this flag to false prevents +# guest accounts from being automatically joined to the rooms. +# +# Defaults to true. 
+# +#auto_join_rooms_for_guests: false + ## Metrics ### diff --git a/synapse/config/registration.py b/synapse/config/registration.py index a9aa8c3737..fecced2d57 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -128,6 +128,7 @@ class RegistrationConfig(Config): if not RoomAlias.is_valid(room_alias): raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,)) self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True) + self.auto_join_rooms_for_guests = config.get("auto_join_rooms_for_guests", True) self.enable_set_displayname = config.get("enable_set_displayname", True) self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) @@ -368,6 +369,13 @@ class RegistrationConfig(Config): # users cannot be auto-joined since they do not exist. # #autocreate_auto_join_rooms: true + + # When auto_join_rooms is specified, setting this flag to false prevents + # guest accounts from being automatically joined to the rooms. + # + # Defaults to true. + # + #auto_join_rooms_for_guests: false """ % locals() ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ffda09226c..5c7113a3bb 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -244,7 +244,13 @@ class RegistrationHandler(BaseHandler): fail_count += 1 if not self.hs.config.user_consent_at_registration: - yield defer.ensureDeferred(self._auto_join_rooms(user_id)) + if not self.hs.config.auto_join_rooms_for_guests and make_guest: + logger.info( + "Skipping auto-join for %s because auto-join for guests is disabled", + user_id, + ) + else: + yield defer.ensureDeferred(self._auto_join_rooms(user_id)) else: logger.info( "Skipping auto-join for %s because consent is required at registration", diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 1b7935cef2..ca32f993a3 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -135,6 +135,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart="local_part"), ResourceLimitError ) + def test_auto_join_rooms_for_guests(self): + room_alias_str = "#room:test" + self.hs.config.auto_join_rooms = [room_alias_str] + self.hs.config.auto_join_rooms_for_guests = False + user_id = self.get_success( + self.handler.register_user(localpart="jeff", make_guest=True), + ) + rooms = self.get_success(self.store.get_rooms_for_user(user_id)) + self.assertEqual(len(rooms), 0) + def test_auto_create_auto_join_rooms(self): room_alias_str = "#room:test" self.hs.config.auto_join_rooms = [room_alias_str] -- cgit 1.5.1 From 737b4a936e35daaa839e7296d888041941546b47 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 5 Jun 2020 14:42:55 -0400 Subject: Convert user directory handler and related classes to async/await. (#7640) --- changelog.d/7640.misc | 1 + synapse/handlers/register.py | 6 +- synapse/handlers/state_deltas.py | 9 +-- synapse/handlers/stats.py | 47 ++++++-------- synapse/handlers/user_directory.py | 118 +++++++++++++--------------------- tests/handlers/test_user_directory.py | 8 +-- 6 files changed, 78 insertions(+), 111 deletions(-) create mode 100644 changelog.d/7640.misc (limited to 'tests/handlers') diff --git a/changelog.d/7640.misc b/changelog.d/7640.misc new file mode 100644 index 0000000000..55edc1c781 --- /dev/null +++ b/changelog.d/7640.misc @@ -0,0 +1 @@ +Convert user directory, state deltas, and stats handlers to async/await. 
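The handler diffs that follow apply the standard Twisted migration: drop `@defer.inlineCallbacks`, declare the method `async def`, turn each `yield` into `await`, and wrap calls made from code that is still generator-based in `defer.ensureDeferred` (as `register.py` does below). A minimal sketch of that pattern; `ThingHandler`, `get_thing` and `legacy_caller` are hypothetical names used only for illustration:

    from twisted.internet import defer


    class ThingHandler:
        """Hypothetical handler, used only to illustrate the conversion."""

        def __init__(self, store):
            self.store = store

        # Before: a generator-based coroutine.
        #
        #     @defer.inlineCallbacks
        #     def handle_thing(self, thing_id):
        #         thing = yield self.store.get_thing(thing_id)
        #         return thing

        # After: a native coroutine; each `yield` becomes `await`.
        async def handle_thing(self, thing_id):
            thing = await self.store.get_thing(thing_id)
            return thing


    @defer.inlineCallbacks
    def legacy_caller(handler, thing_id):
        # Generator-based callers cannot `yield` a bare coroutine, so they wrap
        # it in ensureDeferred -- the same approach register.py takes below.
        thing = yield defer.ensureDeferred(handler.handle_thing(thing_id))
        return thing
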
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 5c7113a3bb..af812dbda9 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -207,8 +207,10 @@ class RegistrationHandler(BaseHandler): if self.hs.config.user_directory_search_all_users: profile = yield self.store.get_profileinfo(localpart) - yield self.user_directory_handler.handle_local_profile_change( - user_id, profile + yield defer.ensureDeferred( + self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) ) else: diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index f065970c40..8590c1eff4 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - logger = logging.getLogger(__name__) @@ -24,8 +22,7 @@ class StateDeltasHandler(object): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. @@ -41,10 +38,10 @@ class StateDeltasHandler(object): prev_event = None event = None if prev_event_id: - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) if event_id: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not event and not prev_event: logger.debug("Neither event exists: %r %r", prev_event_id, event_id) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index d93a276693..149f861239 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -16,17 +16,14 @@ import logging from collections import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership -from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process logger = logging.getLogger(__name__) -class StatsHandler(StateDeltasHandler): +class StatsHandler: """Handles keeping the *_stats tables updated with a simple time-series of information about the users, rooms and media on the server, such that admins have some idea of who is consuming their resources. 
@@ -35,7 +32,6 @@ class StatsHandler(StateDeltasHandler): """ def __init__(self, hs): - super(StatsHandler, self).__init__(hs) self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -68,20 +64,18 @@ class StatsHandler(StateDeltasHandler): self._is_processing = True - @defer.inlineCallbacks - def process(): + async def process(): try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._is_processing = False run_as_background_process("stats.notify_new_event", process) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_stats_positions() + self.pos = await self.store.get_stats_positions() # Loop round handling deltas until we're up to date @@ -96,13 +90,13 @@ class StatsHandler(StateDeltasHandler): logger.debug( "Processing room stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self.pos, room_max_stream_ordering ) if deltas: logger.debug("Handling %d state deltas", len(deltas)) - room_deltas, user_deltas = yield self._handle_deltas(deltas) + room_deltas, user_deltas = await self._handle_deltas(deltas) else: room_deltas = {} user_deltas = {} @@ -111,7 +105,7 @@ class StatsHandler(StateDeltasHandler): ( room_count, user_count, - ) = yield self.store.get_changes_room_total_events_and_bytes( + ) = await self.store.get_changes_room_total_events_and_bytes( self.pos, max_pos ) @@ -125,7 +119,7 @@ class StatsHandler(StateDeltasHandler): logger.debug("user_deltas: %s", user_deltas) # Always call this so that we update the stats position. - yield self.store.bulk_update_stats_delta( + await self.store.bulk_update_stats_delta( self.clock.time_msec(), updates={"room": room_deltas, "user": user_deltas}, stream_id=max_pos, @@ -137,13 +131,12 @@ class StatsHandler(StateDeltasHandler): self.pos = max_pos - @defer.inlineCallbacks - def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas): """Called with the state deltas to process Returns: - Deferred[tuple[dict[str, Counter], dict[str, counter]]] - Resovles to two dicts, the room deltas and the user deltas, + tuple[dict[str, Counter], dict[str, counter]] + Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. """ @@ -162,7 +155,7 @@ class StatsHandler(StateDeltasHandler): logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id) - token = yield self.store.get_earliest_token_for_stats("room", room_id) + token = await self.store.get_earliest_token_for_stats("room", room_id) # If the earliest token to begin from is larger than our current # stream ID, skip processing this delta. @@ -184,7 +177,7 @@ class StatsHandler(StateDeltasHandler): sender = None if event_id is not None: - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if event: event_content = event.content or {} sender = event.sender @@ -200,16 +193,16 @@ class StatsHandler(StateDeltasHandler): room_stats_delta["current_state_events"] += 1 if typ == EventTypes.Member: - # we could use _get_key_change here but it's a bit inefficient - # given we're not testing for a specific result; might as well - # just grab the prev_membership and membership strings and - # compare them. 
+ # we could use StateDeltasHandler._get_key_change here but it's + # a bit inefficient given we're not testing for a specific + # result; might as well just grab the prev_membership and + # membership strings and compare them. # We take None rather than leave as a previous membership # in the absence of a previous event because we do not want to # reduce the leave count when a new-to-the-room user joins. prev_membership = None if prev_event_id is not None: - prev_event = yield self.store.get_event( + prev_event = await self.store.get_event( prev_event_id, allow_none=True ) if prev_event: @@ -301,6 +294,6 @@ class StatsHandler(StateDeltasHandler): for room_id, state in room_to_state_updates.items(): logger.debug("Updating room_stats_state for %s: %s", room_id, state) - yield self.store.update_room_state(room_id, state) + await self.store.update_room_state(room_id, state) return room_to_stats_deltas, user_to_stats_deltas diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 722760c59d..12423b909a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -17,14 +17,11 @@ import logging from six import iteritems, iterkeys -from twisted.internet import defer - import synapse.metrics from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.handlers.state_deltas import StateDeltasHandler from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.roommember import ProfileInfo -from synapse.types import get_localpart_from_id from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -103,43 +100,39 @@ class UserDirectoryHandler(StateDeltasHandler): if self._is_processing: return - @defer.inlineCallbacks - def process(): + async def process(): try: - yield self._unsafe_process() + await self._unsafe_process() finally: self._is_processing = False self._is_processing = True run_as_background_process("user_directory.notify_new_event", process) - @defer.inlineCallbacks - def handle_local_profile_change(self, user_id, profile): + async def handle_local_profile_change(self, user_id, profile): """Called to update index of our local user profiles when they change irrespective of any rooms the user may be in. """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. - is_support = yield self.store.is_support_user(user_id) + is_support = await self.store.is_support_user(user_id) # Support users are for diagnostics and should not appear in the user directory. if not is_support: - yield self.store.update_profile_in_user_dir( + await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - @defer.inlineCallbacks - def handle_user_deactivated(self, user_id): + async def handle_user_deactivated(self, user_id): """Called when a user ID is deactivated """ # FIXME(#3714): We should probably do this in the same worker as all # the other changes. 
- yield self.store.remove_from_user_dir(user_id) + await self.store.remove_from_user_dir(user_id) - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): # If self.pos is None then means we haven't fetched it from DB if self.pos is None: - self.pos = yield self.store.get_user_directory_stream_pos() + self.pos = await self.store.get_user_directory_stream_pos() # If still None then the initial background update hasn't happened yet if self.pos is None: @@ -155,12 +148,12 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug( "Processing user stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = yield self.store.get_current_state_deltas( + max_pos, deltas = await self.store.get_current_state_deltas( self.pos, room_max_stream_ordering ) logger.debug("Handling %d state deltas", len(deltas)) - yield self._handle_deltas(deltas) + await self._handle_deltas(deltas) self.pos = max_pos @@ -169,10 +162,9 @@ class UserDirectoryHandler(StateDeltasHandler): max_pos ) - yield self.store.update_user_directory_stream_pos(max_pos) + await self.store.update_user_directory_stream_pos(max_pos) - @defer.inlineCallbacks - def _handle_deltas(self, deltas): + async def _handle_deltas(self, deltas): """Called with the state deltas to process """ for delta in deltas: @@ -187,11 +179,11 @@ class UserDirectoryHandler(StateDeltasHandler): # For join rule and visibility changes we need to check if the room # may have become public or not and add/remove the users in said room if typ in (EventTypes.RoomHistoryVisibility, EventTypes.JoinRules): - yield self._handle_room_publicity_change( + await self._handle_room_publicity_change( room_id, prev_event_id, event_id, typ ) elif typ == EventTypes.Member: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="membership", @@ -201,7 +193,7 @@ class UserDirectoryHandler(StateDeltasHandler): if change is False: # Need to check if the server left the room entirely, if so # we might need to remove all the users in that room - is_in_room = yield self.store.is_host_joined( + is_in_room = await self.store.is_host_joined( room_id, self.server_name ) if not is_in_room: @@ -209,40 +201,41 @@ class UserDirectoryHandler(StateDeltasHandler): # Fetch all the users that we marked as being in user # directory due to being in the room and then check if # need to remove those users or not - user_ids = yield self.store.get_users_in_dir_due_to_room( + user_ids = await self.store.get_users_in_dir_due_to_room( room_id ) for user_id in user_ids: - yield self._handle_remove_user(room_id, user_id) + await self._handle_remove_user(room_id, user_id) return else: logger.debug("Server is still in room: %r", room_id) - is_support = yield self.store.is_support_user(state_key) + is_support = await self.store.is_support_user(state_key) if not is_support: if change is None: # Handle any profile changes - yield self._handle_profile_change( + await self._handle_profile_change( state_key, room_id, prev_event_id, event_id ) continue if change: # The user joined - event = yield self.store.get_event(event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), ) - yield self._handle_new_user(room_id, state_key, profile) + await self._handle_new_user(room_id, state_key, profile) else: # The user left - yield self._handle_remove_user(room_id, state_key) + await 
self._handle_remove_user(room_id, state_key) else: logger.debug("Ignoring irrelevant type: %r", typ) - @defer.inlineCallbacks - def _handle_room_publicity_change(self, room_id, prev_event_id, event_id, typ): + async def _handle_room_publicity_change( + self, room_id, prev_event_id, event_id, typ + ): """Handle a room having potentially changed from/to world_readable/publically joinable. @@ -255,14 +248,14 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Handling change for %s: %s", typ, room_id) if typ == EventTypes.RoomHistoryVisibility: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="history_visibility", public_value="world_readable", ) elif typ == EventTypes.JoinRules: - change = yield self._get_key_change( + change = await self._get_key_change( prev_event_id, event_id, key_name="join_rule", @@ -278,7 +271,7 @@ class UserDirectoryHandler(StateDeltasHandler): # There's been a change to or from being world readable. - is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) @@ -293,11 +286,11 @@ class UserDirectoryHandler(StateDeltasHandler): # ignore the change return - users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) # Remove every user from the sharing tables for that room. for user_id in iterkeys(users_with_profile): - yield self.store.remove_user_who_share_room(user_id, room_id) + await self.store.remove_user_who_share_room(user_id, room_id) # Then, re-add them to the tables. # NOTE: this is not the most efficient method, as handle_new_user sets @@ -306,26 +299,9 @@ class UserDirectoryHandler(StateDeltasHandler): # being added multiple times. The batching upserts shouldn't make this # too bad, though. for user_id, profile in iteritems(users_with_profile): - yield self._handle_new_user(room_id, user_id, profile) - - @defer.inlineCallbacks - def _handle_local_user(self, user_id): - """Adds a new local roomless user into the user_directory_search table. - Used to populate up the user index when we have an - user_directory_search_all_users specified. - """ - logger.debug("Adding new local user to dir, %r", user_id) - - profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id)) - - row = yield self.store.get_user_in_directory(user_id) - if not row: - yield self.store.update_profile_in_user_dir( - user_id, profile.display_name, profile.avatar_url - ) + await self._handle_new_user(room_id, user_id, profile) - @defer.inlineCallbacks - def _handle_new_user(self, room_id, user_id, profile): + async def _handle_new_user(self, room_id, user_id, profile): """Called when we might need to add user to directory Args: @@ -334,18 +310,18 @@ class UserDirectoryHandler(StateDeltasHandler): """ logger.debug("Adding new user to dir, %r", user_id) - yield self.store.update_profile_in_user_dir( + await self.store.update_profile_in_user_dir( user_id, profile.display_name, profile.avatar_url ) - is_public = yield self.store.is_room_world_readable_or_publicly_joinable( + is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) # Now we update users who share rooms with users. 
- users_with_profile = yield self.state.get_current_users_in_room(room_id) + users_with_profile = await self.state.get_current_users_in_room(room_id) if is_public: - yield self.store.add_users_in_public_rooms(room_id, (user_id,)) + await self.store.add_users_in_public_rooms(room_id, (user_id,)) else: to_insert = set() @@ -376,10 +352,9 @@ class UserDirectoryHandler(StateDeltasHandler): to_insert.add((other_user_id, user_id)) if to_insert: - yield self.store.add_users_who_share_private_room(room_id, to_insert) + await self.store.add_users_who_share_private_room(room_id, to_insert) - @defer.inlineCallbacks - def _handle_remove_user(self, room_id, user_id): + async def _handle_remove_user(self, room_id, user_id): """Called when we might need to remove user from directory Args: @@ -389,24 +364,23 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug("Removing user %r", user_id) # Remove user from sharing tables - yield self.store.remove_user_who_share_room(user_id, room_id) + await self.store.remove_user_who_share_room(user_id, room_id) # Are they still in any rooms? If not, remove them entirely. - rooms_user_is_in = yield self.store.get_user_dir_rooms_user_is_in(user_id) + rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id) if len(rooms_user_is_in) == 0: - yield self.store.remove_from_user_dir(user_id) + await self.store.remove_from_user_dir(user_id) - @defer.inlineCallbacks - def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): + async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id): """Check member event changes for any profile changes and update the database if there are. """ if not prev_event_id or not event_id: return - prev_event = yield self.store.get_event(prev_event_id, allow_none=True) - event = yield self.store.get_event(event_id, allow_none=True) + prev_event = await self.store.get_event(prev_event_id, allow_none=True) + event = await self.store.get_event(event_id, allow_none=True) if not prev_event or not event: return @@ -421,4 +395,4 @@ class UserDirectoryHandler(StateDeltasHandler): new_avatar = event.content.get("avatar_url") if prev_name != new_name or prev_avatar != new_avatar: - yield self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) + await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 572df8d80b..c15bce5bef 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -14,6 +14,8 @@ # limitations under the License. 
from mock import Mock +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import UserTypes from synapse.rest.client.v1 import login, room @@ -75,18 +77,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - self.store.remove_from_user_dir = Mock() - self.store.remove_from_user_in_public_room = Mock() + self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) self.get_success(self.handler.handle_user_deactivated(s_user_id)) self.store.remove_from_user_dir.not_called() - self.store.remove_from_user_in_public_room.not_called() def test_handle_user_deactivated_regular_user(self): r_user_id = "@regular:test" self.get_success( self.store.register_user(user_id=r_user_id, password_hash=None) ) - self.store.remove_from_user_dir = Mock() + self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None)) self.get_success(self.handler.handle_user_deactivated(r_user_id)) self.store.remove_from_user_dir.called_once_with(r_user_id) -- cgit 1.5.1
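A side effect of the conversion, visible in the `test_user_directory.py` changes above: once `handle_user_deactivated` awaits `store.remove_from_user_dir`, a plain `Mock()` result is no longer awaitable, so the stub must return a pre-resolved Deferred. A minimal sketch (the bare `store` object and the free-standing coroutine are simplifications, not the real classes):

    from mock import Mock

    from twisted.internet import defer

    store = Mock()
    # A bare Mock() return value cannot be awaited, so hand back a resolved
    # Deferred, as the test above does.
    store.remove_from_user_dir = Mock(return_value=defer.succeed(None))


    async def handle_user_deactivated(user_id):
        # Stand-in for the real handler method, which now awaits the store call.
        await store.remove_from_user_dir(user_id)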