diff options
Diffstat (limited to 'synapse')
52 files changed, 1634 insertions, 741 deletions
diff --git a/synapse/api/errors.py b/synapse/api/errors.py index af894243f8..3546aaf7c3 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -217,6 +217,13 @@ class InvalidAPICallError(SynapseError): super().__init__(HTTPStatus.BAD_REQUEST, msg, Codes.BAD_JSON) +class InvalidProxyCredentialsError(SynapseError): + """Error raised when the proxy credentials are invalid.""" + + def __init__(self, msg: str, errcode: str = Codes.UNKNOWN): + super().__init__(401, msg, errcode) + + class ProxiedRequestError(SynapseError): """An error from a general matrix endpoint, eg. from a proxied Matrix API call. diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index d9fdda9745..9b8a19ece5 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -80,36 +80,29 @@ class RoomVersion: # MSC2209: Check 'notifications' key while verifying # m.room.power_levels auth rules. limit_notifications_power_levels: bool - # MSC2175: No longer include the creator in m.room.create events. - msc2175_implicit_room_creator: bool - # MSC2174/MSC2176: Apply updated redaction rules algorithm, move redacts to - # content property. - msc2176_redaction_rules: bool - # MSC3083: Support the 'restricted' join_rule. - msc3083_join_rules: bool - # MSC3375: Support for the proper redaction rules for MSC3083. This mustn't - # be enabled if MSC3083 is not. - msc3375_redaction_rules: bool - # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending - # m.room.membership event with membership 'knock'. - msc2403_knocking: bool + # No longer include the creator in m.room.create events. + implicit_room_creator: bool + # Apply updated redaction rules algorithm from room version 11. + updated_redaction_rules: bool + # Support the 'restricted' join rule. + restricted_join_rule: bool + # Support for the proper redaction rules for the restricted join rule. This requires + # restricted_join_rule to be enabled. + restricted_join_rule_fix: bool + # Support the 'knock' join rule. + knock_join_rule: bool # MSC3389: Protect relation information from redaction. msc3389_relation_redactions: bool - # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of - # knocks and restricted join rules into the same join condition. - msc3787_knock_restricted_join_rule: bool - # MSC3667: Enforce integer power levels - msc3667_int_only_power_levels: bool - # MSC3821: Do not redact the third_party_invite content field for membership events. - msc3821_redaction_rules: bool + # Support the 'knock_restricted' join rule. + knock_restricted_join_rule: bool + # Enforce integer power levels + enforce_int_power_levels: bool # MSC3931: Adds a push rule condition for "room version feature flags", making # some push rules room version dependent. Note that adding a flag to this list # is not enough to mark it "supported": the push rule evaluator also needs to # support the flag. Unknown flags are ignored by the evaluator, making conditions # fail if used. msc3931_push_features: Tuple[str, ...] # values from PushRuleRoomFlag - # MSC3989: Redact the origin field. - msc3989_redaction_rules: bool linearized_matrix: bool @@ -123,17 +116,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V2 = RoomVersion( @@ -145,17 +136,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V3 = RoomVersion( @@ -167,17 +156,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V4 = RoomVersion( @@ -189,17 +176,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V5 = RoomVersion( @@ -211,17 +196,15 @@ class RoomVersions: special_case_aliases_auth=True, strict_canonicaljson=False, limit_notifications_power_levels=False, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V6 = RoomVersion( @@ -233,39 +216,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=False, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, - linearized_matrix=False, - ) - MSC2176 = RoomVersion( - "org.matrix.msc2176", - RoomDisposition.UNSTABLE, - EventFormatVersions.ROOM_V4_PLUS, - StateResolutionVersions.V2, - enforce_key_validity=True, - special_case_aliases_auth=False, - strict_canonicaljson=True, - limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=True, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=False, - msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, - msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V7 = RoomVersion( @@ -277,17 +236,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=True, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=False, + restricted_join_rule_fix=False, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V8 = RoomVersion( @@ -299,17 +256,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=False, - msc2403_knocking=True, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=False, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V9 = RoomVersion( @@ -321,61 +276,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, - msc3931_push_features=(), - msc3989_redaction_rules=False, - linearized_matrix=False, - ) - MSC3787 = RoomVersion( - "org.matrix.msc3787", - RoomDisposition.UNSTABLE, - EventFormatVersions.ROOM_V4_PLUS, - StateResolutionVersions.V2, - enforce_key_validity=True, - special_case_aliases_auth=False, - strict_canonicaljson=True, - limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=False, - msc3931_push_features=(), - msc3989_redaction_rules=False, - linearized_matrix=False, - ) - MSC3821 = RoomVersion( - "org.matrix.msc3821.opt1", - RoomDisposition.UNSTABLE, - EventFormatVersions.ROOM_V4_PLUS, - StateResolutionVersions.V2, - enforce_key_validity=True, - special_case_aliases_auth=False, - strict_canonicaljson=True, - limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=False, - msc3821_redaction_rules=True, + knock_restricted_join_rule=False, + enforce_int_power_levels=False, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) V10 = RoomVersion( @@ -387,17 +296,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=True, - msc3821_redaction_rules=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, msc3931_push_features=(), - msc3989_redaction_rules=False, linearized_matrix=False, ) MSC1767v10 = RoomVersion( @@ -410,62 +317,35 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, + implicit_room_creator=False, + updated_redaction_rules=False, + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=True, - msc3821_redaction_rules=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, msc3931_push_features=(PushRuleRoomFlag.EXTENSIBLE_EVENTS,), - msc3989_redaction_rules=False, linearized_matrix=False, ) - MSC3989 = RoomVersion( - "org.matrix.msc3989", - RoomDisposition.UNSTABLE, - EventFormatVersions.ROOM_V4_PLUS, - StateResolutionVersions.V2, - enforce_key_validity=True, - special_case_aliases_auth=False, - strict_canonicaljson=True, - limit_notifications_power_levels=True, - msc2175_implicit_room_creator=False, - msc2176_redaction_rules=False, - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, - msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=True, - msc3821_redaction_rules=False, - msc3931_push_features=(), - msc3989_redaction_rules=True, - linearized_matrix=False, - ) - MSC3820opt2 = RoomVersion( - # Based upon v10 - "org.matrix.msc3820.opt2", - RoomDisposition.UNSTABLE, + V11 = RoomVersion( + "11", + RoomDisposition.STABLE, EventFormatVersions.ROOM_V4_PLUS, StateResolutionVersions.V2, enforce_key_validity=True, special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=True, # Used by MSC3820 - msc2176_redaction_rules=True, # Used by MSC3820 - msc3083_join_rules=True, - msc3375_redaction_rules=True, - msc2403_knocking=True, + implicit_room_creator=True, # Used by MSC3820 + updated_redaction_rules=True, # Used by MSC3820 + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=True, - msc3667_int_only_power_levels=True, - msc3821_redaction_rules=True, # Used by MSC3820 + knock_restricted_join_rule=True, + enforce_int_power_levels=True, msc3931_push_features=(), - msc3989_redaction_rules=True, # Used by MSC3820 linearized_matrix=False, ) # Based on room version 11: @@ -482,17 +362,15 @@ class RoomVersions: special_case_aliases_auth=False, strict_canonicaljson=True, limit_notifications_power_levels=True, - msc2175_implicit_room_creator=True, - msc2176_redaction_rules=True, - msc3083_join_rules=False, - msc3375_redaction_rules=False, - msc2403_knocking=True, + implicit_room_creator=True, + updated_redaction_rules=True, + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, msc3389_relation_redactions=False, - msc3787_knock_restricted_join_rule=False, - msc3667_int_only_power_levels=True, - msc3821_redaction_rules=True, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, msc3931_push_features=(), - msc3989_redaction_rules=True, linearized_matrix=True, ) @@ -506,14 +384,11 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.V4, RoomVersions.V5, RoomVersions.V6, - RoomVersions.MSC2176, RoomVersions.V7, RoomVersions.V8, RoomVersions.V9, - RoomVersions.MSC3787, RoomVersions.V10, - RoomVersions.MSC3989, - RoomVersions.MSC3820opt2, + RoomVersions.V11, RoomVersions.LINEARIZED, ) } @@ -543,12 +418,12 @@ MSC3244_CAPABILITIES = { RoomVersionCapability( "knock", RoomVersions.V7, - lambda room_version: room_version.msc2403_knocking, + lambda room_version: room_version.knock_join_rule, ), RoomVersionCapability( "restricted", RoomVersions.V9, - lambda room_version: room_version.msc3083_join_rules, + lambda room_version: room_version.restricted_join_rule, ), ) } diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 936b1b0430..a94b57a671 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -386,6 +386,7 @@ def listen_unix( def listen_http( + hs: "HomeServer", listener_config: ListenerConfig, root_resource: Resource, version_string: str, @@ -406,6 +407,7 @@ def listen_http( version_string, max_request_body_size=max_request_body_size, reactor=reactor, + hs=hs, ) if isinstance(listener_config, TCPListenerConfig): diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 7406c3948c..dc79efcc14 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -221,6 +221,7 @@ class GenericWorkerServer(HomeServer): root_resource = create_resource_tree(resources, OptionsResource()) _base.listen_http( + self, listener_config, root_resource, self.version_string, diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 84236ac299..f188c7265a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -139,6 +139,7 @@ class SynapseHomeServer(HomeServer): root_resource = OptionsResource() ports = listen_http( + self, listener_config, create_resource_tree(resources, root_resource), self.version_string, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 0970f22a75..1695ed8ca3 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -247,6 +247,27 @@ class ExperimentalConfig(Config): # MSC3026 (busy presence state) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) + # MSC2697 (device dehydration) + # Enabled by default since this option was added after adding the feature. + # It is not recommended that both MSC2697 and MSC3814 both be enabled at + # once. + self.msc2697_enabled: bool = experimental.get("msc2697_enabled", True) + + # MSC3814 (dehydrated devices with SSSS) + # This is an alternative method to achieve the same goals as MSC2697. + # It is not recommended that both MSC2697 and MSC3814 both be enabled at + # once. + self.msc3814_enabled: bool = experimental.get("msc3814_enabled", False) + + if self.msc2697_enabled and self.msc3814_enabled: + raise ConfigError( + "MSC2697 and MSC3814 should not both be enabled.", + ( + "experimental_features", + "msc3814_enabled", + ), + ) + # MSC3244 (room version capabilities) self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index e55ca12a36..6567fb6bb0 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -15,7 +15,7 @@ import argparse import logging -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import attr from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr @@ -171,6 +171,27 @@ class WriterLocations: ) +@attr.s(auto_attribs=True) +class OutboundFederationRestrictedTo: + """Whether we limit outbound federation to a certain set of instances. + + Attributes: + instances: optional list of instances that can make outbound federation + requests. If None then all instances can make federation requests. + locations: list of instance locations to connect to proxy via. + """ + + instances: Optional[List[str]] + locations: List[InstanceLocationConfig] = attr.Factory(list) + + def __contains__(self, instance: str) -> bool: + # It feels a bit dirty to return `True` if `instances` is `None`, but it makes + # sense in downstream usage in the sense that if + # `outbound_federation_restricted_to` is not configured, then any instance can + # talk to federation (no restrictions so always return `True`). + return self.instances is None or instance in self.instances + + class WorkerConfig(Config): """The workers are processes run separately to the main synapse process. They have their own pid_file and listener configuration. They use the @@ -385,6 +406,28 @@ class WorkerConfig(Config): new_option_name="update_user_directory_from_worker", ) + outbound_federation_restricted_to = config.get( + "outbound_federation_restricted_to", None + ) + self.outbound_federation_restricted_to = OutboundFederationRestrictedTo( + outbound_federation_restricted_to + ) + if outbound_federation_restricted_to: + if not self.worker_replication_secret: + raise ConfigError( + "`worker_replication_secret` must be configured when using `outbound_federation_restricted_to`." + ) + + for instance in outbound_federation_restricted_to: + if instance not in self.instance_map: + raise ConfigError( + "Instance %r is configured in 'outbound_federation_restricted_to' but does not appear in `instance_map` config." + % (instance,) + ) + self.outbound_federation_restricted_to.locations.append( + self.instance_map[instance] + ) + def _should_this_worker_perform_duty( self, config: Dict[str, Any], diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 1f9c572a48..97e69c56c9 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -135,7 +135,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: raise AuthError(403, "Event not signed by hub server") is_invite_via_allow_rule = ( - event.room_version.msc3083_join_rules + event.room_version.restricted_join_rule and event.type == EventTypes.Member and event.membership == Membership.JOIN and EventContentFields.AUTHORISING_USER in event.content @@ -368,11 +368,9 @@ LENIENT_EVENT_BYTE_LIMITS_ROOM_VERSIONS = { RoomVersions.V4, RoomVersions.V5, RoomVersions.V6, - RoomVersions.MSC2176, RoomVersions.V7, RoomVersions.V8, RoomVersions.V9, - RoomVersions.MSC3787, RoomVersions.V10, RoomVersions.MSC1767v10, } @@ -465,7 +463,7 @@ def _check_create(event: "EventBase") -> None: # 1.4 If content has no creator field, reject if the room version requires it. if ( - not event.room_version.msc2175_implicit_room_creator + not event.room_version.implicit_room_creator and EventContentFields.ROOM_CREATOR not in event.content ): raise AuthError(403, "Create event lacks a 'creator' property") @@ -502,7 +500,7 @@ def _is_membership_change_allowed( key = (EventTypes.Create, "") create = auth_events.get(key) if create and event.prev_event_ids()[0] == create.event_id: - if room_version.msc2175_implicit_room_creator: + if room_version.implicit_room_creator: creator = create.sender else: creator = create.content[EventContentFields.ROOM_CREATOR] @@ -525,7 +523,7 @@ def _is_membership_change_allowed( caller_invited = caller and caller.membership == Membership.INVITE caller_knocked = ( caller - and room_version.msc2403_knocking + and room_version.knock_join_rule and caller.membership == Membership.KNOCK ) @@ -625,9 +623,9 @@ def _is_membership_change_allowed( elif join_rule == JoinRules.PUBLIC: pass elif ( - room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED + room_version.restricted_join_rule and join_rule == JoinRules.RESTRICTED ) or ( - room_version.msc3787_knock_restricted_join_rule + room_version.knock_restricted_join_rule and join_rule == JoinRules.KNOCK_RESTRICTED ): # This is the same as public, but the event must contain a reference @@ -657,9 +655,9 @@ def _is_membership_change_allowed( elif ( join_rule == JoinRules.INVITE - or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK) + or (room_version.knock_join_rule and join_rule == JoinRules.KNOCK) or ( - room_version.msc3787_knock_restricted_join_rule + room_version.knock_restricted_join_rule and join_rule == JoinRules.KNOCK_RESTRICTED ) ): @@ -693,9 +691,9 @@ def _is_membership_change_allowed( "You don't have permission to ban", errcode=Codes.INSUFFICIENT_POWER, ) - elif room_version.msc2403_knocking and Membership.KNOCK == membership: + elif room_version.knock_join_rule and Membership.KNOCK == membership: if join_rule != JoinRules.KNOCK and ( - not room_version.msc3787_knock_restricted_join_rule + not room_version.knock_restricted_join_rule or join_rule != JoinRules.KNOCK_RESTRICTED ): raise AuthError(403, "You don't have permission to knock") @@ -852,7 +850,7 @@ def _check_power_levels( # Reject events with stringy power levels if required by room version if ( event.type == EventTypes.PowerLevels - and room_version_obj.msc3667_int_only_power_levels + and room_version_obj.enforce_int_power_levels ): for k, v in event.content.items(): if k in { @@ -988,7 +986,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> in key = (EventTypes.Create, "") create_event = auth_events.get(key) if create_event is not None: - if create_event.room_version.msc2175_implicit_room_creator: + if create_event.room_version.implicit_room_creator: creator = create_event.sender else: creator = create_event.content[EventContentFields.ROOM_CREATOR] @@ -1132,7 +1130,7 @@ def auth_types_for_event( ) auth_types.add(key) - if room_version.msc3083_join_rules and membership == Membership.JOIN: + if room_version.restricted_join_rule and membership == Membership.JOIN: if EventContentFields.AUTHORISING_USER in event.content: key = ( EventTypes.Member, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index f95ec7dad7..eaf6fcf2f9 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -353,7 +353,7 @@ class EventBase(metaclass=abc.ABCMeta): @property def redacts(self) -> Optional[str]: """MSC2176 moved the redacts field into the content.""" - if self.room_version.msc2176_redaction_rules: + if self.room_version.updated_redaction_rules: return self.content.get("redacts") return self.get("redacts") diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 9efaff77c9..e4ed46c756 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -183,7 +183,7 @@ class EventBuilder: # MSC2174 moves the redacts property to the content, it is invalid to # provide it as a top-level property. - if self._redacts is not None and not self.room_version.msc2176_redaction_rules: + if self._redacts is not None and not self.room_version.updated_redaction_rules: event_dict["redacts"] = self._redacts if self._origin_server_ts is not None: diff --git a/synapse/events/utils.py b/synapse/events/utils.py index c692b4008d..f834d5dad8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -108,19 +108,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic "origin_server_ts", ] - # Room versions from before MSC2176 had additional allowed keys. - if not room_version.msc2176_redaction_rules: - allowed_keys.extend(["prev_state", "membership"]) + # Earlier room versions from had additional allowed keys. + if not room_version.updated_redaction_rules: + allowed_keys.extend(["prev_state", "membership", "origin"]) + # The hub server should not be redacted for linear matrix. if room_version.linearized_matrix: allowed_keys.append("hub_server") else: allowed_keys.append("depth") - # Room versions before MSC3989 kept the origin field. - if not room_version.msc3989_redaction_rules: - allowed_keys.append("origin") - event_type = event_dict["type"] new_content = {} @@ -132,9 +129,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic if event_type == EventTypes.Member: add_fields("membership") - if room_version.msc3375_redaction_rules: + if room_version.restricted_join_rule_fix: add_fields(EventContentFields.AUTHORISING_USER) - if room_version.msc3821_redaction_rules: + if room_version.updated_redaction_rules: # Preserve the signed field under third_party_invite. third_party_invite = event_dict["content"].get("third_party_invite") if isinstance(third_party_invite, collections.abc.Mapping): @@ -145,15 +142,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic ] elif event_type == EventTypes.Create: - # MSC2176 rules state that create events cannot have their `content` redacted. - if room_version.msc2176_redaction_rules: + if room_version.updated_redaction_rules: + # MSC2176 rules state that create events cannot have their `content` redacted. new_content = event_dict["content"] - - if not room_version.msc2175_implicit_room_creator: + elif not room_version.implicit_room_creator: + # Some room versions give meaning to `creator` add_fields("creator") + elif event_type == EventTypes.JoinRules: add_fields("join_rule") - if room_version.msc3083_join_rules: + if room_version.restricted_join_rule: add_fields("allow") elif event_type == EventTypes.PowerLevels: add_fields( @@ -167,14 +165,14 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic "redact", ) - if room_version.msc2176_redaction_rules: + if room_version.updated_redaction_rules: add_fields("invite") elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth: add_fields("aliases") elif event_type == EventTypes.RoomHistoryVisibility: add_fields("history_visibility") - elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules: + elif event_type == EventTypes.Redaction and room_version.updated_redaction_rules: add_fields("redacts") # Protect the rel_type and event_id fields under the m.relates_to field. @@ -483,6 +481,15 @@ def serialize_event( if config.as_client_event: d = config.event_format(d) + # If the event is a redaction, copy the redacts field from the content to + # top-level for backwards compatibility. + if ( + e.type == EventTypes.Redaction + and e.room_version.updated_redaction_rules + and e.redacts is not None + ): + d["redacts"] = e.redacts + only_event_fields = config.only_event_fields if only_event_fields: if not isinstance(only_event_fields, list) or not all( diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 21cb3e4675..400b74006c 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -232,7 +232,7 @@ async def _check_sigs_on_pdu( # If this is a join event for a restricted room it may have been authorised # via a different server from the sending server. Check those signatures. if ( - room_version.msc3083_join_rules + room_version.restricted_join_rule and pdu.type == EventTypes.Member and pdu.membership == Membership.JOIN and EventContentFields.AUTHORISING_USER in pdu.content diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index b473a5c31b..5c79ddee0c 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1008,7 +1008,7 @@ class FederationClient(FederationBase): if not room_version: raise UnsupportedRoomVersionError() - if not room_version.msc2403_knocking and membership == Membership.KNOCK: + if not room_version.knock_join_rule and membership == Membership.KNOCK: raise SynapseError( 400, "This room version does not support knocking", @@ -1094,7 +1094,7 @@ class FederationClient(FederationBase): # * Ensure the signatures are good. # # Otherwise, fallback to the provided event. - if room_version.msc3083_join_rules and response.event: + if room_version.restricted_join_rule and response.event: event = response.event valid_pdu = await self._check_sigs_and_hash_and_fetch_one( @@ -1220,7 +1220,7 @@ class FederationClient(FederationBase): # MSC3083 defines additional error codes for room joins. failover_errcodes = None - if room_version.msc3083_join_rules: + if room_version.restricted_join_rule: failover_errcodes = ( Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2a8b177162..24f726a4dd 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -857,7 +857,7 @@ class FederationServer(FederationBase): raise IncompatibleRoomVersionError(room_version=room_version.identifier) # Check that this room supports knocking as defined by its room version - if not room_version.msc2403_knocking: + if not room_version.knock_join_rule: raise SynapseError( 403, "This room version does not support knocking", @@ -965,7 +965,7 @@ class FederationServer(FederationBase): errcode=Codes.NOT_FOUND, ) - if membership_type == Membership.KNOCK and not room_version.msc2403_knocking: + if membership_type == Membership.KNOCK and not room_version.knock_join_rule: raise SynapseError( 403, "This room version does not support knocking", @@ -994,7 +994,7 @@ class FederationServer(FederationBase): # the event is valid to be sent into the room. Currently this is only done # if the user is being joined via restricted join rules. if ( - room_version.msc3083_join_rules + room_version.restricted_join_rule and event.membership == Membership.JOIN and EventContentFields.AUTHORISING_USER in event.content ): diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 5d12a39e26..f3a713f5fa 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -653,6 +653,7 @@ class DeviceHandler(DeviceWorkerHandler): async def store_dehydrated_device( self, user_id: str, + device_id: Optional[str], device_data: JsonDict, initial_device_display_name: Optional[str] = None, ) -> str: @@ -661,6 +662,7 @@ class DeviceHandler(DeviceWorkerHandler): Args: user_id: the user that we are storing the device for + device_id: device id supplied by client device_data: the dehydrated device information initial_device_display_name: The display name to use for the device Returns: @@ -668,7 +670,7 @@ class DeviceHandler(DeviceWorkerHandler): """ device_id = await self.check_device_registered( user_id, - None, + device_id, initial_device_display_name, ) old_device_id = await self.store.store_dehydrated_device( @@ -1124,7 +1126,14 @@ class DeviceListUpdater(DeviceListWorkerUpdater): ) if resync: - await self.multi_user_device_resync([user_id]) + # We mark as stale up front in case we get restarted. + await self.store.mark_remote_users_device_caches_as_stale([user_id]) + run_as_background_process( + "_maybe_retry_device_resync", + self.multi_user_device_resync, + [user_id], + False, + ) else: # Simply update the single device, since we know that is the only # change (because of the single prev_id matching the current cache) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 3caf9b31cc..15e94a03cb 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -13,10 +13,11 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Dict, Optional from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import Codes, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -48,6 +49,9 @@ class DeviceMessageHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.is_mine = hs.is_mine + if hs.config.experimental.msc3814_enabled: + self.event_sources = hs.get_event_sources() + self.device_handler = hs.get_device_handler() # We only need to poke the federation sender explicitly if its on the # same instance. Other federation sender instances will get notified by @@ -303,3 +307,103 @@ class DeviceMessageHandler: # Enqueue a new federation transaction to send the new # device messages to each remote destination. self.federation_sender.send_device_messages(destination) + + async def get_events_for_dehydrated_device( + self, + requester: Requester, + device_id: str, + since_token: Optional[str], + limit: int, + ) -> JsonDict: + """Fetches up to `limit` events sent to `device_id` starting from `since_token` + and returns the new since token. If there are no more messages, returns an empty + array. + + Args: + requester: the user requesting the messages + device_id: ID of the dehydrated device + since_token: stream id to start from when fetching messages + limit: the number of messages to fetch + Returns: + A dict containing the to-device messages, as well as a token that the client + can provide in the next call to fetch the next batch of messages + """ + + user_id = requester.user.to_string() + + # only allow fetching messages for the dehydrated device id currently associated + # with the user + dehydrated_device = await self.device_handler.get_dehydrated_device(user_id) + if dehydrated_device is None: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "No dehydrated device exists", + Codes.FORBIDDEN, + ) + + dehydrated_device_id, _ = dehydrated_device + if device_id != dehydrated_device_id: + raise SynapseError( + HTTPStatus.FORBIDDEN, + "You may only fetch messages for your dehydrated device", + Codes.FORBIDDEN, + ) + + since_stream_id = 0 + if since_token: + if not since_token.startswith("d"): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "from parameter %r has an invalid format" % (since_token,), + errcode=Codes.INVALID_PARAM, + ) + + try: + since_stream_id = int(since_token[1:]) + except Exception: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "from parameter %r has an invalid format" % (since_token,), + errcode=Codes.INVALID_PARAM, + ) + + # if we have a since token, delete any to-device messages before that token + # (since we now know that the device has received them) + deleted = await self.store.delete_messages_for_device( + user_id, device_id, since_stream_id + ) + logger.debug( + "Deleted %d to-device messages up to %d for user_id %s device_id %s", + deleted, + since_stream_id, + user_id, + device_id, + ) + + to_token = self.event_sources.get_current_token().to_device_key + + messages, stream_id = await self.store.get_messages_for_device( + user_id, device_id, since_stream_id, to_token, limit + ) + + for message in messages: + # Remove the message id before sending to client + message_id = message.pop("message_id", None) + if message_id: + set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id) + + logger.debug( + "Returning %d to-device messages between %d and %d (current token: %d) for " + "dehydrated device %s, user_id %s", + len(messages), + since_stream_id, + stream_id, + to_token, + device_id, + user_id, + ) + + return { + "events": messages, + "next_batch": f"d{stream_id}", + } diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 3e37c0cbe2..82a7617a08 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -277,7 +277,7 @@ class EventAuthHandler: True if the proper room version and join rules are set for restricted access. """ # This only applies to room versions which support the new join rule. - if not room_version.msc3083_join_rules: + if not room_version.restricted_join_rule: return False # If there's no join rule, then it defaults to invite (so this doesn't apply). @@ -292,7 +292,7 @@ class EventAuthHandler: return True # also check for MSC3787 behaviour - if room_version.msc3787_knock_restricted_join_rule: + if room_version.knock_restricted_join_rule: return content_join_rule == JoinRules.KNOCK_RESTRICTED return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f0f61dbdb0..58cbf9be13 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -959,7 +959,7 @@ class FederationHandler: # Note that this requires the /send_join request to come back to the # same server. prev_event_ids = None - if room_version.msc3083_join_rules: + if room_version.restricted_join_rule: # Note that the room's state can change out from under us and render our # nice join rules-conformant event non-conformant by the time we build the # event. When this happens, our validation at the end fails and we respond @@ -1582,9 +1582,7 @@ class FederationHandler: event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)]) - ) + prev_state_ids = await context.get_prev_state_ids(StateFilter.from_types([key])) original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = await self.store.get_event( @@ -1637,7 +1635,7 @@ class FederationHandler: token = signed["token"] prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)]) + StateFilter.from_types([(EventTypes.ThirdPartyInvite, token)]) ) invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4292b47037..fff0b5fa12 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -738,7 +738,7 @@ class EventCreationHandler: prev_event_id = state_map.get((EventTypes.Member, event.sender)) else: prev_state_ids = await unpersisted_context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) + StateFilter.from_types([(EventTypes.Member, event.sender)]) ) prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( @@ -860,7 +860,7 @@ class EventCreationHandler: return None prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(event.type, None)]) + StateFilter.from_types([(event.type, event.state_key)]) ) prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: @@ -1565,12 +1565,11 @@ class EventCreationHandler: if state_entry.state_group in self._external_cache_joined_hosts_updates: return - state = await state_entry.get_state( - self._storage_controllers.state, StateFilter.all() - ) with opentracing.start_active_span("get_joined_hosts"): - joined_hosts = await self.store.get_joined_hosts( - event.room_id, state, state_entry + joined_hosts = ( + await self._storage_controllers.state.get_joined_hosts( + event.room_id, state_entry + ) ) # Note that the expiry times must be larger than the expiry time in diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 0a219b7962..cd7df0525f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -95,13 +95,12 @@ bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"]) notify_reason_counter = Counter( - "synapse_handler_presence_notify_reason", "", ["reason"] + "synapse_handler_presence_notify_reason", "", ["locality", "reason"] ) state_transition_counter = Counter( - "synapse_handler_presence_state_transition", "", ["from", "to"] + "synapse_handler_presence_state_transition", "", ["locality", "from", "to"] ) - # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" LAST_ACTIVE_GRANULARITY = 60 * 1000 @@ -567,8 +566,8 @@ class WorkerPresenceHandler(BasePresenceHandler): for new_state in states: old_state = self.user_to_current_state.get(new_state.user_id) self.user_to_current_state[new_state.user_id] = new_state - - if not old_state or should_notify(old_state, new_state): + is_mine = self.is_mine_id(new_state.user_id) + if not old_state or should_notify(old_state, new_state, is_mine): state_to_notify.append(new_state) stream_id = token @@ -1499,23 +1498,31 @@ class PresenceHandler(BasePresenceHandler): ) -def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> bool: +def should_notify( + old_state: UserPresenceState, new_state: UserPresenceState, is_mine: bool +) -> bool: """Decides if a presence state change should be sent to interested parties.""" + user_location = "remote" + if is_mine: + user_location = "local" + if old_state == new_state: return False if old_state.status_msg != new_state.status_msg: - notify_reason_counter.labels("status_msg_change").inc() + notify_reason_counter.labels(user_location, "status_msg_change").inc() return True if old_state.state != new_state.state: - notify_reason_counter.labels("state_change").inc() - state_transition_counter.labels(old_state.state, new_state.state).inc() + notify_reason_counter.labels(user_location, "state_change").inc() + state_transition_counter.labels( + user_location, old_state.state, new_state.state + ).inc() return True if old_state.state == PresenceState.ONLINE: if new_state.currently_active != old_state.currently_active: - notify_reason_counter.labels("current_active_change").inc() + notify_reason_counter.labels(user_location, "current_active_change").inc() return True if ( @@ -1524,12 +1531,16 @@ def should_notify(old_state: UserPresenceState, new_state: UserPresenceState) -> ): # Only notify about last active bumps if we're not currently active if not new_state.currently_active: - notify_reason_counter.labels("last_active_change_online").inc() + notify_reason_counter.labels( + user_location, "last_active_change_online" + ).inc() return True elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. - notify_reason_counter.labels("last_active_change_not_online").inc() + notify_reason_counter.labels( + user_location, "last_active_change_not_online" + ).inc() return True return False @@ -1989,7 +2000,7 @@ def handle_update( ) # Check whether the change was something worth notifying about - if should_notify(prev_state, new_state): + if should_notify(prev_state, new_state, is_mine): new_state = new_state.copy_and_replace(last_federation_update_ts=now) persist_and_notify = True diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index bf907b7881..0513e28aab 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1116,7 +1116,7 @@ class RoomCreationHandler: preset_config, config = self._room_preset_config(room_config) # MSC2175 removes the creator field from the create event. - if not room_version.msc2175_implicit_room_creator: + if not room_version.implicit_room_creator: creation_content["creator"] = creator_id creation_event, unpersisted_creation_context = await create_event( EventTypes.Create, creation_content, False diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 82e4fa7363..496e701f13 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -473,7 +473,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, None)]) + StateFilter.from_types([(EventTypes.Member, user_id)]) ) prev_member_event_id = prev_state_ids.get( @@ -1340,7 +1340,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): requester = types.create_requester(target_user) prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.GuestAccess, None)]) + StateFilter.from_types([(EventTypes.GuestAccess, "")]) ) if event.membership == Membership.JOIN: if requester.is_guest: @@ -1362,11 +1362,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ratelimit=ratelimit, ) - prev_member_event_id = prev_state_ids.get( - (EventTypes.Member, event.state_key), None - ) - if event.membership == Membership.LEAVE: + prev_state_ids = await context.get_prev_state_ids( + StateFilter.from_types([(EventTypes.Member, event.state_key)]) + ) + prev_member_event_id = prev_state_ids.get( + (EventTypes.Member, event.state_key), None + ) + if prev_member_event_id: prev_member_event = await self.store.get_event(prev_member_event_id) if prev_member_event.membership == Membership.JOIN: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 807245160d..dad3e23470 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -564,9 +564,9 @@ class RoomSummaryHandler: join_rule = join_rules_event.content.get("join_rule") if ( join_rule == JoinRules.PUBLIC - or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK) + or (room_version.knock_join_rule and join_rule == JoinRules.KNOCK) or ( - room_version.msc3787_knock_restricted_join_rule + room_version.knock_restricted_join_rule and join_rule == JoinRules.KNOCK_RESTRICTED ) ): diff --git a/synapse/http/client.py b/synapse/http/client.py index 09ea93e10d..ca2cdbc6e2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -1037,7 +1037,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): if reason.check(ResponseDone): self.deferred.callback(self.length) elif reason.check(PotentialDataLoss): - # stolen from https://github.com/twisted/treq/pull/49/files + # This applies to requests which don't set `Content-Length` or a + # `Transfer-Encoding` in the response because in this case the end of the + # response is indicated by the connection being closed, an event which may + # also be due to a transient network problem or other error. But since this + # behavior is expected of some servers (like YouTube), let's ignore it. + # Stolen from https://github.com/twisted/treq/pull/49/files # http://twistedmatrix.com/trac/ticket/4840 self.deferred.callback(self.length) else: diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index 23a60af171..636efc33e8 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import base64 import logging from typing import Optional, Union @@ -39,8 +40,14 @@ class ProxyConnectError(ConnectError): pass -@attr.s(auto_attribs=True) class ProxyCredentials: + @abc.abstractmethod + def as_proxy_authorization_value(self) -> bytes: + raise NotImplementedError() + + +@attr.s(auto_attribs=True) +class BasicProxyCredentials(ProxyCredentials): username_password: bytes def as_proxy_authorization_value(self) -> bytes: @@ -55,6 +62,17 @@ class ProxyCredentials: return b"Basic " + base64.encodebytes(self.username_password) +@attr.s(auto_attribs=True) +class BearerProxyCredentials(ProxyCredentials): + access_token: bytes + + def as_proxy_authorization_value(self) -> bytes: + """ + Return the value for a Proxy-Authorization header (i.e. 'Bearer xxx'). + """ + return b"Bearer " + self.access_token + + @implementer(IStreamClientEndpoint) class HTTPConnectProxyEndpoint: """An Endpoint implementation which will send a CONNECT request to an http proxy diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index b973246adc..9c18512813 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -50,7 +50,7 @@ from twisted.internet.interfaces import IReactorTime from twisted.internet.task import Cooperator from twisted.web.client import ResponseFailed from twisted.web.http_headers import Headers -from twisted.web.iweb import IBodyProducer, IResponse +from twisted.web.iweb import IAgent, IBodyProducer, IResponse import synapse.metrics import synapse.util.retryutils @@ -71,7 +71,9 @@ from synapse.http.client import ( encode_query_args, read_body_with_max_size, ) +from synapse.http.connectproxyclient import BearerProxyCredentials from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.http.proxyagent import ProxyAgent from synapse.http.types import QueryParams from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -393,17 +395,41 @@ class MatrixFederationHttpClient: if hs.config.server.user_agent_suffix: user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix) - federation_agent = MatrixFederationAgent( - self.reactor, - tls_client_options_factory, - user_agent.encode("ascii"), - hs.config.server.federation_ip_range_allowlist, - hs.config.server.federation_ip_range_blocklist, + outbound_federation_restricted_to = ( + hs.config.worker.outbound_federation_restricted_to ) + if hs.get_instance_name() in outbound_federation_restricted_to: + # Talk to federation directly + federation_agent: IAgent = MatrixFederationAgent( + self.reactor, + tls_client_options_factory, + user_agent.encode("ascii"), + hs.config.server.federation_ip_range_allowlist, + hs.config.server.federation_ip_range_blocklist, + ) + else: + proxy_authorization_secret = hs.config.worker.worker_replication_secret + assert ( + proxy_authorization_secret is not None + ), "`worker_replication_secret` must be set when using `outbound_federation_restricted_to` (used to authenticate requests across workers)" + federation_proxy_credentials = BearerProxyCredentials( + proxy_authorization_secret.encode("ascii") + ) + + # We need to talk to federation via the proxy via one of the configured + # locations + federation_proxy_locations = outbound_federation_restricted_to.locations + federation_agent = ProxyAgent( + self.reactor, + self.reactor, + tls_client_options_factory, + federation_proxy_locations=federation_proxy_locations, + federation_proxy_credentials=federation_proxy_credentials, + ) # Use a BlocklistingAgentWrapper to prevent circumventing the IP # blocking via IP literals in server names - self.agent = BlocklistingAgentWrapper( + self.agent: IAgent = BlocklistingAgentWrapper( federation_agent, ip_blocklist=hs.config.server.federation_ip_range_blocklist, ) @@ -412,7 +438,6 @@ class MatrixFederationHttpClient: self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout_seconds = hs.config.federation.client_timeout_ms / 1000 - self.max_long_retry_delay_seconds = ( hs.config.federation.max_long_retry_delay_ms / 1000 ) @@ -1176,6 +1201,101 @@ class MatrixFederationHttpClient: RequestSendFailed: If there were problems connecting to the remote, due to e.g. DNS failures, connection timeouts etc. """ + json_dict, _ = await self.get_json_with_headers( + destination=destination, + path=path, + args=args, + retry_on_dns_fail=retry_on_dns_fail, + timeout=timeout, + ignore_backoff=ignore_backoff, + try_trailing_slash_on_400=try_trailing_slash_on_400, + parser=parser, + ) + return json_dict + + @overload + async def get_json_with_headers( + self, + destination: str, + path: str, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, + ) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]: + ... + + @overload + async def get_json_with_headers( + self, + destination: str, + path: str, + args: Optional[QueryParams] = ..., + retry_on_dns_fail: bool = ..., + timeout: Optional[int] = ..., + ignore_backoff: bool = ..., + try_trailing_slash_on_400: bool = ..., + parser: ByteParser[T] = ..., + ) -> Tuple[T, Dict[bytes, List[bytes]]]: + ... + + async def get_json_with_headers( + self, + destination: str, + path: str, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser[T]] = None, + ) -> Tuple[Union[JsonDict, T], Dict[bytes, List[bytes]]]: + """GETs some json from the given host homeserver and path + + Args: + destination: The remote server to send the HTTP request to. + + path: The HTTP path. + + args: A dictionary used to create query strings, defaults to + None. + + retry_on_dns_fail: true if the request should be retried on DNS failures + + timeout: number of milliseconds to wait for the response. + self._default_timeout (60s) by default. + + Note that we may make several attempts to send the request; this + timeout applies to the time spent waiting for response headers for + *each* attempt (including connection time) as well as the time spent + reading the response body after a 200 response. + + ignore_backoff: true to ignore the historical backoff data + and try the request anyway. + + try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED + response we should try appending a trailing slash to the end of + the request. Workaround for #3622 in Synapse <= v0.99.3. + + parser: The parser to use to decode the response. Defaults to + parsing as JSON. + + Returns: + Succeeds when we get a 2xx HTTP response. The result will be a tuple of the + decoded JSON body and a dict of the response headers. + + Raises: + HttpResponseException: If we get an HTTP response code >= 300 + (except 429). + NotRetryingDestination: If we are not yet ready to retry this + server. + FederationDeniedError: If this destination is not on our + federation whitelist + RequestSendFailed: If there were problems connecting to the + remote, due to e.g. DNS failures, connection timeouts etc. + """ request = MatrixFederationRequest( method="GET", destination=destination, path=path, query=args ) @@ -1191,6 +1311,8 @@ class MatrixFederationHttpClient: timeout=timeout, ) + headers = dict(response.headers.getAllRawHeaders()) + if timeout is not None: _sec_timeout = timeout / 1000 else: @@ -1208,7 +1330,7 @@ class MatrixFederationHttpClient: parser=parser, ) - return body + return body, headers async def delete_json( self, diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py new file mode 100644 index 0000000000..c9f51e51bc --- /dev/null +++ b/synapse/http/proxy.py @@ -0,0 +1,283 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import logging +import urllib.parse +from typing import TYPE_CHECKING, Any, Optional, Set, Tuple, cast + +from twisted.internet import protocol +from twisted.internet.interfaces import ITCPTransport +from twisted.internet.protocol import connectionDone +from twisted.python import failure +from twisted.python.failure import Failure +from twisted.web.client import ResponseDone +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse +from twisted.web.resource import IResource +from twisted.web.server import Request, Site + +from synapse.api.errors import Codes, InvalidProxyCredentialsError +from synapse.http import QuieterFileBodyProducer +from synapse.http.server import _AsyncResource +from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import ISynapseReactor +from synapse.util.async_helpers import timeout_deferred + +if TYPE_CHECKING: + from synapse.http.site import SynapseRequest + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616 +# section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be +# consumed by the immediate recipient and not be forwarded on. +HOP_BY_HOP_HEADERS = { + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "TE", + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + + +def parse_connection_header_value( + connection_header_value: Optional[bytes], +) -> Set[str]: + """ + Parse the `Connection` header to determine which headers we should not be copied + over from the remote response. + + As defined by RFC2616 section 14.10 and RFC9110 section 7.6.1 + + Example: `Connection: close, X-Foo, X-Bar` will return `{"Close", "X-Foo", "X-Bar"}` + + Even though "close" is a special directive, let's just treat it as just another + header for simplicity. If people want to check for this directive, they can simply + check for `"Close" in headers`. + + Args: + connection_header_value: The value of the `Connection` header. + + Returns: + The set of header names that should not be copied over from the remote response. + The keys are capitalized in canonical capitalization. + """ + headers = Headers() + extra_headers_to_remove: Set[str] = set() + if connection_header_value: + extra_headers_to_remove = { + headers._canonicalNameCaps(connection_option.strip()).decode("ascii") + for connection_option in connection_header_value.split(b",") + } + + return extra_headers_to_remove + + +class ProxyResource(_AsyncResource): + """ + A stub resource that proxies any requests with a `matrix-federation://` scheme + through the given `federation_agent` to the remote homeserver and ferries back the + info. + """ + + isLeaf = True + + def __init__(self, reactor: ISynapseReactor, hs: "HomeServer"): + super().__init__(True) + + self.reactor = reactor + self.agent = hs.get_federation_http_client().agent + + self._proxy_authorization_secret = hs.config.worker.worker_replication_secret + + def _check_auth(self, request: Request) -> None: + # The `matrix-federation://` proxy functionality can only be used with auth. + # Protect homserver admins forgetting to configure a secret. + assert self._proxy_authorization_secret is not None + + # Get the authorization header. + auth_headers = request.requestHeaders.getRawHeaders(b"Proxy-Authorization") + + if not auth_headers: + raise InvalidProxyCredentialsError( + "Missing Proxy-Authorization header.", Codes.MISSING_TOKEN + ) + if len(auth_headers) > 1: + raise InvalidProxyCredentialsError( + "Too many Proxy-Authorization headers.", Codes.UNAUTHORIZED + ) + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + received_secret = parts[1].decode("ascii") + if self._proxy_authorization_secret == received_secret: + # Success! + return + + raise InvalidProxyCredentialsError( + "Invalid Proxy-Authorization header.", Codes.UNAUTHORIZED + ) + + async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]: + uri = urllib.parse.urlparse(request.uri) + assert uri.scheme == b"matrix-federation" + + # Check the authorization headers before handling the request. + self._check_auth(request) + + headers = Headers() + for header_name in (b"User-Agent", b"Authorization", b"Content-Type"): + header_value = request.getHeader(header_name) + if header_value: + headers.addRawHeader(header_name, header_value) + + request_deferred = run_in_background( + self.agent.request, + request.method, + request.uri, + headers=headers, + bodyProducer=QuieterFileBodyProducer(request.content), + ) + request_deferred = timeout_deferred( + request_deferred, + # This should be set longer than the timeout in `MatrixFederationHttpClient` + # so that it has enough time to complete and pass us the data before we give + # up. + timeout=90, + reactor=self.reactor, + ) + + response = await make_deferred_yieldable(request_deferred) + + return response.code, response + + def _send_response( + self, + request: "SynapseRequest", + code: int, + response_object: Any, + ) -> None: + response = cast(IResponse, response_object) + response_headers = cast(Headers, response.headers) + + request.setResponseCode(code) + + # The `Connection` header also defines which headers should not be copied over. + connection_header = response_headers.getRawHeaders(b"connection") + extra_headers_to_remove = parse_connection_header_value( + connection_header[0] if connection_header else None + ) + + # Copy headers. + for k, v in response_headers.getAllRawHeaders(): + # Do not copy over any hop-by-hop headers. These are meant to only be + # consumed by the immediate recipient and not be forwarded on. + header_key = k.decode("ascii") + if ( + header_key in HOP_BY_HOP_HEADERS + or header_key in extra_headers_to_remove + ): + continue + + request.responseHeaders.setRawHeaders(k, v) + + response.deliverBody(_ProxyResponseBody(request)) + + def _send_error_response( + self, + f: failure.Failure, + request: "SynapseRequest", + ) -> None: + if isinstance(f.value, InvalidProxyCredentialsError): + error_response_code = f.value.code + error_response_json = {"errcode": f.value.errcode, "err": f.value.msg} + else: + error_response_code = 502 + error_response_json = { + "errcode": Codes.UNKNOWN, + "err": "ProxyResource: Error when proxying request: %s %s -> %s" + % ( + request.method.decode("ascii"), + request.uri.decode("ascii"), + f, + ), + } + + request.setResponseCode(error_response_code) + request.setHeader(b"Content-Type", b"application/json") + request.write((json.dumps(error_response_json)).encode()) + request.finish() + + +class _ProxyResponseBody(protocol.Protocol): + """ + A protocol that proxies the given remote response data back out to the given local + request. + """ + + transport: Optional[ITCPTransport] = None + + def __init__(self, request: "SynapseRequest") -> None: + self._request = request + + def dataReceived(self, data: bytes) -> None: + # Avoid sending response data to the local request that already disconnected + if self._request._disconnected and self.transport is not None: + # Close the connection (forcefully) since all the data will get + # discarded anyway. + self.transport.abortConnection() + return + + self._request.write(data) + + def connectionLost(self, reason: Failure = connectionDone) -> None: + # If the local request is already finished (successfully or failed), don't + # worry about sending anything back. + if self._request.finished: + return + + if reason.check(ResponseDone): + self._request.finish() + else: + # Abort the underlying request since our remote request also failed. + self._request.transport.abortConnection() + + +class ProxySite(Site): + """ + Proxies any requests with a `matrix-federation://` scheme through the given + `federation_agent`. Otherwise, behaves like a normal `Site`. + """ + + def __init__( + self, + resource: IResource, + reactor: ISynapseReactor, + hs: "HomeServer", + ): + super().__init__(resource, reactor=reactor) + + self._proxy_resource = ProxyResource(reactor, hs=hs) + + def getResourceFor(self, request: "SynapseRequest") -> IResource: + uri = urllib.parse.urlparse(request.uri) + if uri.scheme == b"matrix-federation": + return self._proxy_resource + + return super().getResourceFor(request) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 7bdc4acae7..59ab8fad35 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random import re -from typing import Any, Dict, Optional, Tuple +from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple from urllib.parse import urlparse from urllib.request import ( # type: ignore[attr-defined] getproxies_environment, @@ -23,8 +24,17 @@ from urllib.request import ( # type: ignore[attr-defined] from zope.interface import implementer from twisted.internet import defer -from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint +from twisted.internet.endpoints import ( + HostnameEndpoint, + UNIXClientEndpoint, + wrapClientTLS, +) +from twisted.internet.interfaces import ( + IProtocol, + IProtocolFactory, + IReactorCore, + IStreamClientEndpoint, +) from twisted.python.failure import Failure from twisted.web.client import ( URI, @@ -36,8 +46,18 @@ from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse +from synapse.config.workers import ( + InstanceLocationConfig, + InstanceTcpLocationConfig, + InstanceUnixLocationConfig, +) from synapse.http import redact_uri -from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials +from synapse.http.connectproxyclient import ( + BasicProxyCredentials, + HTTPConnectProxyEndpoint, + ProxyCredentials, +) +from synapse.logging.context import run_in_background logger = logging.getLogger(__name__) @@ -74,6 +94,14 @@ class ProxyAgent(_AgentBase): use_proxy: Whether proxy settings should be discovered and used from conventional environment variables. + federation_proxy_locations: An optional list of locations to proxy outbound federation + traffic through (only requests that use the `matrix-federation://` scheme + will be proxied). + + federation_proxy_credentials: Required if `federation_proxy_locations` is set. The + credentials to use when proxying outbound federation traffic through another + worker. + Raises: ValueError if use_proxy is set and the environment variables contain an invalid proxy specification. @@ -89,6 +117,8 @@ class ProxyAgent(_AgentBase): bindAddress: Optional[bytes] = None, pool: Optional[HTTPConnectionPool] = None, use_proxy: bool = False, + federation_proxy_locations: Collection[InstanceLocationConfig] = (), + federation_proxy_credentials: Optional[ProxyCredentials] = None, ): contextFactory = contextFactory or BrowserLikePolicyForHTTPS() @@ -127,6 +157,47 @@ class ProxyAgent(_AgentBase): self._policy_for_https = contextFactory self._reactor = reactor + self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None + self._federation_proxy_credentials: Optional[ProxyCredentials] = None + if federation_proxy_locations: + assert ( + federation_proxy_credentials is not None + ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + + endpoints: List[IStreamClientEndpoint] = [] + for federation_proxy_location in federation_proxy_locations: + endpoint: IStreamClientEndpoint + if isinstance(federation_proxy_location, InstanceTcpLocationConfig): + endpoint = HostnameEndpoint( + self.proxy_reactor, + federation_proxy_location.host, + federation_proxy_location.port, + ) + if federation_proxy_location.tls: + tls_connection_creator = ( + self._policy_for_https.creatorForNetloc( + federation_proxy_location.host.encode("utf-8"), + federation_proxy_location.port, + ) + ) + endpoint = wrapClientTLS(tls_connection_creator, endpoint) + + elif isinstance(federation_proxy_location, InstanceUnixLocationConfig): + endpoint = UNIXClientEndpoint( + self.proxy_reactor, federation_proxy_location.path + ) + + else: + # It is supremely unlikely we ever hit this + raise SchemeNotSupported( + f"Unknown type of Endpoint requested, check {federation_proxy_location}" + ) + + endpoints.append(endpoint) + + self._federation_proxy_endpoint = _RandomSampleEndpoints(endpoints) + self._federation_proxy_credentials = federation_proxy_credentials + def request( self, method: bytes, @@ -214,6 +285,25 @@ class ProxyAgent(_AgentBase): parsed_uri.port, self.https_proxy_creds, ) + elif ( + parsed_uri.scheme == b"matrix-federation" + and self._federation_proxy_endpoint + ): + assert ( + self._federation_proxy_credentials is not None + ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + + # Set a Proxy-Authorization header + if headers is None: + headers = Headers() + # We always need authentication for the outbound federation proxy + headers.addRawHeader( + b"Proxy-Authorization", + self._federation_proxy_credentials.as_proxy_authorization_value(), + ) + + endpoint = self._federation_proxy_endpoint + request_path = uri else: # not using a proxy endpoint = HostnameEndpoint( @@ -233,6 +323,11 @@ class ProxyAgent(_AgentBase): endpoint = wrapClientTLS(tls_connection_creator, endpoint) elif parsed_uri.scheme == b"http": pass + elif ( + parsed_uri.scheme == b"matrix-federation" + and self._federation_proxy_endpoint + ): + pass else: return defer.fail( Failure( @@ -334,6 +429,42 @@ def parse_proxy( credentials = None if url.username and url.password: - credentials = ProxyCredentials(b"".join([url.username, b":", url.password])) + credentials = BasicProxyCredentials( + b"".join([url.username, b":", url.password]) + ) return url.scheme, url.hostname, url.port or default_port, credentials + + +@implementer(IStreamClientEndpoint) +class _RandomSampleEndpoints: + """An endpoint that randomly iterates through a given list of endpoints at + each connection attempt. + """ + + def __init__( + self, + endpoints: Sequence[IStreamClientEndpoint], + ) -> None: + assert endpoints + self._endpoints = endpoints + + def __repr__(self) -> str: + return f"<_RandomSampleEndpoints endpoints={self._endpoints}>" + + def connect( + self, protocol_factory: IProtocolFactory + ) -> "defer.Deferred[IProtocol]": + """Implements IStreamClientEndpoint interface""" + + return run_in_background(self._do_connect, protocol_factory) + + async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol: + failures: List[Failure] = [] + for endpoint in random.sample(self._endpoints, k=len(self._endpoints)): + try: + return await endpoint.connect(protocol_factory) + except Exception: + failures.append(Failure()) + + failures.pop().raiseException() diff --git a/synapse/http/server.py b/synapse/http/server.py index e411ac7e62..5109cec983 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -18,6 +18,7 @@ import html import logging import types import urllib +import urllib.parse from http import HTTPStatus from http.client import FOUND from inspect import isawaitable @@ -65,7 +66,6 @@ from synapse.api.errors import ( UnrecognizedRequestError, ) from synapse.config.homeserver import HomeServerConfig -from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background from synapse.logging.opentracing import active_span, start_active_span, trace_servlet from synapse.util import json_encoder @@ -76,6 +76,7 @@ from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: import opentracing + from synapse.http.site import SynapseRequest from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -102,7 +103,7 @@ HTTP_STATUS_REQUEST_CANCELLED = 499 def return_json_error( - f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig] + f: failure.Failure, request: "SynapseRequest", config: Optional[HomeServerConfig] ) -> None: """Sends a JSON error response to clients.""" @@ -220,8 +221,8 @@ def return_html_error( def wrap_async_request_handler( - h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]] -) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]: + h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]] +) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. This helps ensure that work done by the request handler after the request is completed @@ -235,7 +236,7 @@ def wrap_async_request_handler( """ async def wrapped_async_request_handler( - self: "_AsyncResource", request: SynapseRequest + self: "_AsyncResource", request: "SynapseRequest" ) -> None: with request.processing(): await h(self, request) @@ -300,7 +301,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): self._extract_context = extract_context - def render(self, request: SynapseRequest) -> int: + def render(self, request: "SynapseRequest") -> int: """This gets called by twisted every time someone sends us a request.""" request.render_deferred = defer.ensureDeferred( self._async_render_wrapper(request) @@ -308,7 +309,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return NOT_DONE_YET @wrap_async_request_handler - async def _async_render_wrapper(self, request: SynapseRequest) -> None: + async def _async_render_wrapper(self, request: "SynapseRequest") -> None: """This is a wrapper that delegates to `_async_render` and handles exceptions, return values, metrics, etc. """ @@ -328,7 +329,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): f = failure.Failure() self._send_error_response(f, request) - async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]: + async def _async_render( + self, request: "SynapseRequest" + ) -> Optional[Tuple[int, Any]]: """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if no appropriate method exists. Can be overridden in sub classes for different routing. @@ -358,7 +361,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): @abc.abstractmethod def _send_response( self, - request: SynapseRequest, + request: "SynapseRequest", code: int, response_object: Any, ) -> None: @@ -368,7 +371,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): def _send_error_response( self, f: failure.Failure, - request: SynapseRequest, + request: "SynapseRequest", ) -> None: raise NotImplementedError() @@ -384,7 +387,7 @@ class DirectServeJsonResource(_AsyncResource): def _send_response( self, - request: SynapseRequest, + request: "SynapseRequest", code: int, response_object: Any, ) -> None: @@ -401,7 +404,7 @@ class DirectServeJsonResource(_AsyncResource): def _send_error_response( self, f: failure.Failure, - request: SynapseRequest, + request: "SynapseRequest", ) -> None: """Implements _AsyncResource._send_error_response""" return_json_error(f, request, None) @@ -473,7 +476,7 @@ class JsonResource(DirectServeJsonResource): ) def _get_handler_for_request( - self, request: SynapseRequest + self, request: "SynapseRequest" ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. @@ -503,7 +506,7 @@ class JsonResource(DirectServeJsonResource): # Huh. No one wanted to handle that? Fiiiiiine. raise UnrecognizedRequestError(code=404) - async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: + async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) request.is_render_cancellable = is_function_cancellable(callback) @@ -535,7 +538,7 @@ class JsonResource(DirectServeJsonResource): def _send_error_response( self, f: failure.Failure, - request: SynapseRequest, + request: "SynapseRequest", ) -> None: """Implements _AsyncResource._send_error_response""" return_json_error(f, request, self.hs.config) @@ -551,7 +554,7 @@ class DirectServeHtmlResource(_AsyncResource): def _send_response( self, - request: SynapseRequest, + request: "SynapseRequest", code: int, response_object: Any, ) -> None: @@ -565,7 +568,7 @@ class DirectServeHtmlResource(_AsyncResource): def _send_error_response( self, f: failure.Failure, - request: SynapseRequest, + request: "SynapseRequest", ) -> None: """Implements _AsyncResource._send_error_response""" return_html_error(f, request, self.ERROR_TEMPLATE) @@ -592,7 +595,7 @@ class UnrecognizedRequestResource(resource.Resource): errcode of M_UNRECOGNIZED. """ - def render(self, request: SynapseRequest) -> int: + def render(self, request: "SynapseRequest") -> int: f = failure.Failure(UnrecognizedRequestError(code=404)) return_json_error(f, request, None) # A response has already been sent but Twisted requires either NOT_DONE_YET @@ -622,7 +625,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: SynapseRequest) -> bytes: + def render_OPTIONS(self, request: "SynapseRequest") -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -737,7 +740,7 @@ def _encode_json_bytes(json_object: object) -> bytes: def respond_with_json( - request: SynapseRequest, + request: "SynapseRequest", code: int, json_object: Any, send_cors: bool = False, @@ -787,7 +790,7 @@ def respond_with_json( def respond_with_json_bytes( - request: SynapseRequest, + request: "SynapseRequest", code: int, json_bytes: bytes, send_cors: bool = False, @@ -825,7 +828,7 @@ def respond_with_json_bytes( async def _async_write_json_to_request_in_thread( - request: SynapseRequest, + request: "SynapseRequest", json_encoder: Callable[[Any], bytes], json_object: Any, ) -> None: @@ -883,7 +886,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: SynapseRequest) -> None: +def set_cors_headers(request: "SynapseRequest") -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -981,7 +984,7 @@ def set_clickjacking_protection_headers(request: Request) -> None: def respond_with_redirect( - request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False + request: "SynapseRequest", url: bytes, statusCode: int = FOUND, cors: bool = False ) -> None: """ Write a 302 (or other specified status code) response to the request, if it is still alive. diff --git a/synapse/http/site.py b/synapse/http/site.py index 5b5a7c1e59..a388d6cf7f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -21,25 +21,29 @@ from zope.interface import implementer from twisted.internet.address import UNIXAddress from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IAddress, IReactorTime +from twisted.internet.interfaces import IAddress from twisted.python.failure import Failure from twisted.web.http import HTTPChannel from twisted.web.resource import IResource, Resource -from twisted.web.server import Request, Site +from twisted.web.server import Request from synapse.config.server import ListenerConfig from synapse.http import get_request_user_agent, redact_uri +from synapse.http.proxy import ProxySite from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import ( ContextRequest, LoggingContext, PreserveLoggingContext, ) -from synapse.types import Requester +from synapse.types import ISynapseReactor, Requester if TYPE_CHECKING: import opentracing + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) _next_request_seq = 0 @@ -102,7 +106,7 @@ class SynapseRequest(Request): # A boolean indicating whether `render_deferred` should be cancelled if the # client disconnects early. Expected to be set by the coroutine started by # `Resource.render`, if rendering is asynchronous. - self.is_render_cancellable = False + self.is_render_cancellable: bool = False global _next_request_seq self.request_seq = _next_request_seq @@ -601,7 +605,7 @@ class _XForwardedForAddress: host: str -class SynapseSite(Site): +class SynapseSite(ProxySite): """ Synapse-specific twisted http Site @@ -623,7 +627,8 @@ class SynapseSite(Site): resource: IResource, server_version_string: str, max_request_body_size: int, - reactor: IReactorTime, + reactor: ISynapseReactor, + hs: "HomeServer", ): """ @@ -638,7 +643,11 @@ class SynapseSite(Site): dropping the connection reactor: reactor to be used to manage connection timeouts """ - Site.__init__(self, resource, reactor=reactor) + super().__init__( + resource=resource, + reactor=reactor, + hs=hs, + ) self.site_tag = site_tag self.reactor = reactor @@ -649,7 +658,9 @@ class SynapseSite(Site): request_id_header = config.http_options.request_id_header - self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + self.experimental_cors_msc3886: bool = ( + config.http_options.experimental_cors_msc3886 + ) def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 67377c647b..990c079c81 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -375,7 +375,7 @@ class BulkPushRuleEvaluator: # _get_power_levels_and_sender_level in its call to get_user_power_level # (even for room V10.) notification_levels = power_levels.get("notifications", {}) - if not event.room_version.msc3667_int_only_power_levels: + if not event.room_version.enforce_int_power_levels: keys = list(notification_levels.keys()) for key in keys: level = notification_levels.get(key, SENTINEL) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 38dff9703f..690d2ec406 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -14,19 +14,22 @@ # limitations under the License. import logging +from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple from pydantic import Extra, StrictStr from synapse.api import errors -from synapse.api.errors import NotFoundError, UnrecognizedRequestError +from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError from synapse.handlers.device import DeviceHandler from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_and_validate_json_object_from_request, + parse_integer, ) from synapse.http.site import SynapseRequest +from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client.models import AuthenticationData from synapse.rest.models import RequestBodyModel @@ -229,6 +232,8 @@ class DehydratedDeviceDataModel(RequestBodyModel): class DehydratedDeviceServlet(RestServlet): """Retrieve or store a dehydrated device. + Implements either MSC2697 or MSC3814. + GET /org.matrix.msc2697.v2/dehydrated_device HTTP/1.1 200 OK @@ -261,9 +266,7 @@ class DehydratedDeviceServlet(RestServlet): """ - PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device$", releases=()) - - def __init__(self, hs: "HomeServer"): + def __init__(self, hs: "HomeServer", msc2697: bool = True): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -271,6 +274,13 @@ class DehydratedDeviceServlet(RestServlet): assert isinstance(handler, DeviceHandler) self.device_handler = handler + self.PATTERNS = client_patterns( + "/org.matrix.msc2697.v2/dehydrated_device$" + if msc2697 + else "/org.matrix.msc3814.v1/dehydrated_device$", + releases=(), + ) + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) dehydrated_device = await self.device_handler.get_dehydrated_device( @@ -293,6 +303,7 @@ class DehydratedDeviceServlet(RestServlet): device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), + None, submission.device_data.dict(), submission.initial_device_display_name, ) @@ -347,6 +358,210 @@ class ClaimDehydratedDeviceServlet(RestServlet): return 200, result +class DehydratedDeviceEventsServlet(RestServlet): + PATTERNS = client_patterns( + "/org.matrix.msc3814.v1/dehydrated_device/(?P<device_id>[^/]*)/events$", + releases=(), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.message_handler = hs.get_device_message_handler() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + + class PostBody(RequestBodyModel): + next_batch: Optional[StrictStr] + + async def on_POST( + self, request: SynapseRequest, device_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + next_batch = parse_and_validate_json_object_from_request( + request, self.PostBody + ).next_batch + limit = parse_integer(request, "limit", 100) + + msgs = await self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id=device_id, + since_token=next_batch, + limit=limit, + ) + + return 200, msgs + + +class DehydratedDeviceV2Servlet(RestServlet): + """Upload, retrieve, or delete a dehydrated device. + + GET /org.matrix.msc3814.v1/dehydrated_device + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + } + } + + PUT /org.matrix.msc3814.v1/dehydrated_device + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + "device_data": { + "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", + "account": "dehydrated_device" + }, + "device_keys": { + "user_id": "<user_id>", + "device_id": "<device_id>", + "valid_until_ts": <millisecond_timestamp>, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ] + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures:" { + "<user_id>" { + "<algorithm>:<device_id>": "<signature_base64>" + } + } + }, + "fallback_keys": { + "<algorithm>:<device_id>": "<key_base64>", + "signed_<algorithm>:<device_id>": { + "fallback": true, + "key": "<key_base64>", + "signatures": { + "<user_id>": { + "<algorithm>:<device_id>": "<key_base64>" + } + } + } + } + "one_time_keys": { + "<algorithm>:<key_id>": "<key_base64>" + }, + + } + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id" + } + + DELETE /org.matrix.msc3814.v1/dehydrated_device + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "device_id": "dehydrated_device_id", + } + """ + + PATTERNS = [ + *client_patterns("/org.matrix.msc3814.v1/dehydrated_device$", releases=()), + ] + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.hs = hs + self.auth = hs.get_auth() + handler = hs.get_device_handler() + assert isinstance(handler, DeviceHandler) + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.device_handler = handler + + if hs.config.worker.worker_app is None: + # if main process + self.key_uploader = self.e2e_keys_handler.upload_keys_for_user + else: + # then a worker + self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs) + + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + dehydrated_device = await self.device_handler.get_dehydrated_device( + requester.user.to_string() + ) + + if dehydrated_device is not None: + (device_id, device_data) = dehydrated_device + result = {"device_id": device_id, "device_data": device_data} + return 200, result + else: + raise errors.NotFoundError("No dehydrated device available") + + async def on_DELETE(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + dehydrated_device = await self.device_handler.get_dehydrated_device( + requester.user.to_string() + ) + + if dehydrated_device is not None: + (device_id, device_data) = dehydrated_device + + result = await self.device_handler.rehydrate_device( + requester.user.to_string(), + self.auth.get_access_token_from_request(request), + device_id, + ) + + result = {"device_id": device_id} + + return 200, result + else: + raise errors.NotFoundError("No dehydrated device available") + + class PutBody(RequestBodyModel): + device_data: DehydratedDeviceDataModel + device_id: StrictStr + initial_device_display_name: Optional[StrictStr] + + class Config: + extra = Extra.allow + + async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + submission = parse_and_validate_json_object_from_request(request, self.PutBody) + requester = await self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + device_info = submission.dict() + if "device_keys" not in device_info.keys(): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Device key(s) not found, these must be provided.", + ) + + # TODO: Those two operations, creating a device and storing the + # device's keys should be atomic. + device_id = await self.device_handler.store_dehydrated_device( + requester.user.to_string(), + submission.device_id, + submission.device_data.dict(), + submission.initial_device_display_name, + ) + + # TODO: Do we need to do something with the result here? + await self.key_uploader( + user_id=user_id, device_id=submission.device_id, keys=submission.dict() + ) + + return 200, {"device_id": device_id} + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if ( hs.config.worker.worker_app is None @@ -354,7 +569,12 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ): DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) + if hs.config.worker.worker_app is None: DeviceRestServlet(hs).register(http_server) - DehydratedDeviceServlet(hs).register(http_server) - ClaimDehydratedDeviceServlet(hs).register(http_server) + if hs.config.experimental.msc2697_enabled: + DehydratedDeviceServlet(hs, msc2697=True).register(http_server) + ClaimDehydratedDeviceServlet(hs).register(http_server) + if hs.config.experimental.msc3814_enabled: + DehydratedDeviceV2Servlet(hs).register(http_server) + DehydratedDeviceEventsServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 951bd033f5..dc498001e4 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1117,7 +1117,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): # Ensure the redacts property in the content matches the one provided in # the URL. room_version = await self._store.get_room_version(room_id) - if room_version.msc2176_redaction_rules: + if room_version.updated_redaction_rules: if "redacts" in content and content["redacts"] != event_id: raise SynapseError( 400, @@ -1151,7 +1151,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): "sender": requester.user.to_string(), } # Earlier room versions had a top-level redacts property. - if not room_version.msc2176_redaction_rules: + if not room_version.updated_redaction_rules: event_dict["redacts"] = event_id ( diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 9bc0c3b7b9..1b91cf5eaa 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -268,8 +268,7 @@ class StateHandler: The hosts in the room at the given events """ entry = await self.resolve_state_groups_for_events(room_id, event_ids) - state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_hosts(room_id, state, entry) + return await self._state_storage_controller.get_joined_hosts(room_id, entry) @trace @tag_args diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 1b9d7d8457..44c49274a9 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -667,7 +667,7 @@ async def _mainline_sort( order_map = {} for idx, ev_id in enumerate(event_ids, start=1): depth = await _get_mainline_depth_for_event( - event_map[ev_id], mainline_map, event_map, state_res_store + clock, event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) @@ -682,6 +682,7 @@ async def _mainline_sort( async def _get_mainline_depth_for_event( + clock: Clock, event: EventBase, mainline_map: Dict[str, int], event_map: Dict[str, EventBase], @@ -704,6 +705,7 @@ async def _get_mainline_depth_for_event( # We do an iterative search, replacing `event with the power level in its # auth events (if any) + idx = 0 while tmp_event: depth = mainline_map.get(tmp_event.event_id) if depth is not None: @@ -720,6 +722,11 @@ async def _get_mainline_depth_for_event( tmp_event = aev break + idx += 1 + + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) + # Didn't find a power level auth event, so we just return 0 return 0 diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 233df7cce2..278c7832ba 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from itertools import chain from typing import ( TYPE_CHECKING, AbstractSet, @@ -19,14 +20,16 @@ from typing import ( Callable, Collection, Dict, + FrozenSet, Iterable, List, Mapping, Optional, Tuple, + Union, ) -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.logging.opentracing import tag_args, trace from synapse.storage.roommember import ProfileInfo @@ -34,14 +37,20 @@ from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, PartialStateEventsTracker, ) -from synapse.types import MutableStateMap, StateMap +from synapse.types import MutableStateMap, StateMap, get_domain_from_id from synapse.types.state import StateFilter +from synapse.util.async_helpers import Linearizer +from synapse.util.caches import intern_string +from synapse.util.caches.descriptors import cached from synapse.util.cancellation import cancellable +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.state import _StateCacheEntry from synapse.storage.databases import Databases + logger = logging.getLogger(__name__) @@ -52,10 +61,15 @@ class StateStorageController: def __init__(self, hs: "HomeServer", stores: "Databases"): self._is_mine_id = hs.is_mine_id + self._clock = hs.get_clock() self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main) + # Used by `_get_joined_hosts` to ensure only one thing mutates the cache + # at a time. Keyed by room_id. + self._joined_host_linearizer = Linearizer("_JoinedHostsCache") + def notify_event_un_partial_stated(self, event_id: str) -> None: self._partial_state_events_tracker.notify_un_partial_stated(event_id) @@ -627,3 +641,122 @@ class StateStorageController: await self._partial_state_room_tracker.await_full_state(room_id) return await self.stores.main.get_users_in_room_with_profiles(room_id) + + async def get_joined_hosts( + self, room_id: str, state_entry: "_StateCacheEntry" + ) -> FrozenSet[str]: + state_group: Union[object, int] = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + assert state_group is not None + with Measure(self._clock, "get_joined_hosts"): + return await self._get_joined_hosts( + room_id, state_group, state_entry=state_entry + ) + + @cached(num_args=2, max_entries=10000, iterable=True) + async def _get_joined_hosts( + self, + room_id: str, + state_group: Union[object, int], + state_entry: "_StateCacheEntry", + ) -> FrozenSet[str]: + # We don't use `state_group`, it's there so that we can cache based on + # it. However, its important that its never None, since two + # current_state's with a state_group of None are likely to be different. + # + # The `state_group` must match the `state_entry.state_group` (if not None). + assert state_group is not None + assert state_entry.state_group is None or state_entry.state_group == state_group + + # We use a secondary cache of previous work to allow us to build up the + # joined hosts for the given state group based on previous state groups. + # + # We cache one object per room containing the results of the last state + # group we got joined hosts for. The idea is that generally + # `get_joined_hosts` is called with the "current" state group for the + # room, and so consecutive calls will be for consecutive state groups + # which point to the previous state group. + cache = await self.stores.main._get_joined_hosts_cache(room_id) + + # If the state group in the cache matches, we already have the data we need. + if state_entry.state_group == cache.state_group: + return frozenset(cache.hosts_to_joined_users) + + # Since we'll mutate the cache we need to lock. + async with self._joined_host_linearizer.queue(room_id): + if state_entry.state_group == cache.state_group: + # Same state group, so nothing to do. We've already checked for + # this above, but the cache may have changed while waiting on + # the lock. + pass + elif state_entry.prev_group == cache.state_group: + # The cached work is for the previous state group, so we work out + # the delta. + assert state_entry.delta_ids is not None + for (typ, state_key), event_id in state_entry.delta_ids.items(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = cache.hosts_to_joined_users.setdefault(host, set()) + + event = await self.stores.main.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + cache.hosts_to_joined_users.pop(host, None) + else: + # The cache doesn't match the state group or prev state group, + # so we calculate the result from first principles. + # + # We need to fetch all hosts joined to the room according to `state` by + # inspecting all join memberships in `state`. However, if the `state` is + # relatively recent then many of its events are likely to be held in + # the current state of the room, which is easily available and likely + # cached. + # + # We therefore compute the set of `state` events not in the + # current state and only fetch those. + current_memberships = ( + await self.stores.main._get_approximate_current_memberships_in_room( + room_id + ) + ) + unknown_state_events = {} + joined_users_in_current_state = [] + + state = await state_entry.get_state( + self, StateFilter.from_types([(EventTypes.Member, None)]) + ) + + for (type, state_key), event_id in state.items(): + if event_id not in current_memberships: + unknown_state_events[type, state_key] = event_id + elif current_memberships[event_id] == Membership.JOIN: + joined_users_in_current_state.append(state_key) + + joined_user_ids = await self.stores.main.get_joined_user_ids_from_state( + room_id, unknown_state_events + ) + + cache.hosts_to_joined_users = {} + for user_id in chain(joined_user_ids, joined_users_in_current_state): + host = intern_string(get_domain_from_id(user_id)) + cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + cache.state_group = state_entry.state_group + else: + cache.state_group = object() + + return frozenset(cache.hosts_to_joined_users) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c9d687fb2f..a1c8fb0f46 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -98,8 +98,6 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { "event_push_summary": "event_push_summary_unique_index2", "receipts_linearized": "receipts_linearized_unique_index", "receipts_graph": "receipts_graph_unique_index", - "profiles": "profiles_full_user_id_key_idx", - "user_filters": "full_users_filters_unique_idx", } diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index b6028853c9..be67d1ff22 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, List, Optional, Tuple, cast from synapse.api.constants import Direction from synapse.config.homeserver import HomeServerConfig @@ -196,7 +196,7 @@ class DataStore( txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: filters = [] - args: List[Union[str, int]] = [] + args: list = [] # Set ordering order_by_column = UserSortOrder(order_by).value diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 75f7fe8756..fff417f9e3 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -188,13 +188,14 @@ class FilteringWorkerStore(SQLBaseStore): filter_id = max_id + 1 sql = ( - "INSERT INTO user_filters (full_user_id, filter_id, filter_json)" - "VALUES(?, ?, ?)" + "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)" + "VALUES(?, ?, ?, ?)" ) txn.execute( sql, ( user_id.to_string(), + user_id.localpart, filter_id, bytearray(def_json), ), diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 660a5507b7..3ba9cc8853 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -173,9 +173,10 @@ class ProfileWorkerStore(SQLBaseStore): ) async def create_profile(self, user_id: UserID) -> None: + user_localpart = user_id.localpart await self.db_pool.simple_insert( table="profiles", - values={"full_user_id": user_id.to_string()}, + values={"user_id": user_localpart, "full_user_id": user_id.to_string()}, desc="create_profile", ) @@ -190,11 +191,13 @@ class ProfileWorkerStore(SQLBaseStore): new_displayname: The new display name. If this is None, the user's display name is removed. """ + user_localpart = user_id.localpart await self.db_pool.simple_upsert( table="profiles", - keyvalues={"full_user_id": user_id.to_string()}, + keyvalues={"user_id": user_localpart}, values={ "displayname": new_displayname, + "full_user_id": user_id.to_string(), }, desc="set_profile_displayname", ) @@ -210,10 +213,11 @@ class ProfileWorkerStore(SQLBaseStore): new_avatar_url: The new avatar URL. If this is None, the user's avatar is removed. """ + user_localpart = user_id.localpart await self.db_pool.simple_upsert( table="profiles", - keyvalues={"full_user_id": user_id.to_string()}, - values={"avatar_url": new_avatar_url}, + keyvalues={"user_id": user_localpart}, + values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()}, desc="set_profile_avatar_url", ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index ca8be8c80d..830658f328 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2136,7 +2136,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): raise StoreError(400, "No create event in state") # Before MSC2175, the room creator was a separate field. - if not room_version.msc2175_implicit_room_creator: + if not room_version.implicit_room_creator: room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) if not isinstance(room_creator, str): diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 582875c91a..fff259f74c 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from itertools import chain from typing import ( TYPE_CHECKING, AbstractSet, @@ -57,15 +56,12 @@ from synapse.types import ( StrCollection, get_domain_from_id, ) -from synapse.util.async_helpers import Linearizer -from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.state import _StateCacheEntry logger = logging.getLogger(__name__) @@ -91,10 +87,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ): super().__init__(database, db_conn, hs) - # Used by `_get_joined_hosts` to ensure only one thing mutates the cache - # at a time. Keyed by room_id. - self._joined_host_linearizer = Linearizer("_JoinedHostsCache") - self._server_notices_mxid = hs.config.servernotices.server_notices_mxid if ( @@ -1057,120 +1049,6 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn ) - async def get_joined_hosts( - self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" - ) -> FrozenSet[str]: - state_group: Union[object, int] = state_entry.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - assert state_group is not None - with Measure(self._clock, "get_joined_hosts"): - return await self._get_joined_hosts( - room_id, state_group, state, state_entry=state_entry - ) - - @cached(num_args=2, max_entries=10000, iterable=True) - async def _get_joined_hosts( - self, - room_id: str, - state_group: Union[object, int], - state: StateMap[str], - state_entry: "_StateCacheEntry", - ) -> FrozenSet[str]: - # We don't use `state_group`, it's there so that we can cache based on - # it. However, its important that its never None, since two - # current_state's with a state_group of None are likely to be different. - # - # The `state_group` must match the `state_entry.state_group` (if not None). - assert state_group is not None - assert state_entry.state_group is None or state_entry.state_group == state_group - - # We use a secondary cache of previous work to allow us to build up the - # joined hosts for the given state group based on previous state groups. - # - # We cache one object per room containing the results of the last state - # group we got joined hosts for. The idea is that generally - # `get_joined_hosts` is called with the "current" state group for the - # room, and so consecutive calls will be for consecutive state groups - # which point to the previous state group. - cache = await self._get_joined_hosts_cache(room_id) - - # If the state group in the cache matches, we already have the data we need. - if state_entry.state_group == cache.state_group: - return frozenset(cache.hosts_to_joined_users) - - # Since we'll mutate the cache we need to lock. - async with self._joined_host_linearizer.queue(room_id): - if state_entry.state_group == cache.state_group: - # Same state group, so nothing to do. We've already checked for - # this above, but the cache may have changed while waiting on - # the lock. - pass - elif state_entry.prev_group == cache.state_group: - # The cached work is for the previous state group, so we work out - # the delta. - assert state_entry.delta_ids is not None - for (typ, state_key), event_id in state_entry.delta_ids.items(): - if typ != EventTypes.Member: - continue - - host = intern_string(get_domain_from_id(state_key)) - user_id = state_key - known_joins = cache.hosts_to_joined_users.setdefault(host, set()) - - event = await self.get_event(event_id) - if event.membership == Membership.JOIN: - known_joins.add(user_id) - else: - known_joins.discard(user_id) - - if not known_joins: - cache.hosts_to_joined_users.pop(host, None) - else: - # The cache doesn't match the state group or prev state group, - # so we calculate the result from first principles. - # - # We need to fetch all hosts joined to the room according to `state` by - # inspecting all join memberships in `state`. However, if the `state` is - # relatively recent then many of its events are likely to be held in - # the current state of the room, which is easily available and likely - # cached. - # - # We therefore compute the set of `state` events not in the - # current state and only fetch those. - current_memberships = ( - await self._get_approximate_current_memberships_in_room(room_id) - ) - unknown_state_events = {} - joined_users_in_current_state = [] - - for (type, state_key), event_id in state.items(): - if event_id not in current_memberships: - unknown_state_events[type, state_key] = event_id - elif current_memberships[event_id] == Membership.JOIN: - joined_users_in_current_state.append(state_key) - - joined_user_ids = await self.get_joined_user_ids_from_state( - room_id, unknown_state_events - ) - - cache.hosts_to_joined_users = {} - for user_id in chain(joined_user_ids, joined_users_in_current_state): - host = intern_string(get_domain_from_id(user_id)) - cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) - - if state_entry.state_group: - cache.state_group = state_entry.state_group - else: - cache.state_group = object() - - return frozenset(cache.hosts_to_joined_users) - async def _get_approximate_current_memberships_in_room( self, room_id: str ) -> Mapping[str, Optional[str]]: diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 97c4dc2603..f34b7ce8f4 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -697,7 +697,7 @@ class StatsStore(StateDeltasStore): txn: LoggingTransaction, ) -> Tuple[List[JsonDict], int]: filters = [] - args = [self.hs.config.server.server_name] + args: list = [] if search_term: filters.append("(lmr.user_id LIKE ? OR displayname LIKE ?)") @@ -733,7 +733,7 @@ class StatsStore(StateDeltasStore): sql_base = """ FROM local_media_repository as lmr - LEFT JOIN profiles AS p ON lmr.user_id = '@' || p.user_id || ':' || ? + LEFT JOIN profiles AS p ON lmr.user_id = p.full_user_id {} GROUP BY lmr.user_id, displayname """.format( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 924022c95c..2a136f2ff6 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -409,23 +409,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn, users_to_work_on ) - # Next fetch their profiles. Note that the `user_id` here is the - # *localpart*, and that not all users have profiles. + # Next fetch their profiles. Note that not all users have profiles. profile_rows = self.db_pool.simple_select_many_txn( txn, table="profiles", - column="user_id", - iterable=[get_localpart_from_id(u) for u in users_to_insert], + column="full_user_id", + iterable=list(users_to_insert), retcols=( - "user_id", + "full_user_id", "displayname", "avatar_url", ), keyvalues={}, ) profiles = { - f"@{row['user_id']}:{self.server_name}": _UserDirProfile( - f"@{row['user_id']}:{self.server_name}", + row["full_user_id"]: _UserDirProfile( + row["full_user_id"], row["displayname"], row["avatar_url"], ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 6d14963c0a..d3ec648f6d 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -108,7 +108,8 @@ Changes in SCHEMA_VERSION = 78 - Validate check (full_user_id IS NOT NULL) on tables profiles and user_filters Changes in SCHEMA_VERSION = 79 - - We no longer write to column user_id of tables profiles and user_filters + - Add tables to handle in DB read-write locks. + - Add some mitigations for a painful race between foreground and background updates, cf #15677. """ @@ -121,9 +122,7 @@ SCHEMA_COMPAT_VERSION = ( # # insertions to the column `full_user_id` of tables profiles and user_filters can no # longer be null - # - # we no longer write to column `full_user_id` of tables profiles and user_filters - 78 + 76 ) """Limit on how far the synapse codebase can be rolled back without breaking db compat diff --git a/synapse/storage/schema/main/delta/79/01_drop_user_id_constraint_profiles.py b/synapse/storage/schema/main/delta/79/01_drop_user_id_constraint_profiles.py deleted file mode 100644 index 3541266f7d..0000000000 --- a/synapse/storage/schema/main/delta/79/01_drop_user_id_constraint_profiles.py +++ /dev/null @@ -1,50 +0,0 @@ -from synapse.storage.database import LoggingTransaction -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine - - -def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: - """ - Update to drop the NOT NULL constraint on column user_id so that we can cease to - write to it without inserts to other columns triggering the constraint - """ - - if isinstance(database_engine, PostgresEngine): - drop_sql = """ - ALTER TABLE profiles ALTER COLUMN user_id DROP NOT NULL - """ - cur.execute(drop_sql) - else: - # irritatingly in SQLite we need to rewrite the table to drop the constraint. - cur.execute("DROP TABLE IF EXISTS temp_profiles") - - create_sql = """ - CREATE TABLE temp_profiles ( - full_user_id text NOT NULL, - user_id text, - displayname text, - avatar_url text, - UNIQUE (full_user_id), - UNIQUE (user_id) - ) - """ - cur.execute(create_sql) - - copy_sql = """ - INSERT INTO temp_profiles ( - user_id, - displayname, - avatar_url, - full_user_id) - SELECT user_id, displayname, avatar_url, full_user_id FROM profiles - """ - cur.execute(copy_sql) - - drop_sql = """ - DROP TABLE profiles - """ - cur.execute(drop_sql) - - rename_sql = """ - ALTER TABLE temp_profiles RENAME to profiles - """ - cur.execute(rename_sql) diff --git a/synapse/storage/schema/main/delta/79/02_drop_user_id_constraint_user_filters.py b/synapse/storage/schema/main/delta/79/02_drop_user_id_constraint_user_filters.py deleted file mode 100644 index 8e7569c470..0000000000 --- a/synapse/storage/schema/main/delta/79/02_drop_user_id_constraint_user_filters.py +++ /dev/null @@ -1,54 +0,0 @@ -from synapse.storage.database import LoggingTransaction -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine - - -def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: - """ - Update to drop the NOT NULL constraint on column user_id so that we can cease to - write to it without inserts to other columns triggering the constraint - """ - if isinstance(database_engine, PostgresEngine): - drop_sql = """ - ALTER TABLE user_filters ALTER COLUMN user_id DROP NOT NULL - """ - cur.execute(drop_sql) - - else: - # irritatingly in SQLite we need to rewrite the table to drop the constraint. - cur.execute("DROP TABLE IF EXISTS temp_user_filters") - - create_sql = """ - CREATE TABLE temp_user_filters ( - full_user_id text NOT NULL, - user_id text, - filter_id bigint NOT NULL, - filter_json bytea NOT NULL - ) - """ - cur.execute(create_sql) - - index_sql = """ - CREATE UNIQUE INDEX IF NOT EXISTS user_filters_full_user_id_unique ON - temp_user_filters (full_user_id, filter_id) - """ - cur.execute(index_sql) - - copy_sql = """ - INSERT INTO temp_user_filters ( - user_id, - filter_id, - filter_json, - full_user_id) - SELECT user_id, filter_id, filter_json, full_user_id FROM user_filters - """ - cur.execute(copy_sql) - - drop_sql = """ - DROP TABLE user_filters - """ - cur.execute(drop_sql) - - rename_sql = """ - ALTER TABLE temp_user_filters RENAME to user_filters - """ - cur.execute(rename_sql) diff --git a/synapse/storage/schema/main/delta/78/04_read_write_locks_triggers.sql.postgres b/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.postgres index e1a41be9c9..7df07ab0da 100644 --- a/synapse/storage/schema/main/delta/78/04_read_write_locks_triggers.sql.postgres +++ b/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.postgres @@ -44,7 +44,7 @@ -- A table to track whether a lock is currently acquired, and if so whether its -- in read or write mode. -CREATE TABLE worker_read_write_locks_mode ( +CREATE TABLE IF NOT EXISTS worker_read_write_locks_mode ( lock_name TEXT NOT NULL, lock_key TEXT NOT NULL, -- Whether this lock is in read (false) or write (true) mode @@ -55,14 +55,14 @@ CREATE TABLE worker_read_write_locks_mode ( ); -- Ensure that we can only have one row per lock -CREATE UNIQUE INDEX worker_read_write_locks_mode_key ON worker_read_write_locks_mode (lock_name, lock_key); +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_mode_key ON worker_read_write_locks_mode (lock_name, lock_key); -- We need this (redundant) constraint so that we can have a foreign key -- constraint against this table. -CREATE UNIQUE INDEX worker_read_write_locks_mode_type ON worker_read_write_locks_mode (lock_name, lock_key, write_lock); +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_mode_type ON worker_read_write_locks_mode (lock_name, lock_key, write_lock); -- A table to track who has currently acquired a given lock. -CREATE TABLE worker_read_write_locks ( +CREATE TABLE IF NOT EXISTS worker_read_write_locks ( lock_name TEXT NOT NULL, lock_key TEXT NOT NULL, -- We write the instance name to ease manual debugging, we don't ever read @@ -84,9 +84,9 @@ CREATE TABLE worker_read_write_locks ( FOREIGN KEY (lock_name, lock_key, write_lock) REFERENCES worker_read_write_locks_mode (lock_name, lock_key, write_lock) ); -CREATE UNIQUE INDEX worker_read_write_locks_key ON worker_read_write_locks (lock_name, lock_key, token); +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_key ON worker_read_write_locks (lock_name, lock_key, token); -- Ensures that only one instance can acquire a lock in write mode at a time. -CREATE UNIQUE INDEX worker_read_write_locks_write ON worker_read_write_locks (lock_name, lock_key) WHERE write_lock; +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_write ON worker_read_write_locks (lock_name, lock_key) WHERE write_lock; -- Add a foreign key constraint to ensure that if a lock is in @@ -97,56 +97,6 @@ CREATE UNIQUE INDEX worker_read_write_locks_write ON worker_read_write_locks (lo -- We only add to PostgreSQL as SQLite does not support adding constraints -- after table creation, and so doesn't support "circular" foreign key -- constraints. +ALTER TABLE worker_read_write_locks_mode DROP CONSTRAINT IF EXISTS worker_read_write_locks_mode_foreign; ALTER TABLE worker_read_write_locks_mode ADD CONSTRAINT worker_read_write_locks_mode_foreign FOREIGN KEY (lock_name, lock_key, token) REFERENCES worker_read_write_locks(lock_name, lock_key, token) DEFERRABLE INITIALLY DEFERRED; - - --- Add a trigger to UPSERT into `worker_read_write_locks_mode` whenever we try --- and acquire a lock, i.e. insert into `worker_read_write_locks`, -CREATE OR REPLACE FUNCTION upsert_read_write_lock_parent() RETURNS trigger AS $$ -BEGIN - INSERT INTO worker_read_write_locks_mode (lock_name, lock_key, write_lock, token) - VALUES (NEW.lock_name, NEW.lock_key, NEW.write_lock, NEW.token) - ON CONFLICT (lock_name, lock_key) - DO NOTHING; - RETURN NEW; -END -$$ -LANGUAGE plpgsql; - -CREATE TRIGGER upsert_read_write_lock_parent_trigger BEFORE INSERT ON worker_read_write_locks - FOR EACH ROW - EXECUTE PROCEDURE upsert_read_write_lock_parent(); - - --- Ensure that we keep `worker_read_write_locks_mode` up to date whenever a lock --- is released (i.e. a row deleted from `worker_read_write_locks`). Either we --- update the `worker_read_write_locks_mode.token` to match another instance --- that has currently acquired the lock, or we delete the row if nobody has --- currently acquired a lock. -CREATE OR REPLACE FUNCTION delete_read_write_lock_parent() RETURNS trigger AS $$ -DECLARE - new_token TEXT; -BEGIN - SELECT token INTO new_token FROM worker_read_write_locks - WHERE - lock_name = OLD.lock_name - AND lock_key = OLD.lock_key; - - IF NOT FOUND THEN - DELETE FROM worker_read_write_locks_mode - WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key; - ELSE - UPDATE worker_read_write_locks_mode - SET token = new_token - WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key; - END IF; - - RETURN NEW; -END -$$ -LANGUAGE plpgsql; - -CREATE TRIGGER delete_read_write_lock_parent_trigger AFTER DELETE ON worker_read_write_locks - FOR EACH ROW - EXECUTE PROCEDURE delete_read_write_lock_parent(); diff --git a/synapse/storage/schema/main/delta/78/04_read_write_locks_triggers.sql.sqlite b/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.sqlite index be2dfbbb8a..95f9dbf120 100644 --- a/synapse/storage/schema/main/delta/78/04_read_write_locks_triggers.sql.sqlite +++ b/synapse/storage/schema/main/delta/79/03_read_write_locks_triggers.sql.sqlite @@ -22,7 +22,7 @@ -- A table to track whether a lock is currently acquired, and if so whether its -- in read or write mode. -CREATE TABLE worker_read_write_locks_mode ( +CREATE TABLE IF NOT EXISTS worker_read_write_locks_mode ( lock_name TEXT NOT NULL, lock_key TEXT NOT NULL, -- Whether this lock is in read (false) or write (true) mode @@ -38,14 +38,14 @@ CREATE TABLE worker_read_write_locks_mode ( ); -- Ensure that we can only have one row per lock -CREATE UNIQUE INDEX worker_read_write_locks_mode_key ON worker_read_write_locks_mode (lock_name, lock_key); +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_mode_key ON worker_read_write_locks_mode (lock_name, lock_key); -- We need this (redundant) constraint so that we can have a foreign key -- constraint against this table. -CREATE UNIQUE INDEX worker_read_write_locks_mode_type ON worker_read_write_locks_mode (lock_name, lock_key, write_lock); +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_mode_type ON worker_read_write_locks_mode (lock_name, lock_key, write_lock); -- A table to track who has currently acquired a given lock. -CREATE TABLE worker_read_write_locks ( +CREATE TABLE IF NOT EXISTS worker_read_write_locks ( lock_name TEXT NOT NULL, lock_key TEXT NOT NULL, -- We write the instance name to ease manual debugging, we don't ever read @@ -67,53 +67,6 @@ CREATE TABLE worker_read_write_locks ( FOREIGN KEY (lock_name, lock_key, write_lock) REFERENCES worker_read_write_locks_mode (lock_name, lock_key, write_lock) ); -CREATE UNIQUE INDEX worker_read_write_locks_key ON worker_read_write_locks (lock_name, lock_key, token); +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_key ON worker_read_write_locks (lock_name, lock_key, token); -- Ensures that only one instance can acquire a lock in write mode at a time. -CREATE UNIQUE INDEX worker_read_write_locks_write ON worker_read_write_locks (lock_name, lock_key) WHERE write_lock; - - --- Add a trigger to UPSERT into `worker_read_write_locks_mode` whenever we try --- and acquire a lock, i.e. insert into `worker_read_write_locks`, -CREATE TRIGGER IF NOT EXISTS upsert_read_write_lock_parent_trigger -BEFORE INSERT ON worker_read_write_locks -FOR EACH ROW -BEGIN - -- First ensure that `worker_read_write_locks_mode` doesn't have stale - -- entries in it, as on SQLite we don't have the foreign key constraint to - -- enforce this. - DELETE FROM worker_read_write_locks_mode - WHERE lock_name = NEW.lock_name AND lock_key = NEW.lock_key - AND NOT EXISTS ( - SELECT 1 FROM worker_read_write_locks - WHERE lock_name = NEW.lock_name AND lock_key = NEW.lock_key - ); - - INSERT INTO worker_read_write_locks_mode (lock_name, lock_key, write_lock, token) - VALUES (NEW.lock_name, NEW.lock_key, NEW.write_lock, NEW.token) - ON CONFLICT (lock_name, lock_key) - DO NOTHING; -END; - --- Ensure that we keep `worker_read_write_locks_mode` up to date whenever a lock --- is released (i.e. a row deleted from `worker_read_write_locks`). Either we --- update the `worker_read_write_locks_mode.token` to match another instance --- that has currently acquired the lock, or we delete the row if nobody has --- currently acquired a lock. -CREATE TRIGGER IF NOT EXISTS delete_read_write_lock_parent_trigger -AFTER DELETE ON worker_read_write_locks -FOR EACH ROW -BEGIN - DELETE FROM worker_read_write_locks_mode - WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key - AND NOT EXISTS ( - SELECT 1 FROM worker_read_write_locks - WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key - ); - - UPDATE worker_read_write_locks_mode - SET token = ( - SELECT token FROM worker_read_write_locks - WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key - ) - WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key; -END; +CREATE UNIQUE INDEX IF NOT EXISTS worker_read_write_locks_write ON worker_read_write_locks (lock_name, lock_key) WHERE write_lock; diff --git a/synapse/storage/schema/main/delta/79/04_mitigate_stream_ordering_update_race.py b/synapse/storage/schema/main/delta/79/04_mitigate_stream_ordering_update_race.py new file mode 100644 index 0000000000..ae63585847 --- /dev/null +++ b/synapse/storage/schema/main/delta/79/04_mitigate_stream_ordering_update_race.py @@ -0,0 +1,70 @@ +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine + + +def run_create( + cur: LoggingTransaction, + database_engine: BaseDatabaseEngine, +) -> None: + """ + An attempt to mitigate a painful race between foreground and background updates + touching the `stream_ordering` column of the events table. More info can be found + at https://github.com/matrix-org/synapse/issues/15677. + """ + + # technically the bg update we're concerned with below should only have been added in + # postgres but it doesn't hurt to be extra careful + if isinstance(database_engine, PostgresEngine): + select_sql = """ + SELECT 1 FROM background_updates + WHERE update_name = 'replace_stream_ordering_column' + """ + cur.execute(select_sql) + res = cur.fetchone() + + # if the background update `replace_stream_ordering_column` is still pending, we need + # to drop the indexes added in 7403, and re-add them to the column `stream_ordering2` + # with the idea that they will be preserved when the column is renamed `stream_ordering` + # after the background update has finished + if res: + drop_cse_sql = """ + ALTER TABLE current_state_events DROP CONSTRAINT IF EXISTS event_stream_ordering_fkey + """ + cur.execute(drop_cse_sql) + + drop_lcm_sql = """ + ALTER TABLE local_current_membership DROP CONSTRAINT IF EXISTS event_stream_ordering_fkey + """ + cur.execute(drop_lcm_sql) + + drop_rm_sql = """ + ALTER TABLE room_memberships DROP CONSTRAINT IF EXISTS event_stream_ordering_fkey + """ + cur.execute(drop_rm_sql) + + add_cse_sql = """ + ALTER TABLE current_state_events ADD CONSTRAINT event_stream_ordering_fkey + FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering2) NOT VALID; + """ + cur.execute(add_cse_sql) + + add_lcm_sql = """ + ALTER TABLE local_current_membership ADD CONSTRAINT event_stream_ordering_fkey + FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering2) NOT VALID; + """ + cur.execute(add_lcm_sql) + + add_rm_sql = """ + ALTER TABLE room_memberships ADD CONSTRAINT event_stream_ordering_fkey + FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering2) NOT VALID; + """ + cur.execute(add_rm_sql) diff --git a/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.postgres b/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.postgres new file mode 100644 index 0000000000..ea3496ef2d --- /dev/null +++ b/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.postgres @@ -0,0 +1,69 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Fix up the triggers that were in `78/04_read_write_locks_triggers.sql` + +-- Add a trigger to UPSERT into `worker_read_write_locks_mode` whenever we try +-- and acquire a lock, i.e. insert into `worker_read_write_locks`, +CREATE OR REPLACE FUNCTION upsert_read_write_lock_parent() RETURNS trigger AS $$ +BEGIN + INSERT INTO worker_read_write_locks_mode (lock_name, lock_key, write_lock, token) + VALUES (NEW.lock_name, NEW.lock_key, NEW.write_lock, NEW.token) + ON CONFLICT (lock_name, lock_key) + DO UPDATE SET write_lock = NEW.write_lock, token = NEW.token; + RETURN NEW; +END +$$ +LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS upsert_read_write_lock_parent_trigger ON worker_read_write_locks; +CREATE TRIGGER upsert_read_write_lock_parent_trigger BEFORE INSERT ON worker_read_write_locks + FOR EACH ROW + EXECUTE PROCEDURE upsert_read_write_lock_parent(); + + +-- Ensure that we keep `worker_read_write_locks_mode` up to date whenever a lock +-- is released (i.e. a row deleted from `worker_read_write_locks`). Either we +-- update the `worker_read_write_locks_mode.token` to match another instance +-- that has currently acquired the lock, or we delete the row if nobody has +-- currently acquired a lock. +CREATE OR REPLACE FUNCTION delete_read_write_lock_parent() RETURNS trigger AS $$ +DECLARE + new_token TEXT; +BEGIN + SELECT token INTO new_token FROM worker_read_write_locks + WHERE + lock_name = OLD.lock_name + AND lock_key = OLD.lock_key + LIMIT 1 FOR UPDATE; + + IF NOT FOUND THEN + DELETE FROM worker_read_write_locks_mode + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key AND token = OLD.token; + ELSE + UPDATE worker_read_write_locks_mode + SET token = new_token + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key; + END IF; + + RETURN NEW; +END +$$ +LANGUAGE plpgsql; + +DROP TRIGGER IF EXISTS delete_read_write_lock_parent_trigger ON worker_read_write_locks; +CREATE TRIGGER delete_read_write_lock_parent_trigger AFTER DELETE ON worker_read_write_locks + FOR EACH ROW + EXECUTE PROCEDURE delete_read_write_lock_parent(); diff --git a/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.sqlite b/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.sqlite new file mode 100644 index 0000000000..acb1a77c80 --- /dev/null +++ b/synapse/storage/schema/main/delta/79/05_read_write_locks_triggers.sql.sqlite @@ -0,0 +1,65 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Fix up the triggers that were in `78/04_read_write_locks_triggers.sql` + +-- Add a trigger to UPSERT into `worker_read_write_locks_mode` whenever we try +-- and acquire a lock, i.e. insert into `worker_read_write_locks`, +DROP TRIGGER IF EXISTS upsert_read_write_lock_parent_trigger; +CREATE TRIGGER IF NOT EXISTS upsert_read_write_lock_parent_trigger +BEFORE INSERT ON worker_read_write_locks +FOR EACH ROW +BEGIN + -- First ensure that `worker_read_write_locks_mode` doesn't have stale + -- entries in it, as on SQLite we don't have the foreign key constraint to + -- enforce this. + DELETE FROM worker_read_write_locks_mode + WHERE lock_name = NEW.lock_name AND lock_key = NEW.lock_key + AND NOT EXISTS ( + SELECT 1 FROM worker_read_write_locks + WHERE lock_name = NEW.lock_name AND lock_key = NEW.lock_key + ); + + INSERT INTO worker_read_write_locks_mode (lock_name, lock_key, write_lock, token) + VALUES (NEW.lock_name, NEW.lock_key, NEW.write_lock, NEW.token) + ON CONFLICT (lock_name, lock_key) + DO UPDATE SET write_lock = NEW.write_lock, token = NEW.token; +END; + +-- Ensure that we keep `worker_read_write_locks_mode` up to date whenever a lock +-- is released (i.e. a row deleted from `worker_read_write_locks`). Either we +-- update the `worker_read_write_locks_mode.token` to match another instance +-- that has currently acquired the lock, or we delete the row if nobody has +-- currently acquired a lock. +DROP TRIGGER IF EXISTS delete_read_write_lock_parent_trigger; +CREATE TRIGGER IF NOT EXISTS delete_read_write_lock_parent_trigger +AFTER DELETE ON worker_read_write_locks +FOR EACH ROW +BEGIN + DELETE FROM worker_read_write_locks_mode + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key + AND token = OLD.token + AND NOT EXISTS ( + SELECT 1 FROM worker_read_write_locks + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key + ); + + UPDATE worker_read_write_locks_mode + SET token = ( + SELECT token FROM worker_read_write_locks + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key + ) + WHERE lock_name = OLD.lock_name AND lock_key = OLD.lock_key; +END; |