diff --git a/changelog.d/9685.misc b/changelog.d/9685.misc
new file mode 100644
index 0000000000..0506d8af0c
--- /dev/null
+++ b/changelog.d/9685.misc
@@ -0,0 +1 @@
+Update `scripts-dev/complement.sh` to use a local checkout of Complement, allow running a subset of tests and have it use Synapse's Complement test blacklist.
\ No newline at end of file
diff --git a/changelog.d/9700.feature b/changelog.d/9700.feature
new file mode 100644
index 0000000000..037de8367f
--- /dev/null
+++ b/changelog.d/9700.feature
@@ -0,0 +1 @@
+Replace the `room_invite_state_types` configuration setting with `room_prejoin_state`.
diff --git a/changelog.d/9710.feature b/changelog.d/9710.feature
new file mode 100644
index 0000000000..fce308cc41
--- /dev/null
+++ b/changelog.d/9710.feature
@@ -0,0 +1 @@
+Experimental Spaces support: include `m.room.create` in the room state sent with room-invites.
diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix
new file mode 100644
index 0000000000..4ca3438d46
--- /dev/null
+++ b/changelog.d/9711.bugfix
@@ -0,0 +1 @@
+Fix recently added ratelimits to correctly honour the application service `rate_limited` flag.
diff --git a/changelog.d/9718.removal b/changelog.d/9718.removal
new file mode 100644
index 0000000000..6de7814217
--- /dev/null
+++ b/changelog.d/9718.removal
@@ -0,0 +1 @@
+Replace deprecated `imp` module with successor `importlib`. Contributed by Cristina Muñoz.
diff --git a/changelog.d/9719.doc b/changelog.d/9719.doc
new file mode 100644
index 0000000000..f018606dd6
--- /dev/null
+++ b/changelog.d/9719.doc
@@ -0,0 +1 @@
+Make the allowed_local_3pids regex example in the sample config stricter.
diff --git a/docker/conf/homeserver.yaml b/docker/conf/homeserver.yaml
index 0dea62a87d..a792899540 100644
--- a/docker/conf/homeserver.yaml
+++ b/docker/conf/homeserver.yaml
@@ -173,18 +173,10 @@ report_stats: False
## API Configuration ##
-room_invite_state_types:
- - "m.room.join_rules"
- - "m.room.canonical_alias"
- - "m.room.avatar"
- - "m.room.name"
-
{% if SYNAPSE_APPSERVICES %}
app_service_config_files:
{% for appservice in SYNAPSE_APPSERVICES %} - "{{ appservice }}"
{% endfor %}
-{% else %}
-app_service_config_files: []
{% endif %}
macaroon_secret_key: "{{ SYNAPSE_MACAROON_SECRET_KEY }}"
diff --git a/docs/code_style.md b/docs/code_style.md
index 190f8ab2de..28fb7277c4 100644
--- a/docs/code_style.md
+++ b/docs/code_style.md
@@ -128,6 +128,9 @@ Some guidelines follow:
will be if no sub-options are enabled).
- Lines should be wrapped at 80 characters.
- Use two-space indents.
+- `true` and `false` are spelt thus (as opposed to `True`, etc.)
+- Use single quotes (`'`) rather than double-quotes (`"`) or backticks
+ (`` ` ``) to refer to configuration options.
Example:
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 17cda71adc..b0bf987740 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1246,9 +1246,9 @@ account_validity:
#
#allowed_local_3pids:
# - medium: email
-# pattern: '.*@matrix\.org'
+# pattern: '^[^@]+@matrix\.org$'
# - medium: email
-# pattern: '.*@vector\.im'
+# pattern: '^[^@]+@vector\.im$'
# - medium: msisdn
# pattern: '\+44'
@@ -1451,14 +1451,31 @@ metrics_flags:
## API Configuration ##
-# A list of event types that will be included in the room_invite_state
+# Controls for the state that is shared with users who receive an invite
+# to a room
#
-#room_invite_state_types:
-# - "m.room.join_rules"
-# - "m.room.canonical_alias"
-# - "m.room.avatar"
-# - "m.room.encryption"
-# - "m.room.name"
+room_prejoin_state:
+ # By default, the following state event types are shared with users who
+ # receive invites to the room:
+ #
+ # - m.room.join_rules
+ # - m.room.canonical_alias
+ # - m.room.avatar
+ # - m.room.encryption
+ # - m.room.name
+ #
+ # Uncomment the following to disable these defaults (so that only the event
+ # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+ #
+ #disable_default_event_types: true
+
+ # Additional state event types to share with users when they are invited
+ # to a room.
+ #
+ # By default, this list is empty (so only the default event types are shared).
+ #
+ #additional_event_types:
+ # - org.example.custom.event.type
# A list of application service config files to use
diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh
index 3cde53f5c0..31cc20a826 100755
--- a/scripts-dev/complement.sh
+++ b/scripts-dev/complement.sh
@@ -1,22 +1,49 @@
-#! /bin/bash -eu
+#!/usr/bin/env bash
# This script is designed for developers who want to test their code
# against Complement.
#
# It makes a Synapse image which represents the current checkout,
-# then downloads Complement and runs it with that image.
+# builds a synapse-complement image on top, then runs tests with it.
+#
+# By default the script will fetch the latest Complement master branch and
+# run tests with that. This can be overridden to use a custom Complement
+# checkout by setting the COMPLEMENT_DIR environment variable to the
+# filepath of a local Complement checkout.
+#
+# A regular expression of test method names can be supplied as the first
+# argument to the script. Complement will then only run those tests. If
+# no regex is supplied, all tests are run. For example;
+#
+# ./complement.sh "TestOutboundFederation(Profile|Send)"
+#
+
+# Exit if a line returns a non-zero exit code
+set -e
+# Change to the repository root
cd "$(dirname $0)/.."
+# Check for a user-specified Complement checkout
+if [[ -z "$COMPLEMENT_DIR" ]]; then
+ echo "COMPLEMENT_DIR not set. Fetching the latest Complement checkout..."
+ wget -Nq https://github.com/matrix-org/complement/archive/master.tar.gz
+ tar -xzf master.tar.gz
+ COMPLEMENT_DIR=complement-master
+ echo "Checkout available at 'complement-master'"
+fi
+
# Build the base Synapse image from the local checkout
-docker build -t matrixdotorg/synapse:latest -f docker/Dockerfile .
+docker build -t matrixdotorg/synapse -f docker/Dockerfile .
+# Build the Synapse monolith image from Complement, based on the above image we just built
+docker build -t complement-synapse -f "$COMPLEMENT_DIR/dockerfiles/Synapse.Dockerfile" "$COMPLEMENT_DIR/dockerfiles"
-# Download Complement
-wget -N https://github.com/matrix-org/complement/archive/master.tar.gz
-tar -xzf master.tar.gz
-cd complement-master
+cd "$COMPLEMENT_DIR"
-# Build the Synapse image from Complement, based on the above image we just built
-docker build -t complement-synapse -f dockerfiles/Synapse.Dockerfile ./dockerfiles
+EXTRA_COMPLEMENT_ARGS=""
+if [[ -n "$1" ]]; then
+ # A test name regex has been set, supply it to Complement
+ EXTRA_COMPLEMENT_ARGS+="-run $1 "
+fi
-# Run the tests on the resulting image!
-COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -count=1 ./tests
+# Run the tests!
+COMPLEMENT_BASE_IMAGE=complement-synapse go test -v -tags synapse_blacklist -count=1 $EXTRA_COMPLEMENT_ARGS ./tests
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index c3f07bc1a3..2244b8a340 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
+from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
@@ -31,10 +32,13 @@ class Ratelimiter:
burst_count: How many actions that can be performed before being limited.
"""
- def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
+ def __init__(
+ self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
+ ):
self.clock = clock
self.rate_hz = rate_hz
self.burst_count = burst_count
+ self.store = store
# A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type
@@ -46,45 +50,10 @@ class Ratelimiter:
OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]]
- def can_requester_do_action(
- self,
- requester: Requester,
- rate_hz: Optional[float] = None,
- burst_count: Optional[int] = None,
- update: bool = True,
- _time_now_s: Optional[int] = None,
- ) -> Tuple[bool, float]:
- """Can the requester perform the action?
-
- Args:
- requester: The requester to key off when rate limiting. The user property
- will be used.
- rate_hz: The long term number of actions that can be performed in a second.
- Overrides the value set during instantiation if set.
- burst_count: How many actions that can be performed before being limited.
- Overrides the value set during instantiation if set.
- update: Whether to count this check as performing the action
- _time_now_s: The current time. Optional, defaults to the current time according
- to self.clock. Only used by tests.
-
- Returns:
- A tuple containing:
- * A bool indicating if they can perform the action now
- * The reactor timestamp for when the action can be performed next.
- -1 if rate_hz is less than or equal to zero
- """
- # Disable rate limiting of users belonging to any AS that is configured
- # not to be rate limited in its registration file (rate_limited: true|false).
- if requester.app_service and not requester.app_service.is_rate_limited():
- return True, -1.0
-
- return self.can_do_action(
- requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
- )
-
- def can_do_action(
+ async def can_do_action(
self,
- key: Hashable,
+ requester: Optional[Requester],
+ key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
@@ -92,9 +61,16 @@ class Ratelimiter:
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
+ Checks if the user has ratelimiting disabled in the database by looking
+ for null/zero values in the `ratelimit_override` table. (Non-zero
+ values aren't honoured, as they're specific to the event sending
+ ratelimiter, rather than all ratelimiters)
+
Args:
- key: The key we should use when rate limiting. Can be a user ID
- (when sending events), an IP address, etc.
+ requester: The requester that is doing the action, if any. Used to check
+ if the user has ratelimits disabled in the database.
+ key: An arbitrary key used to classify an action. Defaults to the
+ requester's user ID.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
@@ -109,6 +85,30 @@ class Ratelimiter:
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
+ if key is None:
+ if not requester:
+ raise ValueError("Must supply at least one of `requester` or `key`")
+
+ key = requester.user.to_string()
+
+ if requester:
+ # Disable rate limiting of users belonging to any AS that is configured
+ # not to be rate limited in its registration file (rate_limited: true|false).
+ if requester.app_service and not requester.app_service.is_rate_limited():
+ return True, -1.0
+
+ # Check if ratelimiting has been disabled for the user.
+ #
+ # Note that we don't use the returned rate/burst count, as the table
+ # is specifically for the event sending ratelimiter. Instead, we
+ # only use it to (somewhat cheekily) infer whether the user should
+ # be subject to any rate limiting or not.
+ override = await self.store.get_ratelimit_for_user(
+ requester.authenticated_entity
+ )
+ if override and not override.messages_per_second:
+ return True, -1.0
+
# Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
@@ -175,9 +175,10 @@ class Ratelimiter:
else:
del self.actions[key]
- def ratelimit(
+ async def ratelimit(
self,
- key: Hashable,
+ requester: Optional[Requester],
+ key: Optional[Hashable] = None,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
@@ -185,8 +186,16 @@ class Ratelimiter:
):
"""Checks if an action can be performed. If not, raises a LimitExceededError
+ Checks if the user has ratelimiting disabled in the database by looking
+ for null/zero values in the `ratelimit_override` table. (Non-zero
+ values aren't honoured, as they're specific to the event sending
+ ratelimiter, rather than all ratelimiters)
+
Args:
- key: An arbitrary key used to classify an action
+ requester: The requester that is doing the action, if any. Used to check for
+ if the user has ratelimits disabled.
+ key: An arbitrary key used to classify an action. Defaults to the
+ requester's user ID.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
@@ -201,7 +210,8 @@ class Ratelimiter:
"""
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
- allowed, time_allowed = self.can_do_action(
+ allowed, time_allowed = await self.can_do_action(
+ requester,
key,
rate_hz=rate_hz,
burst_count=burst_count,
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 74cd53a8ed..55c038c0c4 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -1,4 +1,4 @@
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,38 +12,131 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+from typing import Iterable
+
from synapse.api.constants import EventTypes
+from synapse.config._base import Config, ConfigError
+from synapse.config._util import validate_config
+from synapse.types import JsonDict
-from ._base import Config
+logger = logging.getLogger(__name__)
class ApiConfig(Config):
section = "api"
- def read_config(self, config, **kwargs):
- self.room_invite_state_types = config.get(
- "room_invite_state_types",
- [
- EventTypes.JoinRules,
- EventTypes.CanonicalAlias,
- EventTypes.RoomAvatar,
- EventTypes.RoomEncryption,
- EventTypes.Name,
- ],
+ def read_config(self, config: JsonDict, **kwargs):
+ validate_config(_MAIN_SCHEMA, config, ())
+ self.room_prejoin_state = list(self._get_prejoin_state_types(config))
+
+ def generate_config_section(cls, **kwargs) -> str:
+ formatted_default_state_types = "\n".join(
+ " # - %s" % (t,) for t in _DEFAULT_PREJOIN_STATE_TYPES
)
- def generate_config_section(cls, **kwargs):
return """\
## API Configuration ##
- # A list of event types that will be included in the room_invite_state
+ # Controls for the state that is shared with users who receive an invite
+ # to a room
#
- #room_invite_state_types:
- # - "{JoinRules}"
- # - "{CanonicalAlias}"
- # - "{RoomAvatar}"
- # - "{RoomEncryption}"
- # - "{Name}"
- """.format(
- **vars(EventTypes)
- )
+ room_prejoin_state:
+ # By default, the following state event types are shared with users who
+ # receive invites to the room:
+ #
+%(formatted_default_state_types)s
+ #
+ # Uncomment the following to disable these defaults (so that only the event
+ # types listed in 'additional_event_types' are shared). Defaults to 'false'.
+ #
+ #disable_default_event_types: true
+
+ # Additional state event types to share with users when they are invited
+ # to a room.
+ #
+ # By default, this list is empty (so only the default event types are shared).
+ #
+ #additional_event_types:
+ # - org.example.custom.event.type
+ """ % {
+ "formatted_default_state_types": formatted_default_state_types
+ }
+
+ def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
+ """Get the event types to include in the prejoin state
+
+ Parses the config and returns an iterable of the event types to be included.
+ """
+ room_prejoin_state_config = config.get("room_prejoin_state") or {}
+
+ # backwards-compatibility support for room_invite_state_types
+ if "room_invite_state_types" in config:
+ # if both "room_invite_state_types" and "room_prejoin_state" are set, then
+ # we don't really know what to do.
+ if room_prejoin_state_config:
+ raise ConfigError(
+ "Can't specify both 'room_invite_state_types' and 'room_prejoin_state' "
+ "in config"
+ )
+
+ logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
+
+ yield from config["room_invite_state_types"]
+ return
+
+ if not room_prejoin_state_config.get("disable_default_event_types"):
+ yield from _DEFAULT_PREJOIN_STATE_TYPES
+
+ if self.spaces_enabled:
+ # MSC1772 suggests adding m.room.create to the prejoin state
+ yield EventTypes.Create
+
+ yield from room_prejoin_state_config.get("additional_event_types", [])
+
+
+_ROOM_INVITE_STATE_TYPES_WARNING = """\
+WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
+and replaced with 'room_prejoin_state'. New features may not work correctly
+unless 'room_invite_state_types' is removed. See the sample configuration file for
+details of 'room_prejoin_state'.
+--------------------------------------------------------------------------------
+"""
+
+_DEFAULT_PREJOIN_STATE_TYPES = [
+ EventTypes.JoinRules,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomAvatar,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+]
+
+
+# room_prejoin_state can either be None (as it is in the default config), or
+# an object containing other config settings
+_ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
+ "oneOf": [
+ {
+ "type": "object",
+ "properties": {
+ "disable_default_event_types": {"type": "boolean"},
+ "additional_event_types": {
+ "type": "array",
+ "items": {"type": "string"},
+ },
+ },
+ },
+ {"type": "null"},
+ ]
+}
+
+# the legacy room_invite_state_types setting
+_ROOM_INVITE_STATE_TYPES_SCHEMA = {"type": "array", "items": {"type": "string"}}
+
+_MAIN_SCHEMA = {
+ "type": "object",
+ "properties": {
+ "room_prejoin_state": _ROOM_PREJOIN_STATE_CONFIG_SCHEMA,
+ "room_invite_state_types": _ROOM_INVITE_STATE_TYPES_SCHEMA,
+ },
+}
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ead007ba5a..f27d1e14ac 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -298,9 +298,9 @@ class RegistrationConfig(Config):
#
#allowed_local_3pids:
# - medium: email
- # pattern: '.*@matrix\\.org'
+ # pattern: '^[^@]+@matrix\\.org$'
# - medium: email
- # pattern: '.*@vector\\.im'
+ # pattern: '^[^@]+@vector\\.im$'
# - medium: msisdn
# pattern: '\\+44'
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d84e362070..71cb120ef7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -870,6 +870,7 @@ class FederationHandlerRegistry:
# A rate limiter for incoming room key requests per origin.
self._room_key_request_rate_limiter = Ratelimiter(
+ store=hs.get_datastore(),
clock=self.clock,
rate_hz=self.config.rc_key_requests.per_second,
burst_count=self.config.rc_key_requests.burst_count,
@@ -930,7 +931,9 @@ class FederationHandlerRegistry:
# the limit, drop them.
if (
edu_type == EduTypes.RoomKeyRequest
- and not self._room_key_request_rate_limiter.can_do_action(origin)
+ and not await self._room_key_request_rate_limiter.can_do_action(
+ None, origin
+ )
):
return
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index aade2c4a3a..fb899aa90d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -49,7 +49,7 @@ class BaseHandler:
# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(
- clock=self.clock, rate_hz=0, burst_count=0
+ store=self.store, clock=self.clock, rate_hz=0, burst_count=0
)
self._rc_message = self.hs.config.rc_message
@@ -57,6 +57,7 @@ class BaseHandler:
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
@@ -91,11 +92,6 @@ class BaseHandler:
if app_service is not None:
return # do not ratelimit app service senders
- # Disable rate limiting of users belonging to any AS that is configured
- # not to be rate limited in its registration file (rate_limited: true|false).
- if requester.app_service and not requester.app_service.is_rate_limited():
- return
-
messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count
@@ -113,11 +109,11 @@ class BaseHandler:
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_ids clash
- self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
+ await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
else:
# Override rate and burst count per-user
- self.request_ratelimiter.ratelimit(
- user_id,
+ await self.request_ratelimiter.ratelimit(
+ requester,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d537ea8137..08e413bc98 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -238,6 +238,7 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -248,6 +249,7 @@ class AuthHandler(BaseHandler):
# Ratelimitier for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter(
+ store=self.store,
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -352,7 +354,7 @@ class AuthHandler(BaseHandler):
requester_user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
- self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
+ await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)
# build a list of supported flows
supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -373,7 +375,9 @@ class AuthHandler(BaseHandler):
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
- self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
+ await self._failed_uia_attempts_ratelimiter.can_do_action(
+ requester,
+ )
raise
# find the completed login type
@@ -982,8 +986,8 @@ class AuthHandler(BaseHandler):
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
if ratelimit:
- self._failed_login_attempts_ratelimiter.ratelimit(
- (medium, address), update=False
+ await self._failed_login_attempts_ratelimiter.ratelimit(
+ None, (medium, address), update=False
)
# Check for login providers that support 3pid login types
@@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler):
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
if ratelimit:
- self._failed_login_attempts_ratelimiter.can_do_action(
- (medium, address)
+ await self._failed_login_attempts_ratelimiter.can_do_action(
+ None, (medium, address)
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler):
# Check if we've hit the failed ratelimit (but don't update it)
if ratelimit:
- self._failed_login_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), update=False
+ await self._failed_login_attempts_ratelimiter.ratelimit(
+ None, qualified_user_id.lower(), update=False
)
try:
@@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler):
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
if ratelimit:
- self._failed_login_attempts_ratelimiter.can_do_action(
- qualified_user_id.lower()
+ await self._failed_login_attempts_ratelimiter.can_do_action(
+ None, qualified_user_id.lower()
)
raise
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index eb547743be..5ee48be6ff 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -81,6 +81,7 @@ class DeviceMessageHandler:
)
self._ratelimiter = Ratelimiter(
+ store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.rc_key_requests.per_second,
burst_count=hs.config.rc_key_requests.burst_count,
@@ -191,8 +192,8 @@ class DeviceMessageHandler:
if (
message_type == EduTypes.RoomKeyRequest
and user_id != sender_user_id
- and self._ratelimiter.can_do_action(
- (sender_user_id, requester.device_id)
+ and await self._ratelimiter.can_do_action(
+ requester, (sender_user_id, requester.device_id)
)
):
continue
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 598a66f74c..3ebee38ebe 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1711,7 +1711,7 @@ class FederationHandler(BaseHandler):
member_handler = self.hs.get_room_member_handler()
# We don't rate limit based on room ID, as that should be done by
# sending server.
- member_handler.ratelimit_invite(None, event.state_key)
+ await member_handler.ratelimit_invite(None, None, event.state_key)
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 5f346f6d6d..d89fa5fb30 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler):
# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
+ store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
+ store=self.store,
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
- def ratelimit_request_token_requests(
+ async def ratelimit_request_token_requests(
self,
request: SynapseRequest,
medium: str,
@@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler):
address: The actual threepid ID, e.g. the phone number or email address
"""
- self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
- self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+ await self._3pid_validation_ratelimiter_ip.ratelimit(
+ None, (medium, request.getClientIP())
+ )
+ await self._3pid_validation_ratelimiter_address.ratelimit(
+ None, (medium, address)
+ )
async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1b7c065b34..6069968f7f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -385,7 +385,7 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
- self.room_invite_state_types = self.hs.config.room_invite_state_types
+ self.room_invite_state_types = self.hs.config.api.room_prejoin_state
self.membership_types_to_include_profile_data_in = (
{Membership.JOIN, Membership.INVITE}
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0fc2bf15d5..9701b76d0f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler):
Raises:
SynapseError if there was a problem registering.
"""
- self.check_registration_ratelimit(address)
+ await self.check_registration_ratelimit(address)
result = await self.spam_checker.check_registration_for_spam(
threepid,
@@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- def check_registration_ratelimit(self, address: Optional[str]) -> None:
+ async def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
@@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler):
if not address:
return
- self.ratelimiter.ratelimit(address)
+ await self.ratelimiter.ratelimit(None, address)
async def register_with_store(
self,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4d20ed8357..1cf12f3255 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -75,22 +75,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._join_rate_limiter_local = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
)
self._join_rate_limiter_remote = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
self._invites_per_room_limiter = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
)
self._invites_per_user_limiter = Ratelimiter(
+ store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
@@ -159,15 +163,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async def forget(self, user: UserID, room_id: str) -> None:
raise NotImplementedError()
- def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
+ async def ratelimit_invite(
+ self,
+ requester: Optional[Requester],
+ room_id: Optional[str],
+ invitee_user_id: str,
+ ):
"""Ratelimit invites by room and by target user.
If room ID is missing then we just rate limit by target user.
"""
if room_id:
- self._invites_per_room_limiter.ratelimit(room_id)
+ await self._invites_per_room_limiter.ratelimit(requester, room_id)
- self._invites_per_user_limiter.ratelimit(invitee_user_id)
+ await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
async def _local_membership_update(
self,
@@ -237,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
(
allowed,
time_allowed,
- ) = self._join_rate_limiter_local.can_requester_do_action(requester)
+ ) = await self._join_rate_limiter_local.can_do_action(requester)
if not allowed:
raise LimitExceededError(
@@ -421,9 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if effective_membership_state == Membership.INVITE:
target_id = target.to_string()
if ratelimit:
- # Don't ratelimit application services.
- if not requester.app_service or requester.app_service.is_rate_limited():
- self.ratelimit_invite(room_id, target_id)
+ await self.ratelimit_invite(requester, room_id, target_id)
# block any attempts to invite the server notices mxid
if target_id == self._server_notices_mxid:
@@ -534,7 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
(
allowed,
time_allowed,
- ) = self._join_rate_limiter_remote.can_requester_do_action(
+ ) = await self._join_rate_limiter_remote.can_do_action(
requester,
)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index d005f38767..73d7477854 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
- self.registration_handler.check_registration_ratelimit(content["address"])
+ await self.registration_handler.check_registration_ratelimit(content["address"])
await self.registration_handler.register_with_store(
user_id=user_id,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index e4c352f572..3151e72d4f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet):
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
+ store=hs.get_datastore(),
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
)
self._account_ratelimiter = Ratelimiter(
+ store=hs.get_datastore(),
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
@@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet):
appservice = self.auth.get_appservice_by_req(request)
if appservice.is_rate_limited():
- self._address_ratelimiter.ratelimit(request.getClientIP())
+ await self._address_ratelimiter.ratelimit(
+ None, request.getClientIP()
+ )
result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
- self._address_ratelimiter.ratelimit(request.getClientIP())
+ await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- self._address_ratelimiter.ratelimit(request.getClientIP())
+ await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_token_login(login_submission)
else:
- self._address_ratelimiter.ratelimit(request.getClientIP())
+ await self._address_ratelimiter.ratelimit(None, request.getClientIP())
result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet):
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
if ratelimit:
- self._account_ratelimiter.ratelimit(user_id.lower())
+ await self._account_ratelimiter.ratelimit(None, user_id.lower())
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index c2ba790bab..411fb57c47 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
- self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+ await self.identity_handler.ratelimit_request_token_requests(
+ request, "email", email
+ )
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
@@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+ await self.identity_handler.ratelimit_request_token_requests(
+ request, "email", email
+ )
if next_link:
# Raise if the provided next_link value isn't valid
@@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- self.identity_handler.ratelimit_request_token_requests(
+ await self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 8f68d8dfc8..c212da0cb2 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+ await self.identity_handler.ratelimit_request_token_requests(
+ request, "email", email
+ )
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
@@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- self.identity_handler.ratelimit_request_token_requests(
+ await self.identity_handler.ratelimit_request_token_requests(
request, "msisdn", msisdn
)
@@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet):
client_addr = request.getClientIP()
- self.ratelimiter.ratelimit(client_addr, update=False)
+ await self.ratelimiter.ratelimit(None, client_addr, update=False)
kind = b"user"
if b"kind" in request.args:
diff --git a/synapse/server.py b/synapse/server.py
index e85b9391fa..e42f7b1a18 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -329,6 +329,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:
return Ratelimiter(
+ store=self.get_datastore(),
clock=self.get_clock(),
rate_hz=self.config.rc_registration.per_second,
burst_count=self.config.rc_registration.burst_count,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 952d4969b2..c00780969f 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,7 +16,7 @@
import logging
import threading
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Tuple, overload
+from typing import Container, Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
from typing_extensions import Literal
@@ -544,7 +544,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
- state_types_to_include: List[EventTypes],
+ state_types_to_include: Container[str],
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6c3c2da520..c7f0b8ccb5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import imp
+import importlib.util
import logging
import os
import re
@@ -454,8 +454,13 @@ def _upgrade_existing_database(
)
module_name = "synapse.storage.v%d_%s" % (v, root_name)
- with open(absolute_path) as python_file:
- module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
+
+ spec = importlib.util.spec_from_file_location(
+ module_name, absolute_path
+ )
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module) # type: ignore
+
logger.info("Running script %s", relative_path)
module.run_create(cur, database_engine) # type: ignore
if not is_empty:
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester
from tests import unittest
-class TestRatelimiter(unittest.TestCase):
+class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_via_can_do_action(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
- self.assertTrue(allowed)
- self.assertEquals(10.0, time_allowed)
-
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
- self.assertFalse(allowed)
- self.assertEquals(10.0, time_allowed)
-
- allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
- self.assertTrue(allowed)
- self.assertEquals(20.0, time_allowed)
-
- def test_allowed_user_via_can_requester_do_action(self):
- user_requester = create_requester("@user:example.com")
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=5)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- user_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id", _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
@@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
)
as_requester = create_requester("@user:example.com", app_service=appservice)
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=5)
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
@@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
)
as_requester = create_requester("@user:example.com", app_service=appservice)
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=0
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=0)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=5
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=5)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
- allowed, time_allowed = limiter.can_requester_do_action(
- as_requester, _time_now_s=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(as_requester, _time_now_s=10)
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
def test_allowed_via_ratelimit(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# Shouldn't raise
- limiter.ratelimit(key="test_id", _time_now_s=0)
+ self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
# Should raise
with self.assertRaises(LimitExceededError) as context:
- limiter.ratelimit(key="test_id", _time_now_s=5)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key="test_id", _time_now_s=5)
+ )
self.assertEqual(context.exception.retry_after_ms, 5000)
# Shouldn't raise
- limiter.ratelimit(key="test_id", _time_now_s=10)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key="test_id", _time_now_s=10)
+ )
def test_allowed_via_can_do_action_and_overriding_parameters(self):
"""Test that we can override options of can_do_action that would otherwise fail
an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# First attempt should be allowed
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",),
- _time_now_s=0,
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(
+ None,
+ ("test_id",),
+ _time_now_s=0,
+ )
)
self.assertTrue(allowed)
self.assertEqual(10.0, time_allowed)
# Second attempt, 1s later, will fail
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",),
- _time_now_s=1,
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(
+ None,
+ ("test_id",),
+ _time_now_s=1,
+ )
)
self.assertFalse(allowed)
self.assertEqual(10.0, time_allowed)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",), _time_now_s=1, rate_hz=10.0
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
)
self.assertTrue(allowed)
self.assertEqual(1.1, time_allowed)
# Similarly if we allow a burst of 10 actions
- allowed, time_allowed = limiter.can_do_action(
- ("test_id",), _time_now_s=1, burst_count=10
+ allowed, time_allowed = self.get_success_or_raise(
+ limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
)
self.assertTrue(allowed)
self.assertEqual(1.0, time_allowed)
@@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
fail an action
"""
# Create a Ratelimiter with a very low allowed rate_hz and burst_count
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
# First attempt should be allowed
- limiter.ratelimit(key=("test_id",), _time_now_s=0)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
+ )
# Second attempt, 1s later, will fail
with self.assertRaises(LimitExceededError) as context:
- limiter.ratelimit(key=("test_id",), _time_now_s=1)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
+ )
self.assertEqual(context.exception.retry_after_ms, 9000)
# But, if we allow 10 actions/sec for this request, we should be allowed
# to continue.
- limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
+ )
# Similarly if we allow a burst of 10 actions
- limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+ self.get_success_or_raise(
+ limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
+ )
def test_pruning(self):
- limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
- limiter.can_do_action(key="test_id_1", _time_now_s=0)
+ limiter = Ratelimiter(
+ store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+ )
+ self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
+ )
self.assertIn("test_id_1", limiter.actions)
- limiter.can_do_action(key="test_id_2", _time_now_s=10)
+ self.get_success_or_raise(
+ limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
+ )
self.assertNotIn("test_id_1", limiter.actions)
+
+ def test_db_user_override(self):
+ """Test that users that have ratelimiting disabled in the DB aren't
+ ratelimited.
+ """
+ store = self.hs.get_datastore()
+
+ user_id = "@user:test"
+ requester = create_requester(user_id)
+
+ self.get_success(
+ store.db_pool.simple_insert(
+ table="ratelimit_override",
+ values={
+ "user_id": user_id,
+ "messages_per_second": None,
+ "burst_count": None,
+ },
+ desc="test_db_user_override",
+ )
+ )
+
+ limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
+
+ # Shouldn't raise
+ for _ in range(20):
+ self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
diff --git a/tests/utils.py b/tests/utils.py
index be80b13760..a141ee6496 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -122,7 +122,6 @@ def default_config(name, parse=False):
"enable_registration_captcha": False,
"macaroon_secret_key": "not even a little secret",
"trusted_third_party_id_servers": [],
- "room_invite_state_types": [],
"password_providers": [],
"worker_replication_url": "",
"worker_app": None,
|