From 8000cf131592b6edcded65ef4be20b8ac0f1bfd3 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 16 Mar 2021 16:44:25 +0100 Subject: Return m.change_password.enabled=false if local database is disabled (#9588) Instead of if the user does not have a password hash. This allows a SSO user to add a password to their account, but only if the local password database is configured. --- synapse/handlers/auth.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'synapse/handlers/auth.py') diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fb5f8118f0..badac8c26c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -886,6 +886,19 @@ class AuthHandler(BaseHandler): ) return result + def can_change_password(self) -> bool: + """Get whether users on this server are allowed to change or set a password. + + Both `config.password_enabled` and `config.password_localdb_enabled` must be true. + + Note that any account (even SSO accounts) are allowed to add passwords if the above + is true. + + Returns: + Whether users on this server are allowed to change or set a password + """ + return self._password_enabled and self._password_localdb_enabled + def get_supported_login_types(self) -> Iterable[str]: """Get a the login types supported for the /login API -- cgit 1.5.1 From b7748d3c00e87df8c49346e67d643916487254e4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 23 Mar 2021 07:12:48 -0400 Subject: Import HomeServer from the proper module. (#9665) --- changelog.d/9665.misc | 1 + synapse/crypto/keyring.py | 2 +- synapse/federation/federation_client.py | 2 +- synapse/groups/attestations.py | 2 +- synapse/groups/groups_server.py | 2 +- synapse/handlers/_base.py | 2 +- synapse/handlers/account_data.py | 2 +- synapse/handlers/account_validity.py | 2 +- synapse/handlers/acme.py | 2 +- synapse/handlers/admin.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 2 +- synapse/handlers/cas_handler.py | 2 +- synapse/handlers/deactivate_account.py | 2 +- synapse/handlers/device.py | 2 +- synapse/handlers/devicemessage.py | 2 +- synapse/handlers/e2e_keys.py | 2 +- synapse/handlers/e2e_room_keys.py | 2 +- synapse/handlers/groups_local.py | 2 +- synapse/handlers/password_policy.py | 2 +- synapse/handlers/profile.py | 2 +- synapse/handlers/read_marker.py | 2 +- synapse/handlers/receipts.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member_worker.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/set_password.py | 2 +- synapse/handlers/state_deltas.py | 2 +- synapse/handlers/stats.py | 2 +- synapse/handlers/user_directory.py | 2 +- synapse/http/client.py | 2 +- synapse/push/__init__.py | 2 +- synapse/push/action_generator.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/emailpusher.py | 2 +- synapse/push/httppusher.py | 2 +- synapse/push/mailer.py | 2 +- synapse/push/pusher.py | 2 +- synapse/replication/slave/storage/pushers.py | 2 +- synapse/replication/tcp/streams/_base.py | 2 +- synapse/rest/admin/media.py | 2 +- synapse/rest/client/v1/room.py | 2 +- synapse/rest/client/v2_alpha/account.py | 2 +- synapse/rest/client/v2_alpha/groups.py | 2 +- synapse/rest/media/v1/config_resource.py | 2 +- synapse/rest/media/v1/download_resource.py | 2 +- synapse/rest/media/v1/media_repository.py | 2 +- synapse/rest/media/v1/preview_url_resource.py | 2 +- synapse/rest/media/v1/storage_provider.py | 2 +- synapse/rest/media/v1/thumbnail_resource.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/storage/__init__.py | 2 +- synapse/storage/_base.py | 2 +- synapse/storage/background_updates.py | 2 +- synapse/storage/databases/main/appservice.py | 2 +- synapse/storage/databases/main/pusher.py | 2 +- synapse/storage/purge_events.py | 2 +- synapse/storage/state.py | 2 +- 59 files changed, 59 insertions(+), 58 deletions(-) create mode 100644 changelog.d/9665.misc (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/9665.misc b/changelog.d/9665.misc new file mode 100644 index 0000000000..b8bf76c639 --- /dev/null +++ b/changelog.d/9665.misc @@ -0,0 +1 @@ +Import `HomeServer` from the proper module. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 902128a23c..d5fb51513b 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -57,7 +57,7 @@ from synapse.util.metrics import Measure from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index bee81fc019..3b2f51baab 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -62,7 +62,7 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index a3f8d92d08..368c44708d 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -46,7 +46,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, get_domain_from_id if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index f9a0f40221..4b16a4ac29 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -25,7 +25,7 @@ from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d29b066a56..aade2c4a3a 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -24,7 +24,7 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index b1a5df9638..1ce6d697ed 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -25,7 +25,7 @@ from synapse.replication.http.account_data import ( from synapse.types import JsonDict, UserID if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer class AccountDataHandler: diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 664d09da1c..d781bb251d 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -27,7 +27,7 @@ from synapse.types import UserID from synapse.util import stringutils if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 132be238dd..2a25af6288 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -24,7 +24,7 @@ from twisted.web.resource import Resource from synapse.app import check_bind_error if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index db68c94c50..c494de49a3 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -25,7 +25,7 @@ from synapse.visibility import filter_events_for_client from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index deab8ff2d0..996f9e5deb 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -38,7 +38,7 @@ from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, User from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index badac8c26c..d537ea8137 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -70,7 +70,7 @@ from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index cb67589f7d..5060936f94 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -27,7 +27,7 @@ from synapse.http.site import SynapseRequest from synapse.types import UserID, map_username_to_mxid_localpart if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 3886d3124d..2bcd8f5435 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -23,7 +23,7 @@ from synapse.types import Requester, UserID, create_requester from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 2fc4951df4..54293d0b9c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -45,7 +45,7 @@ from synapse.util.retryutils import NotRetryingDestination from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 7db4f48965..eb547743be 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -32,7 +32,7 @@ from synapse.util import json_encoder from synapse.util.stringutils import random_string if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9a946a3cfe..2ad9b6d930 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -42,7 +42,7 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 622cae23be..a910d246d6 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -29,7 +29,7 @@ from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index bfb95e3eee..a41ca5df9c 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -21,7 +21,7 @@ from synapse.api.errors import HttpResponseException, RequestSendFailed, Synapse from synapse.types import GroupID, JsonDict, get_domain_from_id if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py index 6c635cc31b..92cefa11aa 100644 --- a/synapse/handlers/password_policy.py +++ b/synapse/handlers/password_policy.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING from synapse.api.errors import Codes, PasswordRefusedError if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dd59392bda..a755363c3f 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -36,7 +36,7 @@ from synapse.types import ( from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index 6bb2fd936b..a54fe1968e 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -21,7 +21,7 @@ from synapse.util.async_helpers import Linearizer from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 6a6c528849..dbfe9bfaca 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -20,7 +20,7 @@ from synapse.handlers._base import BaseHandler from synapse.types import JsonDict, ReadReceipt, get_domain_from_id if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index d7f226d589..0fc2bf15d5 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -38,7 +38,7 @@ from synapse.types import RoomAlias, UserID, create_requester from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 8bfc46c654..924b81db7c 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -29,7 +29,7 @@ from synapse.util.caches.response_cache import ResponseCache from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index d75506c75e..3a90fc0c16 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -26,7 +26,7 @@ from synapse.replication.http.membership import ( from synapse.types import Requester, UserID if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 94062e79cb..d742dfbd53 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -30,7 +30,7 @@ from synapse.visibility import filter_events_for_client from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 04e7c64c94..f98a338ec5 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -21,7 +21,7 @@ from synapse.types import Requester from ._base import BaseHandler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index b3f9875358..ee8f87e59a 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -17,7 +17,7 @@ import logging from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 924281144c..8730f99d03 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -24,7 +24,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 1a8340000a..b121286d95 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -25,7 +25,7 @@ from synapse.types import JsonDict from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/http/client.py b/synapse/http/client.py index 1e01e0a9f2..a0caba84e4 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -77,7 +77,7 @@ from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index f4f7ec96f8..9fc3da49a2 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -21,7 +21,7 @@ import attr from synapse.types import JsonDict, RoomStreamToken if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer @attr.s(slots=True) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index aaed28650d..38a47a600f 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -22,7 +22,7 @@ from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index c016a83909..1897f59153 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -33,7 +33,7 @@ from synapse.util.caches.lrucache import LruCache from .push_rule_evaluator import PushRuleEvaluatorForEvent if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 3dc06a79e8..c0968dc7a1 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -24,7 +24,7 @@ from synapse.push import Pusher, PusherConfig, ThrottleParams from synapse.push.mailer import Mailer if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 026134ae26..26af5309c1 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -31,7 +31,7 @@ from synapse.push import Pusher, PusherConfig, PusherConfigException from . import push_rule_evaluator, push_tools if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index d10201b6b3..2e5161de2c 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -40,7 +40,7 @@ from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 2aa7918fb4..cb94127850 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -22,7 +22,7 @@ from synapse.push.httppusher import HttpPusher from synapse.push.mailer import Mailer if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 045bd014da..93161c3dfb 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -24,7 +24,7 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 7e8e64d61c..3dfee76743 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -33,7 +33,7 @@ import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 7fcc48a9d7..40646ef241 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -28,7 +28,7 @@ from synapse.rest.admin._base import ( from synapse.types import JsonDict if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6c722d634d..525efdf221 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -58,7 +58,7 @@ from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index adf1d39728..c2ba790bab 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -45,7 +45,7 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed from ._base import client_patterns, interactive_auth_handler if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 5901432fad..08fb6b2b06 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -38,7 +38,7 @@ from synapse.types import GroupID, JsonDict from ._base import client_patterns if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 1eff98ef14..c41a7ab412 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -23,7 +23,7 @@ from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.site import SynapseRequest if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer class MediaConfigResource(DirectServeJsonResource): diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index 8a43581f1f..5dadaeaf57 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -24,8 +24,8 @@ from synapse.http.servlet import parse_boolean from ._base import parse_media_id, respond_404 if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 8b4841ed5d..0c041b542d 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -58,7 +58,7 @@ from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index b8895aeaa9..e590a0deab 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -54,8 +54,8 @@ from ._base import FileInfo if TYPE_CHECKING: from lxml import etree - from synapse.app.homeserver import HomeServer from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index e92006faa9..031947557d 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -29,7 +29,7 @@ from .media_storage import FileResponder logger = logging.getLogger(__name__) if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer class StorageProvider(metaclass=abc.ABCMeta): diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index fbcd50f1e2..af802bc0b1 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -34,8 +34,8 @@ from ._base import ( ) if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index ae5aef2f7f..0138b2e2d1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -26,8 +26,8 @@ from synapse.http.site import SynapseRequest from synapse.rest.media.v1.media_storage import SpamMediaException if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer from synapse.rest.media.v1.media_repository import MediaRepository + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index a3c52695e9..0b9007e51f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -36,7 +36,7 @@ from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.state import StateGroupStorage if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer __all__ = ["Databases", "DataStore"] diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index a25c4093bc..240905329f 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -27,7 +27,7 @@ from synapse.types import Collection, StreamToken, get_domain_from_id from synapse.util import json_decoder if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 329660cf0f..ccb06aab39 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -23,7 +23,7 @@ from synapse.util import json_encoder from . import engines if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer from synapse.storage.database import DatabasePool, LoggingTransaction logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 03a38422a1..85bb853d33 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -32,7 +32,7 @@ from synapse.types import JsonDict from synapse.util import json_encoder if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 85f1ebac98..c65558c280 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -27,7 +27,7 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index 4dcd848c59..ad954990a7 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Set from synapse.storage.databases import Databases if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer logger = logging.getLogger(__name__) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index d179a41884..aa25bd8350 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -32,7 +32,7 @@ from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer + from synapse.server import HomeServer from synapse.storage.databases import Databases logger = logging.getLogger(__name__) -- cgit 1.5.1 From 963f4309fe29206f3ba92b493e922280feea30ed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Mar 2021 12:06:09 +0100 Subject: Make RateLimiter class check for ratelimit overrides (#9711) This should fix a class of bug where we forget to check if e.g. the appservice shouldn't be ratelimited. We also check the `ratelimit_override` table to check if the user has ratelimiting disabled. That table is really only meant to override the event sender ratelimiting, so we don't use any values from it (as they might not make sense for different rate limits), but we do infer that if ratelimiting is disabled for the user we should disabled all ratelimits. Fixes #9663 --- changelog.d/9711.bugfix | 1 + synapse/api/ratelimiting.py | 100 +++++++++--------- synapse/federation/federation_server.py | 5 +- synapse/handlers/_base.py | 14 +-- synapse/handlers/auth.py | 24 +++-- synapse/handlers/devicemessage.py | 5 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 12 ++- synapse/handlers/register.py | 6 +- synapse/handlers/room_member.py | 23 +++-- synapse/replication/http/register.py | 2 +- synapse/rest/client/v1/login.py | 14 ++- synapse/rest/client/v2_alpha/account.py | 10 +- synapse/rest/client/v2_alpha/register.py | 8 +- synapse/server.py | 1 + tests/api/test_ratelimiting.py | 168 ++++++++++++++++++++----------- 16 files changed, 241 insertions(+), 154 deletions(-) create mode 100644 changelog.d/9711.bugfix (limited to 'synapse/handlers/auth.py') diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix new file mode 100644 index 0000000000..4ca3438d46 --- /dev/null +++ b/changelog.d/9711.bugfix @@ -0,0 +1 @@ +Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index c3f07bc1a3..2244b8a340 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -31,10 +32,13 @@ class Ratelimiter: burst_count: How many actions that can be performed before being limited. """ - def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + def __init__( + self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + ): self.clock = clock self.rate_hz = rate_hz self.burst_count = burst_count + self.store = store # A ordered dictionary keeping track of actions, when they were last # performed and how often. Each entry is a mapping from a key of arbitrary type @@ -46,45 +50,10 @@ class Ratelimiter: OrderedDict() ) # type: OrderedDict[Hashable, Tuple[float, int, float]] - def can_requester_do_action( - self, - requester: Requester, - rate_hz: Optional[float] = None, - burst_count: Optional[int] = None, - update: bool = True, - _time_now_s: Optional[int] = None, - ) -> Tuple[bool, float]: - """Can the requester perform the action? - - Args: - requester: The requester to key off when rate limiting. The user property - will be used. - rate_hz: The long term number of actions that can be performed in a second. - Overrides the value set during instantiation if set. - burst_count: How many actions that can be performed before being limited. - Overrides the value set during instantiation if set. - update: Whether to count this check as performing the action - _time_now_s: The current time. Optional, defaults to the current time according - to self.clock. Only used by tests. - - Returns: - A tuple containing: - * A bool indicating if they can perform the action now - * The reactor timestamp for when the action can be performed next. - -1 if rate_hz is less than or equal to zero - """ - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return True, -1.0 - - return self.can_do_action( - requester.user.to_string(), rate_hz, burst_count, update, _time_now_s - ) - - def can_do_action( + async def can_do_action( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -92,9 +61,16 @@ class Ratelimiter: ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: The key we should use when rate limiting. Can be a user ID - (when sending events), an IP address, etc. + requester: The requester that is doing the action, if any. Used to check + if the user has ratelimits disabled in the database. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -109,6 +85,30 @@ class Ratelimiter: * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + + if requester: + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + # Check if ratelimiting has been disabled for the user. + # + # Note that we don't use the returned rate/burst count, as the table + # is specifically for the event sending ratelimiter. Instead, we + # only use it to (somewhat cheekily) infer whether the user should + # be subject to any rate limiting or not. + override = await self.store.get_ratelimit_for_user( + requester.authenticated_entity + ) + if override and not override.messages_per_second: + return True, -1.0 + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz @@ -175,9 +175,10 @@ class Ratelimiter: else: del self.actions[key] - def ratelimit( + async def ratelimit( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -185,8 +186,16 @@ class Ratelimiter: ): """Checks if an action can be performed. If not, raises a LimitExceededError + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: An arbitrary key used to classify an action + requester: The requester that is doing the action, if any. Used to check for + if the user has ratelimits disabled. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -201,7 +210,8 @@ class Ratelimiter: """ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - allowed, time_allowed = self.can_do_action( + allowed, time_allowed = await self.can_do_action( + requester, key, rate_hz=rate_hz, burst_count=burst_count, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d84e362070..71cb120ef7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -870,6 +870,7 @@ class FederationHandlerRegistry: # A rate limiter for incoming room key requests per origin. self._room_key_request_rate_limiter = Ratelimiter( + store=hs.get_datastore(), clock=self.clock, rate_hz=self.config.rc_key_requests.per_second, burst_count=self.config.rc_key_requests.burst_count, @@ -930,7 +931,9 @@ class FederationHandlerRegistry: # the limit, drop them. if ( edu_type == EduTypes.RoomKeyRequest - and not self._room_key_request_rate_limiter.can_do_action(origin) + and not await self._room_key_request_rate_limiter.can_do_action( + None, origin + ) ): return diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index aade2c4a3a..fb899aa90d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -49,7 +49,7 @@ class BaseHandler: # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 ) self._rc_message = self.hs.config.rc_message @@ -57,6 +57,7 @@ class BaseHandler: # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: self.admin_redaction_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, @@ -91,11 +92,6 @@ class BaseHandler: if app_service is not None: return # do not ratelimit app service senders - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return - messages_per_second = self._rc_message.per_second burst_count = self._rc_message.burst_count @@ -113,11 +109,11 @@ class BaseHandler: if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) else: # Override rate and burst count per-user - self.request_ratelimiter.ratelimit( - user_id, + await self.request_ratelimiter.ratelimit( + requester, rate_hz=messages_per_second, burst_count=burst_count, update=update, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d537ea8137..08e413bc98 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -238,6 +238,7 @@ class AuthHandler(BaseHandler): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -248,6 +249,7 @@ class AuthHandler(BaseHandler): # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -352,7 +354,7 @@ class AuthHandler(BaseHandler): requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) + await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -373,7 +375,9 @@ class AuthHandler(BaseHandler): ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) + await self._failed_uia_attempts_ratelimiter.can_do_action( + requester, + ) raise # find the completed login type @@ -982,8 +986,8 @@ class AuthHandler(BaseHandler): # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - (medium, address), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, (medium, address), update=False ) # Check for login providers that support 3pid login types @@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler): # this code path, which is fine as then the per-user ratelimit # will kick in below. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - (medium, address) + await self._failed_login_attempts_ratelimiter.can_do_action( + None, (medium, address) ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler): # Check if we've hit the failed ratelimit (but don't update it) if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, qualified_user_id.lower(), update=False ) try: @@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler): # exception and masking the LoginError. The actual ratelimiting # should have happened above. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - qualified_user_id.lower() + await self._failed_login_attempts_ratelimiter.can_do_action( + None, qualified_user_id.lower() ) raise diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index eb547743be..5ee48be6ff 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -81,6 +81,7 @@ class DeviceMessageHandler: ) self._ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.rc_key_requests.per_second, burst_count=hs.config.rc_key_requests.burst_count, @@ -191,8 +192,8 @@ class DeviceMessageHandler: if ( message_type == EduTypes.RoomKeyRequest and user_id != sender_user_id - and self._ratelimiter.can_do_action( - (sender_user_id, requester.device_id) + and await self._ratelimiter.can_do_action( + requester, (sender_user_id, requester.device_id) ) ): continue diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 598a66f74c..3ebee38ebe 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1711,7 +1711,7 @@ class FederationHandler(BaseHandler): member_handler = self.hs.get_room_member_handler() # We don't rate limit based on room ID, as that should be done by # sending server. - member_handler.ratelimit_invite(None, event.state_key) + await member_handler.ratelimit_invite(None, None, event.state_key) # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 5f346f6d6d..d89fa5fb30 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler): # Ratelimiters for `/requestToken` endpoints. self._3pid_validation_ratelimiter_ip = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) self._3pid_validation_ratelimiter_address = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) - def ratelimit_request_token_requests( + async def ratelimit_request_token_requests( self, request: SynapseRequest, medium: str, @@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler): address: The actual threepid ID, e.g. the phone number or email address """ - self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) - self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + await self._3pid_validation_ratelimiter_ip.ratelimit( + None, (medium, request.getClientIP()) + ) + await self._3pid_validation_ratelimiter_address.ratelimit( + None, (medium, address) + ) async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0fc2bf15d5..9701b76d0f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler): Raises: SynapseError if there was a problem registering. """ - self.check_registration_ratelimit(address) + await self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( threepid, @@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address: Optional[str]) -> None: + async def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler): if not address: return - self.ratelimiter.ratelimit(address) + await self.ratelimiter.ratelimit(None, address) async def register_with_store( self, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4d20ed8357..1cf12f3255 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -75,22 +75,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.allow_per_room_profiles = self.config.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) self._join_rate_limiter_remote = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) self._invites_per_room_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, ) self._invites_per_user_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, @@ -159,15 +163,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): async def forget(self, user: UserID, room_id: str) -> None: raise NotImplementedError() - def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): + async def ratelimit_invite( + self, + requester: Optional[Requester], + room_id: Optional[str], + invitee_user_id: str, + ): """Ratelimit invites by room and by target user. If room ID is missing then we just rate limit by target user. """ if room_id: - self._invites_per_room_limiter.ratelimit(room_id) + await self._invites_per_room_limiter.ratelimit(requester, room_id) - self._invites_per_user_limiter.ratelimit(invitee_user_id) + await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) async def _local_membership_update( self, @@ -237,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ( allowed, time_allowed, - ) = self._join_rate_limiter_local.can_requester_do_action(requester) + ) = await self._join_rate_limiter_local.can_do_action(requester) if not allowed: raise LimitExceededError( @@ -421,9 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if effective_membership_state == Membership.INVITE: target_id = target.to_string() if ratelimit: - # Don't ratelimit application services. - if not requester.app_service or requester.app_service.is_rate_limited(): - self.ratelimit_invite(room_id, target_id) + await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: @@ -534,7 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ( allowed, time_allowed, - ) = self._join_rate_limiter_remote.can_requester_do_action( + ) = await self._join_rate_limiter_remote.can_do_action( requester, ) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index d005f38767..73d7477854 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - self.registration_handler.check_registration_ratelimit(content["address"]) + await self.registration_handler.check_registration_ratelimit(content["address"]) await self.registration_handler.register_with_store( user_id=user_id, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e4c352f572..3151e72d4f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, @@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientIP() + ) result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_token_login(login_submission) else: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet): # too often. This happens here rather than before as we don't # necessarily know the user before now. if ratelimit: - self._account_ratelimiter.ratelimit(user_id.lower()) + await self._account_ratelimiter.ratelimit(None, user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c2ba790bab..411fb57c47 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to @@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) if next_link: # Raise if the provided next_link value isn't valid @@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 8f68d8dfc8..c212da0cb2 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email @@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) @@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet): client_addr = request.getClientIP() - self.ratelimiter.ratelimit(client_addr, update=False) + await self.ratelimiter.ratelimit(None, client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index e85b9391fa..e42f7b1a18 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -329,6 +329,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( + store=self.get_datastore(), clock=self.get_clock(), rate_hz=self.config.rc_registration.per_second, burst_count=self.config.rc_registration.burst_count, diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483418192c..fa96ba07a5 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -5,38 +5,25 @@ from synapse.types import create_requester from tests import unittest -class TestRatelimiter(unittest.TestCase): +class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) - self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) - self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) - self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) - - def test_allowed_user_via_can_requester_do_action(self): - user_requester = create_requester("@user:example.com") - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) def test_allowed_via_ratelimit(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=0) + self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0)) # Should raise with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key="test_id", _time_now_s=5) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=5) + ) self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=10) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=10) + ) def test_allowed_via_can_do_action_and_overriding_parameters(self): """Test that we can override options of can_do_action that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=0, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=0, + ) ) self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=1, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=1, + ) ) self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, rate_hz=10.0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0) ) self.assertTrue(allowed) self.assertEqual(1.1, time_allowed) # Similarly if we allow a burst of 10 actions - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, burst_count=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10) ) self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) @@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase): fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - limiter.ratelimit(key=("test_id",), _time_now_s=0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=0) + ) # Second attempt, 1s later, will fail with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1) + ) self.assertEqual(context.exception.retry_after_ms, 9000) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0) + ) # Similarly if we allow a burst of 10 actions - limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) + ) def test_pruning(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - limiter.can_do_action(key="test_id_1", _time_now_s=0) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_1", _time_now_s=0) + ) self.assertIn("test_id_1", limiter.actions) - limiter.can_do_action(key="test_id_2", _time_now_s=10) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_2", _time_now_s=10) + ) self.assertNotIn("test_id_1", limiter.actions) + + def test_db_user_override(self): + """Test that users that have ratelimiting disabled in the DB aren't + ratelimited. + """ + store = self.hs.get_datastore() + + user_id = "@user:test" + requester = create_requester(user_id) + + self.get_success( + store.db_pool.simple_insert( + table="ratelimit_override", + values={ + "user_id": user_id, + "messages_per_second": None, + "burst_count": None, + }, + desc="test_db_user_override", + ) + ) + + limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + for _ in range(20): + self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0)) -- cgit 1.5.1