diff --git a/changelog.d/9411.misc b/changelog.d/9411.misc
new file mode 100644
index 0000000000..c3e6cfa5f1
--- /dev/null
+++ b/changelog.d/9411.misc
@@ -0,0 +1 @@
+Preparatory steps for removing redundant `outlier` data from `event_json.internal_metadata` column.
diff --git a/changelog.d/9499.misc b/changelog.d/9499.misc
new file mode 100644
index 0000000000..1513017a10
--- /dev/null
+++ b/changelog.d/9499.misc
@@ -0,0 +1 @@
+Introduce bugbear to the test suite and fix some of it's lint violations.
\ No newline at end of file
diff --git a/changelog.d/9585.bugfix b/changelog.d/9585.bugfix
new file mode 100644
index 0000000000..de472ddfd1
--- /dev/null
+++ b/changelog.d/9585.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug that could cause issues when editing a reply to a message.
\ No newline at end of file
diff --git a/changelog.d/9588.bugfix b/changelog.d/9588.bugfix
new file mode 100644
index 0000000000..b8d6140565
--- /dev/null
+++ b/changelog.d/9588.bugfix
@@ -0,0 +1 @@
+Fix the `/capabilities` endpoint to return `m.change_password` as disabled if the local password database is not used for authentication. Contributed by @dklimpel.
diff --git a/changelog.d/9609.feature b/changelog.d/9609.feature
new file mode 100644
index 0000000000..f3b6342069
--- /dev/null
+++ b/changelog.d/9609.feature
@@ -0,0 +1 @@
+Logins using OpenID Connect can require attributes on the `userinfo` response in order to login. Contributed by Hubbe King.
diff --git a/changelog.d/9631.misc b/changelog.d/9631.misc
new file mode 100644
index 0000000000..35338cd332
--- /dev/null
+++ b/changelog.d/9631.misc
@@ -0,0 +1 @@
+Add additional type hints to the Homeserver object.
diff --git a/changelog.d/9634.misc b/changelog.d/9634.misc
new file mode 100644
index 0000000000..59ac42cb83
--- /dev/null
+++ b/changelog.d/9634.misc
@@ -0,0 +1 @@
+Only save remote cross-signing and device keys if they're different from the current ones.
diff --git a/changelog.d/9636.bugfix b/changelog.d/9636.bugfix
new file mode 100644
index 0000000000..fa772ed6fc
--- /dev/null
+++ b/changelog.d/9636.bugfix
@@ -0,0 +1 @@
+Checks if passwords are allowed before setting it for the user.
\ No newline at end of file
diff --git a/changelog.d/9637.misc b/changelog.d/9637.misc
new file mode 100644
index 0000000000..90a27d9f8f
--- /dev/null
+++ b/changelog.d/9637.misc
@@ -0,0 +1 @@
+Rename storage function to fix spelling and not conflict with another functions name.
diff --git a/changelog.d/9638.misc b/changelog.d/9638.misc
new file mode 100644
index 0000000000..35338cd332
--- /dev/null
+++ b/changelog.d/9638.misc
@@ -0,0 +1 @@
+Add additional type hints to the Homeserver object.
diff --git a/changelog.d/9640.misc b/changelog.d/9640.misc
new file mode 100644
index 0000000000..3d410ed4cd
--- /dev/null
+++ b/changelog.d/9640.misc
@@ -0,0 +1 @@
+Improve performance of federation catch up by sending events the latest events in the room to the remote, rather than just the last event sent by the local server.
diff --git a/changelog.d/9643.feature b/changelog.d/9643.feature
new file mode 100644
index 0000000000..2f7ccedcfb
--- /dev/null
+++ b/changelog.d/9643.feature
@@ -0,0 +1 @@
+Add initial experimental support for a "space summary" API.
diff --git a/changelog.d/9644.feature b/changelog.d/9644.feature
new file mode 100644
index 0000000000..556bcf0f9f
--- /dev/null
+++ b/changelog.d/9644.feature
@@ -0,0 +1 @@
+Implement the busy presence state as described in [MSC3026](https://github.com/matrix-org/matrix-doc/pull/3026).
diff --git a/changelog.d/9645.misc b/changelog.d/9645.misc
new file mode 100644
index 0000000000..9a7ce364c1
--- /dev/null
+++ b/changelog.d/9645.misc
@@ -0,0 +1 @@
+In the `federation_client` commandline client, stop automatically adding the URL prefix, so that servlets on other prefixes can be tested.
diff --git a/changelog.d/9647.misc b/changelog.d/9647.misc
new file mode 100644
index 0000000000..303a8c6606
--- /dev/null
+++ b/changelog.d/9647.misc
@@ -0,0 +1 @@
+In the `federation_client` commandline client, handle inline `signing_key`s in `homeserver.yaml`.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 7de000f4a4..a9f59e39f7 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1873,6 +1873,24 @@ saml2_config:
# which is set to the claims returned by the UserInfo Endpoint and/or
# in the ID Token.
#
+# It is possible to configure Synapse to only allow logins if certain attributes
+# match particular values in the OIDC userinfo. The requirements can be listed under
+# `attribute_requirements` as shown below. All of the listed attributes must
+# match for the login to be permitted. Additional attributes can be added to
+# userinfo by expanding the `scopes` section of the OIDC config to retrieve
+# additional information from the OIDC provider.
+#
+# If the OIDC claim is a list, then the attribute must match any value in the list.
+# Otherwise, it must exactly match the value of the claim. Using the example
+# below, the `family_name` claim MUST be "Stephensson", but the `groups`
+# claim MUST contain "admin".
+#
+# attribute_requirements:
+# - attribute: family_name
+# value: "Stephensson"
+# - attribute: groups
+# value: "admin"
+#
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
# for information on how to configure these options.
#
@@ -1905,6 +1923,9 @@ oidc_providers:
# localpart_template: "{{ user.login }}"
# display_name_template: "{{ user.name }}"
# email_template: "{{ user.email }}"
+ # attribute_requirements:
+ # - attribute: userGroup
+ # value: "synapseUsers"
# For use with Keycloak
#
@@ -1914,6 +1935,9 @@ oidc_providers:
# client_id: "synapse"
# client_secret: "copy secret generated in Keycloak UI"
# scopes: ["openid", "profile"]
+ # attribute_requirements:
+ # - attribute: groups
+ # value: "admin"
# For use with Github
#
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index abcec48c4f..6f76c08fcf 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -22,8 +22,8 @@ import sys
from typing import Any, Optional
from urllib import parse as urlparse
-import nacl.signing
import requests
+import signedjson.key
import signedjson.types
import srvlookup
import yaml
@@ -44,18 +44,6 @@ def encode_base64(input_bytes):
return output_string
-def decode_base64(input_string):
- """Decode a base64 string to bytes inferring padding from the length of the
- string."""
-
- input_bytes = input_string.encode("ascii")
- input_len = len(input_bytes)
- padding = b"=" * (3 - ((input_len + 3) % 4))
- output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2
- output_bytes = base64.b64decode(input_bytes + padding)
- return output_bytes[:output_len]
-
-
def encode_canonical_json(value):
return json.dumps(
value,
@@ -88,42 +76,6 @@ def sign_json(
return json_object
-NACL_ED25519 = "ed25519"
-
-
-def decode_signing_key_base64(algorithm, version, key_base64):
- """Decode a base64 encoded signing key
- Args:
- algorithm (str): The algorithm the key is for (currently "ed25519").
- version (str): Identifies this key out of the keys for this entity.
- key_base64 (str): Base64 encoded bytes of the key.
- Returns:
- A SigningKey object.
- """
- if algorithm == NACL_ED25519:
- key_bytes = decode_base64(key_base64)
- key = nacl.signing.SigningKey(key_bytes)
- key.version = version
- key.alg = NACL_ED25519
- return key
- else:
- raise ValueError("Unsupported algorithm %s" % (algorithm,))
-
-
-def read_signing_keys(stream):
- """Reads a list of keys from a stream
- Args:
- stream : A stream to iterate for keys.
- Returns:
- list of SigningKey objects.
- """
- keys = []
- for line in stream:
- algorithm, version, key_base64 = line.split()
- keys.append(decode_signing_key_base64(algorithm, version, key_base64))
- return keys
-
-
def request(
method: Optional[str],
origin_name: str,
@@ -223,23 +175,28 @@ def main():
parser.add_argument("--body", help="Data to send as the body of the HTTP request")
parser.add_argument(
- "path", help="request path. We will add '/_matrix/federation/v1/' to this."
+ "path", help="request path, including the '/_matrix/federation/...' prefix."
)
args = parser.parse_args()
- if not args.server_name or not args.signing_key_path:
+ args.signing_key = None
+ if args.signing_key_path:
+ with open(args.signing_key_path) as f:
+ args.signing_key = f.readline()
+
+ if not args.server_name or not args.signing_key:
read_args_from_config(args)
- with open(args.signing_key_path) as f:
- key = read_signing_keys(f)[0]
+ algorithm, version, key_base64 = args.signing_key.split()
+ key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
result = request(
args.method,
args.server_name,
key,
args.destination,
- "/_matrix/federation/v1/" + args.path,
+ args.path,
content=args.body,
)
@@ -255,10 +212,16 @@ def main():
def read_args_from_config(args):
with open(args.config, "r") as fh:
config = yaml.safe_load(fh)
+
if not args.server_name:
args.server_name = config["server_name"]
- if not args.signing_key_path:
- args.signing_key_path = config["signing_key_path"]
+
+ if not args.signing_key:
+ if "signing_key" in config:
+ args.signing_key = config["signing_key"]
+ else:
+ with open(config["signing_key_path"]) as f:
+ args.signing_key = f.readline()
class MatrixConnectionAdapter(HTTPAdapter):
diff --git a/setup.cfg b/setup.cfg
index 5e301c2cd7..920868df20 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -18,7 +18,8 @@ ignore =
# E203: whitespace before ':' (which is contrary to pep8?)
# E731: do not assign a lambda expression, use a def
# E501: Line too long (black enforces this for us)
-ignore=W503,W504,E203,E731,E501
+# B00: Subsection of the bugbear suite (TODO: add in remaining fixes)
+ignore=W503,W504,E203,E731,E501,B00
[isort]
line_length = 88
diff --git a/setup.py b/setup.py
index bbd9e7862a..b834e4e55b 100755
--- a/setup.py
+++ b/setup.py
@@ -99,6 +99,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"isort==5.7.0",
"black==20.8b1",
"flake8-comprehensions",
+ "flake8-bugbear",
"flake8",
]
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 691f8f9adf..8f37d2cf3b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -51,6 +51,7 @@ class PresenceState:
OFFLINE = "offline"
UNAVAILABLE = "unavailable"
ONLINE = "online"
+ BUSY = "org.matrix.msc3026.busy"
class JoinRules:
@@ -100,6 +101,9 @@ class EventTypes:
Dummy = "org.matrix.dummy_event"
+ MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child"
+ MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent"
+
class EduTypes:
Presence = "m.presence"
@@ -160,6 +164,9 @@ class EventContentFields:
# cf https://github.com/matrix-org/matrix-doc/pull/2228
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
+ # cf https://github.com/matrix-org/matrix-doc/pull/1772
+ MSC1772_ROOM_TYPE = "org.matrix.msc1772.type"
+
class RoomEncryptionAlgorithms:
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index 4a9b0129c3..d1a2cd5e19 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -22,7 +22,9 @@ logger = logging.getLogger(__name__)
try:
python_dependencies.check_requirements()
except python_dependencies.DependencyException as e:
- sys.stderr.writelines(e.message)
+ sys.stderr.writelines(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
sys.exit(1)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 274d582d07..caef394e1d 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -302,6 +302,8 @@ class GenericWorkerPresence(BasePresenceHandler):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
+ self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
+
hs.get_reactor().addSystemEventTrigger(
"before",
"shutdown",
@@ -439,8 +441,12 @@ class GenericWorkerPresence(BasePresenceHandler):
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
+ PresenceState.BUSY,
)
- if presence not in valid_presence:
+
+ if presence not in valid_presence or (
+ presence == PresenceState.BUSY and not self._busy_presence_enabled
+ ):
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b1c1c51e4d..86f4d9af9d 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -27,3 +27,7 @@ class ExperimentalConfig(Config):
# MSC2858 (multiple SSO identity providers)
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
+ # Spaces (MSC1772, MSC2946, etc)
+ self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
+ # MSC3026 (busy presence state)
+ self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
diff --git a/synapse/config/key.py b/synapse/config/key.py
index de964dff13..350ff1d665 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -404,7 +404,11 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
try:
jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA)
except jsonschema.ValidationError as e:
- raise ConfigError("Unable to parse 'trusted_key_servers': " + e.message)
+ raise ConfigError(
+ "Unable to parse 'trusted_key_servers': {}".format(
+ e.message # noqa: B306, jsonschema.ValidationError.message is a valid attribute
+ )
+ )
for server in key_servers:
server_name = server["server_name"]
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index dfd27e1523..2b289f4208 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -56,7 +56,9 @@ class MetricsConfig(Config):
try:
check_requirements("sentry")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 2bfb537c15..747ab9a7fe 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,11 +15,12 @@
# limitations under the License.
from collections import Counter
-from typing import Iterable, Mapping, Optional, Tuple, Type
+from typing import Iterable, List, Mapping, Optional, Tuple, Type
import attr
from synapse.config._util import validate_config
+from synapse.config.sso import SsoAttributeRequirement
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
@@ -41,7 +42,9 @@ class OIDCConfig(Config):
try:
check_requirements("oidc")
except DependencyException as e:
- raise ConfigError(e.message) from e
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ ) from e
# check we don't have any duplicate idp_ids now. (The SSO handler will also
# check for duplicates when the REST listeners get registered, but that happens
@@ -191,6 +194,24 @@ class OIDCConfig(Config):
# which is set to the claims returned by the UserInfo Endpoint and/or
# in the ID Token.
#
+ # It is possible to configure Synapse to only allow logins if certain attributes
+ # match particular values in the OIDC userinfo. The requirements can be listed under
+ # `attribute_requirements` as shown below. All of the listed attributes must
+ # match for the login to be permitted. Additional attributes can be added to
+ # userinfo by expanding the `scopes` section of the OIDC config to retrieve
+ # additional information from the OIDC provider.
+ #
+ # If the OIDC claim is a list, then the attribute must match any value in the list.
+ # Otherwise, it must exactly match the value of the claim. Using the example
+ # below, the `family_name` claim MUST be "Stephensson", but the `groups`
+ # claim MUST contain "admin".
+ #
+ # attribute_requirements:
+ # - attribute: family_name
+ # value: "Stephensson"
+ # - attribute: groups
+ # value: "admin"
+ #
# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md
# for information on how to configure these options.
#
@@ -223,6 +244,9 @@ class OIDCConfig(Config):
# localpart_template: "{{{{ user.login }}}}"
# display_name_template: "{{{{ user.name }}}}"
# email_template: "{{{{ user.email }}}}"
+ # attribute_requirements:
+ # - attribute: userGroup
+ # value: "synapseUsers"
# For use with Keycloak
#
@@ -232,6 +256,9 @@ class OIDCConfig(Config):
# client_id: "synapse"
# client_secret: "copy secret generated in Keycloak UI"
# scopes: ["openid", "profile"]
+ # attribute_requirements:
+ # - attribute: groups
+ # value: "admin"
# For use with Github
#
@@ -329,6 +356,10 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
},
"allow_existing_users": {"type": "boolean"},
"user_mapping_provider": {"type": ["object", "null"]},
+ "attribute_requirements": {
+ "type": "array",
+ "items": SsoAttributeRequirement.JSON_SCHEMA,
+ },
},
}
@@ -465,6 +496,11 @@ def _parse_oidc_config_dict(
jwt_header=client_secret_jwt_key_config["jwt_header"],
jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
)
+ # parse attribute_requirements from config (list of dicts) into a list of SsoAttributeRequirement
+ attribute_requirements = [
+ SsoAttributeRequirement(**x)
+ for x in oidc_config.get("attribute_requirements", [])
+ ]
return OidcProviderConfig(
idp_id=idp_id,
@@ -488,6 +524,7 @@ def _parse_oidc_config_dict(
allow_existing_users=oidc_config.get("allow_existing_users", False),
user_mapping_provider_class=user_mapping_provider_class,
user_mapping_provider_config=user_mapping_provider_config,
+ attribute_requirements=attribute_requirements,
)
@@ -577,3 +614,6 @@ class OidcProviderConfig:
# the config of the user mapping provider
user_mapping_provider_config = attr.ib()
+
+ # required attributes to require in userinfo to allow login/registration
+ attribute_requirements = attr.ib(type=List[SsoAttributeRequirement])
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 69d9de5a43..061c4ec83f 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -176,7 +176,9 @@ class ContentRepositoryConfig(Config):
check_requirements("url_preview")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
if "url_preview_ip_range_blacklist" not in config:
raise ConfigError(
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 4b494f217f..6db9cb5ced 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -76,7 +76,9 @@ class SAML2Config(Config):
try:
check_requirements("saml2")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
self.saml2_enabled = True
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 0c1a854f09..727a1e7008 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -39,7 +39,9 @@ class TracerConfig(Config):
try:
check_requirements("opentracing")
except DependencyException as e:
- raise ConfigError(e.message)
+ raise ConfigError(
+ e.message # noqa: B306, DependencyException.message is a property
+ )
# The tracer is enabled so sanitize the config
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 14b21796d9..4ca13011e5 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -219,7 +219,7 @@ class SSLClientConnectionCreator:
# ... and we also gut-wrench a '_synapse_tls_verifier' attribute into the
# tls_protocol so that the SSL context's info callback has something to
# call to do the cert verification.
- setattr(tls_protocol, "_synapse_tls_verifier", self._verifier)
+ tls_protocol._synapse_tls_verifier = self._verifier
return connection
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 3ec4120f85..8f6b955d17 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -98,7 +98,7 @@ class DefaultDictProperty(DictProperty):
class _EventInternalMetadata:
- __slots__ = ["_dict", "stream_ordering"]
+ __slots__ = ["_dict", "stream_ordering", "outlier"]
def __init__(self, internal_metadata_dict: JsonDict):
# we have to copy the dict, because it turns out that the same dict is
@@ -108,7 +108,10 @@ class _EventInternalMetadata:
# the stream ordering of this event. None, until it has been persisted.
self.stream_ordering = None # type: Optional[int]
- outlier = DictProperty("outlier") # type: bool
+ # whether this event is an outlier (ie, whether we have the state at that point
+ # in the DAG)
+ self.outlier = False
+
out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str
recheck_redaction = DictProperty("recheck_redaction") # type: bool
@@ -129,7 +132,7 @@ class _EventInternalMetadata:
return dict(self._dict)
def is_outlier(self) -> bool:
- return self._dict.get("outlier", False)
+ return self.outlier
def is_out_of_band_membership(self) -> bool:
"""Whether this is an out of band membership, like an invite or an invite
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 7ca5c9940a..0f8a3b5ad8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.util.async_helpers import yieldable_gather_results
+from synapse.util.frozenutils import unfreeze
from . import EventBase
@@ -54,6 +55,8 @@ def prune_event(event: EventBase) -> EventBase:
event.internal_metadata.stream_ordering
)
+ pruned_event.internal_metadata.outlier = event.internal_metadata.outlier
+
# Mark the event as redacted
pruned_event.internal_metadata.redacted = True
@@ -400,10 +403,19 @@ class EventClientSerializer:
# If there is an edit replace the content, preserving existing
# relations.
+ # Ensure we take copies of the edit content, otherwise we risk modifying
+ # the original event.
+ edit_content = edit.content.copy()
+
+ # Unfreeze the event content if necessary, so that we may modify it below
+ edit_content = unfreeze(edit_content)
+ serialized_event["content"] = edit_content.get("m.new_content", {})
+
+ # Check for existing relations
relations = event.content.get("m.relates_to")
- serialized_event["content"] = edit.content.get("m.new_content", {})
if relations:
- serialized_event["content"]["m.relates_to"] = relations
+ # Keep the relations, ensuring we use a dict copy of the original
+ serialized_event["content"]["m.relates_to"] = relations.copy()
else:
serialized_event["content"].pop("m.relates_to", None)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9839d3d016..d84e362070 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -35,7 +35,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import EduTypes, EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import (
AuthError,
Codes,
@@ -63,7 +63,7 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet,
)
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -727,27 +727,6 @@ class FederationServer(FederationBase):
if the event was unacceptable for any other reason (eg, too large,
too many prev_events, couldn't find the prev_events)
"""
- # check that it's actually being sent from a valid destination to
- # workaround bug #1753 in 0.18.5 and 0.18.6
- if origin != get_domain_from_id(pdu.sender):
- # We continue to accept join events from any server; this is
- # necessary for the federation join dance to work correctly.
- # (When we join over federation, the "helper" server is
- # responsible for sending out the join event, rather than the
- # origin. See bug #1893. This is also true for some third party
- # invites).
- if not (
- pdu.type == "m.room.member"
- and pdu.content
- and pdu.content.get("membership", None)
- in (Membership.JOIN, Membership.INVITE)
- ):
- logger.info(
- "Discarding PDU %s from invalid origin %s", pdu.event_id, origin
- )
- return
- else:
- logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
room_version = await self.store.get_room_version(pdu.room_id)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index cc0d765e5f..af85fe0a1e 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,7 +15,7 @@
# limitations under the License.
import datetime
import logging
-from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
import attr
from prometheus_client import Counter
@@ -77,6 +77,7 @@ class PerDestinationQueue:
self._transaction_manager = transaction_manager
self._instance_name = hs.get_instance_name()
self._federation_shard_config = hs.config.worker.federation_shard_config
+ self._state = hs.get_state_handler()
self._should_send_on_this_instance = True
if not self._federation_shard_config.should_handle(
@@ -415,22 +416,95 @@ class PerDestinationQueue:
"This should not happen." % event_ids
)
- if logger.isEnabledFor(logging.INFO):
- rooms = [p.room_id for p in catchup_pdus]
- logger.info("Catching up rooms to %s: %r", self._destination, rooms)
+ # We send transactions with events from one room only, as its likely
+ # that the remote will have to do additional processing, which may
+ # take some time. It's better to give it small amounts of work
+ # rather than risk the request timing out and repeatedly being
+ # retried, and not making any progress.
+ #
+ # Note: `catchup_pdus` will have exactly one PDU per room.
+ for pdu in catchup_pdus:
+ # The PDU from the DB will be the last PDU in the room from
+ # *this server* that wasn't sent to the remote. However, other
+ # servers may have sent lots of events since then, and we want
+ # to try and tell the remote only about the *latest* events in
+ # the room. This is so that it doesn't get inundated by events
+ # from various parts of the DAG, which all need to be processed.
+ #
+ # Note: this does mean that in large rooms a server coming back
+ # online will get sent the same events from all the different
+ # servers, but the remote will correctly deduplicate them and
+ # handle it only once.
+
+ # Step 1, fetch the current extremities
+ extrems = await self._store.get_prev_events_for_room(pdu.room_id)
+
+ if pdu.event_id in extrems:
+ # If the event is in the extremities, then great! We can just
+ # use that without having to do further checks.
+ room_catchup_pdus = [pdu]
+ else:
+ # If not, fetch the extremities and figure out which we can
+ # send.
+ extrem_events = await self._store.get_events_as_list(extrems)
+
+ new_pdus = []
+ for p in extrem_events:
+ # We pulled this from the DB, so it'll be non-null
+ assert p.internal_metadata.stream_ordering
+
+ # Filter out events that happened before the remote went
+ # offline
+ if (
+ p.internal_metadata.stream_ordering
+ < self._last_successful_stream_ordering
+ ):
+ continue
- await self._transaction_manager.send_new_transaction(
- self._destination, catchup_pdus, []
- )
+ # Filter out events where the server is not in the room,
+ # e.g. it may have left/been kicked. *Ideally* we'd pull
+ # out the kick and send that, but it's a rare edge case
+ # so we don't bother for now (the server that sent the
+ # kick should send it out if its online).
+ hosts = await self._state.get_hosts_in_room_at_events(
+ p.room_id, [p.event_id]
+ )
+ if self._destination not in hosts:
+ continue
- sent_transactions_counter.inc()
- final_pdu = catchup_pdus[-1]
- self._last_successful_stream_ordering = cast(
- int, final_pdu.internal_metadata.stream_ordering
- )
- await self._store.set_destination_last_successful_stream_ordering(
- self._destination, self._last_successful_stream_ordering
- )
+ new_pdus.append(p)
+
+ # If we've filtered out all the extremities, fall back to
+ # sending the original event. This should ensure that the
+ # server gets at least some of missed events (especially if
+ # the other sending servers are up).
+ if new_pdus:
+ room_catchup_pdus = new_pdus
+
+ logger.info(
+ "Catching up rooms to %s: %r", self._destination, pdu.room_id
+ )
+
+ await self._transaction_manager.send_new_transaction(
+ self._destination, room_catchup_pdus, []
+ )
+
+ sent_transactions_counter.inc()
+
+ # We pulled this from the DB, so it'll be non-null
+ assert pdu.internal_metadata.stream_ordering
+
+ # Note that we mark the last successful stream ordering as that
+ # from the *original* PDU, rather than the PDU(s) we actually
+ # send. This is because we use it to mark our position in the
+ # queue of missed PDUs to process.
+ self._last_successful_stream_ordering = (
+ pdu.internal_metadata.stream_ordering
+ )
+
+ await self._store.set_destination_last_successful_stream_ordering(
+ self._destination, self._last_successful_stream_ordering
+ )
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
if not self._pending_rrs:
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fb5f8118f0..badac8c26c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -886,6 +886,19 @@ class AuthHandler(BaseHandler):
)
return result
+ def can_change_password(self) -> bool:
+ """Get whether users on this server are allowed to change or set a password.
+
+ Both `config.password_enabled` and `config.password_localdb_enabled` must be true.
+
+ Note that any account (even SSO accounts) are allowed to add passwords if the above
+ is true.
+
+ Returns:
+ Whether users on this server are allowed to change or set a password
+ """
+ return self._password_enabled and self._password_localdb_enabled
+
def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index df3cdc8fba..2fc4951df4 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -166,7 +166,7 @@ class DeviceWorkerHandler(BaseHandler):
# Fetch the current state at the time.
try:
- event_ids = await self.store.get_forward_extremeties_for_room(
+ event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
@@ -907,6 +907,7 @@ class DeviceListUpdater:
master_key = result.get("master_key")
self_signing_key = result.get("self_signing_key")
+ ignore_devices = False
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
@@ -925,6 +926,12 @@ class DeviceListUpdater:
len(devices),
)
devices = []
+ ignore_devices = True
+ else:
+ cached_devices = await self.store.get_cached_devices_for_user(user_id)
+ if cached_devices == {d["device_id"]: d for d in devices}:
+ devices = []
+ ignore_devices = True
for device in devices:
logger.debug(
@@ -934,7 +941,10 @@ class DeviceListUpdater:
stream_id,
)
- await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+ if not ignore_devices:
+ await self.store.update_remote_device_list_cache(
+ user_id, devices, stream_id
+ )
device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys.
@@ -945,7 +955,8 @@ class DeviceListUpdater:
)
device_ids = device_ids + cross_signing_device_ids
- await self.device_handler.notify_device_update(user_id, device_ids)
+ if device_ids:
+ await self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
@@ -973,14 +984,17 @@ class DeviceListUpdater:
"""
device_ids = []
- if master_key:
+ current_keys_map = await self.store.get_e2e_cross_signing_keys_bulk([user_id])
+ current_keys = current_keys_map.get(user_id) or {}
+
+ if master_key and master_key != current_keys.get("master"):
await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
- if self_signing_key:
+ if self_signing_key and self_signing_key != current_keys.get("self_signing"):
await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 6d8551a6d6..bc3630e9e9 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -280,6 +280,7 @@ class OidcProvider:
self._config = provider
self._callback_url = hs.config.oidc_callback_url # type: str
+ self._oidc_attribute_requirements = provider.attribute_requirements
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
@@ -859,6 +860,18 @@ class OidcProvider:
)
# otherwise, it's a login
+ logger.debug("Userinfo for OIDC login: %s", userinfo)
+
+ # Ensure that the attributes of the logged in user meet the required
+ # attributes by checking the userinfo against attribute_requirements
+ # In order to deal with the fact that OIDC userinfo can contain many
+ # types of data, we wrap non-list values in lists.
+ if not self._sso_handler.check_required_attributes(
+ request,
+ {k: v if isinstance(v, list) else [v] for k, v in userinfo.items()},
+ self._oidc_attribute_requirements,
+ ):
+ return
# Call the mapper to register/login the user
try:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 54631b4ee2..da92feacc9 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -104,6 +104,8 @@ class BasePresenceHandler(abc.ABC):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ self._busy_presence_enabled = hs.config.experimental.msc3026_enabled
+
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -730,8 +732,12 @@ class PresenceHandler(BasePresenceHandler):
PresenceState.ONLINE,
PresenceState.UNAVAILABLE,
PresenceState.OFFLINE,
+ PresenceState.BUSY,
)
- if presence not in valid_presence:
+
+ if presence not in valid_presence or (
+ presence == PresenceState.BUSY and not self._busy_presence_enabled
+ ):
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
@@ -744,7 +750,9 @@ class PresenceHandler(BasePresenceHandler):
msg = status_msg if presence != PresenceState.OFFLINE else None
new_fields["status_msg"] = msg
- if presence == PresenceState.ONLINE:
+ if presence == PresenceState.ONLINE or (
+ presence == PresenceState.BUSY and self._busy_presence_enabled
+ ):
new_fields["last_active_ts"] = self.clock.time_msec()
await self._update_states([prev_state.copy_and_replace(**new_fields)])
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 1abc8875cb..d7f226d589 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -437,10 +437,10 @@ class RegistrationHandler(BaseHandler):
if RoomAlias.is_valid(r):
(
- room_id,
+ room,
remote_room_hosts,
) = await room_member_handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
+ room_id = room.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (r,)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1660921306..4d20ed8357 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -155,6 +155,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
+ @abc.abstractmethod
+ async def forget(self, user: UserID, room_id: str) -> None:
+ raise NotImplementedError()
+
def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
"""Ratelimit invites by room and by target user.
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 108730a7a1..d75506c75e 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
@@ -25,11 +25,14 @@ from synapse.replication.http.membership import (
)
from synapse.types import Requester, UserID
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class RoomMemberWorkerHandler(RoomMemberHandler):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
@@ -83,3 +86,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left"
)
+
+ async def forget(self, target: UserID, room_id: str) -> None:
+ raise RuntimeError("Cannot forget rooms on workers.")
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 84af2dde7e..04e7c64c94 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
logout_devices: bool,
requester: Optional[Requester] = None,
) -> None:
- if not self.hs.config.password_localdb_enabled:
+ if not self._auth_handler.can_change_password():
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
try:
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
new file mode 100644
index 0000000000..513dc0c71a
--- /dev/null
+++ b/synapse/handlers/space_summary.py
@@ -0,0 +1,199 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from collections import deque
+from typing import TYPE_CHECKING, Iterable, List, Optional, Set
+
+from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility
+from synapse.api.errors import AuthError
+from synapse.events import EventBase
+from synapse.events.utils import format_event_for_client_v2
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+# number of rooms to return. We'll stop once we hit this limit.
+# TODO: allow clients to reduce this with a request param.
+MAX_ROOMS = 50
+
+# max number of events to return per room.
+MAX_ROOMS_PER_SPACE = 50
+
+
+class SpaceSummaryHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._clock = hs.get_clock()
+ self._auth = hs.get_auth()
+ self._room_list_handler = hs.get_room_list_handler()
+ self._state_handler = hs.get_state_handler()
+ self._store = hs.get_datastore()
+ self._event_serializer = hs.get_event_client_serializer()
+
+ async def get_space_summary(
+ self,
+ requester: str,
+ room_id: str,
+ suggested_only: bool = False,
+ max_rooms_per_space: Optional[int] = None,
+ ) -> JsonDict:
+ """
+ Implementation of the space summary API
+
+ Args:
+ requester: user id of the user making this request
+
+ room_id: room id to start the summary at
+
+ suggested_only: whether we should only return children with the "suggested"
+ flag set.
+
+ max_rooms_per_space: an optional limit on the number of child rooms we will
+ return. This does not apply to the root room (ie, room_id), and
+ is overridden by ROOMS_PER_SPACE_LIMIT.
+
+ Returns:
+ summary dict to return
+ """
+ # first of all, check that the user is in the room in question (or it's
+ # world-readable)
+ await self._auth.check_user_in_room_or_world_readable(room_id, requester)
+
+ # the queue of rooms to process
+ room_queue = deque((room_id,))
+
+ processed_rooms = set() # type: Set[str]
+
+ rooms_result = [] # type: List[JsonDict]
+ events_result = [] # type: List[JsonDict]
+
+ now = self._clock.time_msec()
+
+ while room_queue and len(rooms_result) < MAX_ROOMS:
+ room_id = room_queue.popleft()
+ logger.debug("Processing room %s", room_id)
+ processed_rooms.add(room_id)
+
+ try:
+ await self._auth.check_user_in_room_or_world_readable(
+ room_id, requester
+ )
+ except AuthError:
+ logger.info(
+ "user %s cannot view room %s, omitting from summary",
+ requester,
+ room_id,
+ )
+ continue
+
+ room_entry = await self._build_room_entry(room_id)
+ rooms_result.append(room_entry)
+
+ # look for child rooms/spaces.
+ child_events = await self._get_child_events(room_id)
+
+ if suggested_only:
+ # we only care about suggested children
+ child_events = filter(_is_suggested_child_event, child_events)
+
+ # The client-specified max_rooms_per_space limit doesn't apply to the
+ # room_id specified in the request, so we ignore it if this is the
+ # first room we are processing. Otherwise, apply any client-specified
+ # limit, capping to our built-in limit.
+ if max_rooms_per_space is not None and len(processed_rooms) > 1:
+ max_rooms = min(MAX_ROOMS_PER_SPACE, max_rooms_per_space)
+ else:
+ max_rooms = MAX_ROOMS_PER_SPACE
+
+ for edge_event in itertools.islice(child_events, max_rooms):
+ edge_room_id = edge_event.state_key
+
+ events_result.append(
+ await self._event_serializer.serialize_event(
+ edge_event,
+ time_now=now,
+ event_format=format_event_for_client_v2,
+ )
+ )
+
+ # if we haven't yet visited the target of this link, add it to the queue
+ if edge_room_id not in processed_rooms:
+ room_queue.append(edge_room_id)
+
+ return {"rooms": rooms_result, "events": events_result}
+
+ async def _build_room_entry(self, room_id: str) -> JsonDict:
+ """Generate en entry suitable for the 'rooms' list in the summary response"""
+ stats = await self._store.get_room_with_stats(room_id)
+
+ # currently this should be impossible because we call
+ # check_user_in_room_or_world_readable on the room before we get here, so
+ # there should always be an entry
+ assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
+
+ current_state_ids = await self._store.get_current_state_ids(room_id)
+ create_event = await self._store.get_event(
+ current_state_ids[(EventTypes.Create, "")]
+ )
+
+ # TODO: update once MSC1772 lands
+ room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
+
+ entry = {
+ "room_id": stats["room_id"],
+ "name": stats["name"],
+ "topic": stats["topic"],
+ "canonical_alias": stats["canonical_alias"],
+ "num_joined_members": stats["joined_members"],
+ "avatar_url": stats["avatar"],
+ "world_readable": (
+ stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
+ ),
+ "guest_can_join": stats["guest_access"] == "can_join",
+ "room_type": room_type,
+ }
+
+ # Filter out Nones – rather omit the field altogether
+ room_entry = {k: v for k, v in entry.items() if v is not None}
+
+ return room_entry
+
+ async def _get_child_events(self, room_id: str) -> Iterable[EventBase]:
+ # look for child rooms/spaces.
+ current_state_ids = await self._store.get_current_state_ids(room_id)
+
+ events = await self._store.get_events_as_list(
+ [
+ event_id
+ for key, event_id in current_state_ids.items()
+ # TODO: update once MSC1772 lands
+ if key[0] == EventTypes.MSC1772_SPACE_CHILD
+ ]
+ )
+
+ # filter out any events without a "via" (which implies it has been redacted)
+ return (e for e in events if e.content.get("via"))
+
+
+def _is_suggested_child_event(edge_event: EventBase) -> bool:
+ suggested = edge_event.content.get("suggested")
+ if isinstance(suggested, bool) and suggested:
+ return True
+ logger.debug("Ignorning not-suggested child %s", edge_event.state_key)
+ return False
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f50257cd57..7b723ead58 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1979,8 +1979,10 @@ class SyncHandler:
logger.info("User joined room after current token: %s", room_id)
- extrems = await self.store.get_forward_extremeties_for_room(
- room_id, event_pos.stream
+ extrems = (
+ await self.store.get_forward_extremities_for_room_at_stream_ordering(
+ room_id, event_pos.stream
+ )
)
users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 8af53b4f28..82ea3b895f 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -40,6 +40,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
+ "outlier": true|false,
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
}],
@@ -84,6 +85,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
+ "outlier": event.internal_metadata.is_outlier(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
}
@@ -116,6 +118,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event = make_event_from_dict(
event_dict, room_ver, internal_metadata, rejected_reason
)
+ event.internal_metadata.outlier = event_payload["outlier"]
context = EventContext.deserialize(
self.storage, event_payload["context"]
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 8fa104c8d3..a4c5b44292 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -40,6 +40,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
+ "outlier": true|false,
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
"requester": { .. serialized requester .. },
@@ -79,7 +80,6 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event
"""
-
serialized_context = await context.serialize(event, store)
payload = {
@@ -87,6 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
+ "outlier": event.internal_metadata.is_outlier(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
"requester": requester.serialize(),
@@ -108,6 +109,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event = make_event_from_dict(
event_dict, room_ver, internal_metadata, rejected_reason
)
+ event.internal_metadata.outlier = content["outlier"]
requester = Requester.deserialize(self.store, content["requester"])
context = EventContext.deserialize(self.storage, content["context"])
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f45e7a8c89..7e8e64d61c 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -33,7 +33,7 @@ import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
- import synapse.server
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -299,20 +299,23 @@ class TypingStream(Stream):
NAME = "typing"
ROW_TYPE = TypingStreamRow
- def __init__(self, hs):
- typing_handler = hs.get_typing_handler()
-
+ def __init__(self, hs: "HomeServer"):
writer_instance = hs.config.worker.writers.typing
if writer_instance == hs.get_instance_name():
# On the writer, query the typing handler
- update_function = typing_handler.get_all_typing_updates
+ typing_writer_handler = hs.get_typing_writer_handler()
+ update_function = (
+ typing_writer_handler.get_all_typing_updates
+ ) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]]
+ current_token_function = typing_writer_handler.get_current_token
else:
# Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
+ current_token_function = hs.get_typing_handler().get_current_token
super().__init__(
hs.get_instance_name(),
- current_token_without_instance(typing_handler.get_current_token),
+ current_token_without_instance(current_token_function),
update_function,
)
@@ -509,7 +512,7 @@ class AccountDataStream(Stream):
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2c89b62e25..aaa56a7024 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet):
elif not deactivate and user["deactivated"]:
if (
"password" not in body
- and self.hs.config.password_localdb_enabled
+ and self.auth_handler.can_change_password()
):
raise SynapseError(
400, "Must provide a password to re-activate an account."
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 5884daea6d..6c722d634d 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -18,7 +18,7 @@
import logging
import re
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@@ -35,21 +35,30 @@ from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_boolean,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
-from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.types import (
+ JsonDict,
+ RoomAlias,
+ RoomID,
+ StreamToken,
+ ThirdPartyInstanceID,
+ UserID,
+)
from synapse.util import json_decoder
from synapse.util.stringutils import parse_and_validate_server_name, random_string
if TYPE_CHECKING:
- import synapse.server
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -846,10 +855,10 @@ class RoomTypingRestServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
+ self.hs = hs
self.presence_handler = hs.get_presence_handler()
- self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
# If we're not on the typing writer instance we should scream if we get
@@ -874,16 +883,19 @@ class RoomTypingRestServlet(RestServlet):
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
+ # Defer getting the typing handler since it will raise on workers.
+ typing_handler = self.hs.get_typing_writer_handler()
+
try:
if content["typing"]:
- await self.typing_handler.started_typing(
+ await typing_handler.started_typing(
target_user=target_user,
requester=requester,
room_id=room_id,
timeout=timeout,
)
else:
- await self.typing_handler.stopped_typing(
+ await typing_handler.stopped_typing(
target_user=target_user, requester=requester, room_id=room_id
)
except ShadowBanError:
@@ -901,7 +913,7 @@ class RoomAliasListServlet(RestServlet):
),
]
- def __init__(self, hs: "synapse.server.HomeServer"):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.directory_handler = hs.get_directory_handler()
@@ -984,7 +996,58 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
)
-def register_servlets(hs, http_server, is_worker=False):
+class RoomSpaceSummaryRestServlet(RestServlet):
+ PATTERNS = (
+ re.compile(
+ "^/_matrix/client/unstable/org.matrix.msc2946"
+ "/rooms/(?P<room_id>[^/]*)/spaces$"
+ ),
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._auth = hs.get_auth()
+ self._space_summary_handler = hs.get_space_summary_handler()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request, allow_guest=True)
+
+ return 200, await self._space_summary_handler.get_space_summary(
+ requester.user.to_string(),
+ room_id,
+ suggested_only=parse_boolean(request, "suggested_only", default=False),
+ max_rooms_per_space=parse_integer(request, "max_rooms_per_space"),
+ )
+
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request, allow_guest=True)
+ content = parse_json_object_from_request(request)
+
+ suggested_only = content.get("suggested_only", False)
+ if not isinstance(suggested_only, bool):
+ raise SynapseError(
+ 400, "'suggested_only' must be a boolean", Codes.BAD_JSON
+ )
+
+ max_rooms_per_space = content.get("max_rooms_per_space")
+ if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
+ raise SynapseError(
+ 400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON
+ )
+
+ return 200, await self._space_summary_handler.get_space_summary(
+ requester.user.to_string(),
+ room_id,
+ suggested_only=suggested_only,
+ max_rooms_per_space=max_rooms_per_space,
+ )
+
+
+def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -998,6 +1061,9 @@ def register_servlets(hs, http_server, is_worker=False):
RoomTypingRestServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
+ if hs.config.experimental.spaces_enabled:
+ RoomSpaceSummaryRestServlet(hs).register(http_server)
+
# Some servlets only get registered for the main process.
if not is_worker:
RoomCreateRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index 76879ac559..44ccf10ed4 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -13,12 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -27,21 +33,16 @@ class CapabilitiesRestServlet(RestServlet):
PATTERNS = client_patterns("/capabilities$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.config = hs.config
self.auth = hs.get_auth()
- self.store = hs.get_datastore()
+ self.auth_handler = hs.get_auth_handler()
- async def on_GET(self, request):
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
- user = await self.store.get_user_by_id(requester.user.to_string())
- change_password = bool(user["password_hash"])
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.auth.get_user_by_req(request, allow_guest=True)
+ change_password = self.auth_handler.can_change_password()
response = {
"capabilities": {
@@ -58,5 +59,5 @@ class CapabilitiesRestServlet(RestServlet):
return 200, response
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
CapabilitiesRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index d24a199318..3e3d8839f4 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -81,6 +81,8 @@ class VersionsRestServlet(RestServlet):
"io.element.e2ee_forced.public": self.e2ee_forced_public,
"io.element.e2ee_forced.private": self.e2ee_forced_private,
"io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
+ # Supports the busy presence state described in MSC3026.
+ "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
},
},
)
diff --git a/synapse/server.py b/synapse/server.py
index 48ac87a124..98822d8e2f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -96,10 +96,11 @@ from synapse.handlers.room import (
RoomShutdownHandler,
)
from synapse.handlers.room_list import RoomListHandler
-from synapse.handlers.room_member import RoomMemberMasterHandler
+from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.space_summary import SpaceSummaryHandler
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
@@ -417,10 +418,19 @@ class HomeServer(metaclass=abc.ABCMeta):
return PresenceHandler(self)
@cache_in_self
- def get_typing_handler(self):
+ def get_typing_writer_handler(self) -> TypingWriterHandler:
if self.config.worker.writers.typing == self.get_instance_name():
return TypingWriterHandler(self)
else:
+ raise Exception("Workers cannot write typing")
+
+ @cache_in_self
+ def get_typing_handler(self) -> FollowerTypingHandler:
+ if self.config.worker.writers.typing == self.get_instance_name():
+ # Use get_typing_writer_handler to ensure that we use the same
+ # cached version.
+ return self.get_typing_writer_handler()
+ else:
return FollowerTypingHandler(self)
@cache_in_self
@@ -630,7 +640,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return ThirdPartyEventRules(self)
@cache_in_self
- def get_room_member_handler(self):
+ def get_room_member_handler(self) -> RoomMemberHandler:
if self.config.worker_app:
return RoomMemberWorkerHandler(self)
return RoomMemberMasterHandler(self)
@@ -724,6 +734,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return AccountDataHandler(self)
@cache_in_self
+ def get_space_summary_handler(self) -> SpaceSummaryHandler:
+ return SpaceSummaryHandler(self)
+
+ @cache_in_self
def get_external_cache(self) -> ExternalCache:
return ExternalCache(self)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 332193ad1c..a956be491a 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -793,7 +793,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None
- async def get_forward_extremeties_for_room(
+ async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index cd1ceac50e..98dac19a95 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1270,8 +1270,10 @@ class PersistEventsStore:
logger.exception("")
raise
+ # update the stored internal_metadata to update the "outlier" flag.
+ # TODO: This is unused as of Synapse 1.31. Remove it once we are happy
+ # to drop backwards-compatibility with 1.30.
metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
-
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -1319,6 +1321,19 @@ class PersistEventsStore:
d.pop("redacted_because", None)
return d
+ def get_internal_metadata(event):
+ im = event.internal_metadata.get_dict()
+
+ # temporary hack for database compatibility with Synapse 1.30 and earlier:
+ # store the `outlier` flag inside the internal_metadata json as well as in
+ # the `events` table, so that if anyone rolls back to an older Synapse,
+ # things keep working. This can be removed once we are happy to drop support
+ # for that
+ if event.internal_metadata.is_outlier():
+ im["outlier"] = True
+
+ return im
+
self.db_pool.simple_insert_many_txn(
txn,
table="event_json",
@@ -1327,7 +1342,7 @@ class PersistEventsStore:
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": json_encoder.encode(
- event.internal_metadata.get_dict()
+ get_internal_metadata(event)
),
"json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c04e162ccc..952d4969b2 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -799,6 +799,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected_reason=rejected_reason,
)
original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
+ original_ev.internal_metadata.outlier = row["outlier"]
event_map[event_id] = original_ev
@@ -905,7 +906,8 @@ class EventsWorkerStore(SQLBaseStore):
ej.json,
ej.format_version,
r.room_version,
- rej.reason
+ rej.reason,
+ e.outlier
FROM events AS e
JOIN event_json AS ej USING (event_id)
LEFT JOIN rooms r ON r.room_id = e.room_id
@@ -929,6 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
"room_version_id": row[5],
"rejected_reason": row[6],
"redactions": [],
+ "outlier": row[7],
}
# check for redactions
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index eba66ff352..90a8f664ef 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
+ self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index 6f96cd7940..95eac6a5a3 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -2,6 +2,7 @@ from typing import List, Tuple
from mock import Mock
+from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
@@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertNotIn("zzzerver", woken)
# - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1])
+
+ @override_config({"send_federation": True})
+ def test_not_latest_event(self):
+ """Test that we send the latest event in the room even if its not ours."""
+
+ per_dest_queue, sent_pdus = self.make_fake_destination_queue()
+
+ # Make a room with a local user, and two servers. One will go offline
+ # and one will send some events.
+ self.register_user("u1", "you the one")
+ u1_token = self.login("u1", "you the one")
+ room_1 = self.helper.create_room_as("u1", tok=u1_token)
+
+ self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
+ )
+ event_1 = self.get_success(
+ event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
+ )
+
+ # First we send something from the local server, so that we notice the
+ # remote is down and go into catchup mode.
+ self.helper.send(room_1, "you hear me!!", tok=u1_token)
+
+ # Now simulate us receiving an event from the still online remote.
+ event_2 = self.get_success(
+ event_injection.inject_event(
+ self.hs,
+ type=EventTypes.Message,
+ sender="@user:host3",
+ room_id=room_1,
+ content={"msgtype": "m.text", "body": "Hello"},
+ )
+ )
+
+ self.get_success(
+ self.hs.get_datastore().set_destination_last_successful_stream_ordering(
+ "host2", event_1.internal_metadata.stream_ordering
+ )
+ )
+
+ self.get_success(per_dest_queue._catch_up_transmission_loop())
+
+ # We expect only the last message from the remote, event_2, to have been
+ # sent, rather than the last *local* event that was sent.
+ self.assertEqual(len(sent_pdus), 1)
+ self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
+ self.assertFalse(per_dest_queue._catching_up)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5e9c9c2e88..c7796fb837 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -989,6 +989,138 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements(self):
+ """The required attributes must be met from the OIDC userinfo response."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # userinfo lacking "test": "foobar" attribute should fail.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": "foobar" attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": "foobar",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements_contains(self):
+ """Test that auth succeeds if userinfo attribute CONTAINS required value"""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["foobar", "foo", "bar"],
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@tester:test", "oidc", ANY, ANY, None, new_user=True
+ )
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
+ "attribute_requirements": [{"attribute": "test", "value": "foobar"}],
+ }
+ }
+ )
+ def test_attribute_requirements_mismatch(self):
+ """
+ Test that auth fails if attributes exist but don't match,
+ or are non-string values.
+ """
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ # userinfo with "test": "not_foobar" attribute should fail
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": "not_foobar",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": ["foo", "bar"] attribute should fail
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": ["foo", "bar"],
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": False attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": False,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": None attribute should fail
+ # a value of None breaks the OIDC spec, but it's important to not crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": 1 attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": 1,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
+ # userinfo with "test": 3.14 attribute should fail
+ # this is largely just to ensure we don't crash here
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ "test": 3.14,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+
def _generate_oidc_session_token(
self,
state: str,
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 996c614198..77330f59a9 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -310,6 +310,26 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
+ def test_busy_no_idle(self):
+ """
+ Tests that a user setting their presence to busy but idling doesn't turn their
+ presence state into unavailable.
+ """
+ user_id = "@foo:bar"
+ now = 5000000
+
+ state = UserPresenceState.default(user_id)
+ state = state.copy_and_replace(
+ state=PresenceState.BUSY,
+ last_active_ts=now - IDLE_TIMER - 1,
+ last_user_sync_ts=now,
+ )
+
+ new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
+
+ self.assertIsNotNone(new_state)
+ self.assertEquals(new_state.state, PresenceState.BUSY)
+
def test_sync_timeout(self):
user_id = "@foo:bar"
now = 5000000
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e58d5cf0db..cf61f284cb 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
+ self.auth_handler = hs.get_auth_handler()
+ # create users and get access tokens
+ # regardless of whether password login or SSO is allowed
self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
+ self.admin_user_tok = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ self.admin_user, device_id=None, valid_until_ms=None
+ )
+ )
self.other_user = self.register_user("user", "pass", displayname="User")
- self.other_user_token = self.login("user", "pass")
+ self.other_user_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ self.other_user, device_id=None, valid_until_ms=None
+ )
+ )
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user
)
@@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
@@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(True, channel.json_body["admin"])
- self.assertEqual(False, channel.json_body["is_guest"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["admin"])
+ self.assertFalse(channel.json_body["is_guest"])
+ self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self):
@@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user
@@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
- self.assertEqual(False, channel.json_body["admin"])
- self.assertEqual(False, channel.json_body["is_guest"])
- self.assertEqual(False, channel.json_body["deactivated"])
- self.assertEqual(False, channel.json_body["shadow_banned"])
+ self.assertFalse(channel.json_body["admin"])
+ self.assertFalse(channel.json_body["is_guest"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["shadow_banned"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config(
@@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Admin user is not blocked by mau anymore
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["admin"])
+ self.assertFalse(channel.json_body["admin"])
@override_config(
{
@@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
# Deactivate user
- body = json.dumps({"deactivated": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"])
@@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertTrue(profile["display_name"] == "User")
# Deactivate user
- body = json.dumps({"deactivated": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
# is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
# Set new displayname user
- body = json.dumps({"displayname": "Foobar"})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"displayname": "Foobar"},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["deactivated"])
+ self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"])
# is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
- self.assertTrue(profile is None)
+ self.assertIsNone(profile)
def test_reactivate_user(self):
"""
@@ -1520,48 +1527,92 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Attempt to reactivate the user (without a password).
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False},
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ # Reactivate the user.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+ content={"deactivated": False, "password": "foo"},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
- d = self.store.mark_user_erased("@user:test")
- self.assertIsNone(self.get_success(d))
- self._is_erased("@user:test", True)
- # Attempt to reactivate the user (without a password).
+ @override_config({"password_config": {"localdb_enabled": False}})
+ def test_reactivate_user_localdb_disabled(self):
+ """
+ Test reactivating another user when using SSO.
+ """
+
+ # Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Reactivate the user with a password
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+ content={"deactivated": False, "password": "foo"},
)
- self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- # Reactivate the user.
+ # Reactivate the user without a password.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=json.dumps({"deactivated": False, "password": "foo"}).encode(
- encoding="utf_8"
- ),
+ content={"deactivated": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self._is_erased("@user:test", False)
- # Get user
+ @override_config({"password_config": {"enabled": False}})
+ def test_reactivate_user_password_disabled(self):
+ """
+ Test reactivating another user when using SSO.
+ """
+
+ # Deactivate the user.
+ self._deactivate_user("@user:test")
+
+ # Reactivate the user with a password
channel = self.make_request(
- "GET",
+ "PUT",
self.url_other_user,
access_token=self.admin_user_tok,
+ content={"deactivated": False, "password": "foo"},
)
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+ # Reactivate the user without a password.
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"deactivated": False},
+ )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(False, channel.json_body["deactivated"])
+ self.assertFalse(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
def test_set_user_as_admin(self):
@@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"""
# Set a user as an admin
- body = json.dumps({"admin": True})
-
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"admin": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
# Get user
channel = self.make_request(
@@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
- self.assertEqual(True, channel.json_body["admin"])
+ self.assertTrue(channel.json_body["admin"])
def test_accidental_deactivation_prevention(self):
"""
@@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test"
# Create user
- body = json.dumps({"password": "abc123"})
-
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123"},
)
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["deactivated"])
# Change password (and use a str for deactivate instead of a bool)
- body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
-
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"password": "abc123", "deactivated": "false"},
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"])
- def _is_erased(self, user_id, expect):
+ def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id)
if expect:
@@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
else:
self.assertFalse(self.get_success(d))
+ def _deactivate_user(self, user_id: str) -> None:
+ """Deactivate user and set as erased"""
+
+ # Deactivate the user.
+ channel = self.make_request(
+ "PUT",
+ "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+ access_token=self.admin_user_tok,
+ content={"deactivated": True},
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertTrue(channel.json_body["deactivated"])
+ self.assertIsNone(channel.json_body["password_hash"])
+ self._is_erased(user_id, False)
+ d = self.store.mark_user_erased(user_id)
+ self.assertIsNone(self.get_success(d))
+ self._is_erased(user_id, True)
+
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 227fffab58..bf39014277 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -161,6 +161,68 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
+ def test_message_edit(self):
+ """Ensure that the module doesn't cause issues with edited messages."""
+ # first patch the event checker so that it will modify the event
+ async def check(ev: EventBase, state):
+ d = ev.get_dict()
+ d["content"] = {
+ "msgtype": "m.text",
+ "body": d["content"]["body"].upper(),
+ }
+ return d
+
+ current_rules_module().check_event_allowed = check
+
+ # Send an event, then edit it.
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
+ {
+ "msgtype": "m.text",
+ "body": "Original body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ orig_event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/m.room.message/2" % self.room_id,
+ {
+ "m.new_content": {"msgtype": "m.text", "body": "Edited body"},
+ "m.relates_to": {
+ "rel_type": "m.replace",
+ "event_id": orig_event_id,
+ },
+ "msgtype": "m.text",
+ "body": "Edited body",
+ },
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ edited_event_id = channel.json_body["event_id"]
+
+ # ... and check that they both got modified
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
+
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.result["code"], b"200", channel.result)
+ ev = channel.json_body
+ self.assertEqual(ev["content"]["body"], "EDITED BODY")
+
def test_send_event(self):
"""Tests that the module can send an event into a room via the module api"""
content = {
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index e808339fb3..287a1a485c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
from tests import unittest
+from tests.unittest import override_config
class CapabilitiesTestCase(unittest.HomeserverTestCase):
@@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
self.config = hs.config
+ self.auth_handler = hs.get_auth_handler()
return hs
def test_check_auth_required(self):
@@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities["m.room_versions"]["default"],
)
- def test_get_change_password_capabilities(self):
+ def test_get_change_password_capabilities_password_login(self):
localpart = "user"
password = "pass"
user = self.register_user(localpart, password)
@@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
capabilities = channel.json_body["capabilities"]
self.assertEqual(channel.code, 200)
-
- # Test case where password is handled outside of Synapse
self.assertTrue(capabilities["m.change_password"]["enabled"])
- self.get_success(self.store.user_set_password_hash(user, None))
+
+ @override_config({"password_config": {"localdb_enabled": False}})
+ def test_get_change_password_capabilities_localdb_disabled(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ user, device_id=None, valid_until_ms=None
+ )
+ )
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["m.change_password"]["enabled"])
+
+ @override_config({"password_config": {"enabled": False}})
+ def test_get_change_password_capabilities_password_disabled(self):
+ localpart = "user"
+ password = "pass"
+ user = self.register_user(localpart, password)
+ access_token = self.get_success(
+ self.auth_handler.get_access_token_for_user_id(
+ user, device_id=None, valid_until_ms=None
+ )
+ )
+
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/v2_alpha/test_relations.py
index 7c457754f1..e7bb5583fc 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/v2_alpha/test_relations.py
@@ -39,6 +39,11 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# We need to enable msc1849 support for aggregations
config = self.default_config()
config["experimental_msc1849_support_enabled"] = True
+
+ # We enable frozen dicts as relations/edits change event contents, so we
+ # want to test that we don't modify the events in the caches.
+ config["use_frozen_dicts"] = True
+
return self.setup_test_homeserver(config=config)
def prepare(self, reactor, clock, hs):
@@ -518,6 +523,63 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
+ def test_edit_reply(self):
+ """Test that editing a reply works."""
+
+ # Create a reply to edit.
+ channel = self._send_relation(
+ RelationTypes.REFERENCE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "A reply!"},
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ reply = channel.json_body["event_id"]
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ parent_id=reply,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ edit_event_id = channel.json_body["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/rooms/%s/event/%s" % (self.room, reply),
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # We expect to see the new body in the dict, as well as the reference
+ # metadata sill intact.
+ self.assertDictContainsSubset(new_body, channel.json_body["content"])
+ self.assertDictContainsSubset(
+ {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "key": None,
+ "rel_type": "m.reference",
+ }
+ },
+ channel.json_body["content"],
+ )
+
+ # We expect that the edit relation appears in the unsigned relations
+ # section.
+ relations_dict = channel.json_body["unsigned"].get("m.relations")
+ self.assertIn(RelationTypes.REPLACE, relations_dict)
+
+ m_replace_dict = relations_dict[RelationTypes.REPLACE]
+ for key in ["event_id", "sender", "origin_server_ts"]:
+ self.assertIn(key, m_replace_dict)
+
+ self.assert_dict(
+ {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
+ )
+
def test_relations_redaction_redacts_edits(self):
"""Test that edits of an event are redacted when the original event
is redacted.
diff --git a/tests/unittest.py b/tests/unittest.py
index ca7031c724..58a4daa1ec 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -32,6 +32,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from twisted.web.resource import Resource
+from synapse import events
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
@@ -140,7 +141,7 @@ class TestCase(unittest.TestCase):
try:
self.assertEquals(attrs[key], getattr(obj, key))
except AssertionError as e:
- raise (type(e))(e.message + " for '.%s'" % key)
+ raise (type(e))("Assert error for '.{}':".format(key)) from e
def assert_dict(self, required, actual):
"""Does a partial assert of a dict.
@@ -229,6 +230,11 @@ class HomeserverTestCase(TestCase):
self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock)
+ # Honour the `use_frozen_dicts` config option. We have to do this
+ # manually because this is taken care of in the app `start` code, which
+ # we don't run. Plus we want to reset it on tearDown.
+ events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts
+
if self.hs is None:
raise Exception("No homeserver returned from make_homeserver.")
@@ -292,6 +298,10 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)
+ def tearDown(self):
+ # Reset to not use frozen dicts.
+ events.USE_FROZEN_DICTS = False
+
def wait_on_thread(self, deferred, timeout=10):
"""
Wait until a Deferred is done, where it's waiting on a real thread.
|