summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py5
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py9
-rw-r--r--synapse/api/errors.py22
-rw-r--r--synapse/api/ratelimiting.py20
-rw-r--r--synapse/config/_base.py8
-rw-r--r--synapse/config/appservice.py2
-rw-r--r--synapse/config/cas.py13
-rw-r--r--synapse/config/experimental.py14
-rw-r--r--synapse/config/ratelimiting.py132
-rw-r--r--synapse/event_auth.py12
-rw-r--r--synapse/events/utils.py4
-rw-r--r--synapse/events/validator.py4
-rw-r--r--synapse/federation/federation_base.py2
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/federation/send_queue.py12
-rw-r--r--synapse/federation/sender/__init__.py86
-rw-r--r--synapse/federation/sender/per_destination_queue.py6
-rw-r--r--synapse/federation/transport/client.py20
-rw-r--r--synapse/handlers/admin.py1
-rw-r--r--synapse/handlers/auth.py8
-rw-r--r--synapse/handlers/cas.py6
-rw-r--r--synapse/handlers/device.py26
-rw-r--r--synapse/handlers/devicemessage.py10
-rw-r--r--synapse/handlers/events.py1
-rw-r--r--synapse/handlers/identity.py6
-rw-r--r--synapse/handlers/message.py56
-rw-r--r--synapse/handlers/presence.py360
-rw-r--r--synapse/handlers/room_member.py21
-rw-r--r--synapse/handlers/room_summary.py5
-rw-r--r--synapse/handlers/send_email.py28
-rw-r--r--synapse/handlers/typing.py14
-rw-r--r--synapse/http/matrixfederationclient.py10
-rw-r--r--synapse/http/server.py8
-rw-r--r--synapse/logging/_terse_json.py1
-rw-r--r--synapse/logging/context.py19
-rw-r--r--synapse/logging/opentracing.py14
-rw-r--r--synapse/media/media_repository.py5
-rw-r--r--synapse/media/oembed.py2
-rw-r--r--synapse/media/thumbnailer.py2
-rw-r--r--synapse/module_api/__init__.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py8
-rw-r--r--synapse/replication/http/devices.py2
-rw-r--r--synapse/replication/http/presence.py21
-rw-r--r--synapse/replication/tcp/client.py8
-rw-r--r--synapse/replication/tcp/commands.py29
-rw-r--r--synapse/replication/tcp/handler.py37
-rw-r--r--synapse/rest/admin/__init__.py2
-rw-r--r--synapse/rest/admin/registration_tokens.py21
-rw-r--r--synapse/rest/admin/users.py8
-rw-r--r--synapse/rest/client/login.py6
-rw-r--r--synapse/rest/client/login_token_request.py10
-rw-r--r--synapse/rest/client/presence.py2
-rw-r--r--synapse/rest/client/read_marker.py4
-rw-r--r--synapse/rest/client/receipts.py4
-rw-r--r--synapse/rest/client/register.py3
-rw-r--r--synapse/rest/client/report_event.py2
-rw-r--r--synapse/rest/client/room.py4
-rw-r--r--synapse/rest/client/sync.py1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py39
-rw-r--r--synapse/server.py3
-rw-r--r--synapse/storage/background_updates.py4
-rw-r--r--synapse/storage/database.py17
-rw-r--r--synapse/storage/databases/main/__init__.py6
-rw-r--r--synapse/storage/databases/main/events.py41
-rw-r--r--synapse/storage/databases/main/events_worker.py41
-rw-r--r--synapse/storage/databases/main/lock.py36
-rw-r--r--synapse/storage/databases/main/push_rule.py1
-rw-r--r--synapse/storage/databases/main/registration.py7
-rw-r--r--synapse/storage/databases/main/stats.py1
-rw-r--r--synapse/storage/databases/main/transactions.py26
-rw-r--r--synapse/storage/schema/__init__.py16
-rw-r--r--synapse/types/__init__.py2
-rw-r--r--synapse/util/caches/deferred_cache.py2
-rw-r--r--synapse/util/check_dependencies.py6
-rw-r--r--synapse/util/ratelimitutils.py3
-rw-r--r--synapse/util/retryutils.py43
-rw-r--r--synapse/util/task_scheduler.py92
77 files changed, 929 insertions, 607 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index 2f9c22a833..4a9bbc4d57 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -21,9 +21,14 @@ import os import sys from typing import Any, Dict +from PIL import ImageFile + from synapse.util.rust import check_rust_lib_up_to_date from synapse.util.stringutils import strtobool +# Allow truncated JPEG images to be thumbnailed. +ImageFile.LOAD_TRUNCATED_IMAGES = True + # Check that we're not running on an unsupported Python version. # # Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 49242800b8..ab2b29cf1b 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py
@@ -482,7 +482,10 @@ class Porter: do_backward[0] = False if forward_rows or backward_rows: - headers = [column[0] for column in txn.description] + assert txn.description is not None + headers: Optional[List[str]] = [ + column[0] for column in txn.description + ] else: headers = None @@ -544,6 +547,7 @@ class Porter: def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select, (forward_chunk, self.batch_size)) rows = txn.fetchall() + assert txn.description is not None headers = [column[0] for column in txn.description] return headers, rows @@ -919,7 +923,8 @@ class Porter: def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]: txn.execute(select) rows = txn.fetchall() - headers: List[str] = [column[0] for column in txn.description] + assert txn.description is not None + headers = [column[0] for column in txn.description] ts_ind = headers.index("ts") diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 7ffd72c42c..fdb2955be8 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -16,6 +16,7 @@ """Contains exceptions and error codes.""" import logging +import math import typing from enum import Enum from http import HTTPStatus @@ -210,6 +211,11 @@ class SynapseError(CodeMessageException): def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, **self._additional_fields) + @property + def debug_context(self) -> Optional[str]: + """Override this to add debugging context that shouldn't be sent to clients.""" + return None + class InvalidAPICallError(SynapseError): """You called an existing API endpoint, but fed that endpoint @@ -503,19 +509,31 @@ class InvalidCaptchaError(SynapseError): class LimitExceededError(SynapseError): """A client has sent too many requests and is being throttled.""" + include_retry_after_header = False + def __init__( self, + limiter_name: str, code: int = 429, - msg: str = "Too Many Requests", retry_after_ms: Optional[int] = None, errcode: str = Codes.LIMIT_EXCEEDED, ): - super().__init__(code, msg, errcode) + headers = ( + {"Retry-After": str(math.ceil(retry_after_ms / 1000))} + if self.include_retry_after_header and retry_after_ms is not None + else None + ) + super().__init__(code, "Too Many Requests", errcode, headers=headers) self.retry_after_ms = retry_after_ms + self.limiter_name = limiter_name def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) + @property + def debug_context(self) -> Optional[str]: + return self.limiter_name + class RoomKeysVersionError(SynapseError): """A client has tried to upload to a non-current version of the room_keys store""" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 511790c7c5..887b214d64 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py
@@ -61,12 +61,16 @@ class Ratelimiter: """ def __init__( - self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + self, + store: DataStore, + clock: Clock, + cfg: RatelimitSettings, ): self.clock = clock - self.rate_hz = rate_hz - self.burst_count = burst_count + self.rate_hz = cfg.per_second + self.burst_count = cfg.burst_count self.store = store + self._limiter_name = cfg.key # An ordered dictionary representing the token buckets tracked by this rate # limiter. Each entry maps a key of arbitrary type to a tuple representing: @@ -305,7 +309,8 @@ class Ratelimiter: if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now_s)) + limiter_name=self._limiter_name, + retry_after_ms=int(1000 * (time_allowed - time_now_s)), ) @@ -322,7 +327,9 @@ class RequestRatelimiter: # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - store=self.store, clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, + clock=self.clock, + cfg=RatelimitSettings(key=rc_message.key, per_second=0, burst_count=0), ) self._rc_message = rc_message @@ -332,8 +339,7 @@ class RequestRatelimiter: self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=rc_admin_redaction.per_second, - burst_count=rc_admin_redaction.burst_count, + cfg=rc_admin_redaction, ) else: self.admin_redaction_ratelimiter = None diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 1d268a1817..69a8318127 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -186,9 +186,9 @@ class Config: TypeError, if given something other than an integer or a string ValueError: if given a string not of the form described above. """ - if type(value) is int: + if type(value) is int: # noqa: E721 return value - elif type(value) is str: + elif isinstance(value, str): sizes = {"K": 1024, "M": 1024 * 1024} size = 1 suffix = value[-1] @@ -218,9 +218,9 @@ class Config: TypeError, if given something other than an integer or a string ValueError: if given a string not of the form described above. """ - if type(value) is int: + if type(value) is int: # noqa: E721 return value - elif type(value) is str: + elif isinstance(value, str): second = 1000 minute = 60 * second hour = 60 * minute diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 919f81a9b7..a70dfbf41f 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py
@@ -34,7 +34,7 @@ class AppServiceConfig(Config): def read_config(self, config: JsonDict, **kwargs: Any) -> None: self.app_service_config_files = config.get("app_service_config_files", []) if not isinstance(self.app_service_config_files, list) or not all( - type(x) is str for x in self.app_service_config_files + isinstance(x, str) for x in self.app_service_config_files ): raise ConfigError( "Expected '%s' to be a list of AS config files:" diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index c4e63e7411..6e2d9addbf 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py
@@ -18,7 +18,7 @@ from typing import Any, List from synapse.config.sso import SsoAttributeRequirement from synapse.types import JsonDict -from ._base import Config +from ._base import Config, ConfigError from ._util import validate_config @@ -41,6 +41,16 @@ class CasConfig(Config): public_baseurl = self.root.server.public_baseurl self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" + self.cas_protocol_version = cas_config.get("protocol_version") + if ( + self.cas_protocol_version is not None + and self.cas_protocol_version not in [1, 2, 3] + ): + raise ConfigError( + "Unsupported CAS protocol version %s (only versions 1, 2, 3 are supported)" + % (self.cas_protocol_version,), + ("cas_config", "protocol_version"), + ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") required_attributes = cas_config.get("required_attributes") or {} self.cas_required_attributes = _parsed_required_attributes_def( @@ -54,6 +64,7 @@ class CasConfig(Config): else: self.cas_server_url = None self.cas_service_url = None + self.cas_protocol_version = None self.cas_displayname_attribute = None self.cas_required_attributes = [] diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 277ea4675b..cabe0d4397 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Optional import attr import attr.validators +from synapse.api.errors import LimitExceededError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config import ConfigError from synapse.config._base import Config, RootConfig @@ -383,11 +384,6 @@ class ExperimentalConfig(Config): # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) - # MSC3959: Do not generate notifications for edits. - self.msc3958_supress_edit_notifs = experimental.get( - "msc3958_supress_edit_notifs", False - ) - # MSC3967: Do not require UIA when first uploading cross signing keys self.msc3967_enabled = experimental.get("msc3967_enabled", False) @@ -411,3 +407,11 @@ class ExperimentalConfig(Config): self.msc4010_push_rules_account_data = experimental.get( "msc4010_push_rules_account_data", False ) + + # MSC4041: Use HTTP header Retry-After to enable library-assisted retry handling + # + # This is a bit hacky, but the most reasonable way to *alway* include the + # headers. + LimitExceededError.include_retry_after_header = experimental.get( + "msc4041_enabled", False + ) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index a5514e70a2..4efbaeac0d 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, cast import attr @@ -21,16 +21,47 @@ from synapse.types import JsonDict from ._base import Config +@attr.s(slots=True, frozen=True, auto_attribs=True) class RatelimitSettings: - def __init__( - self, - config: Dict[str, float], + key: str + per_second: float + burst_count: int + + @classmethod + def parse( + cls, + config: Dict[str, Any], + key: str, defaults: Optional[Dict[str, float]] = None, - ): + ) -> "RatelimitSettings": + """Parse config[key] as a new-style rate limiter config. + + The key may refer to a nested dictionary using a full stop (.) to separate + each nested key. For example, use the key "a.b.c" to parse the following: + + a: + b: + c: + per_second: 10 + burst_count: 200 + + If this lookup fails, we'll fallback to the defaults. + """ defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} - self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = int(config.get("burst_count", defaults["burst_count"])) + rl_config = config + for part in key.split("."): + rl_config = rl_config.get(part, {}) + + # By this point we should have hit the rate limiter parameters. + # We don't actually check this though! + rl_config = cast(Dict[str, float], rl_config) + + return cls( + key=key, + per_second=rl_config.get("per_second", defaults["per_second"]), + burst_count=int(rl_config.get("burst_count", defaults["burst_count"])), + ) @attr.s(auto_attribs=True) @@ -49,15 +80,14 @@ class RatelimitConfig(Config): # Load the new-style messages config if it exists. Otherwise fall back # to the old method. if "rc_message" in config: - self.rc_message = RatelimitSettings( - config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} + self.rc_message = RatelimitSettings.parse( + config, "rc_message", defaults={"per_second": 0.2, "burst_count": 10.0} ) else: self.rc_message = RatelimitSettings( - { - "per_second": config.get("rc_messages_per_second", 0.2), - "burst_count": config.get("rc_message_burst_count", 10.0), - } + key="rc_messages", + per_second=config.get("rc_messages_per_second", 0.2), + burst_count=config.get("rc_message_burst_count", 10.0), ) # Load the new-style federation config, if it exists. Otherwise, fall @@ -79,51 +109,59 @@ class RatelimitConfig(Config): } ) - self.rc_registration = RatelimitSettings(config.get("rc_registration", {})) + self.rc_registration = RatelimitSettings.parse(config, "rc_registration", {}) - self.rc_registration_token_validity = RatelimitSettings( - config.get("rc_registration_token_validity", {}), + self.rc_registration_token_validity = RatelimitSettings.parse( + config, + "rc_registration_token_validity", defaults={"per_second": 0.1, "burst_count": 5}, ) # It is reasonable to login with a bunch of devices at once (i.e. when # setting up an account), but it is *not* valid to continually be # logging into new devices. - rc_login_config = config.get("rc_login", {}) - self.rc_login_address = RatelimitSettings( - rc_login_config.get("address", {}), + self.rc_login_address = RatelimitSettings.parse( + config, + "rc_login.address", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_account = RatelimitSettings( - rc_login_config.get("account", {}), + self.rc_login_account = RatelimitSettings.parse( + config, + "rc_login.account", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_login_failed_attempts = RatelimitSettings( - rc_login_config.get("failed_attempts", {}) + self.rc_login_failed_attempts = RatelimitSettings.parse( + config, + "rc_login.failed_attempts", + {}, ) self.federation_rr_transactions_per_room_per_second = config.get( "federation_rr_transactions_per_room_per_second", 50 ) - rc_admin_redaction = config.get("rc_admin_redaction") self.rc_admin_redaction = None - if rc_admin_redaction: - self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction) + if "rc_admin_redaction" in config: + self.rc_admin_redaction = RatelimitSettings.parse( + config, "rc_admin_redaction", {} + ) - self.rc_joins_local = RatelimitSettings( - config.get("rc_joins", {}).get("local", {}), + self.rc_joins_local = RatelimitSettings.parse( + config, + "rc_joins.local", defaults={"per_second": 0.1, "burst_count": 10}, ) - self.rc_joins_remote = RatelimitSettings( - config.get("rc_joins", {}).get("remote", {}), + self.rc_joins_remote = RatelimitSettings.parse( + config, + "rc_joins.remote", defaults={"per_second": 0.01, "burst_count": 10}, ) # Track the rate of joins to a given room. If there are too many, temporarily # prevent local joins and remote joins via this server. - self.rc_joins_per_room = RatelimitSettings( - config.get("rc_joins_per_room", {}), + self.rc_joins_per_room = RatelimitSettings.parse( + config, + "rc_joins_per_room", defaults={"per_second": 1, "burst_count": 10}, ) @@ -132,31 +170,37 @@ class RatelimitConfig(Config): # * For requests received over federation this is keyed by the origin. # # Note that this isn't exposed in the configuration as it is obscure. - self.rc_key_requests = RatelimitSettings( - config.get("rc_key_requests", {}), + self.rc_key_requests = RatelimitSettings.parse( + config, + "rc_key_requests", defaults={"per_second": 20, "burst_count": 100}, ) - self.rc_3pid_validation = RatelimitSettings( - config.get("rc_3pid_validation") or {}, + self.rc_3pid_validation = RatelimitSettings.parse( + config, + "rc_3pid_validation", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_room = RatelimitSettings( - config.get("rc_invites", {}).get("per_room", {}), + self.rc_invites_per_room = RatelimitSettings.parse( + config, + "rc_invites.per_room", defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_invites_per_user = RatelimitSettings( - config.get("rc_invites", {}).get("per_user", {}), + self.rc_invites_per_user = RatelimitSettings.parse( + config, + "rc_invites.per_user", defaults={"per_second": 0.003, "burst_count": 5}, ) - self.rc_invites_per_issuer = RatelimitSettings( - config.get("rc_invites", {}).get("per_issuer", {}), + self.rc_invites_per_issuer = RatelimitSettings.parse( + config, + "rc_invites.per_issuer", defaults={"per_second": 0.3, "burst_count": 10}, ) - self.rc_third_party_invite = RatelimitSettings( - config.get("rc_third_party_invite", {}), + self.rc_third_party_invite = RatelimitSettings.parse( + config, + "rc_third_party_invite", defaults={"per_second": 0.0025, "burst_count": 5}, ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 3a260a492b..2ac9f8b309 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py
@@ -669,12 +669,18 @@ def _is_membership_change_allowed( errcode=Codes.INSUFFICIENT_POWER, ) elif Membership.BAN == membership: - if user_level < ban_level or user_level <= target_level: + if user_level < ban_level: raise UnstableSpecAuthError( 403, "You don't have permission to ban", errcode=Codes.INSUFFICIENT_POWER, ) + elif user_level <= target_level: + raise UnstableSpecAuthError( + 403, + "You don't have permission to ban this user", + errcode=Codes.INSUFFICIENT_POWER, + ) elif room_version.knock_join_rule and Membership.KNOCK == membership: if join_rule != JoinRules.KNOCK and ( not room_version.knock_restricted_join_rule @@ -846,11 +852,11 @@ def _check_power_levels( "kick", "invite", }: - if type(v) is not int: + if type(v) is not int: # noqa: E721 raise SynapseError(400, f"{v!r} must be an integer.") if k in {"events", "notifications", "users"}: if not isinstance(v, collections.abc.Mapping) or not all( - type(v) is int for v in v.values() + type(v) is int for v in v.values() # noqa: E721 ): raise SynapseError( 400, diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 52acb21955..53af423a5a 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py
@@ -702,7 +702,7 @@ def _copy_power_level_value_as_integer( :raises TypeError: if `old_value` is neither an integer nor a base-10 string representation of an integer. """ - if type(old_value) is int: + if type(old_value) is int: # noqa: E721 power_levels[key] = old_value return @@ -730,7 +730,7 @@ def validate_canonicaljson(value: Any) -> None: * Floats * NaN, Infinity, -Infinity """ - if type(value) is int: + if type(value) is int: # noqa: E721 if value < CANONICALJSON_MIN_INT or CANONICALJSON_MAX_INT < value: raise SynapseError(400, "JSON integer out of range", Codes.BAD_JSON) diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 9278f1a1aa..34625dd7a1 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py
@@ -151,7 +151,7 @@ class EventValidator: max_lifetime = event.content.get("max_lifetime") if min_lifetime is not None: - if type(min_lifetime) is not int: + if type(min_lifetime) is not int: # noqa: E721 raise SynapseError( code=400, msg="'min_lifetime' must be an integer", @@ -159,7 +159,7 @@ class EventValidator: ) if max_lifetime is not None: - if type(max_lifetime) is not int: + if type(max_lifetime) is not int: # noqa: E721 raise SynapseError( code=400, msg="'max_lifetime' must be an integer", diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 31e0260b83..d4e7dd45a9 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py
@@ -280,7 +280,7 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB _strip_unsigned_values(pdu_json) depth = pdu_json["depth"] - if type(depth) is not int: + if type(depth) is not int: # noqa: E721 raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 89bd597409..607013f121 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -1891,7 +1891,7 @@ class TimestampToEventResponse: ) origin_server_ts = d.get("origin_server_ts") - if type(origin_server_ts) is not int: + if type(origin_server_ts) is not int: # noqa: E721 raise ValueError( "Invalid response: 'origin_server_ts' must be a int but received %r" % origin_server_ts diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index fb448f2155..6520795635 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py
@@ -49,7 +49,7 @@ from synapse.api.presence import UserPresenceState from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.metrics import LaterGauge from synapse.replication.tcp.streams.federation import FederationStream -from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection from synapse.util.metrics import Measure from .units import Edu @@ -229,7 +229,7 @@ class FederationRemoteSendQueue(AbstractFederationSender): """ # nothing to do here: the replication listener will handle it. - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """As per FederationSender @@ -245,7 +245,9 @@ class FederationRemoteSendQueue(AbstractFederationSender): self.notifier.on_new_replication_data() - def send_device_messages(self, destination: str, immediate: bool = True) -> None: + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: """As per FederationSender""" # We don't need to replicate this as it gets sent down a different # stream. @@ -463,7 +465,7 @@ class ParsedFederationStreamData: edus: Dict[str, List[Edu]] -def process_rows_for_federation( +async def process_rows_for_federation( transaction_queue: FederationSender, rows: List[FederationStream.FederationStreamRow], ) -> None: @@ -496,7 +498,7 @@ def process_rows_for_federation( parsed_row.add_to_buffer(buff) for state, destinations in buff.presence_destinations: - transaction_queue.send_presence_to_destinations( + await transaction_queue.send_presence_to_destinations( states=[state], destinations=destinations ) diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 97abbdee18..fb20fd8a10 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py
@@ -147,7 +147,10 @@ from twisted.internet import defer import synapse.metrics from synapse.api.presence import UserPresenceState from synapse.events import EventBase -from synapse.federation.sender.per_destination_queue import PerDestinationQueue +from synapse.federation.sender.per_destination_queue import ( + CATCHUP_RETRY_INTERVAL, + PerDestinationQueue, +) from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -161,9 +164,10 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.types import JsonDict, ReadReceipt, RoomStreamToken +from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection from synapse.util import Clock from synapse.util.metrics import Measure +from synapse.util.retryutils import filter_destinations_by_retry_limiter if TYPE_CHECKING: from synapse.events.presence_router import PresenceRouter @@ -213,7 +217,7 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. @@ -242,9 +246,11 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod - def send_device_messages(self, destination: str, immediate: bool = True) -> None: + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: """Tells the sender that a new device message is ready to be sent to the - destination. The `immediate` flag specifies whether the messages should + destinations. The `immediate` flag specifies whether the messages should be tried to be sent immediately, or whether it can be delayed for a short while (to aid performance). """ @@ -716,6 +722,13 @@ class FederationSender(AbstractFederationSender): pdu.internal_metadata.stream_ordering, ) + destinations = await filter_destinations_by_retry_limiter( + destinations, + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + for destination in destinations: self._get_per_destination_queue(destination).send_pdu(pdu) @@ -763,12 +776,20 @@ class FederationSender(AbstractFederationSender): domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( room_id ) - domains = [ + domains: StrCollection = [ d for d in domains_set if not self.is_mine_server_name(d) and self._federation_shard_config.should_handle(self._instance_name, d) ] + + domains = await filter_destinations_by_retry_limiter( + domains, + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + if not domains: return @@ -816,7 +837,7 @@ class FederationSender(AbstractFederationSender): for queue in queues: queue.flush_read_receipts_for_room(room_id) - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Iterable[UserPresenceState], destinations: Iterable[str] ) -> None: """Send the given presence states to the given destinations. @@ -831,13 +852,20 @@ class FederationSender(AbstractFederationSender): for state in states: assert self.is_mine_id(state.user_id) + destinations = await filter_destinations_by_retry_limiter( + [ + d + for d in destinations + if self._federation_shard_config.should_handle(self._instance_name, d) + ], + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) + for destination in destinations: if self.is_mine_server_name(destination): continue - if not self._federation_shard_config.should_handle( - self._instance_name, destination - ): - continue self._get_per_destination_queue(destination).send_presence( states, start_loop=False @@ -896,21 +924,29 @@ class FederationSender(AbstractFederationSender): else: queue.send_edu(edu) - def send_device_messages(self, destination: str, immediate: bool = True) -> None: - if self.is_mine_server_name(destination): - logger.warning("Not sending device update to ourselves") - return - - if not self._federation_shard_config.should_handle( - self._instance_name, destination - ): - return + async def send_device_messages( + self, destinations: StrCollection, immediate: bool = True + ) -> None: + destinations = await filter_destinations_by_retry_limiter( + [ + destination + for destination in destinations + if self._federation_shard_config.should_handle( + self._instance_name, destination + ) + and not self.is_mine_server_name(destination) + ], + clock=self.clock, + store=self.store, + retry_due_within_ms=CATCHUP_RETRY_INTERVAL, + ) - if immediate: - self._get_per_destination_queue(destination).attempt_new_transaction() - else: - self._get_per_destination_queue(destination).mark_new_data() - self._destination_wakeup_queue.add_to_queue(destination) + for destination in destinations: + if immediate: + self._get_per_destination_queue(destination).attempt_new_transaction() + else: + self._get_per_destination_queue(destination).mark_new_data() + self._destination_wakeup_queue.add_to_queue(destination) def wake_destination(self, destination: str) -> None: """Called when we want to retry sending transactions to a remote. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 31c5c2b7de..9105ba664c 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py
@@ -59,6 +59,10 @@ sent_edus_by_type = Counter( ) +# If the retry interval is larger than this then we enter "catchup" mode +CATCHUP_RETRY_INTERVAL = 60 * 60 * 1000 + + class PerDestinationQueue: """ Manages the per-destination transmission queues. @@ -370,7 +374,7 @@ class PerDestinationQueue: ), ) - if e.retry_interval > 60 * 60 * 1000: + if e.retry_interval > CATCHUP_RETRY_INTERVAL: # we won't retry for another hour! # (this suggests a significant outage) # We drop pending EDUs because otherwise they will diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 0afa1a3514..37903a79ec 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -249,8 +249,10 @@ class TransportLayerClient: data=json_data, json_data_callback=json_data_callback, long_retries=True, - backoff_on_404=True, # If we get a 404 the other side has gone try_trailing_slash_on_400=True, + # Sending a transaction should always succeed, if it doesn't + # then something is wrong and we should backoff. + backoff_on_all_error_codes=True, ) async def make_query( @@ -475,13 +477,11 @@ class TransportLayerClient: See synapse.federation.federation_client.FederationClient.get_public_rooms for more information. """ + path = _create_v1_path("/publicRooms") + if search_filter: # this uses MSC2197 (Search Filtering over Federation) - path = _create_v1_path("/publicRooms") - - data: Dict[str, Any] = { - "include_all_networks": "true" if include_all_networks else "false" - } + data: Dict[str, Any] = {"include_all_networks": include_all_networks} if third_party_instance_id: data["third_party_instance_id"] = third_party_instance_id if limit: @@ -505,17 +505,15 @@ class TransportLayerClient: ) raise else: - path = _create_v1_path("/publicRooms") - args: Dict[str, Union[str, Iterable[str]]] = { "include_all_networks": "true" if include_all_networks else "false" } if third_party_instance_id: - args["third_party_instance_id"] = (third_party_instance_id,) + args["third_party_instance_id"] = third_party_instance_id if limit: - args["limit"] = [str(limit)] + args["limit"] = str(limit) if since_token: - args["since"] = [since_token] + args["since"] = since_token try: response = await self.client.get_json( diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 0e812a6d8b..2f0e5f3b0a 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py
@@ -76,6 +76,7 @@ class AdminHandler: "consent_ts", "user_type", "is_guest", + "last_seen_ts", } if self._msc3866_enabled: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 59ecafa6a0..2b0c505130 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -218,19 +218,17 @@ class AuthHandler: self._failed_uia_attempts_ratelimiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, ) # The number of seconds to keep a UI auth session active. self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout - # Ratelimitier for failed /login attempts + # Ratelimiter for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, ) self._clock = self.hs.get_clock() diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 5c71637038..a850545453 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py
@@ -67,6 +67,7 @@ class CasHandler: self._cas_server_url = hs.config.cas.cas_server_url self._cas_service_url = hs.config.cas.cas_service_url + self._cas_protocol_version = hs.config.cas.cas_protocol_version self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute self._cas_required_attributes = hs.config.cas.cas_required_attributes @@ -121,7 +122,10 @@ class CasHandler: Returns: The parsed CAS response. """ - uri = self._cas_server_url + "/proxyValidate" + if self._cas_protocol_version == 3: + uri = self._cas_server_url + "/p3/proxyValidate" + else: + uri = self._cas_server_url + "/proxyValidate" args = { "ticket": ticket, "service": self._build_service_param(service_args), diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 5ae427d52c..763f56dfc1 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -836,17 +836,16 @@ class DeviceHandler(DeviceWorkerHandler): user_id, hosts, ) - for host in hosts: - self.federation_sender.send_device_messages( - host, immediate=False - ) - # TODO: when called, this isn't in a logging context. - # This leads to log spam, sentry event spam, and massive - # memory usage. - # See https://github.com/matrix-org/synapse/issues/12552. - # log_kv( - # {"message": "sent device update to host", "host": host} - # ) + await self.federation_sender.send_device_messages( + hosts, immediate=False + ) + # TODO: when called, this isn't in a logging context. + # This leads to log spam, sentry event spam, and massive + # memory usage. + # See https://github.com/matrix-org/synapse/issues/12552. + # log_kv( + # {"message": "sent device update to host", "host": host} + # ) if current_stream_id != stream_id: # Clear the set of hosts we've already sent to as we're @@ -951,8 +950,9 @@ class DeviceHandler(DeviceWorkerHandler): # Notify things that device lists need to be sent out. self.notifier.notify_replication() - for host in potentially_changed_hosts: - self.federation_sender.send_device_messages(host, immediate=False) + await self.federation_sender.send_device_messages( + potentially_changed_hosts, immediate=False + ) def _update_device_from_client_ips( diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 17ff8821d9..1c79f7a61e 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -90,8 +90,7 @@ class DeviceMessageHandler: self._ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_key_requests.per_second, - burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, + cfg=hs.config.ratelimiting.rc_key_requests, ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: @@ -303,10 +302,9 @@ class DeviceMessageHandler: ) if self.federation_sender: - for destination in remote_messages.keys(): - # Enqueue a new federation transaction to send the new - # device messages to each remote destination. - self.federation_sender.send_device_messages(destination) + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + await self.federation_sender.send_device_messages(remote_messages.keys()) async def get_events_for_dehydrated_device( self, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 33359f6ed7..d12803bf0f 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py
@@ -67,6 +67,7 @@ class EventStreamHandler: context = await presence_handler.user_syncing( requester.user.to_string(), + requester.device_id, affect_presence=affect_presence, presence_state=PresenceState.ONLINE, ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 3031384d25..472879c964 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -66,14 +66,12 @@ class IdentityHandler: self._3pid_validation_ratelimiter_ip = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, - burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + cfg=hs.config.ratelimiting.rc_3pid_validation, ) self._3pid_validation_ratelimiter_address = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, - burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + cfg=hs.config.ratelimiting.rc_3pid_validation, ) async def ratelimit_request_token_requests( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a74db1dccf..d6be18cdef 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -379,7 +379,7 @@ class MessageHandler: """ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if type(expiry_ts) is not int or event.is_state(): + if type(expiry_ts) is not int or event.is_state(): # noqa: E721 return # _schedule_expiry_for_event won't actually schedule anything if there's already @@ -908,19 +908,6 @@ class EventCreationHandler: if existing_event_id: return existing_event_id - # Some requsters don't have device IDs (appservice, guests, and access - # tokens minted with the admin API), fallback to checking the access token - # ID, which should be close enough. - if requester.access_token_id: - existing_event_id = ( - await self.store.get_event_id_from_transaction_id_and_token_id( - room_id, - requester.user.to_string(), - requester.access_token_id, - txn_id, - ) - ) - return existing_event_id async def get_event_from_transaction( @@ -1474,23 +1461,23 @@ class EventCreationHandler: # We now persist the event (and update the cache in parallel, since we # don't want to block on it). - event, context = events_and_context[0] + # + # Note: mypy gets confused if we inline dl and check with twisted#11770. + # Some kind of bug in mypy's deduction? + deferreds = ( + run_in_background( + self._persist_events, + requester=requester, + events_and_context=events_and_context, + ratelimit=ratelimit, + extra_users=extra_users, + ), + run_in_background( + self.cache_joined_hosts_for_events, events_and_context + ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), + ) result, _ = await make_deferred_yieldable( - gather_results( - ( - run_in_background( - self._persist_events, - requester=requester, - events_and_context=events_and_context, - ratelimit=ratelimit, - extra_users=extra_users, - ), - run_in_background( - self.cache_joined_hosts_for_events, events_and_context - ).addErrback(log_failure, "cache_joined_hosts_for_event failed"), - ), - consumeErrors=True, - ) + gather_results(deferreds, consumeErrors=True) ).addErrback(unwrapFirstError) return result @@ -1921,7 +1908,10 @@ class EventCreationHandler: # We don't want to block sending messages on any presence code. This # matters as sometimes presence code can take a while. run_as_background_process( - "bump_presence_active_time", self._bump_active_time, requester.user + "bump_presence_active_time", + self._bump_active_time, + requester.user, + requester.device_id, ) async def _notify() -> None: @@ -1958,10 +1948,10 @@ class EventCreationHandler: logger.info("maybe_kick_guest_users %r", current_state) await self.hs.get_room_member_handler().kick_guest_users(current_state) - async def _bump_active_time(self, user: UserID) -> None: + async def _bump_active_time(self, user: UserID, device_id: Optional[str]) -> None: try: presence = self.hs.get_presence_handler() - await presence.bump_presence_active_time(user) + await presence.bump_presence_active_time(user, device_id) except Exception: logger.exception("Error bumping presence active time") diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index e8e9db4b91..f31e18328b 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -23,6 +23,7 @@ The methods that define policy are: """ import abc import contextlib +import itertools import logging from bisect import bisect from contextlib import contextmanager @@ -151,15 +152,13 @@ class BasePresenceHandler(abc.ABC): self._federation_queue = PresenceFederationQueue(hs, self) - self._busy_presence_enabled = hs.config.experimental.msc3026_enabled - self.VALID_PRESENCE: Tuple[str, ...] = ( PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, ) - if self._busy_presence_enabled: + if hs.config.experimental.msc3026_enabled: self.VALID_PRESENCE += (PresenceState.BUSY,) active_presence = self.store.take_presence_startup_info() @@ -167,7 +166,11 @@ class BasePresenceHandler(abc.ABC): @abc.abstractmethod async def user_syncing( - self, user_id: str, affect_presence: bool, presence_state: str + self, + user_id: str, + device_id: Optional[str], + affect_presence: bool, + presence_state: str, ) -> ContextManager[None]: """Returns a context manager that should surround any stream requests from the user. @@ -178,6 +181,7 @@ class BasePresenceHandler(abc.ABC): Args: user_id: the user that is starting a sync + device_id: the user's device that is starting a sync affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. @@ -185,15 +189,17 @@ class BasePresenceHandler(abc.ABC): """ @abc.abstractmethod - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: - """Get an iterable of syncing users on this worker, to send to the presence handler + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: + """Get an iterable of syncing users and devices on this worker, to send to the presence handler This is called when a replication connection is established. It should return - a list of user ids, which are then sent as USER_SYNC commands to inform the - process handling presence about those users. + a list of tuples of user ID & device ID, which are then sent as USER_SYNC commands + to inform the process handling presence about those users/devices. Returns: - An iterable of user_id strings. + An iterable of tuples of user ID and device ID. """ async def get_state(self, target_user: UserID) -> UserPresenceState: @@ -254,28 +260,39 @@ class BasePresenceHandler(abc.ABC): async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ @abc.abstractmethod - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ async def update_external_syncs_row( # noqa: B027 (no-op by design) - self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + self, + process_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + sync_time_msec: int, ) -> None: """Update the syncing users for an external process as a delta. @@ -286,6 +303,7 @@ class BasePresenceHandler(abc.ABC): syncing against. This allows synapse to process updates as user start and stop syncing against a given process. user_id: The user who has started or stopped syncing + device_id: The user's device that has started or stopped syncing is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ @@ -336,7 +354,9 @@ class BasePresenceHandler(abc.ABC): ) for destination, host_states in hosts_to_states.items(): - self._federation.send_presence_to_destinations(host_states, [destination]) + await self._federation.send_presence_to_destinations( + host_states, [destination] + ) async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: """ @@ -381,7 +401,9 @@ class BasePresenceHandler(abc.ABC): # We set force_notify=True here so that this presence update is guaranteed to # increment the presence stream ID (which resending the current user's presence # otherwise would not do). - await self.set_state(UserID.from_string(user_id), state, force_notify=True) + await self.set_state( + UserID.from_string(user_id), None, state, force_notify=True + ) async def is_visible(self, observed_user: UserID, observer_user: UserID) -> bool: raise NotImplementedError( @@ -414,16 +436,18 @@ class WorkerPresenceHandler(BasePresenceHandler): hs.config.worker.writers.presence, ) - # The number of ongoing syncs on this process, by user id. + # The number of ongoing syncs on this process, by (user ID, device ID). # Empty if _presence_enabled is false. - self._user_to_num_current_syncs: Dict[str, int] = {} + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} self.notifier = hs.get_notifier() self.instance_id = hs.get_instance_id() - # user_id -> last_sync_ms. Lists the users that have stopped syncing but - # we haven't notified the presence writer of that yet - self.users_going_offline: Dict[str, int] = {} + # (user_id, device_id) -> last_sync_ms. Lists the devices that have stopped + # syncing but we haven't notified the presence writer of that yet + self._user_devices_going_offline: Dict[Tuple[str, Optional[str]], int] = {} self._bump_active_client = ReplicationBumpPresenceActiveTime.make_client(hs) self._set_state_client = ReplicationPresenceSetState.make_client(hs) @@ -446,42 +470,54 @@ class WorkerPresenceHandler(BasePresenceHandler): ClearUserSyncsCommand(self.instance_id) ) - def send_user_sync(self, user_id: str, is_syncing: bool, last_sync_ms: int) -> None: + def send_user_sync( + self, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, + ) -> None: if self._presence_enabled: self.hs.get_replication_command_handler().send_user_sync( - self.instance_id, user_id, is_syncing, last_sync_ms + self.instance_id, user_id, device_id, is_syncing, last_sync_ms ) - def mark_as_coming_online(self, user_id: str) -> None: + def mark_as_coming_online(self, user_id: str, device_id: Optional[str]) -> None: """A user has started syncing. Send a UserSync to the presence writer, unless they had recently stopped syncing. """ - going_offline = self.users_going_offline.pop(user_id, None) + going_offline = self._user_devices_going_offline.pop((user_id, device_id), None) if not going_offline: # Safe to skip because we haven't yet told the presence writer they # were offline - self.send_user_sync(user_id, True, self.clock.time_msec()) + self.send_user_sync(user_id, device_id, True, self.clock.time_msec()) - def mark_as_going_offline(self, user_id: str) -> None: + def mark_as_going_offline(self, user_id: str, device_id: Optional[str]) -> None: """A user has stopped syncing. We wait before notifying the presence writer as its likely they'll come back soon. This allows us to avoid sending a stopped syncing immediately followed by a started syncing notification to the presence writer """ - self.users_going_offline[user_id] = self.clock.time_msec() + self._user_devices_going_offline[(user_id, device_id)] = self.clock.time_msec() def send_stop_syncing(self) -> None: """Check if there are any users who have stopped syncing a while ago and haven't come back yet. If there are poke the presence writer about them. """ now = self.clock.time_msec() - for user_id, last_sync_ms in list(self.users_going_offline.items()): + for (user_id, device_id), last_sync_ms in list( + self._user_devices_going_offline.items() + ): if now - last_sync_ms > UPDATE_SYNCING_USERS_MS: - self.users_going_offline.pop(user_id, None) - self.send_user_sync(user_id, False, last_sync_ms) + self._user_devices_going_offline.pop((user_id, device_id), None) + self.send_user_sync(user_id, device_id, False, last_sync_ms) async def user_syncing( - self, user_id: str, affect_presence: bool, presence_state: str + self, + user_id: str, + device_id: Optional[str], + affect_presence: bool, + presence_state: str, ) -> ContextManager[None]: """Record that a user is syncing. @@ -491,36 +527,32 @@ class WorkerPresenceHandler(BasePresenceHandler): if not affect_presence or not self._presence_enabled: return _NullContextManager() - prev_state = await self.current_state_for_user(user_id) - if prev_state.state != PresenceState.BUSY: - # We set state here but pass ignore_status_msg = True as we don't want to - # cause the status message to be cleared. - # Note that this causes last_active_ts to be incremented which is not - # what the spec wants: see comment in the BasePresenceHandler version - # of this function. - await self.set_state( - UserID.from_string(user_id), - {"presence": presence_state}, - ignore_status_msg=True, - ) + # Note that this causes last_active_ts to be incremented which is not + # what the spec wants. + await self.set_state( + UserID.from_string(user_id), + device_id, + state={"presence": presence_state}, + is_sync=True, + ) - curr_sync = self._user_to_num_current_syncs.get(user_id, 0) - self._user_to_num_current_syncs[user_id] = curr_sync + 1 + curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0) + self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1 - # If we went from no in flight sync to some, notify replication - if self._user_to_num_current_syncs[user_id] == 1: - self.mark_as_coming_online(user_id) + # If this is the first in-flight sync, notify replication + if self._user_device_to_num_current_syncs[(user_id, device_id)] == 1: + self.mark_as_coming_online(user_id, device_id) def _end() -> None: # We check that the user_id is in user_to_num_current_syncs because # user_to_num_current_syncs may have been cleared if we are # shutting down. - if user_id in self._user_to_num_current_syncs: - self._user_to_num_current_syncs[user_id] -= 1 + if (user_id, device_id) in self._user_device_to_num_current_syncs: + self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1 - # If we went from one in flight sync to non, notify replication - if self._user_to_num_current_syncs[user_id] == 0: - self.mark_as_going_offline(user_id) + # If there are no more in-flight syncs, notify replication + if self._user_device_to_num_current_syncs[(user_id, device_id)] == 0: + self.mark_as_going_offline(user_id, device_id) @contextlib.contextmanager def _user_syncing() -> Generator[None, None, None]: @@ -587,28 +619,34 @@ class WorkerPresenceHandler(BasePresenceHandler): # If this is a federation sender, notify about presence updates. await self.maybe_send_presence_to_interested_destinations(state_to_notify) - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: return [ - user_id - for user_id, count in self._user_to_num_current_syncs.items() + user_id_device_id + for user_id_device_id, count in self._user_device_to_num_current_syncs.items() if count > 0 ] async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ presence = state["presence"] @@ -625,12 +663,15 @@ class WorkerPresenceHandler(BasePresenceHandler): await self._set_state_client( instance_name=self._presence_writer_instance, user_id=user_id, + device_id=device_id, state=state, - ignore_status_msg=ignore_status_msg, force_notify=force_notify, + is_sync=is_sync, ) - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -641,7 +682,9 @@ class WorkerPresenceHandler(BasePresenceHandler): # Proxy request to instance that writes presence user_id = user.to_string() await self._bump_active_client( - instance_name=self._presence_writer_instance, user_id=user_id + instance_name=self._presence_writer_instance, + user_id=user_id, + device_id=device_id, ) @@ -703,17 +746,23 @@ class PresenceHandler(BasePresenceHandler): # Keeps track of the number of *ongoing* syncs on this process. While # this is non zero a user will never go offline. - self.user_to_num_current_syncs: Dict[str, int] = {} + self._user_device_to_num_current_syncs: Dict[ + Tuple[str, Optional[str]], int + ] = {} # Keeps track of the number of *ongoing* syncs on other processes. + # # While any sync is ongoing on another process the user will never # go offline. + # # Each process has a unique identifier and an update frequency. If # no update is received from that process within the update period then # we assume that all the sync requests on that process have stopped. - # Stored as a dict from process_id to set of user_id, and a dict of - # process_id to millisecond timestamp last updated. - self.external_process_to_current_syncs: Dict[str, Set[str]] = {} + # Stored as a dict from process_id to set of (user_id, device_id), and + # a dict of process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs: Dict[ + str, Set[Tuple[str, Optional[str]]] + ] = {} self.external_process_last_updated_ms: Dict[str, int] = {} self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") @@ -889,7 +938,7 @@ class PresenceHandler(BasePresenceHandler): ) for destination, states in hosts_to_states.items(): - self._federation_queue.send_presence_to_destinations( + await self._federation_queue.send_presence_to_destinations( states, [destination] ) @@ -918,7 +967,10 @@ class PresenceHandler(BasePresenceHandler): # that were syncing on that process to see if they need to be timed # out. users_to_check.update( - self.external_process_to_current_syncs.pop(process_id, ()) + user_id + for user_id, device_id in self.external_process_to_current_syncs.pop( + process_id, () + ) ) self.external_process_last_updated_ms.pop(process_id) @@ -931,11 +983,15 @@ class PresenceHandler(BasePresenceHandler): syncing_user_ids = { user_id - for user_id, count in self.user_to_num_current_syncs.items() + for (user_id, _), count in self._user_device_to_num_current_syncs.items() if count } - for user_ids in self.external_process_to_current_syncs.values(): - syncing_user_ids.update(user_ids) + syncing_user_ids.update( + user_id + for user_id, _ in itertools.chain( + *self.external_process_to_current_syncs.values() + ) + ) changes = handle_timeouts( states, @@ -946,7 +1002,9 @@ class PresenceHandler(BasePresenceHandler): return await self._update_states(changes) - async def bump_presence_active_time(self, user: UserID) -> None: + async def bump_presence_active_time( + self, user: UserID, device_id: Optional[str] + ) -> None: """We've seen the user do something that indicates they're interacting with the app. """ @@ -969,6 +1027,7 @@ class PresenceHandler(BasePresenceHandler): async def user_syncing( self, user_id: str, + device_id: Optional[str], affect_presence: bool = True, presence_state: str = PresenceState.ONLINE, ) -> ContextManager[None]: @@ -980,7 +1039,8 @@ class PresenceHandler(BasePresenceHandler): when users disconnect/reconnect. Args: - user_id + user_id: the user that is starting a sync + device_id: the user's device that is starting a sync affect_presence: If false this function will be a no-op. Useful for streams that are not associated with an actual client that is being used by a user. @@ -989,52 +1049,21 @@ class PresenceHandler(BasePresenceHandler): if not affect_presence or not self._presence_enabled: return _NullContextManager() - curr_sync = self.user_to_num_current_syncs.get(user_id, 0) - self.user_to_num_current_syncs[user_id] = curr_sync + 1 + curr_sync = self._user_device_to_num_current_syncs.get((user_id, device_id), 0) + self._user_device_to_num_current_syncs[(user_id, device_id)] = curr_sync + 1 - prev_state = await self.current_state_for_user(user_id) - - # If they're busy then they don't stop being busy just by syncing, - # so just update the last sync time. - if prev_state.state != PresenceState.BUSY: - # XXX: We set_state separately here and just update the last_active_ts above - # This keeps the logic as similar as possible between the worker and single - # process modes. Using set_state will actually cause last_active_ts to be - # updated always, which is not what the spec calls for, but synapse has done - # this for... forever, I think. - await self.set_state( - UserID.from_string(user_id), - {"presence": presence_state}, - ignore_status_msg=True, - ) - # Retrieve the new state for the logic below. This should come from the - # in-memory cache. - prev_state = await self.current_state_for_user(user_id) - - # To keep the single process behaviour consistent with worker mode, run the - # same logic as `update_external_syncs_row`, even though it looks weird. - if prev_state.state == PresenceState.OFFLINE: - await self._update_states( - [ - prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=self.clock.time_msec(), - last_user_sync_ts=self.clock.time_msec(), - ) - ] - ) - # otherwise, set the new presence state & update the last sync time, - # but don't update last_active_ts as this isn't an indication that - # they've been active (even though it's probably been updated by - # set_state above) - else: - await self._update_states( - [prev_state.copy_and_replace(last_user_sync_ts=self.clock.time_msec())] - ) + # Note that this causes last_active_ts to be incremented which is not + # what the spec wants. + await self.set_state( + UserID.from_string(user_id), + device_id, + state={"presence": presence_state}, + is_sync=True, + ) async def _end() -> None: try: - self.user_to_num_current_syncs[user_id] -= 1 + self._user_device_to_num_current_syncs[(user_id, device_id)] -= 1 prev_state = await self.current_state_for_user(user_id) await self._update_states( @@ -1056,12 +1085,19 @@ class PresenceHandler(BasePresenceHandler): return _user_syncing() - def get_currently_syncing_users_for_replication(self) -> Iterable[str]: + def get_currently_syncing_users_for_replication( + self, + ) -> Iterable[Tuple[str, Optional[str]]]: # since we are the process handling presence, there is nothing to do here. return [] async def update_external_syncs_row( - self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int + self, + process_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + sync_time_msec: int, ) -> None: """Update the syncing users for an external process as a delta. @@ -1070,6 +1106,7 @@ class PresenceHandler(BasePresenceHandler): syncing against. This allows synapse to process updates as user start and stop syncing against a given process. user_id: The user who has started or stopped syncing + device_id: The user's device that has started or stopped syncing is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ @@ -1080,31 +1117,27 @@ class PresenceHandler(BasePresenceHandler): process_id, set() ) - updates = [] - if is_syncing and user_id not in process_presence: - if prev_state.state == PresenceState.OFFLINE: - updates.append( - prev_state.copy_and_replace( - state=PresenceState.ONLINE, - last_active_ts=sync_time_msec, - last_user_sync_ts=sync_time_msec, - ) - ) - else: - updates.append( - prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) - ) - process_presence.add(user_id) - elif user_id in process_presence: - updates.append( - prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec) + # USER_SYNC is sent when a user's device starts or stops syncing on + # a remote # process. (But only for the initial and last sync for that + # device.) + # + # When a device *starts* syncing it also calls set_state(...) which + # will update the state, last_active_ts, and last_user_sync_ts. + # Simply ensure the user & device is tracked as syncing in this case. + # + # When a device *stops* syncing, update the last_user_sync_ts and mark + # them as no longer syncing. Note this doesn't quite match the + # monolith behaviour, which updates last_user_sync_ts at the end of + # every sync, not just the last in-flight sync. + if is_syncing and (user_id, device_id) not in process_presence: + process_presence.add((user_id, device_id)) + elif not is_syncing and (user_id, device_id) in process_presence: + new_state = prev_state.copy_and_replace( + last_user_sync_ts=sync_time_msec ) + await self._update_states([new_state]) - if not is_syncing: - process_presence.discard(user_id) - - if updates: - await self._update_states(updates) + process_presence.discard((user_id, device_id)) self.external_process_last_updated_ms[process_id] = self.clock.time_msec() @@ -1118,7 +1151,9 @@ class PresenceHandler(BasePresenceHandler): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) - prev_states = await self.current_state_for_users(process_presence) + prev_states = await self.current_state_for_users( + {user_id for user_id, device_id in process_presence} + ) time_now_ms = self.clock.time_msec() await self._update_states( @@ -1203,18 +1238,22 @@ class PresenceHandler(BasePresenceHandler): async def set_state( self, target_user: UserID, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> None: """Set the presence state of the user. Args: target_user: The ID of the user to set the presence state of. + device_id: the device that the user is setting the presence state of. state: The presence state as a JSON dictionary. - ignore_status_msg: True to ignore the "status_msg" field of the `state` dict. - If False, the user's current status will be updated. force_notify: Whether to force notification of the update to clients. + is_sync: True if this update was from a sync, which results in + *not* overriding a previously set BUSY status, updating the + user's last_user_sync_ts, and ignoring the "status_msg" field of + the `state` dict. """ status_msg = state.get("status_msg", None) presence = state["presence"] @@ -1227,18 +1266,27 @@ class PresenceHandler(BasePresenceHandler): return user_id = target_user.to_string() + now = self.clock.time_msec() prev_state = await self.current_state_for_user(user_id) + # Syncs do not override a previous presence of busy. + # + # TODO: This is a hack for lack of multi-device support. Unfortunately + # removing this requires coordination with clients. + if prev_state.state == PresenceState.BUSY and is_sync: + presence = PresenceState.BUSY + new_fields = {"state": presence} - if not ignore_status_msg: - new_fields["status_msg"] = status_msg + if presence == PresenceState.ONLINE or presence == PresenceState.BUSY: + new_fields["last_active_ts"] = now - if presence == PresenceState.ONLINE or ( - presence == PresenceState.BUSY and self._busy_presence_enabled - ): - new_fields["last_active_ts"] = self.clock.time_msec() + if is_sync: + new_fields["last_user_sync_ts"] = now + else: + # Syncs do not override the status message. + new_fields["status_msg"] = status_msg await self._update_states( [prev_state.copy_and_replace(**new_fields)], force_notify=force_notify @@ -1462,7 +1510,7 @@ class PresenceHandler(BasePresenceHandler): or state.status_msg is not None ] - self._federation_queue.send_presence_to_destinations( + await self._federation_queue.send_presence_to_destinations( destinations=newly_joined_remote_hosts, states=states, ) @@ -1473,7 +1521,7 @@ class PresenceHandler(BasePresenceHandler): prev_remote_hosts or newly_joined_remote_hosts ): local_states = await self.current_state_for_users(newly_joined_local_users) - self._federation_queue.send_presence_to_destinations( + await self._federation_queue.send_presence_to_destinations( destinations=prev_remote_hosts | newly_joined_remote_hosts, states=list(local_states.values()), ) @@ -2136,7 +2184,7 @@ class PresenceFederationQueue: index = bisect(self._queue, (clear_before,)) self._queue = self._queue[index:] - def send_presence_to_destinations( + async def send_presence_to_destinations( self, states: Collection[UserPresenceState], destinations: StrCollection ) -> None: """Send the presence states to the given destinations. @@ -2156,7 +2204,7 @@ class PresenceFederationQueue: return if self._federation: - self._federation.send_presence_to_destinations( + await self._federation.send_presence_to_destinations( states=states, destinations=destinations, ) @@ -2279,7 +2327,7 @@ class PresenceFederationQueue: for host, user_ids in hosts_to_users.items(): states = await self._presence_handler.current_state_for_users(user_ids) - self._federation.send_presence_to_destinations( + await self._federation.send_presence_to_destinations( states=states.values(), destinations=[host], ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index bb409f97b7..da146b0d45 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -112,8 +112,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_limiter_local = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, - burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, + cfg=hs.config.ratelimiting.rc_joins_local, ) # Tracks joins from local users to rooms this server isn't a member of. # I.e. joins this server makes by requesting /make_join /send_join from @@ -121,8 +120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_limiter_remote = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, - burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, + cfg=hs.config.ratelimiting.rc_joins_remote, ) # TODO: find a better place to keep this Ratelimiter. # It needs to be @@ -135,8 +133,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._join_rate_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, - burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, + cfg=hs.config.ratelimiting.rc_joins_per_room, ) # Ratelimiter for invites, keyed by room (across all issuers, all @@ -144,8 +141,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_room_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_room, ) # Ratelimiter for invites, keyed by recipient (across all rooms, all @@ -153,8 +149,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_recipient_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_user, ) # Ratelimiter for invites, keyed by issuer (across all rooms, all @@ -162,15 +157,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self._invites_per_issuer_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second, - burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count, + cfg=hs.config.ratelimiting.rc_invites_per_issuer, ) self._third_party_invite_limiter = Ratelimiter( store=self.store, clock=self.clock, - rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second, - burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count, + cfg=hs.config.ratelimiting.rc_third_party_invite, ) self.request_ratelimiter = hs.get_request_ratelimiter() diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index dad3e23470..dd559b4c45 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py
@@ -35,6 +35,7 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter +from synapse.config.ratelimiting import RatelimitSettings from synapse.events import EventBase from synapse.types import JsonDict, Requester, StrCollection from synapse.util.caches.response_cache import ResponseCache @@ -94,7 +95,9 @@ class RoomSummaryHandler: self._server_name = hs.hostname self._federation_client = hs.get_federation_client() self._ratelimiter = Ratelimiter( - store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 + store=self._store, + clock=hs.get_clock(), + cfg=RatelimitSettings("<room summary>", per_second=5, burst_count=10), ) # If a user tries to fetch the same page multiple times in quick succession, diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index 804cc6e81e..05e21509de 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py
@@ -23,9 +23,11 @@ from pkg_resources import parse_version import twisted from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IOpenSSLContextFactory +from twisted.internet.endpoints import HostnameEndpoint +from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory from twisted.internet.ssl import optionsForClientTLS from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory +from twisted.protocols.tls import TLSMemoryBIOFactory from synapse.logging.context import make_deferred_yieldable from synapse.types import ISynapseReactor @@ -97,6 +99,7 @@ async def _sendmail( **kwargs, ) + factory: IProtocolFactory if _is_old_twisted: # before twisted 21.2, we have to override the ESMTPSender protocol to disable # TLS @@ -110,22 +113,13 @@ async def _sendmail( factory = build_sender_factory(hostname=smtphost if enable_tls else None) if force_tls: - reactor.connectSSL( - smtphost, - smtpport, - factory, - optionsForClientTLS(smtphost), - timeout=30, - bindAddress=None, - ) - else: - reactor.connectTCP( - smtphost, - smtpport, - factory, - timeout=30, - bindAddress=None, - ) + factory = TLSMemoryBIOFactory(optionsForClientTLS(smtphost), True, factory) + + endpoint = HostnameEndpoint( + reactor, smtphost, smtpport, timeout=30, bindAddress=None + ) + + await make_deferred_yieldable(endpoint.connect(factory)) await make_deferred_yieldable(d) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 7aeae5319c..4b4227003d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -26,9 +26,10 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, StreamKeyType, UserID +from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure +from synapse.util.retryutils import filter_destinations_by_retry_limiter from synapse.util.wheel_timer import WheelTimer if TYPE_CHECKING: @@ -150,8 +151,15 @@ class FollowerTypingHandler: now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - hosts = await self._storage_controllers.state.get_current_hosts_in_room( - member.room_id + hosts: StrCollection = ( + await self._storage_controllers.state.get_current_hosts_in_room( + member.room_id + ) + ) + hosts = await filter_destinations_by_retry_limiter( + hosts, + clock=self.clock, + store=self.store, ) for domain in hosts: if not self.is_mine_server_name(domain): diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 583c03447c..08c7fc1631 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -243,7 +243,7 @@ class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]): return ( isinstance(v, list) and len(v) == 2 - and type(v[0]) == int + and type(v[0]) == int # noqa: E721 and isinstance(v[1], dict) ) @@ -512,6 +512,7 @@ class MatrixFederationHttpClient: long_retries: bool = False, ignore_backoff: bool = False, backoff_on_404: bool = False, + backoff_on_all_error_codes: bool = False, ) -> IResponse: """ Sends a request to the given server. @@ -552,6 +553,7 @@ class MatrixFederationHttpClient: and try the request anyway. backoff_on_404: Back off if we get a 404 + backoff_on_all_error_codes: Back off if we get any error response Returns: Resolves with the HTTP response object on success. @@ -594,6 +596,7 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, notifier=self.hs.get_notifier(), replication_client=self.hs.get_replication_command_handler(), + backoff_on_all_error_codes=backoff_on_all_error_codes, ) method_bytes = request.method.encode("ascii") @@ -889,6 +892,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, + backoff_on_all_error_codes: bool = False, ) -> JsonDict: ... @@ -906,6 +910,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, + backoff_on_all_error_codes: bool = False, ) -> T: ... @@ -922,6 +927,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, + backoff_on_all_error_codes: bool = False, ) -> Union[JsonDict, T]: """Sends the specified json data using PUT @@ -957,6 +963,7 @@ class MatrixFederationHttpClient: enabled. parser: The parser to use to decode the response. Defaults to parsing as JSON. + backoff_on_all_error_codes: Back off if we get any error response Returns: Succeeds when we get a 2xx HTTP response. The @@ -990,6 +997,7 @@ class MatrixFederationHttpClient: ignore_backoff=ignore_backoff, long_retries=long_retries, timeout=timeout, + backoff_on_all_error_codes=backoff_on_all_error_codes, ) if timeout is not None: diff --git a/synapse/http/server.py b/synapse/http/server.py
index 5109cec983..3bbf91298e 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -115,7 +115,13 @@ def return_json_error( if exc.headers is not None: for header, value in exc.headers.items(): request.setHeader(header, value) - logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) + error_ctx = exc.debug_context + if error_ctx: + logger.info( + "%s SynapseError: %s - %s (%s)", request, error_code, exc.msg, error_ctx + ) + else: + logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) elif f.check(CancelledError): error_code = HTTP_STATUS_REQUEST_CANCELLED error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN} diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index b78d6e17c9..98c6038ff2 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py
@@ -44,6 +44,7 @@ _IGNORED_LOG_RECORD_ATTRIBUTES = { "processName", "relativeCreated", "stack_info", + "taskName", "thread", "threadName", } diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index f62bea968f..64c6ae4512 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -809,23 +809,24 @@ def run_in_background( # type: ignore[misc] # `res` may be a coroutine, `Deferred`, some other kind of awaitable, or a plain # value. Convert it to a `Deferred`. + d: "defer.Deferred[R]" if isinstance(res, typing.Coroutine): # Wrap the coroutine in a `Deferred`. - res = defer.ensureDeferred(res) + d = defer.ensureDeferred(res) elif isinstance(res, defer.Deferred): - pass + d = res elif isinstance(res, Awaitable): # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` # or `Future` from `make_awaitable`. - res = defer.ensureDeferred(_unwrap_awaitable(res)) + d = defer.ensureDeferred(_unwrap_awaitable(res)) else: # `res` is a plain value. Wrap it in a `Deferred`. - res = defer.succeed(res) + d = defer.succeed(res) - if res.called and not res.paused: + if d.called and not d.paused: # The function should have maintained the logcontext, so we can # optimise out the messing about - return res + return d # The function may have reset the context before returning, so # we need to restore it now. @@ -843,8 +844,8 @@ def run_in_background( # type: ignore[misc] # which is supposed to have a single entry and exit point. But # by spawning off another deferred, we are effectively # adding a new exit point.) - res.addBoth(_set_context_cb, ctx) - return res + d.addBoth(_set_context_cb, ctx) + return d T = TypeVar("T") @@ -877,7 +878,7 @@ def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T] ResultT = TypeVar("ResultT") -def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT: +def _set_context_cb(result: ResultT, context: LoggingContextOrSentinel) -> ResultT: """A callback function which just sets the logging context""" set_current_context(context) return result diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index be910128aa..5c3045e197 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -910,10 +910,10 @@ def _custom_sync_async_decorator( async def _wrapper( *args: P.args, **kwargs: P.kwargs ) -> Any: # Return type is RInner - with wrapping_logic(func, *args, **kwargs): - # type-ignore: func() returns R, but mypy doesn't know that R is - # Awaitable here. - return await func(*args, **kwargs) # type: ignore[misc] + # type-ignore: func() returns R, but mypy doesn't know that R is + # Awaitable here. + with wrapping_logic(func, *args, **kwargs): # type: ignore[arg-type] + return await func(*args, **kwargs) else: # The other case here handles sync functions including those decorated with @@ -980,8 +980,7 @@ def trace_with_opname( See the module's doc string for usage examples. """ - # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 - @contextlib.contextmanager # type: ignore[arg-type] + @contextlib.contextmanager def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: @@ -1024,8 +1023,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]: if not opentracing: return func - # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909 - @contextlib.contextmanager # type: ignore[arg-type] + @contextlib.contextmanager def _wrapping_logic( func: Callable[P, R], *args: P.args, **kwargs: P.kwargs ) -> Generator[None, None, None]: diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 4b750c700b..1b7b014f9a 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py
@@ -214,7 +214,10 @@ class MediaRepository: user_id=auth_user, ) - await self._generate_thumbnails(None, media_id, media_id, media_type) + try: + await self._generate_thumbnails(None, media_id, media_id, media_type) + except Exception as e: + logger.info("Failed to generate thumbnails: %s", e) return MXCUri(self.server_name, media_id) diff --git a/synapse/media/oembed.py b/synapse/media/oembed.py
index 5ad9eec80b..2ce842c98d 100644 --- a/synapse/media/oembed.py +++ b/synapse/media/oembed.py
@@ -204,7 +204,7 @@ class OEmbedProvider: calc_description_and_urls(open_graph_response, oembed["html"]) for size in ("width", "height"): val = oembed.get(size) - if type(val) is int: + if type(val) is int: # noqa: E721 open_graph_response[f"og:video:{size}"] = val elif oembed_type == "link": diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
index 2bfa58ceee..d8979813b3 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py
@@ -78,7 +78,7 @@ class Thumbnailer: image_exif = self.image._getexif() # type: ignore if image_exif is not None: image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) - assert type(image_orientation) is int + assert type(image_orientation) is int # noqa: E721 self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) except Exception as e: # A lot of parsing errors can happen when parsing EXIF diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 9ad8e038ae..2f00a7ba20 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -1180,7 +1180,7 @@ class ModuleApi: # Send to remote destinations. destination = UserID.from_string(user).domain - presence_handler.get_federation_queue().send_presence_to_destinations( + await presence_handler.get_federation_queue().send_presence_to_destinations( presence_events, [destination] ) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 990c079c81..554634579e 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -379,7 +379,7 @@ class BulkPushRuleEvaluator: keys = list(notification_levels.keys()) for key in keys: level = notification_levels.get(key, SENTINEL) - if level is not SENTINEL and type(level) is not int: + if level is not SENTINEL and type(level) is not int: # noqa: E721 try: notification_levels[key] = int(level) except (TypeError, ValueError): @@ -472,7 +472,11 @@ StateGroup = Union[object, int] def _is_simple_value(value: Any) -> bool: - return isinstance(value, (bool, str)) or type(value) is int or value is None + return ( + isinstance(value, (bool, str)) + or type(value) is int # noqa: E721 + or value is None + ) def _flatten_dict( diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index 73f3de3642..209833d287 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py
@@ -62,7 +62,7 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint): NAME = "multi_user_device_resync" PATH_ARGS = () - CACHE = False + CACHE = True def __init__(self, hs: "HomeServer"): super().__init__(hs) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index db16aac9c2..6c9e79fb07 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py
@@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.server import Request @@ -51,14 +51,14 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - async def _serialize_payload(user_id: str) -> JsonDict: # type: ignore[override] - return {} + async def _serialize_payload(user_id: str, device_id: Optional[str]) -> JsonDict: # type: ignore[override] + return {"device_id": device_id} async def _handle_request( # type: ignore[override] self, request: Request, content: JsonDict, user_id: str ) -> Tuple[int, JsonDict]: await self._presence_handler.bump_presence_active_time( - UserID.from_string(user_id) + UserID.from_string(user_id), content.get("device_id") ) return (200, {}) @@ -73,8 +73,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint): { "state": { ... }, - "ignore_status_msg": false, - "force_notify": false + "force_notify": false, + "is_sync": false } 200 OK @@ -95,14 +95,16 @@ class ReplicationPresenceSetState(ReplicationEndpoint): @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, + device_id: Optional[str], state: JsonDict, - ignore_status_msg: bool = False, force_notify: bool = False, + is_sync: bool = False, ) -> JsonDict: return { + "device_id": device_id, "state": state, - "ignore_status_msg": ignore_status_msg, "force_notify": force_notify, + "is_sync": is_sync, } async def _handle_request( # type: ignore[override] @@ -110,9 +112,10 @@ class ReplicationPresenceSetState(ReplicationEndpoint): ) -> Tuple[int, JsonDict]: await self._presence_handler.set_state( UserID.from_string(user_id), + content.get("device_id"), content["state"], - content["ignore_status_msg"], content["force_notify"], + content.get("is_sync", False), ) return (200, {}) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3b88dc68ea..51285e6d33 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -422,7 +422,7 @@ class FederationSenderHandler: # The federation stream contains things that we want to send out, e.g. # presence, typing, etc. if stream_name == "federation": - send_queue.process_rows_for_federation(self.federation_sender, rows) + await send_queue.process_rows_for_federation(self.federation_sender, rows) await self.update_token(token) # ... and when new receipts happen @@ -439,16 +439,14 @@ class FederationSenderHandler: for row in rows if not row.entity.startswith("@") and not row.is_signature } - for host in hosts: - self.federation_sender.send_device_messages(host, immediate=False) + await self.federation_sender.send_device_messages(hosts, immediate=False) elif stream_name == ToDeviceStream.NAME: # The to_device stream includes stuff to be pushed to both local # clients and remote servers, so we ignore entities that start with # '@' (since they'll be local users rather than destinations). hosts = {row.entity for row in rows if not row.entity.startswith("@")} - for host in hosts: - self.federation_sender.send_device_messages(host) + await self.federation_sender.send_device_messages(hosts) async def _on_new_receipts( self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow] diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 10f5c98ff8..e616b5e1c8 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -267,27 +267,38 @@ class UserSyncCommand(Command): NAME = "USER_SYNC" def __init__( - self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + self, + instance_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, ): self.instance_id = instance_id self.user_id = user_id + self.device_id = device_id self.is_syncing = is_syncing self.last_sync_ms = last_sync_ms @classmethod def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand": - instance_id, user_id, state, last_sync_ms = line.split(" ", 3) + device_id: Optional[str] + instance_id, user_id, device_id, state, last_sync_ms = line.split(" ", 4) + + if device_id == "None": + device_id = None if state not in ("start", "end"): raise Exception("Invalid USER_SYNC state %r" % (state,)) - return cls(instance_id, user_id, state == "start", int(last_sync_ms)) + return cls(instance_id, user_id, device_id, state == "start", int(last_sync_ms)) def to_line(self) -> str: return " ".join( ( self.instance_id, self.user_id, + str(self.device_id), "start" if self.is_syncing else "end", str(self.last_sync_ms), ) @@ -452,6 +463,17 @@ class LockReleasedCommand(Command): return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key]) +class NewActiveTaskCommand(_SimpleCommand): + """Sent to inform instance handling background tasks that a new active task is available to run. + + Format:: + + NEW_ACTIVE_TASK "<task_id>" + """ + + NAME = "NEW_ACTIVE_TASK" + + _COMMANDS: Tuple[Type[Command], ...] = ( ServerCommand, RdataCommand, @@ -466,6 +488,7 @@ _COMMANDS: Tuple[Type[Command], ...] = ( RemoteServerUpCommand, ClearUserSyncsCommand, LockReleasedCommand, + NewActiveTaskCommand, ) # Map of command name to command type. diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 38adcbe1d0..d9045d7b73 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import ( Command, FederationAckCommand, LockReleasedCommand, + NewActiveTaskCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -238,6 +239,10 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + self._task_scheduler = None + if hs.config.worker.run_background_tasks: + self._task_scheduler = hs.get_task_scheduler() + if hs.config.redis.redis_enabled: # If we're using Redis, it's the background worker that should # receive USER_IP commands and store the relevant client IPs. @@ -423,7 +428,11 @@ class ReplicationCommandHandler: if self._is_presence_writer: return self._presence_handler.update_external_syncs_row( - cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms + cmd.instance_id, + cmd.user_id, + cmd.device_id, + cmd.is_syncing, + cmd.last_sync_ms, ) else: return None @@ -663,6 +672,15 @@ class ReplicationCommandHandler: cmd.instance_name, cmd.lock_name, cmd.lock_key ) + async def on_NEW_ACTIVE_TASK( + self, conn: IReplicationConnection, cmd: NewActiveTaskCommand + ) -> None: + """Called when get a new NEW_ACTIVE_TASK command.""" + if self._task_scheduler: + task = await self._task_scheduler.get_task(cmd.data) + if task: + await self._task_scheduler._launch_task(task) + def new_connection(self, connection: IReplicationConnection) -> None: """Called when we have a new connection.""" self._connections.append(connection) @@ -685,9 +703,9 @@ class ReplicationCommandHandler: ) now = self._clock.time_msec() - for user_id in currently_syncing: + for user_id, device_id in currently_syncing: connection.send_command( - UserSyncCommand(self._instance_id, user_id, True, now) + UserSyncCommand(self._instance_id, user_id, device_id, True, now) ) def lost_connection(self, connection: IReplicationConnection) -> None: @@ -739,11 +757,16 @@ class ReplicationCommandHandler: self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( - self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int + self, + instance_id: str, + user_id: str, + device_id: Optional[str], + is_syncing: bool, + last_sync_ms: int, ) -> None: """Poke the master that a user has started/stopped syncing.""" self.send_command( - UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) + UserSyncCommand(instance_id, user_id, device_id, is_syncing, last_sync_ms) ) def send_user_ip( @@ -776,6 +799,10 @@ class ReplicationCommandHandler: if instance_name == self._instance_name: self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key)) + def send_new_active_task(self, task_id: str) -> None: + """Called when a new task has been scheduled for immediate launch and is ACTIVE.""" + self.send_command(NewActiveTaskCommand(task_id)) + UpdateToken = TypeVar("UpdateToken") UpdateRow = TypeVar("UpdateRow") diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 55e752fda8..94170715fb 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -157,7 +157,7 @@ class PurgeHistoryRestServlet(RestServlet): logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: ts = body["purge_up_to_ts"] - if type(ts) is not int: + if type(ts) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "purge_up_to_ts must be an int", diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 95e751288b..ffce92d45e 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py
@@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet): else: # Get length of token to generate (default is 16) length = body.get("length", 16) - if type(length) is not int: + if type(length) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "length must be an integer", @@ -163,7 +163,8 @@ class NewRegistrationTokenRestServlet(RestServlet): uses_allowed = body.get("uses_allowed", None) if not ( - uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0) + uses_allowed is None + or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -172,13 +173,16 @@ class NewRegistrationTokenRestServlet(RestServlet): ) expiry_time = body.get("expiry_time", None) - if type(expiry_time) not in (int, type(None)): + if expiry_time is not None and type(expiry_time) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if type(expiry_time) is int and expiry_time < self.clock.time_msec(): + if ( + type(expiry_time) is int # noqa: E721 + and expiry_time < self.clock.time_msec() + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", @@ -283,7 +287,7 @@ class RegistrationTokenRestServlet(RestServlet): uses_allowed = body["uses_allowed"] if not ( uses_allowed is None - or (type(uses_allowed) is int and uses_allowed >= 0) + or (type(uses_allowed) is int and uses_allowed >= 0) # noqa: E721 ): raise SynapseError( HTTPStatus.BAD_REQUEST, @@ -294,13 +298,16 @@ class RegistrationTokenRestServlet(RestServlet): if "expiry_time" in body: expiry_time = body["expiry_time"] - if type(expiry_time) not in (int, type(None)): + if expiry_time is not None and type(expiry_time) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must be an integer or null", Codes.INVALID_PARAM, ) - if type(expiry_time) is int and expiry_time < self.clock.time_msec(): + if ( + type(expiry_time) is int # noqa: E721 + and expiry_time < self.clock.time_msec() + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "expiry_time must not be in the past", diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 240e6254b0..91898a5c13 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -132,6 +132,7 @@ class UsersRestServletV2(RestServlet): UserSortOrder.AVATAR_URL.value, UserSortOrder.SHADOW_BANNED.value, UserSortOrder.CREATION_TS.value, + UserSortOrder.LAST_SEEN_TS.value, ), ) @@ -1172,14 +1173,17 @@ class RateLimitRestServlet(RestServlet): messages_per_second = body.get("messages_per_second", 0) burst_count = body.get("burst_count", 0) - if type(messages_per_second) is not int or messages_per_second < 0: + if ( + type(messages_per_second) is not int # noqa: E721 + or messages_per_second < 0 + ): raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (messages_per_second,), errcode=Codes.INVALID_PARAM, ) - if type(burst_count) is not int or burst_count < 0: + if type(burst_count) is not int or burst_count < 0: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "%r parameter must be a positive int" % (burst_count,), diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index d724c68920..7be327e26f 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py
@@ -120,14 +120,12 @@ class LoginRestServlet(RestServlet): self._address_ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_address, ) self._account_ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, - burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, + cfg=self.hs.config.ratelimiting.rc_login_account, ) # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py
index b1629f94a5..d189a923b5 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py
@@ -16,6 +16,7 @@ import logging from typing import TYPE_CHECKING, Tuple from synapse.api.ratelimiting import Ratelimiter +from synapse.config.ratelimiting import RatelimitSettings from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest @@ -66,15 +67,18 @@ class LoginTokenRequestServlet(RestServlet): self.token_timeout = hs.config.auth.login_via_existing_token_timeout self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth - # Ratelimit aggressively to a maxmimum of 1 request per minute. + # Ratelimit aggressively to a maximum of 1 request per minute. # # This endpoint can be used to spawn additional sessions and could be # abused by a malicious client to create many sessions. self._ratelimiter = Ratelimiter( store=self._main_store, clock=hs.get_clock(), - rate_hz=1 / 60, - burst_count=1, + cfg=RatelimitSettings( + key="<login token request>", + per_second=1 / 60, + burst_count=1, + ), ) @interactive_auth_handler diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py
index 8e193330f8..d578faa969 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py
@@ -97,7 +97,7 @@ class PresenceStatusRestServlet(RestServlet): raise SynapseError(400, "Unable to parse state") if self._use_presence: - await self.presence_handler.set_state(user, state) + await self.presence_handler.set_state(user, requester.device_id, state) return 200, {} diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 4f96e51eeb..1707e51972 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py
@@ -52,7 +52,9 @@ class ReadMarkerRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 316e7b9982..869a374459 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py
@@ -94,7 +94,9 @@ class ReceiptRestServlet(RestServlet): Codes.INVALID_PARAM, ) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) if receipt_type == ReceiptTypes.FULLY_READ: await self.read_marker_handler.received_client_read_marker( diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 77e3b91b79..132623462a 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py
@@ -376,8 +376,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): self.ratelimiter = Ratelimiter( store=self.store, clock=hs.get_clock(), - rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, - burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, + cfg=hs.config.ratelimiting.rc_registration_token_validity, ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py
index ac1a63ca27..ee93e459f6 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py
@@ -55,7 +55,7 @@ class ReportEventRestServlet(RestServlet): "Param 'reason' must be a string", Codes.BAD_JSON, ) - if type(body.get("score", 0)) is not int: + if type(body.get("score", 0)) is not int: # noqa: E721 raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index dc498001e4..553938ce9d 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py
@@ -1229,7 +1229,9 @@ class RoomTypingRestServlet(RestServlet): content = parse_json_object_from_request(request) - await self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time( + requester.user, requester.device_id + ) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d7854ed4fd..42bdd3bb10 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py
@@ -205,6 +205,7 @@ class SyncRestServlet(RestServlet): context = await self.presence_handler.user_syncing( user.to_string(), + requester.device_id, affect_presence=affect_presence, presence_state=set_presence, ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 981fd1f58a..0aaa838d04 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -16,6 +16,7 @@ import logging import re from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple +from pydantic import Extra, StrictInt, StrictStr from signedjson.sign import sign_json from twisted.web.server import Request @@ -24,9 +25,10 @@ from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, + parse_and_validate_json_object_from_request, parse_integer, - parse_json_object_from_request, ) +from synapse.rest.models import RequestBodyModel from synapse.storage.keys import FetchKeyResultForRemote from synapse.types import JsonDict from synapse.util import json_decoder @@ -38,6 +40,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class _KeyQueryCriteriaDataModel(RequestBodyModel): + class Config: + extra = Extra.allow + + minimum_valid_until_ts: Optional[StrictInt] + + class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported @@ -96,6 +105,9 @@ class RemoteKey(RestServlet): CATEGORY = "Federation requests" + class PostBody(RequestBodyModel): + server_keys: Dict[StrictStr, Dict[StrictStr, _KeyQueryCriteriaDataModel]] + def __init__(self, hs: "HomeServer"): self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main @@ -137,24 +149,29 @@ class RemoteKey(RestServlet): ) minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") - arguments = {} - if minimum_valid_until_ts is not None: - arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server: {key_id: arguments}} + query = { + server: { + key_id: _KeyQueryCriteriaDataModel( + minimum_valid_until_ts=minimum_valid_until_ts + ) + } + } else: query = {server: {}} return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) + content = parse_and_validate_json_object_from_request(request, self.PostBody) - query = content["server_keys"] + query = content.server_keys return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, query: JsonDict, query_remote_on_cache_miss: bool = False + self, + query: Dict[str, Dict[str, _KeyQueryCriteriaDataModel]], + query_remote_on_cache_miss: bool = False, ) -> JsonDict: logger.info("Handling query for keys %r", query) @@ -196,8 +213,10 @@ class RemoteKey(RestServlet): else: ts_added_ms = key_result.added_ts ts_valid_until_ms = key_result.valid_until_ts - req_key = query.get(server_name, {}).get(key_id, {}) - req_valid_until = req_key.get("minimum_valid_until_ts") + req_key = query.get(server_name, {}).get( + key_id, _KeyQueryCriteriaDataModel(minimum_valid_until_ts=None) + ) + req_valid_until = req_key.minimum_valid_until_ts if req_valid_until is not None: if ts_valid_until_ms < req_valid_until: logger.debug( diff --git a/synapse/server.py b/synapse/server.py
index 8f5e4fc140..71ead524d6 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -408,8 +408,7 @@ class HomeServer(metaclass=abc.ABCMeta): return Ratelimiter( store=self.get_datastores().main, clock=self.get_clock(), - rate_hz=self.config.ratelimiting.rc_registration.per_second, - burst_count=self.config.ratelimiting.rc_registration.burst_count, + cfg=self.config.ratelimiting.rc_registration, ) @cache_in_self diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index ddca0af1da..7619f405fa 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
@@ -405,14 +405,14 @@ class BackgroundUpdater: try: result = await self.do_next_background_update(sleep) back_to_back_failures = 0 - except Exception: + except Exception as e: + logger.exception("Error doing update: %s", e) back_to_back_failures += 1 if back_to_back_failures >= 5: self._aborted = True raise RuntimeError( "5 back-to-back background update failures; aborting." ) - logger.exception("Error doing update") else: if result: logger.info( diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a1c8fb0f46..55ac313f33 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -31,6 +31,7 @@ from typing import ( Iterator, List, Optional, + Sequence, Tuple, Type, TypeVar, @@ -358,7 +359,21 @@ class LoggingTransaction: return self.txn.rowcount @property - def description(self) -> Any: + def description( + self, + ) -> Optional[ + Sequence[ + Tuple[ + str, + Optional[Any], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + ] + ] + ]: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a85633efcd..0836e247ef 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -277,6 +277,10 @@ class DataStore( FROM users as u LEFT JOIN profiles AS p ON u.name = p.full_user_id LEFT JOIN erased_users AS eu ON u.name = eu.user_id + LEFT JOIN ( + SELECT user_id, MAX(last_seen) AS last_seen_ts + FROM user_ips GROUP BY user_id + ) ls ON u.name = ls.user_id {where_clause} """ sql = "SELECT COUNT(*) as total_users " + sql_base @@ -286,7 +290,7 @@ class DataStore( sql = f""" SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url, creation_ts * 1000 as creation_ts, approved, - eu.user_id is not null as erased + eu.user_id is not null as erased, last_seen_ts {sql_base} ORDER BY {order_by_column} {order}, u.name ASC LIMIT ? OFFSET ? diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index c1353b18c1..0c1ed75240 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -978,26 +978,12 @@ class PersistEventsStore: """Persist the mapping from transaction IDs to event IDs (if defined).""" inserted_ts = self._clock.time_msec() - to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = [] to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = [] for event, _ in events_and_contexts: txn_id = getattr(event.internal_metadata, "txn_id", None) - token_id = getattr(event.internal_metadata, "token_id", None) device_id = getattr(event.internal_metadata, "device_id", None) if txn_id is not None: - if token_id is not None: - to_insert_token_id.append( - ( - event.event_id, - event.room_id, - event.sender, - token_id, - txn_id, - inserted_ts, - ) - ) - if device_id is not None: to_insert_device_id.append( ( @@ -1010,26 +996,7 @@ class PersistEventsStore: ) ) - # Synapse usually relies on the device_id to scope transactions for events, - # except for users without device IDs (appservice, guests, and access - # tokens minted with the admin API) which use the access token ID instead. - # - # TODO https://github.com/matrix-org/synapse/issues/16042 - if to_insert_token_id: - self.db_pool.simple_insert_many_txn( - txn, - table="event_txn_id", - keys=( - "event_id", - "room_id", - "user_id", - "token_id", - "txn_id", - "inserted_ts", - ), - values=to_insert_token_id, - ) - + # Synapse relies on the device_id to scope transactions for events.. if to_insert_device_id: self.db_pool.simple_insert_many_txn( txn, @@ -1671,7 +1638,7 @@ class PersistEventsStore: if self._ephemeral_messages_enabled: # If there's an expiry timestamp on the event, store it. expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) - if type(expiry_ts) is int and not event.is_state(): + if type(expiry_ts) is int and not event.is_state(): # noqa: E721 self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) # Insert into the room_memberships table. @@ -2039,10 +2006,10 @@ class PersistEventsStore: ): if ( "min_lifetime" in event.content - and type(event.content["min_lifetime"]) is not int + and type(event.content["min_lifetime"]) is not int # noqa: E721 ) or ( "max_lifetime" in event.content - and type(event.content["max_lifetime"]) is not int + and type(event.content["max_lifetime"]) is not int # noqa: E721 ): # Ignore the event if one of the value isn't an integer. return diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d0dd455aec..943666ed4f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -2022,25 +2022,6 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) - async def get_event_id_from_transaction_id_and_token_id( - self, room_id: str, user_id: str, token_id: int, txn_id: str - ) -> Optional[str]: - """Look up if we have already persisted an event for the transaction ID, - returning the event ID if so. - """ - return await self.db_pool.simple_select_one_onecol( - table="event_txn_id", - keyvalues={ - "room_id": room_id, - "user_id": user_id, - "token_id": token_id, - "txn_id": txn_id, - }, - retcol="event_id", - allow_none=True, - desc="get_event_id_from_transaction_id_and_token_id", - ) - async def get_event_id_from_transaction_id_and_device_id( self, room_id: str, user_id: str, device_id: str, txn_id: str ) -> Optional[str]: @@ -2072,29 +2053,35 @@ class EventsWorkerStore(SQLBaseStore): """ mapping = {} - txn_id_to_event: Dict[Tuple[str, int, str], str] = {} + txn_id_to_event: Dict[Tuple[str, str, str, str], str] = {} for event in events: - token_id = getattr(event.internal_metadata, "token_id", None) + device_id = getattr(event.internal_metadata, "device_id", None) txn_id = getattr(event.internal_metadata, "txn_id", None) - if token_id and txn_id: + if device_id and txn_id: # Check if this is a duplicate of an event in the given events. - existing = txn_id_to_event.get((event.room_id, token_id, txn_id)) + existing = txn_id_to_event.get( + (event.room_id, event.sender, device_id, txn_id) + ) if existing: mapping[event.event_id] = existing continue # Check if this is a duplicate of an event we've already # persisted. - existing = await self.get_event_id_from_transaction_id_and_token_id( - event.room_id, event.sender, token_id, txn_id + existing = await self.get_event_id_from_transaction_id_and_device_id( + event.room_id, event.sender, device_id, txn_id ) if existing: mapping[event.event_id] = existing - txn_id_to_event[(event.room_id, token_id, txn_id)] = existing + txn_id_to_event[ + (event.room_id, event.sender, device_id, txn_id) + ] = existing else: - txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id + txn_id_to_event[ + (event.room_id, event.sender, device_id, txn_id) + ] = event.event_id return mapping diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
index 54d40e7a3a..5a01ec2137 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py
@@ -17,7 +17,7 @@ from types import TracebackType from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type from weakref import WeakValueDictionary -from twisted.internet.interfaces import IReactorCore +from twisted.internet.task import LoopingCall from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore @@ -26,6 +26,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.types import ISynapseReactor from synapse.util import Clock from synapse.util.stringutils import random_string @@ -358,7 +359,7 @@ class Lock: def __init__( self, - reactor: IReactorCore, + reactor: ISynapseReactor, clock: Clock, store: LockStore, read_write: bool, @@ -377,19 +378,25 @@ class Lock: self._table = "worker_read_write_locks" if read_write else "worker_locks" - self._looping_call = clock.looping_call( + # We might be called from a non-main thread, so we defer setting up the + # looping call. + self._looping_call: Optional[LoopingCall] = None + reactor.callFromThread(self._setup_looping_call) + + self._dropped = False + + def _setup_looping_call(self) -> None: + self._looping_call = self._clock.looping_call( self._renew, _RENEWAL_INTERVAL_MS, - store, - clock, - read_write, - lock_name, - lock_key, - token, + self._store, + self._clock, + self._read_write, + self._lock_name, + self._lock_key, + self._token, ) - self._dropped = False - @staticmethod @wrap_as_background_process("Lock._renew") async def _renew( @@ -459,7 +466,7 @@ class Lock: if self._dropped: return - if self._looping_call.running: + if self._looping_call and self._looping_call.running: self._looping_call.stop() await self._store.db_pool.simple_delete( @@ -486,8 +493,9 @@ class Lock: # We should not be dropped without the lock being released (unless # we're shutting down), but if we are then let's at least stop # renewing the lock. - if self._looping_call.running: - self._looping_call.stop() + if self._looping_call and self._looping_call.running: + # We might be called from a non-main thread. + self._reactor.callFromThread(self._looping_call.stop) if self._reactor.running: logger.error( diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index c13c0bc7d7..bec0dc2afe 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -88,7 +88,6 @@ def _load_rules( msc1767_enabled=experimental_config.msc1767_enabled, msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, - msc3958_suppress_edits_enabled=experimental_config.msc3958_supress_edit_notifs, ) return filtered_rules diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d3a01d526f..7e85b73e8e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -206,8 +206,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): consent_server_notice_sent, appservice_id, creation_ts, user_type, deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned, COALESCE(approved, TRUE) AS approved, - COALESCE(locked, FALSE) AS locked + COALESCE(locked, FALSE) AS locked, last_seen_ts FROM users + LEFT JOIN ( + SELECT user_id, MAX(last_seen) AS last_seen_ts + FROM user_ips GROUP BY user_id + ) ls ON users.name = ls.user_id WHERE name = ? """, (user_id,), @@ -268,6 +272,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): is_shadow_banned=bool(user_data["shadow_banned"]), user_id=UserID.from_string(user_data["name"]), user_type=user_data["user_type"], + last_seen_ts=user_data["last_seen_ts"], ) async def is_trial_user(self, user_id: str) -> bool: diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 6298f0984d..3a2966b9e4 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -107,6 +107,7 @@ class UserSortOrder(Enum): AVATAR_URL = "avatar_url" SHADOW_BANNED = "shadow_banned" CREATION_TS = "creation_ts" + LAST_SEEN_TS = "last_seen_ts" class StatsStore(StateDeltasStore): diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 860bbf7c0f..efd21b5bfc 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@ import logging from enum import Enum -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json @@ -28,8 +28,8 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore -from synapse.types import JsonDict -from synapse.util.caches.descriptors import cached +from synapse.types import JsonDict, StrCollection +from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer @@ -205,6 +205,26 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): else: return None + @cachedList( + cached_method_name="get_destination_retry_timings", list_name="destinations" + ) + async def get_destination_retry_timings_batch( + self, destinations: StrCollection + ) -> Dict[str, Optional[DestinationRetryTimings]]: + rows = await self.db_pool.simple_select_many_batch( + table="destinations", + iterable=destinations, + column="destination", + retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"), + desc="get_destination_retry_timings_batch", + ) + + return { + row.pop("destination"): DestinationRetryTimings(**row) + for row in rows + if row["retry_last_ts"] and row["failure_ts"] and row["retry_interval"] + } + async def set_destination_retry_timings( self, destination: str, diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 649d3c8e9f..422f11f59e 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 80 # remember to update the list below when updating +SCHEMA_VERSION = 81 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -114,19 +114,15 @@ Changes in SCHEMA_VERSION = 79 Changes in SCHEMA_VERSION = 80 - The event_txn_id_device_id is always written to for new events. - Add tables for the task scheduler. + +Changes in SCHEMA_VERSION = 81 + - The event_txn_id is no longer written to for new events. """ SCHEMA_COMPAT_VERSION = ( - # Queries against `event_stream_ordering` columns in membership tables must - # be disambiguated. - # - # The threads_id column must written to with non-null values for the - # event_push_actions, event_push_actions_staging, and event_push_summary tables. - # - # insertions to the column `full_user_id` of tables profiles and user_filters can no - # longer be null - 76 + # The `event_txn_id_device_id` must be written to for new events. + 80 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index e750417189..488714f60c 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py
@@ -946,6 +946,7 @@ class UserInfo: is_guest: True if the user is a guest user. is_shadow_banned: True if the user has been shadow-banned. user_type: User type (None for normal user, 'support' and 'bot' other options). + last_seen_ts: Last activity timestamp of the user. """ user_id: UserID @@ -958,6 +959,7 @@ class UserInfo: is_deactivated: bool is_guest: bool is_shadow_banned: bool + last_seen_ts: Optional[int] class UserProfile(TypedDict): diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py
index bf7bd351e0..029eedcc6f 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py
@@ -470,7 +470,7 @@ class CacheMultipleEntries(CacheEntry[KT, VT]): def deferred(self, key: KT) -> "defer.Deferred[VT]": if not self._deferred: self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) - return self._deferred.observe().addCallback(lambda res: res.get(key)) + return self._deferred.observe().addCallback(lambda res: res[key]) def add_invalidation_callback( self, key: KT, callback: Optional[Callable[[], None]] diff --git a/synapse/util/check_dependencies.py b/synapse/util/check_dependencies.py
index 114130a08f..f7cead9e12 100644 --- a/synapse/util/check_dependencies.py +++ b/synapse/util/check_dependencies.py
@@ -51,9 +51,9 @@ class DependencyException(Exception): DEV_EXTRAS = {"lint", "mypy", "test", "dev"} -RUNTIME_EXTRAS = ( - set(metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra")) - DEV_EXTRAS -) +ALL_EXTRAS = metadata.metadata(DISTRIBUTION_NAME).get_all("Provides-Extra") +assert ALL_EXTRAS is not None +RUNTIME_EXTRAS = set(ALL_EXTRAS) - DEV_EXTRAS VERSION = metadata.version(DISTRIBUTION_NAME) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index cde4a0780f..f693ba2a8c 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py
@@ -291,7 +291,8 @@ class _PerHostRatelimiter: if self.metrics_name: rate_limit_reject_counter.labels(self.metrics_name).inc() raise LimitExceededError( - retry_after_ms=int(self.window_size / self.sleep_limit) + limiter_name="rc_federation", + retry_after_ms=int(self.window_size / self.sleep_limit), ) self.request_times.append(time_now) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 27e9fc976c..0e1f907667 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Optional, Type from synapse.api.errors import CodeMessageException from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import DataStore +from synapse.types import StrCollection from synapse.util import Clock if TYPE_CHECKING: @@ -116,6 +117,30 @@ async def get_retry_limiter( ) +async def filter_destinations_by_retry_limiter( + destinations: StrCollection, + clock: Clock, + store: DataStore, + retry_due_within_ms: int = 0, +) -> StrCollection: + """Filter down the list of destinations to only those that will are either + alive or due for a retry (within `retry_due_within_ms`) + """ + if not destinations: + return destinations + + retry_timings = await store.get_destination_retry_timings_batch(destinations) + + now = int(clock.time_msec()) + + return [ + destination + for destination, timings in retry_timings.items() + if timings is None + or timings.retry_last_ts + timings.retry_interval <= now + retry_due_within_ms + ] + + class RetryDestinationLimiter: def __init__( self, @@ -128,6 +153,7 @@ class RetryDestinationLimiter: backoff_on_failure: bool = True, notifier: Optional["Notifier"] = None, replication_client: Optional["ReplicationCommandHandler"] = None, + backoff_on_all_error_codes: bool = False, ): """Marks the destination as "down" if an exception is thrown in the context, except for CodeMessageException with code < 500. @@ -147,6 +173,9 @@ class RetryDestinationLimiter: backoff_on_failure: set to False if we should not increase the retry interval on a failure. + + backoff_on_all_error_codes: Whether we should back off on any + error code. """ self.clock = clock self.store = store @@ -156,6 +185,7 @@ class RetryDestinationLimiter: self.retry_interval = retry_interval self.backoff_on_404 = backoff_on_404 self.backoff_on_failure = backoff_on_failure + self.backoff_on_all_error_codes = backoff_on_all_error_codes self.notifier = notifier self.replication_client = replication_client @@ -179,6 +209,7 @@ class RetryDestinationLimiter: exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: + success = exc_type is None valid_err_code = False if exc_type is None: valid_err_code = True @@ -195,7 +226,9 @@ class RetryDestinationLimiter: # won't accept our requests for at least a while. # 429 is us being aggressively rate limited, so lets rate limit # ourselves. - if exc_val.code == 404 and self.backoff_on_404: + if self.backoff_on_all_error_codes: + valid_err_code = False + elif exc_val.code == 404 and self.backoff_on_404: valid_err_code = False elif exc_val.code in (401, 429): valid_err_code = False @@ -204,7 +237,7 @@ class RetryDestinationLimiter: else: valid_err_code = False - if valid_err_code: + if success: # We connected successfully. if not self.retry_interval: return @@ -215,6 +248,12 @@ class RetryDestinationLimiter: self.failure_ts = None retry_last_ts = 0 self.retry_interval = 0 + elif valid_err_code: + # We got a potentially valid error code back. We don't reset the + # timers though, as the other side might actually be down anyway + # (e.g. some deprovisioned servers will always return a 404 or 403, + # and we don't want to keep resetting the retry timers for them). + return elif not self.backoff_on_failure: return else: diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 4aea64b338..9e89aeb748 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py
@@ -57,14 +57,13 @@ class TaskScheduler: the code launching the task. You can also specify the `result` (and/or an `error`) when returning from the function. - The reconciliation loop runs every 5 mns, so this is not a precise scheduler. When wanting - to launch now, the launch will still not happen before the next loop run. - - Tasks will be run on the worker specified with `run_background_tasks_on` config, - or the main one by default. + The reconciliation loop runs every minute, so this is not a precise scheduler. There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already full. In this regard, please take great care that scheduled tasks can actually finished. For now there is no mechanism to stop a running task if it is stuck. + + Tasks will be run on the worker specified with `run_background_tasks_on` config, + or the main one by default. """ # Precision of the scheduler, evaluation of tasks to run will only happen @@ -85,7 +84,7 @@ class TaskScheduler: self._actions: Dict[ str, Callable[ - [ScheduledTask, bool], + [ScheduledTask], Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], ], ] = {} @@ -98,11 +97,13 @@ class TaskScheduler: "handle_scheduled_tasks", self._handle_scheduled_tasks, ) + else: + self.replication_client = hs.get_replication_command_handler() def register_action( self, function: Callable[ - [ScheduledTask, bool], + [ScheduledTask], Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], ], action_name: str, @@ -115,10 +116,9 @@ class TaskScheduler: calling `schedule_task` but rather in an `__init__` method. Args: - function: The function to be executed for this action. The parameters - passed to the function when launched are the `ScheduledTask` being run, - and a `first_launch` boolean to signal if it's a resumed task or the first - launch of it. The function should return a tuple of new `status`, `result` + function: The function to be executed for this action. The parameter + passed to the function when launched is the `ScheduledTask` being run. + The function should return a tuple of new `status`, `result` and `error` as specified in `ScheduledTask`. action_name: The name of the action to be associated with the function """ @@ -171,6 +171,12 @@ class TaskScheduler: ) await self._store.insert_scheduled_task(task) + if status == TaskStatus.ACTIVE: + if self._run_background_tasks: + await self._launch_task(task) + else: + self.replication_client.send_new_active_task(task.id) + return task.id async def update_task( @@ -265,21 +271,13 @@ class TaskScheduler: Args: id: id of the task to delete """ - if self.task_is_running(id): - raise Exception(f"Task {id} is currently running and can't be deleted") + task = await self.get_task(id) + if task is None: + raise Exception(f"Task {id} does not exist") + if task.status == TaskStatus.ACTIVE: + raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") await self._store.delete_scheduled_task(id) - def task_is_running(self, id: str) -> bool: - """Check if a task is currently running. - - Can only be called from the worker handling the task scheduling. - - Args: - id: id of the task to check - """ - assert self._run_background_tasks - return id in self._running_tasks - async def _handle_scheduled_tasks(self) -> None: """Main loop taking care of launching tasks and cleaning up old ones.""" await self._launch_scheduled_tasks() @@ -288,29 +286,11 @@ class TaskScheduler: async def _launch_scheduled_tasks(self) -> None: """Retrieve and launch scheduled tasks that should be running at that time.""" for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]): - if not self.task_is_running(task.id): - if ( - len(self._running_tasks) - < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task, first_launch=False) - else: - if ( - self._clock.time_msec() - > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS - ): - logger.warn( - f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" - ) + await self._launch_task(task) for task in await self.get_tasks( statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec() ): - if ( - not self.task_is_running(task.id) - and len(self._running_tasks) - < TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS - ): - await self._launch_task(task, first_launch=True) + await self._launch_task(task) running_tasks_gauge.set(len(self._running_tasks)) @@ -320,27 +300,27 @@ class TaskScheduler: statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE] ): # FAILED and COMPLETE tasks should never be running - assert not self.task_is_running(task.id) + assert task.id not in self._running_tasks if ( self._clock.time_msec() > task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS ): await self._store.delete_scheduled_task(task.id) - async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None: + async def _launch_task(self, task: ScheduledTask) -> None: """Launch a scheduled task now. Args: task: the task to launch - first_launch: `True` if it's the first time is launched, `False` otherwise """ - assert task.action in self._actions + assert self._run_background_tasks + assert task.action in self._actions function = self._actions[task.action] async def wrapper() -> None: try: - (status, result, error) = await function(task, first_launch) + (status, result, error) = await function(task) except Exception: f = Failure() logger.error( @@ -360,6 +340,20 @@ class TaskScheduler: ) self._running_tasks.remove(task.id) + if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: + return + + if ( + self._clock.time_msec() + > task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS + ): + logger.warn( + f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" + ) + + if task.id in self._running_tasks: + return + self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) description = f"{task.id}-{task.action}"