diff options
64 files changed, 1158 insertions, 220 deletions
diff --git a/changelog.d/9411.misc b/changelog.d/9411.misc new file mode 100644 index 0000000000..c3e6cfa5f1 --- /dev/null +++ b/changelog.d/9411.misc @@ -0,0 +1 @@ +Preparatory steps for removing redundant `outlier` data from `event_json.internal_metadata` column. diff --git a/changelog.d/9499.misc b/changelog.d/9499.misc new file mode 100644 index 0000000000..1513017a10 --- /dev/null +++ b/changelog.d/9499.misc @@ -0,0 +1 @@ +Introduce bugbear to the test suite and fix some of it's lint violations. \ No newline at end of file diff --git a/changelog.d/9585.bugfix b/changelog.d/9585.bugfix new file mode 100644 index 0000000000..de472ddfd1 --- /dev/null +++ b/changelog.d/9585.bugfix @@ -0,0 +1 @@ +Fix a longstanding bug that could cause issues when editing a reply to a message. \ No newline at end of file diff --git a/changelog.d/9588.bugfix b/changelog.d/9588.bugfix new file mode 100644 index 0000000000..b8d6140565 --- /dev/null +++ b/changelog.d/9588.bugfix @@ -0,0 +1 @@ +Fix the `/capabilities` endpoint to return `m.change_password` as disabled if the local password database is not used for authentication. Contributed by @dklimpel. diff --git a/changelog.d/9609.feature b/changelog.d/9609.feature new file mode 100644 index 0000000000..f3b6342069 --- /dev/null +++ b/changelog.d/9609.feature @@ -0,0 +1 @@ +Logins using OpenID Connect can require attributes on the `userinfo` response in order to login. Contributed by Hubbe King. diff --git a/changelog.d/9631.misc b/changelog.d/9631.misc new file mode 100644 index 0000000000..35338cd332 --- /dev/null +++ b/changelog.d/9631.misc @@ -0,0 +1 @@ +Add additional type hints to the Homeserver object. diff --git a/changelog.d/9634.misc b/changelog.d/9634.misc new file mode 100644 index 0000000000..59ac42cb83 --- /dev/null +++ b/changelog.d/9634.misc @@ -0,0 +1 @@ +Only save remote cross-signing and device keys if they're different from the current ones. diff --git a/changelog.d/9636.bugfix b/changelog.d/9636.bugfix new file mode 100644 index 0000000000..fa772ed6fc --- /dev/null +++ b/changelog.d/9636.bugfix @@ -0,0 +1 @@ +Checks if passwords are allowed before setting it for the user. \ No newline at end of file diff --git a/changelog.d/9637.misc b/changelog.d/9637.misc new file mode 100644 index 0000000000..90a27d9f8f --- /dev/null +++ b/changelog.d/9637.misc @@ -0,0 +1 @@ +Rename storage function to fix spelling and not conflict with another functions name. diff --git a/changelog.d/9638.misc b/changelog.d/9638.misc new file mode 100644 index 0000000000..35338cd332 --- /dev/null +++ b/changelog.d/9638.misc @@ -0,0 +1 @@ +Add additional type hints to the Homeserver object. diff --git a/changelog.d/9640.misc b/changelog.d/9640.misc new file mode 100644 index 0000000000..3d410ed4cd --- /dev/null +++ b/changelog.d/9640.misc @@ -0,0 +1 @@ +Improve performance of federation catch up by sending events the latest events in the room to the remote, rather than just the last event sent by the local server. diff --git a/changelog.d/9643.feature b/changelog.d/9643.feature new file mode 100644 index 0000000000..2f7ccedcfb --- /dev/null +++ b/changelog.d/9643.feature @@ -0,0 +1 @@ +Add initial experimental support for a "space summary" API. diff --git a/changelog.d/9644.feature b/changelog.d/9644.feature new file mode 100644 index 0000000000..556bcf0f9f --- /dev/null +++ b/changelog.d/9644.feature @@ -0,0 +1 @@ +Implement the busy presence state as described in [MSC3026](https://github.com/matrix-org/matrix-doc/pull/3026). diff --git a/changelog.d/9645.misc b/changelog.d/9645.misc new file mode 100644 index 0000000000..9a7ce364c1 --- /dev/null +++ b/changelog.d/9645.misc @@ -0,0 +1 @@ +In the `federation_client` commandline client, stop automatically adding the URL prefix, so that servlets on other prefixes can be tested. diff --git a/changelog.d/9647.misc b/changelog.d/9647.misc new file mode 100644 index 0000000000..303a8c6606 --- /dev/null +++ b/changelog.d/9647.misc @@ -0,0 +1 @@ +In the `federation_client` commandline client, handle inline `signing_key`s in `homeserver.yaml`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7de000f4a4..a9f59e39f7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1873,6 +1873,24 @@ saml2_config: # which is set to the claims returned by the UserInfo Endpoint and/or # in the ID Token. # +# It is possible to configure Synapse to only allow logins if certain attributes +# match particular values in the OIDC userinfo. The requirements can be listed under +# `attribute_requirements` as shown below. All of the listed attributes must +# match for the login to be permitted. Additional attributes can be added to +# userinfo by expanding the `scopes` section of the OIDC config to retrieve +# additional information from the OIDC provider. +# +# If the OIDC claim is a list, then the attribute must match any value in the list. +# Otherwise, it must exactly match the value of the claim. Using the example +# below, the `family_name` claim MUST be "Stephensson", but the `groups` +# claim MUST contain "admin". +# +# attribute_requirements: +# - attribute: family_name +# value: "Stephensson" +# - attribute: groups +# value: "admin" +# # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md # for information on how to configure these options. # @@ -1905,6 +1923,9 @@ oidc_providers: # localpart_template: "{{ user.login }}" # display_name_template: "{{ user.name }}" # email_template: "{{ user.email }}" + # attribute_requirements: + # - attribute: userGroup + # value: "synapseUsers" # For use with Keycloak # @@ -1914,6 +1935,9 @@ oidc_providers: # client_id: "synapse" # client_secret: "copy secret generated in Keycloak UI" # scopes: ["openid", "profile"] + # attribute_requirements: + # - attribute: groups + # value: "admin" # For use with Github # diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py index abcec48c4f..6f76c08fcf 100755 --- a/scripts-dev/federation_client.py +++ b/scripts-dev/federation_client.py @@ -22,8 +22,8 @@ import sys from typing import Any, Optional from urllib import parse as urlparse -import nacl.signing import requests +import signedjson.key import signedjson.types import srvlookup import yaml @@ -44,18 +44,6 @@ def encode_base64(input_bytes): return output_string -def decode_base64(input_string): - """Decode a base64 string to bytes inferring padding from the length of the - string.""" - - input_bytes = input_string.encode("ascii") - input_len = len(input_bytes) - padding = b"=" * (3 - ((input_len + 3) % 4)) - output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2 - output_bytes = base64.b64decode(input_bytes + padding) - return output_bytes[:output_len] - - def encode_canonical_json(value): return json.dumps( value, @@ -88,42 +76,6 @@ def sign_json( return json_object -NACL_ED25519 = "ed25519" - - -def decode_signing_key_base64(algorithm, version, key_base64): - """Decode a base64 encoded signing key - Args: - algorithm (str): The algorithm the key is for (currently "ed25519"). - version (str): Identifies this key out of the keys for this entity. - key_base64 (str): Base64 encoded bytes of the key. - Returns: - A SigningKey object. - """ - if algorithm == NACL_ED25519: - key_bytes = decode_base64(key_base64) - key = nacl.signing.SigningKey(key_bytes) - key.version = version - key.alg = NACL_ED25519 - return key - else: - raise ValueError("Unsupported algorithm %s" % (algorithm,)) - - -def read_signing_keys(stream): - """Reads a list of keys from a stream - Args: - stream : A stream to iterate for keys. - Returns: - list of SigningKey objects. - """ - keys = [] - for line in stream: - algorithm, version, key_base64 = line.split() - keys.append(decode_signing_key_base64(algorithm, version, key_base64)) - return keys - - def request( method: Optional[str], origin_name: str, @@ -223,23 +175,28 @@ def main(): parser.add_argument("--body", help="Data to send as the body of the HTTP request") parser.add_argument( - "path", help="request path. We will add '/_matrix/federation/v1/' to this." + "path", help="request path, including the '/_matrix/federation/...' prefix." ) args = parser.parse_args() - if not args.server_name or not args.signing_key_path: + args.signing_key = None + if args.signing_key_path: + with open(args.signing_key_path) as f: + args.signing_key = f.readline() + + if not args.server_name or not args.signing_key: read_args_from_config(args) - with open(args.signing_key_path) as f: - key = read_signing_keys(f)[0] + algorithm, version, key_base64 = args.signing_key.split() + key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64) result = request( args.method, args.server_name, key, args.destination, - "/_matrix/federation/v1/" + args.path, + args.path, content=args.body, ) @@ -255,10 +212,16 @@ def main(): def read_args_from_config(args): with open(args.config, "r") as fh: config = yaml.safe_load(fh) + if not args.server_name: args.server_name = config["server_name"] - if not args.signing_key_path: - args.signing_key_path = config["signing_key_path"] + + if not args.signing_key: + if "signing_key" in config: + args.signing_key = config["signing_key"] + else: + with open(config["signing_key_path"]) as f: + args.signing_key = f.readline() class MatrixConnectionAdapter(HTTPAdapter): diff --git a/setup.cfg b/setup.cfg index 5e301c2cd7..920868df20 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,8 @@ ignore = # E203: whitespace before ':' (which is contrary to pep8?) # E731: do not assign a lambda expression, use a def # E501: Line too long (black enforces this for us) -ignore=W503,W504,E203,E731,E501 +# B00: Subsection of the bugbear suite (TODO: add in remaining fixes) +ignore=W503,W504,E203,E731,E501,B00 [isort] line_length = 88 diff --git a/setup.py b/setup.py index bbd9e7862a..b834e4e55b 100755 --- a/setup.py +++ b/setup.py @@ -99,6 +99,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [ "isort==5.7.0", "black==20.8b1", "flake8-comprehensions", + "flake8-bugbear", "flake8", ] diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 691f8f9adf..8f37d2cf3b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -51,6 +51,7 @@ class PresenceState: OFFLINE = "offline" UNAVAILABLE = "unavailable" ONLINE = "online" + BUSY = "org.matrix.msc3026.busy" class JoinRules: @@ -100,6 +101,9 @@ class EventTypes: Dummy = "org.matrix.dummy_event" + MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child" + MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent" + class EduTypes: Presence = "m.presence" @@ -160,6 +164,9 @@ class EventContentFields: # cf https://github.com/matrix-org/matrix-doc/pull/2228 SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" + # cf https://github.com/matrix-org/matrix-doc/pull/1772 + MSC1772_ROOM_TYPE = "org.matrix.msc1772.type" + class RoomEncryptionAlgorithms: MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py index 4a9b0129c3..d1a2cd5e19 100644 --- a/synapse/app/__init__.py +++ b/synapse/app/__init__.py @@ -22,7 +22,9 @@ logger = logging.getLogger(__name__) try: python_dependencies.check_requirements() except python_dependencies.DependencyException as e: - sys.stderr.writelines(e.message) + sys.stderr.writelines( + e.message # noqa: B306, DependencyException.message is a property + ) sys.exit(1) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 274d582d07..caef394e1d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -302,6 +302,8 @@ class GenericWorkerPresence(BasePresenceHandler): self.send_stop_syncing, UPDATE_SYNCING_USERS_MS ) + self._busy_presence_enabled = hs.config.experimental.msc3026_enabled + hs.get_reactor().addSystemEventTrigger( "before", "shutdown", @@ -439,8 +441,12 @@ class GenericWorkerPresence(BasePresenceHandler): PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, + PresenceState.BUSY, ) - if presence not in valid_presence: + + if presence not in valid_presence or ( + presence == PresenceState.BUSY and not self._busy_presence_enabled + ): raise SynapseError(400, "Invalid presence state") user_id = target_user.to_string() diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b1c1c51e4d..86f4d9af9d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -27,3 +27,7 @@ class ExperimentalConfig(Config): # MSC2858 (multiple SSO identity providers) self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool + # Spaces (MSC1772, MSC2946, etc) + self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool + # MSC3026 (busy presence state) + self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool diff --git a/synapse/config/key.py b/synapse/config/key.py index de964dff13..350ff1d665 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -404,7 +404,11 @@ def _parse_key_servers(key_servers, federation_verify_certificates): try: jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA) except jsonschema.ValidationError as e: - raise ConfigError("Unable to parse 'trusted_key_servers': " + e.message) + raise ConfigError( + "Unable to parse 'trusted_key_servers': {}".format( + e.message # noqa: B306, jsonschema.ValidationError.message is a valid attribute + ) + ) for server in key_servers: server_name = server["server_name"] diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py index dfd27e1523..2b289f4208 100644 --- a/synapse/config/metrics.py +++ b/synapse/config/metrics.py @@ -56,7 +56,9 @@ class MetricsConfig(Config): try: check_requirements("sentry") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) self.sentry_dsn = config["sentry"].get("dsn") if not self.sentry_dsn: diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 2bfb537c15..747ab9a7fe 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -15,11 +15,12 @@ # limitations under the License. from collections import Counter -from typing import Iterable, Mapping, Optional, Tuple, Type +from typing import Iterable, List, Mapping, Optional, Tuple, Type import attr from synapse.config._util import validate_config +from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module @@ -41,7 +42,9 @@ class OIDCConfig(Config): try: check_requirements("oidc") except DependencyException as e: - raise ConfigError(e.message) from e + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) from e # check we don't have any duplicate idp_ids now. (The SSO handler will also # check for duplicates when the REST listeners get registered, but that happens @@ -191,6 +194,24 @@ class OIDCConfig(Config): # which is set to the claims returned by the UserInfo Endpoint and/or # in the ID Token. # + # It is possible to configure Synapse to only allow logins if certain attributes + # match particular values in the OIDC userinfo. The requirements can be listed under + # `attribute_requirements` as shown below. All of the listed attributes must + # match for the login to be permitted. Additional attributes can be added to + # userinfo by expanding the `scopes` section of the OIDC config to retrieve + # additional information from the OIDC provider. + # + # If the OIDC claim is a list, then the attribute must match any value in the list. + # Otherwise, it must exactly match the value of the claim. Using the example + # below, the `family_name` claim MUST be "Stephensson", but the `groups` + # claim MUST contain "admin". + # + # attribute_requirements: + # - attribute: family_name + # value: "Stephensson" + # - attribute: groups + # value: "admin" + # # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md # for information on how to configure these options. # @@ -223,6 +244,9 @@ class OIDCConfig(Config): # localpart_template: "{{{{ user.login }}}}" # display_name_template: "{{{{ user.name }}}}" # email_template: "{{{{ user.email }}}}" + # attribute_requirements: + # - attribute: userGroup + # value: "synapseUsers" # For use with Keycloak # @@ -232,6 +256,9 @@ class OIDCConfig(Config): # client_id: "synapse" # client_secret: "copy secret generated in Keycloak UI" # scopes: ["openid", "profile"] + # attribute_requirements: + # - attribute: groups + # value: "admin" # For use with Github # @@ -329,6 +356,10 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { }, "allow_existing_users": {"type": "boolean"}, "user_mapping_provider": {"type": ["object", "null"]}, + "attribute_requirements": { + "type": "array", + "items": SsoAttributeRequirement.JSON_SCHEMA, + }, }, } @@ -465,6 +496,11 @@ def _parse_oidc_config_dict( jwt_header=client_secret_jwt_key_config["jwt_header"], jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}), ) + # parse attribute_requirements from config (list of dicts) into a list of SsoAttributeRequirement + attribute_requirements = [ + SsoAttributeRequirement(**x) + for x in oidc_config.get("attribute_requirements", []) + ] return OidcProviderConfig( idp_id=idp_id, @@ -488,6 +524,7 @@ def _parse_oidc_config_dict( allow_existing_users=oidc_config.get("allow_existing_users", False), user_mapping_provider_class=user_mapping_provider_class, user_mapping_provider_config=user_mapping_provider_config, + attribute_requirements=attribute_requirements, ) @@ -577,3 +614,6 @@ class OidcProviderConfig: # the config of the user mapping provider user_mapping_provider_config = attr.ib() + + # required attributes to require in userinfo to allow login/registration + attribute_requirements = attr.ib(type=List[SsoAttributeRequirement]) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 69d9de5a43..061c4ec83f 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -176,7 +176,9 @@ class ContentRepositoryConfig(Config): check_requirements("url_preview") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) if "url_preview_ip_range_blacklist" not in config: raise ConfigError( diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 4b494f217f..6db9cb5ced 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -76,7 +76,9 @@ class SAML2Config(Config): try: check_requirements("saml2") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) self.saml2_enabled = True diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 0c1a854f09..727a1e7008 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -39,7 +39,9 @@ class TracerConfig(Config): try: check_requirements("opentracing") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError( + e.message # noqa: B306, DependencyException.message is a property + ) # The tracer is enabled so sanitize the config diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index 14b21796d9..4ca13011e5 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -219,7 +219,7 @@ class SSLClientConnectionCreator: # ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the # tls_protocol so that the SSL context's info callback has something to # call to do the cert verification. - setattr(tls_protocol, "_synapse_tls_verifier", self._verifier) + tls_protocol._synapse_tls_verifier = self._verifier return connection diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 3ec4120f85..8f6b955d17 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -98,7 +98,7 @@ class DefaultDictProperty(DictProperty): class _EventInternalMetadata: - __slots__ = ["_dict", "stream_ordering"] + __slots__ = ["_dict", "stream_ordering", "outlier"] def __init__(self, internal_metadata_dict: JsonDict): # we have to copy the dict, because it turns out that the same dict is @@ -108,7 +108,10 @@ class _EventInternalMetadata: # the stream ordering of this event. None, until it has been persisted. self.stream_ordering = None # type: Optional[int] - outlier = DictProperty("outlier") # type: bool + # whether this event is an outlier (ie, whether we have the state at that point + # in the DAG) + self.outlier = False + out_of_band_membership = DictProperty("out_of_band_membership") # type: bool send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str recheck_redaction = DictProperty("recheck_redaction") # type: bool @@ -129,7 +132,7 @@ class _EventInternalMetadata: return dict(self._dict) def is_outlier(self) -> bool: - return self._dict.get("outlier", False) + return self.outlier def is_out_of_band_membership(self) -> bool: """Whether this is an out of band membership, like an invite or an invite diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 7ca5c9940a..0f8a3b5ad8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.util.async_helpers import yieldable_gather_results +from synapse.util.frozenutils import unfreeze from . import EventBase @@ -54,6 +55,8 @@ def prune_event(event: EventBase) -> EventBase: event.internal_metadata.stream_ordering ) + pruned_event.internal_metadata.outlier = event.internal_metadata.outlier + # Mark the event as redacted pruned_event.internal_metadata.redacted = True @@ -400,10 +403,19 @@ class EventClientSerializer: # If there is an edit replace the content, preserving existing # relations. + # Ensure we take copies of the edit content, otherwise we risk modifying + # the original event. + edit_content = edit.content.copy() + + # Unfreeze the event content if necessary, so that we may modify it below + edit_content = unfreeze(edit_content) + serialized_event["content"] = edit_content.get("m.new_content", {}) + + # Check for existing relations relations = event.content.get("m.relates_to") - serialized_event["content"] = edit.content.get("m.new_content", {}) if relations: - serialized_event["content"]["m.relates_to"] = relations + # Keep the relations, ensuring we use a dict copy of the original + serialized_event["content"]["m.relates_to"] = relations.copy() else: serialized_event["content"].pop("m.relates_to", None) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9839d3d016..d84e362070 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -35,7 +35,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import ( AuthError, Codes, @@ -63,7 +63,7 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) -from synapse.types import JsonDict, get_domain_from_id +from synapse.types import JsonDict from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -727,27 +727,6 @@ class FederationServer(FederationBase): if the event was unacceptable for any other reason (eg, too large, too many prev_events, couldn't find the prev_events) """ - # check that it's actually being sent from a valid destination to - # workaround bug #1753 in 0.18.5 and 0.18.6 - if origin != get_domain_from_id(pdu.sender): - # We continue to accept join events from any server; this is - # necessary for the federation join dance to work correctly. - # (When we join over federation, the "helper" server is - # responsible for sending out the join event, rather than the - # origin. See bug #1893. This is also true for some third party - # invites). - if not ( - pdu.type == "m.room.member" - and pdu.content - and pdu.content.get("membership", None) - in (Membership.JOIN, Membership.INVITE) - ): - logger.info( - "Discarding PDU %s from invalid origin %s", pdu.event_id, origin - ) - return - else: - logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point room_version = await self.store.get_room_version(pdu.room_id) diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index cc0d765e5f..af85fe0a1e 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -15,7 +15,7 @@ # limitations under the License. import datetime import logging -from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple import attr from prometheus_client import Counter @@ -77,6 +77,7 @@ class PerDestinationQueue: self._transaction_manager = transaction_manager self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config + self._state = hs.get_state_handler() self._should_send_on_this_instance = True if not self._federation_shard_config.should_handle( @@ -415,22 +416,95 @@ class PerDestinationQueue: "This should not happen." % event_ids ) - if logger.isEnabledFor(logging.INFO): - rooms = [p.room_id for p in catchup_pdus] - logger.info("Catching up rooms to %s: %r", self._destination, rooms) + # We send transactions with events from one room only, as its likely + # that the remote will have to do additional processing, which may + # take some time. It's better to give it small amounts of work + # rather than risk the request timing out and repeatedly being + # retried, and not making any progress. + # + # Note: `catchup_pdus` will have exactly one PDU per room. + for pdu in catchup_pdus: + # The PDU from the DB will be the last PDU in the room from + # *this server* that wasn't sent to the remote. However, other + # servers may have sent lots of events since then, and we want + # to try and tell the remote only about the *latest* events in + # the room. This is so that it doesn't get inundated by events + # from various parts of the DAG, which all need to be processed. + # + # Note: this does mean that in large rooms a server coming back + # online will get sent the same events from all the different + # servers, but the remote will correctly deduplicate them and + # handle it only once. + + # Step 1, fetch the current extremities + extrems = await self._store.get_prev_events_for_room(pdu.room_id) + + if pdu.event_id in extrems: + # If the event is in the extremities, then great! We can just + # use that without having to do further checks. + room_catchup_pdus = [pdu] + else: + # If not, fetch the extremities and figure out which we can + # send. + extrem_events = await self._store.get_events_as_list(extrems) + + new_pdus = [] + for p in extrem_events: + # We pulled this from the DB, so it'll be non-null + assert p.internal_metadata.stream_ordering + + # Filter out events that happened before the remote went + # offline + if ( + p.internal_metadata.stream_ordering + < self._last_successful_stream_ordering + ): + continue - await self._transaction_manager.send_new_transaction( - self._destination, catchup_pdus, [] - ) + # Filter out events where the server is not in the room, + # e.g. it may have left/been kicked. *Ideally* we'd pull + # out the kick and send that, but it's a rare edge case + # so we don't bother for now (the server that sent the + # kick should send it out if its online). + hosts = await self._state.get_hosts_in_room_at_events( + p.room_id, [p.event_id] + ) + if self._destination not in hosts: + continue - sent_transactions_counter.inc() - final_pdu = catchup_pdus[-1] - self._last_successful_stream_ordering = cast( - int, final_pdu.internal_metadata.stream_ordering - ) - await self._store.set_destination_last_successful_stream_ordering( - self._destination, self._last_successful_stream_ordering - ) + new_pdus.append(p) + + # If we've filtered out all the extremities, fall back to + # sending the original event. This should ensure that the + # server gets at least some of missed events (especially if + # the other sending servers are up). + if new_pdus: + room_catchup_pdus = new_pdus + + logger.info( + "Catching up rooms to %s: %r", self._destination, pdu.room_id + ) + + await self._transaction_manager.send_new_transaction( + self._destination, room_catchup_pdus, [] + ) + + sent_transactions_counter.inc() + + # We pulled this from the DB, so it'll be non-null + assert pdu.internal_metadata.stream_ordering + + # Note that we mark the last successful stream ordering as that + # from the *original* PDU, rather than the PDU(s) we actually + # send. This is because we use it to mark our position in the + # queue of missed PDUs to process. + self._last_successful_stream_ordering = ( + pdu.internal_metadata.stream_ordering + ) + + await self._store.set_destination_last_successful_stream_ordering( + self._destination, self._last_successful_stream_ordering + ) def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: if not self._pending_rrs: diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fb5f8118f0..badac8c26c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -886,6 +886,19 @@ class AuthHandler(BaseHandler): ) return result + def can_change_password(self) -> bool: + """Get whether users on this server are allowed to change or set a password. + + Both `config.password_enabled` and `config.password_localdb_enabled` must be true. + + Note that any account (even SSO accounts) are allowed to add passwords if the above + is true. + + Returns: + Whether users on this server are allowed to change or set a password + """ + return self._password_enabled and self._password_localdb_enabled + def get_supported_login_types(self) -> Iterable[str]: """Get a the login types supported for the /login API diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index df3cdc8fba..2fc4951df4 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -166,7 +166,7 @@ class DeviceWorkerHandler(BaseHandler): # Fetch the current state at the time. try: - event_ids = await self.store.get_forward_extremeties_for_room( + event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering( room_id, stream_ordering=stream_ordering ) except errors.StoreError: @@ -907,6 +907,7 @@ class DeviceListUpdater: master_key = result.get("master_key") self_signing_key = result.get("self_signing_key") + ignore_devices = False # If the remote server has more than ~1000 devices for this user # we assume that something is going horribly wrong (e.g. a bot # that logs in and creates a new device every time it tries to @@ -925,6 +926,12 @@ class DeviceListUpdater: len(devices), ) devices = [] + ignore_devices = True + else: + cached_devices = await self.store.get_cached_devices_for_user(user_id) + if cached_devices == {d["device_id"]: d for d in devices}: + devices = [] + ignore_devices = True for device in devices: logger.debug( @@ -934,7 +941,10 @@ class DeviceListUpdater: stream_id, ) - await self.store.update_remote_device_list_cache(user_id, devices, stream_id) + if not ignore_devices: + await self.store.update_remote_device_list_cache( + user_id, devices, stream_id + ) device_ids = [device["device_id"] for device in devices] # Handle cross-signing keys. @@ -945,7 +955,8 @@ class DeviceListUpdater: ) device_ids = device_ids + cross_signing_device_ids - await self.device_handler.notify_device_update(user_id, device_ids) + if device_ids: + await self.device_handler.notify_device_update(user_id, device_ids) # We clobber the seen updates since we've re-synced from a given # point. @@ -973,14 +984,17 @@ class DeviceListUpdater: """ device_ids = [] - if master_key: + current_keys_map = await self.store.get_e2e_cross_signing_keys_bulk([user_id]) + current_keys = current_keys_map.get(user_id) or {} + + if master_key and master_key != current_keys.get("master"): await self.store.set_e2e_cross_signing_key(user_id, "master", master_key) _, verify_key = get_verify_key_from_cross_signing_key(master_key) # verify_key is a VerifyKey from signedjson, which uses # .version to denote the portion of the key ID after the # algorithm and colon, which is the device ID device_ids.append(verify_key.version) - if self_signing_key: + if self_signing_key and self_signing_key != current_keys.get("self_signing"): await self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 6d8551a6d6..bc3630e9e9 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -280,6 +280,7 @@ class OidcProvider: self._config = provider self._callback_url = hs.config.oidc_callback_url # type: str + self._oidc_attribute_requirements = provider.attribute_requirements self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method @@ -859,6 +860,18 @@ class OidcProvider: ) # otherwise, it's a login + logger.debug("Userinfo for OIDC login: %s", userinfo) + + # Ensure that the attributes of the logged in user meet the required + # attributes by checking the userinfo against attribute_requirements + # In order to deal with the fact that OIDC userinfo can contain many + # types of data, we wrap non-list values in lists. + if not self._sso_handler.check_required_attributes( + request, + {k: v if isinstance(v, list) else [v] for k, v in userinfo.items()}, + self._oidc_attribute_requirements, + ): + return # Call the mapper to register/login the user try: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 54631b4ee2..da92feacc9 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -104,6 +104,8 @@ class BasePresenceHandler(abc.ABC): self.clock = hs.get_clock() self.store = hs.get_datastore() + self._busy_presence_enabled = hs.config.experimental.msc3026_enabled + active_presence = self.store.take_presence_startup_info() self.user_to_current_state = {state.user_id: state for state in active_presence} @@ -730,8 +732,12 @@ class PresenceHandler(BasePresenceHandler): PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE, + PresenceState.BUSY, ) - if presence not in valid_presence: + + if presence not in valid_presence or ( + presence == PresenceState.BUSY and not self._busy_presence_enabled + ): raise SynapseError(400, "Invalid presence state") user_id = target_user.to_string() @@ -744,7 +750,9 @@ class PresenceHandler(BasePresenceHandler): msg = status_msg if presence != PresenceState.OFFLINE else None new_fields["status_msg"] = msg - if presence == PresenceState.ONLINE: + if presence == PresenceState.ONLINE or ( + presence == PresenceState.BUSY and self._busy_presence_enabled + ): new_fields["last_active_ts"] = self.clock.time_msec() await self._update_states([prev_state.copy_and_replace(**new_fields)]) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 1abc8875cb..d7f226d589 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -437,10 +437,10 @@ class RegistrationHandler(BaseHandler): if RoomAlias.is_valid(r): ( - room_id, + room, remote_room_hosts, ) = await room_member_handler.lookup_room_alias(room_alias) - room_id = room_id.to_string() + room_id = room.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (r,) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1660921306..4d20ed8357 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -155,6 +155,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() + @abc.abstractmethod + async def forget(self, user: UserID, room_id: str) -> None: + raise NotImplementedError() + def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): """Ratelimit invites by room and by target user. diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 108730a7a1..d75506c75e 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -25,11 +25,14 @@ from synapse.replication.http.membership import ( ) from synapse.types import Requester, UserID +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class RoomMemberWorkerHandler(RoomMemberHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs) @@ -83,3 +86,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler): await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) + + async def forget(self, target: UserID, room_id: str) -> None: + raise RuntimeError("Cannot forget rooms on workers.") diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 84af2dde7e..04e7c64c94 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler): logout_devices: bool, requester: Optional[Requester] = None, ) -> None: - if not self.hs.config.password_localdb_enabled: + if not self._auth_handler.can_change_password(): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) try: diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py new file mode 100644 index 0000000000..513dc0c71a --- /dev/null +++ b/synapse/handlers/space_summary.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import logging +from collections import deque +from typing import TYPE_CHECKING, Iterable, List, Optional, Set + +from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility +from synapse.api.errors import AuthError +from synapse.events import EventBase +from synapse.events.utils import format_event_for_client_v2 +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# number of rooms to return. We'll stop once we hit this limit. +# TODO: allow clients to reduce this with a request param. +MAX_ROOMS = 50 + +# max number of events to return per room. +MAX_ROOMS_PER_SPACE = 50 + + +class SpaceSummaryHandler: + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() + self._auth = hs.get_auth() + self._room_list_handler = hs.get_room_list_handler() + self._state_handler = hs.get_state_handler() + self._store = hs.get_datastore() + self._event_serializer = hs.get_event_client_serializer() + + async def get_space_summary( + self, + requester: str, + room_id: str, + suggested_only: bool = False, + max_rooms_per_space: Optional[int] = None, + ) -> JsonDict: + """ + Implementation of the space summary API + + Args: + requester: user id of the user making this request + + room_id: room id to start the summary at + + suggested_only: whether we should only return children with the "suggested" + flag set. + + max_rooms_per_space: an optional limit on the number of child rooms we will + return. This does not apply to the root room (ie, room_id), and + is overridden by ROOMS_PER_SPACE_LIMIT. + + Returns: + summary dict to return + """ + # first of all, check that the user is in the room in question (or it's + # world-readable) + await self._auth.check_user_in_room_or_world_readable(room_id, requester) + + # the queue of rooms to process + room_queue = deque((room_id,)) + + processed_rooms = set() # type: Set[str] + + rooms_result = [] # type: List[JsonDict] + events_result = [] # type: List[JsonDict] + + now = self._clock.time_msec() + + while room_queue and len(rooms_result) < MAX_ROOMS: + room_id = room_queue.popleft() + logger.debug("Processing room %s", room_id) + processed_rooms.add(room_id) + + try: + await self._auth.check_user_in_room_or_world_readable( + room_id, requester + ) + except AuthError: + logger.info( + "user %s cannot view room %s, omitting from summary", + requester, + room_id, + ) + continue + + room_entry = await self._build_room_entry(room_id) + rooms_result.append(room_entry) + + # look for child rooms/spaces. + child_events = await self._get_child_events(room_id) + + if suggested_only: + # we only care about suggested children + child_events = filter(_is_suggested_child_event, child_events) + + # The client-specified max_rooms_per_space limit doesn't apply to the + # room_id specified in the request, so we ignore it if this is the + # first room we are processing. Otherwise, apply any client-specified + # limit, capping to our built-in limit. + if max_rooms_per_space is not None and len(processed_rooms) > 1: + max_rooms = min(MAX_ROOMS_PER_SPACE, max_rooms_per_space) + else: + max_rooms = MAX_ROOMS_PER_SPACE + + for edge_event in itertools.islice(child_events, max_rooms): + edge_room_id = edge_event.state_key + + events_result.append( + await self._event_serializer.serialize_event( + edge_event, + time_now=now, + event_format=format_event_for_client_v2, + ) + ) + + # if we haven't yet visited the target of this link, add it to the queue + if edge_room_id not in processed_rooms: + room_queue.append(edge_room_id) + + return {"rooms": rooms_result, "events": events_result} + + async def _build_room_entry(self, room_id: str) -> JsonDict: + """Generate en entry suitable for the 'rooms' list in the summary response""" + stats = await self._store.get_room_with_stats(room_id) + + # currently this should be impossible because we call + # check_user_in_room_or_world_readable on the room before we get here, so + # there should always be an entry + assert stats is not None, "unable to retrieve stats for %s" % (room_id,) + + current_state_ids = await self._store.get_current_state_ids(room_id) + create_event = await self._store.get_event( + current_state_ids[(EventTypes.Create, "")] + ) + + # TODO: update once MSC1772 lands + room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE) + + entry = { + "room_id": stats["room_id"], + "name": stats["name"], + "topic": stats["topic"], + "canonical_alias": stats["canonical_alias"], + "num_joined_members": stats["joined_members"], + "avatar_url": stats["avatar"], + "world_readable": ( + stats["history_visibility"] == HistoryVisibility.WORLD_READABLE + ), + "guest_can_join": stats["guest_access"] == "can_join", + "room_type": room_type, + } + + # Filter out Nones – rather omit the field altogether + room_entry = {k: v for k, v in entry.items() if v is not None} + + return room_entry + + async def _get_child_events(self, room_id: str) -> Iterable[EventBase]: + # look for child rooms/spaces. + current_state_ids = await self._store.get_current_state_ids(room_id) + + events = await self._store.get_events_as_list( + [ + event_id + for key, event_id in current_state_ids.items() + # TODO: update once MSC1772 lands + if key[0] == EventTypes.MSC1772_SPACE_CHILD + ] + ) + + # filter out any events without a "via" (which implies it has been redacted) + return (e for e in events if e.content.get("via")) + + +def _is_suggested_child_event(edge_event: EventBase) -> bool: + suggested = edge_event.content.get("suggested") + if isinstance(suggested, bool) and suggested: + return True + logger.debug("Ignorning not-suggested child %s", edge_event.state_key) + return False diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f50257cd57..7b723ead58 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1979,8 +1979,10 @@ class SyncHandler: logger.info("User joined room after current token: %s", room_id) - extrems = await self.store.get_forward_extremeties_for_room( - room_id, event_pos.stream + extrems = ( + await self.store.get_forward_extremities_for_room_at_stream_ordering( + room_id, event_pos.stream + ) ) users_in_room = await self.state.get_current_users_in_room(room_id, extrems) if user_id in users_in_room: diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 8af53b4f28..82ea3b895f 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -40,6 +40,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): // containing the event "event_format_version": .., // 1,2,3 etc: the event format version "internal_metadata": { .. serialized internal_metadata .. }, + "outlier": true|false, "rejected_reason": .., // The event.rejected_reason field "context": { .. serialized event context .. }, }], @@ -84,6 +85,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): "room_version": event.room_version.identifier, "event_format_version": event.format_version, "internal_metadata": event.internal_metadata.get_dict(), + "outlier": event.internal_metadata.is_outlier(), "rejected_reason": event.rejected_reason, "context": serialized_context, } @@ -116,6 +118,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event = make_event_from_dict( event_dict, room_ver, internal_metadata, rejected_reason ) + event.internal_metadata.outlier = event_payload["outlier"] context = EventContext.deserialize( self.storage, event_payload["context"] diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 8fa104c8d3..a4c5b44292 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -40,6 +40,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): // containing the event "event_format_version": .., // 1,2,3 etc: the event format version "internal_metadata": { .. serialized internal_metadata .. }, + "outlier": true|false, "rejected_reason": .., // The event.rejected_reason field "context": { .. serialized event context .. }, "requester": { .. serialized requester .. }, @@ -79,7 +80,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): ratelimit (bool) extra_users (list(UserID)): Any extra users to notify about event """ - serialized_context = await context.serialize(event, store) payload = { @@ -87,6 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "room_version": event.room_version.identifier, "event_format_version": event.format_version, "internal_metadata": event.internal_metadata.get_dict(), + "outlier": event.internal_metadata.is_outlier(), "rejected_reason": event.rejected_reason, "context": serialized_context, "requester": requester.serialize(), @@ -108,6 +109,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = make_event_from_dict( event_dict, room_ver, internal_metadata, rejected_reason ) + event.internal_metadata.outlier = content["outlier"] requester = Requester.deserialize(self.store, content["requester"]) context = EventContext.deserialize(self.storage, content["context"]) diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index f45e7a8c89..7e8e64d61c 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -33,7 +33,7 @@ import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates if TYPE_CHECKING: - import synapse.server + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -299,20 +299,23 @@ class TypingStream(Stream): NAME = "typing" ROW_TYPE = TypingStreamRow - def __init__(self, hs): - typing_handler = hs.get_typing_handler() - + def __init__(self, hs: "HomeServer"): writer_instance = hs.config.worker.writers.typing if writer_instance == hs.get_instance_name(): # On the writer, query the typing handler - update_function = typing_handler.get_all_typing_updates + typing_writer_handler = hs.get_typing_writer_handler() + update_function = ( + typing_writer_handler.get_all_typing_updates + ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] + current_token_function = typing_writer_handler.get_current_token else: # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) + current_token_function = hs.get_typing_handler().get_current_token super().__init__( hs.get_instance_name(), - current_token_without_instance(typing_handler.get_current_token), + current_token_without_instance(current_token_function), update_function, ) @@ -509,7 +512,7 @@ class AccountDataStream(Stream): NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2c89b62e25..aaa56a7024 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet): elif not deactivate and user["deactivated"]: if ( "password" not in body - and self.hs.config.password_localdb_enabled + and self.auth_handler.can_change_password() ): raise SynapseError( 400, "Must provide a password to re-activate an account." diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5884daea6d..6c722d634d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -18,7 +18,7 @@ import logging import re -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple from urllib import parse as urlparse from synapse.api.constants import EventTypes, Membership @@ -35,21 +35,30 @@ from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( RestServlet, assert_params_in_dict, + parse_boolean, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.logging.opentracing import set_tag from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig -from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types import ( + JsonDict, + RoomAlias, + RoomID, + StreamToken, + ThirdPartyInstanceID, + UserID, +) from synapse.util import json_decoder from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: - import synapse.server + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -846,10 +855,10 @@ class RoomTypingRestServlet(RestServlet): "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() + self.hs = hs self.presence_handler = hs.get_presence_handler() - self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() # If we're not on the typing writer instance we should scream if we get @@ -874,16 +883,19 @@ class RoomTypingRestServlet(RestServlet): # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) + # Defer getting the typing handler since it will raise on workers. + typing_handler = self.hs.get_typing_writer_handler() + try: if content["typing"]: - await self.typing_handler.started_typing( + await typing_handler.started_typing( target_user=target_user, requester=requester, room_id=room_id, timeout=timeout, ) else: - await self.typing_handler.stopped_typing( + await typing_handler.stopped_typing( target_user=target_user, requester=requester, room_id=room_id ) except ShadowBanError: @@ -901,7 +913,7 @@ class RoomAliasListServlet(RestServlet): ), ] - def __init__(self, hs: "synapse.server.HomeServer"): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.directory_handler = hs.get_directory_handler() @@ -984,7 +996,58 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): ) -def register_servlets(hs, http_server, is_worker=False): +class RoomSpaceSummaryRestServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc2946" + "/rooms/(?P<room_id>[^/]*)/spaces$" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._auth = hs.get_auth() + self._space_summary_handler = hs.get_space_summary_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + + return 200, await self._space_summary_handler.get_space_summary( + requester.user.to_string(), + room_id, + suggested_only=parse_boolean(request, "suggested_only", default=False), + max_rooms_per_space=parse_integer(request, "max_rooms_per_space"), + ) + + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self._auth.get_user_by_req(request, allow_guest=True) + content = parse_json_object_from_request(request) + + suggested_only = content.get("suggested_only", False) + if not isinstance(suggested_only, bool): + raise SynapseError( + 400, "'suggested_only' must be a boolean", Codes.BAD_JSON + ) + + max_rooms_per_space = content.get("max_rooms_per_space") + if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int): + raise SynapseError( + 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON + ) + + return 200, await self._space_summary_handler.get_space_summary( + requester.user.to_string(), + room_id, + suggested_only=suggested_only, + max_rooms_per_space=max_rooms_per_space, + ) + + +def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -998,6 +1061,9 @@ def register_servlets(hs, http_server, is_worker=False): RoomTypingRestServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server) + if hs.config.experimental.spaces_enabled: + RoomSpaceSummaryRestServlet(hs).register(http_server) + # Some servlets only get registered for the main process. if not is_worker: RoomCreateRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index 76879ac559..44ccf10ed4 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -13,12 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -27,21 +33,16 @@ class CapabilitiesRestServlet(RestServlet): PATTERNS = client_patterns("/capabilities$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.config = hs.config self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user = await self.store.get_user_by_id(requester.user.to_string()) - change_password = bool(user["password_hash"]) + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + change_password = self.auth_handler.can_change_password() response = { "capabilities": { @@ -58,5 +59,5 @@ class CapabilitiesRestServlet(RestServlet): return 200, response -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): CapabilitiesRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index d24a199318..3e3d8839f4 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -81,6 +81,8 @@ class VersionsRestServlet(RestServlet): "io.element.e2ee_forced.public": self.e2ee_forced_public, "io.element.e2ee_forced.private": self.e2ee_forced_private, "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private, + # Supports the busy presence state described in MSC3026. + "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled, }, }, ) diff --git a/synapse/server.py b/synapse/server.py index 48ac87a124..98822d8e2f 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -96,10 +96,11 @@ from synapse.handlers.room import ( RoomShutdownHandler, ) from synapse.handlers.room_list import RoomListHandler -from synapse.handlers.room_member import RoomMemberMasterHandler +from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.search import SearchHandler from synapse.handlers.set_password import SetPasswordHandler +from synapse.handlers.space_summary import SpaceSummaryHandler from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler @@ -417,10 +418,19 @@ class HomeServer(metaclass=abc.ABCMeta): return PresenceHandler(self) @cache_in_self - def get_typing_handler(self): + def get_typing_writer_handler(self) -> TypingWriterHandler: if self.config.worker.writers.typing == self.get_instance_name(): return TypingWriterHandler(self) else: + raise Exception("Workers cannot write typing") + + @cache_in_self + def get_typing_handler(self) -> FollowerTypingHandler: + if self.config.worker.writers.typing == self.get_instance_name(): + # Use get_typing_writer_handler to ensure that we use the same + # cached version. + return self.get_typing_writer_handler() + else: return FollowerTypingHandler(self) @cache_in_self @@ -630,7 +640,7 @@ class HomeServer(metaclass=abc.ABCMeta): return ThirdPartyEventRules(self) @cache_in_self - def get_room_member_handler(self): + def get_room_member_handler(self) -> RoomMemberHandler: if self.config.worker_app: return RoomMemberWorkerHandler(self) return RoomMemberMasterHandler(self) @@ -724,6 +734,10 @@ class HomeServer(metaclass=abc.ABCMeta): return AccountDataHandler(self) @cache_in_self + def get_space_summary_handler(self) -> SpaceSummaryHandler: + return SpaceSummaryHandler(self) + + @cache_in_self def get_external_cache(self) -> ExternalCache: return ExternalCache(self) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 332193ad1c..a956be491a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -793,7 +793,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return int(min_depth) if min_depth is not None else None - async def get_forward_extremeties_for_room( + async def get_forward_extremities_for_room_at_stream_ordering( self, room_id: str, stream_ordering: int ) -> List[str]: """For a given room_id and stream_ordering, return the forward diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index cd1ceac50e..98dac19a95 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1270,8 +1270,10 @@ class PersistEventsStore: logger.exception("") raise + # update the stored internal_metadata to update the "outlier" flag. + # TODO: This is unused as of Synapse 1.31. Remove it once we are happy + # to drop backwards-compatibility with 1.30. metadata_json = json_encoder.encode(event.internal_metadata.get_dict()) - sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?" txn.execute(sql, (metadata_json, event.event_id)) @@ -1319,6 +1321,19 @@ class PersistEventsStore: d.pop("redacted_because", None) return d + def get_internal_metadata(event): + im = event.internal_metadata.get_dict() + + # temporary hack for database compatibility with Synapse 1.30 and earlier: + # store the `outlier` flag inside the internal_metadata json as well as in + # the `events` table, so that if anyone rolls back to an older Synapse, + # things keep working. This can be removed once we are happy to drop support + # for that + if event.internal_metadata.is_outlier(): + im["outlier"] = True + + return im + self.db_pool.simple_insert_many_txn( txn, table="event_json", @@ -1327,7 +1342,7 @@ class PersistEventsStore: "event_id": event.event_id, "room_id": event.room_id, "internal_metadata": json_encoder.encode( - event.internal_metadata.get_dict() + get_internal_metadata(event) ), "json": json_encoder.encode(event_dict(event)), "format_version": event.format_version, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c04e162ccc..952d4969b2 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -799,6 +799,7 @@ class EventsWorkerStore(SQLBaseStore): rejected_reason=rejected_reason, ) original_ev.internal_metadata.stream_ordering = row["stream_ordering"] + original_ev.internal_metadata.outlier = row["outlier"] event_map[event_id] = original_ev @@ -905,7 +906,8 @@ class EventsWorkerStore(SQLBaseStore): ej.json, ej.format_version, r.room_version, - rej.reason + rej.reason, + e.outlier FROM events AS e JOIN event_json AS ej USING (event_id) LEFT JOIN rooms r ON r.room_id = e.room_id @@ -929,6 +931,7 @@ class EventsWorkerStore(SQLBaseStore): "room_version_id": row[5], "rejected_reason": row[6], "redactions": [], + "outlier": row[7], } # check for redactions diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index eba66ff352..90a8f664ef 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) @cached() diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 6f96cd7940..95eac6a5a3 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -2,6 +2,7 @@ from typing import List, Tuple from mock import Mock +from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu @@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): self.assertNotIn("zzzerver", woken) # - all destinations are woken exactly once; they appear once in woken. self.assertCountEqual(woken, server_names[:-1]) + + @override_config({"send_federation": True}) + def test_not_latest_event(self): + """Test that we send the latest event in the room even if its not ours.""" + + per_dest_queue, sent_pdus = self.make_fake_destination_queue() + + # Make a room with a local user, and two servers. One will go offline + # and one will send some events. + self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + event_1 = self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join") + ) + + # First we send something from the local server, so that we notice the + # remote is down and go into catchup mode. + self.helper.send(room_1, "you hear me!!", tok=u1_token) + + # Now simulate us receiving an event from the still online remote. + event_2 = self.get_success( + event_injection.inject_event( + self.hs, + type=EventTypes.Message, + sender="@user:host3", + room_id=room_1, + content={"msgtype": "m.text", "body": "Hello"}, + ) + ) + + self.get_success( + self.hs.get_datastore().set_destination_last_successful_stream_ordering( + "host2", event_1.internal_metadata.stream_ordering + ) + ) + + self.get_success(per_dest_queue._catch_up_transmission_loop()) + + # We expect only the last message from the remote, event_2, to have been + # sent, rather than the last *local* event that was sent. + self.assertEqual(len(sent_pdus), 1) + self.assertEqual(sent_pdus[0].event_id, event_2.event_id) + self.assertFalse(per_dest_queue._catching_up) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5e9c9c2e88..c7796fb837 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -989,6 +989,138 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) self.assertRenderedError("mapping_error", "localpart is invalid: ") + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements(self): + """The required attributes must be met from the OIDC userinfo response.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # userinfo lacking "test": "foobar" attribute should fail. + userinfo = { + "sub": "tester", + "username": "tester", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": "foobar" attribute should succeed. + userinfo = { + "sub": "tester", + "username": "tester", + "test": "foobar", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@tester:test", "oidc", ANY, ANY, None, new_user=True + ) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements_contains(self): + """Test that auth succeeds if userinfo attribute CONTAINS required value""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed. + userinfo = { + "sub": "tester", + "username": "tester", + "test": ["foobar", "foo", "bar"], + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@tester:test", "oidc", ANY, ANY, None, new_user=True + ) + + @override_config( + { + "oidc_config": { + **DEFAULT_CONFIG, + "attribute_requirements": [{"attribute": "test", "value": "foobar"}], + } + } + ) + def test_attribute_requirements_mismatch(self): + """ + Test that auth fails if attributes exist but don't match, + or are non-string values. + """ + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # userinfo with "test": "not_foobar" attribute should fail + userinfo = { + "sub": "tester", + "username": "tester", + "test": "not_foobar", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": ["foo", "bar"] attribute should fail + userinfo = { + "sub": "tester", + "username": "tester", + "test": ["foo", "bar"], + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": False attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": False, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": None attribute should fail + # a value of None breaks the OIDC spec, but it's important to not crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": 1 attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": 1, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + + # userinfo with "test": 3.14 attribute should fail + # this is largely just to ensure we don't crash here + userinfo = { + "sub": "tester", + "username": "tester", + "test": 3.14, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + auth_handler.complete_sso_login.assert_not_called() + def _generate_oidc_session_token( self, state: str, diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 996c614198..77330f59a9 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -310,6 +310,26 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) + def test_busy_no_idle(self): + """ + Tests that a user setting their presence to busy but idling doesn't turn their + presence state into unavailable. + """ + user_id = "@foo:bar" + now = 5000000 + + state = UserPresenceState.default(user_id) + state = state.copy_and_replace( + state=PresenceState.BUSY, + last_active_ts=now - IDLE_TIMER - 1, + last_user_sync_ts=now, + ) + + new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + + self.assertIsNotNone(new_state) + self.assertEquals(new_state.state, PresenceState.BUSY) + def test_sync_timeout(self): user_id = "@foo:bar" now = 5000000 diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e58d5cf0db..cf61f284cb 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() + # create users and get access tokens + # regardless of whether password login or SSO is allowed self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") + self.admin_user_tok = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.admin_user, device_id=None, valid_until_ms=None + ) + ) self.other_user = self.register_user("user", "pass", displayname="User") - self.other_user_token = self.login("user", "pass") + self.other_user_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.other_user, device_id=None, valid_until_ms=None + ) + ) self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.other_user ) @@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user @@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(True, channel.json_body["admin"]) - self.assertEqual(False, channel.json_body["is_guest"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["admin"]) + self.assertFalse(channel.json_body["is_guest"]) + self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) def test_create_user(self): @@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user @@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(False, channel.json_body["admin"]) - self.assertEqual(False, channel.json_body["is_guest"]) - self.assertEqual(False, channel.json_body["deactivated"]) - self.assertEqual(False, channel.json_body["shadow_banned"]) + self.assertFalse(channel.json_body["admin"]) + self.assertFalse(channel.json_body["is_guest"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["shadow_banned"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) @override_config( @@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} @@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Admin user is not blocked by mau anymore self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) @override_config( { @@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) # Deactivate user - body = json.dumps({"deactivated": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"deactivated": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertTrue(profile["display_name"] == "User") # Deactivate user - body = json.dumps({"deactivated": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"deactivated": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) # is not in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) - self.assertTrue(profile is None) + self.assertIsNone(profile) # Set new displayname user - body = json.dumps({"displayname": "Foobar"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"displayname": "Foobar"}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) # is not in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) - self.assertTrue(profile is None) + self.assertIsNone(profile) def test_reactivate_user(self): """ @@ -1520,48 +1527,92 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Deactivate the user. + self._deactivate_user("@user:test") + + # Attempt to reactivate the user (without a password). + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": False}, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Reactivate the user. channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), + content={"deactivated": False, "password": "foo"}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNotNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) - d = self.store.mark_user_erased("@user:test") - self.assertIsNone(self.get_success(d)) - self._is_erased("@user:test", True) - # Attempt to reactivate the user (without a password). + @override_config({"password_config": {"localdb_enabled": False}}) + def test_reactivate_user_localdb_disabled(self): + """ + Test reactivating another user when using SSO. + """ + + # Deactivate the user. + self._deactivate_user("@user:test") + + # Reactivate the user with a password channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), + content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - # Reactivate the user. + # Reactivate the user without a password. channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": False, "password": "foo"}).encode( - encoding="utf_8" - ), + content={"deactivated": False}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) + self._is_erased("@user:test", False) - # Get user + @override_config({"password_config": {"enabled": False}}) + def test_reactivate_user_password_disabled(self): + """ + Test reactivating another user when using SSO. + """ + + # Deactivate the user. + self._deactivate_user("@user:test") + + # Reactivate the user with a password channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"deactivated": False, "password": "foo"}, ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + # Reactivate the user without a password. + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": False}, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) def test_set_user_as_admin(self): @@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Set a user as an admin - body = json.dumps({"admin": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"admin": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) # Get user channel = self.make_request( @@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) def test_accidental_deactivation_prevention(self): """ @@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123"}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123"}, ) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) @@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["deactivated"]) # Change password (and use a str for deactivate instead of a bool) - body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops! - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "deactivated": "false"}, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) - def _is_erased(self, user_id, expect): + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) if expect: @@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): else: self.assertFalse(self.get_success(d)) + def _deactivate_user(self, user_id: str) -> None: + """Deactivate user and set as erased""" + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id), + access_token=self.admin_user_tok, + content={"deactivated": True}, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) + self._is_erased(user_id, False) + d = self.store.mark_user_erased(user_id) + self.assertIsNone(self.get_success(d)) + self._is_erased(user_id, True) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 227fffab58..bf39014277 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -161,6 +161,68 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): ev = channel.json_body self.assertEqual(ev["content"]["x"], "y") + def test_message_edit(self): + """Ensure that the module doesn't cause issues with edited messages.""" + # first patch the event checker so that it will modify the event + async def check(ev: EventBase, state): + d = ev.get_dict() + d["content"] = { + "msgtype": "m.text", + "body": d["content"]["body"].upper(), + } + return d + + current_rules_module().check_event_allowed = check + + # Send an event, then edit it. + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id, + { + "msgtype": "m.text", + "body": "Original body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + orig_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/m.room.message/2" % self.room_id, + { + "m.new_content": {"msgtype": "m.text", "body": "Edited body"}, + "m.relates_to": { + "rel_type": "m.replace", + "event_id": orig_event_id, + }, + "msgtype": "m.text", + "body": "Edited body", + }, + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + edited_event_id = channel.json_body["event_id"] + + # ... and check that they both got modified + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["body"], "ORIGINAL BODY") + + channel = self.make_request( + "GET", + "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id), + access_token=self.tok, + ) + self.assertEqual(channel.result["code"], b"200", channel.result) + ev = channel.json_body + self.assertEqual(ev["content"]["body"], "EDITED BODY") + def test_send_event(self): """Tests that the module can send an event into a room via the module api""" content = { diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index e808339fb3..287a1a485c 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import capabilities from tests import unittest +from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): @@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver() self.store = hs.get_datastore() self.config = hs.config + self.auth_handler = hs.get_auth_handler() return hs def test_check_auth_required(self): @@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities["m.room_versions"]["default"], ) - def test_get_change_password_capabilities(self): + def test_get_change_password_capabilities_password_login(self): localpart = "user" password = "pass" user = self.register_user(localpart, password) @@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - - # Test case where password is handled outside of Synapse self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.get_success(self.store.user_set_password_hash(user, None)) + + @override_config({"password_config": {"localdb_enabled": False}}) + def test_get_change_password_capabilities_localdb_disabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["m.change_password"]["enabled"]) + + @override_config({"password_config": {"enabled": False}}) + def test_get_change_password_capabilities_password_disabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"] diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py index 7c457754f1..e7bb5583fc 100644 --- a/tests/rest/client/v2_alpha/test_relations.py +++ b/tests/rest/client/v2_alpha/test_relations.py @@ -39,6 +39,11 @@ class RelationsTestCase(unittest.HomeserverTestCase): # We need to enable msc1849 support for aggregations config = self.default_config() config["experimental_msc1849_support_enabled"] = True + + # We enable frozen dicts as relations/edits change event contents, so we + # want to test that we don't modify the events in the caches. + config["use_frozen_dicts"] = True + return self.setup_test_homeserver(config=config) def prepare(self, reactor, clock, hs): @@ -518,6 +523,63 @@ class RelationsTestCase(unittest.HomeserverTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + def test_edit_reply(self): + """Test that editing a reply works.""" + + # Create a reply to edit. + channel = self._send_relation( + RelationTypes.REFERENCE, + "m.room.message", + content={"msgtype": "m.text", "body": "A reply!"}, + ) + self.assertEquals(200, channel.code, channel.json_body) + reply = channel.json_body["event_id"] + + new_body = {"msgtype": "m.text", "body": "I've been edited!"} + channel = self._send_relation( + RelationTypes.REPLACE, + "m.room.message", + content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, + parent_id=reply, + ) + self.assertEquals(200, channel.code, channel.json_body) + + edit_event_id = channel.json_body["event_id"] + + channel = self.make_request( + "GET", + "/rooms/%s/event/%s" % (self.room, reply), + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + + # We expect to see the new body in the dict, as well as the reference + # metadata sill intact. + self.assertDictContainsSubset(new_body, channel.json_body["content"]) + self.assertDictContainsSubset( + { + "m.relates_to": { + "event_id": self.parent_id, + "key": None, + "rel_type": "m.reference", + } + }, + channel.json_body["content"], + ) + + # We expect that the edit relation appears in the unsigned relations + # section. + relations_dict = channel.json_body["unsigned"].get("m.relations") + self.assertIn(RelationTypes.REPLACE, relations_dict) + + m_replace_dict = relations_dict[RelationTypes.REPLACE] + for key in ["event_id", "sender", "origin_server_ts"]: + self.assertIn(key, m_replace_dict) + + self.assert_dict( + {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict + ) + def test_relations_redaction_redacts_edits(self): """Test that edits of an event are redacted when the original event is redacted. diff --git a/tests/unittest.py b/tests/unittest.py index ca7031c724..58a4daa1ec 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -32,6 +32,7 @@ from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from twisted.web.resource import Resource +from synapse import events from synapse.api.constants import EventTypes, Membership from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig @@ -140,7 +141,7 @@ class TestCase(unittest.TestCase): try: self.assertEquals(attrs[key], getattr(obj, key)) except AssertionError as e: - raise (type(e))(e.message + " for '.%s'" % key) + raise (type(e))("Assert error for '.{}':".format(key)) from e def assert_dict(self, required, actual): """Does a partial assert of a dict. @@ -229,6 +230,11 @@ class HomeserverTestCase(TestCase): self._hs_args = {"clock": self.clock, "reactor": self.reactor} self.hs = self.make_homeserver(self.reactor, self.clock) + # Honour the `use_frozen_dicts` config option. We have to do this + # manually because this is taken care of in the app `start` code, which + # we don't run. Plus we want to reset it on tearDown. + events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts + if self.hs is None: raise Exception("No homeserver returned from make_homeserver.") @@ -292,6 +298,10 @@ class HomeserverTestCase(TestCase): if hasattr(self, "prepare"): self.prepare(self.reactor, self.clock, self.hs) + def tearDown(self): + # Reset to not use frozen dicts. + events.USE_FROZEN_DICTS = False + def wait_on_thread(self, deferred, timeout=10): """ Wait until a Deferred is done, where it's waiting on a real thread. |