diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 77a7129ee2..14cb21c7fb 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -52,6 +52,7 @@ def request_registration(
user_type: Optional[str] = None,
_print: Callable[[str], None] = print,
exit: Callable[[int], None] = sys.exit,
+ exists_ok: bool = False,
) -> None:
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
@@ -97,6 +98,10 @@ def request_registration(
r = requests.post(url, json=data)
if r.status_code != 200:
+ response = r.json()
+ if exists_ok and response["errcode"] == "M_USER_IN_USE":
+ _print("User already exists. Skipping.")
+ return
_print("ERROR! Received %d %s" % (r.status_code, r.reason))
if 400 <= r.status_code < 500:
try:
@@ -115,6 +120,7 @@ def register_new_user(
shared_secret: str,
admin: Optional[bool],
user_type: Optional[str],
+ exists_ok: bool = False,
) -> None:
if not user:
try:
@@ -154,7 +160,13 @@ def register_new_user(
admin = False
request_registration(
- user, password, server_location, shared_secret, bool(admin), user_type
+ user,
+ password,
+ server_location,
+ shared_secret,
+ bool(admin),
+ user_type,
+ exists_ok=exists_ok,
)
@@ -174,10 +186,22 @@ def main() -> None:
help="Local part of the new user. Will prompt if omitted.",
)
parser.add_argument(
+ "--exists-ok",
+ action="store_true",
+ help="Do not fail if user already exists.",
+ )
+ password_group = parser.add_mutually_exclusive_group()
+ password_group.add_argument(
"-p",
"--password",
default=None,
- help="New password for user. Will prompt if omitted.",
+ help="New password for user. Will prompt for a password if "
+ "this flag and `--password-file` are both omitted.",
+ )
+ password_group.add_argument(
+ "--password-file",
+ default=None,
+ help="File containing the new password for user. If set, will override `--password`.",
)
parser.add_argument(
"-t",
@@ -185,6 +209,7 @@ def main() -> None:
default=None,
help="User type as specified in synapse.api.constants.UserTypes",
)
+
admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument(
"-a",
@@ -247,6 +272,11 @@ def main() -> None:
print(_NO_SHARED_SECRET_OPTS_ERROR, file=sys.stderr)
sys.exit(1)
+ if args.password_file:
+ password = _read_file(args.password_file, "password-file").strip()
+ else:
+ password = args.password
+
if args.server_url:
server_url = args.server_url
elif config is not None:
@@ -270,7 +300,13 @@ def main() -> None:
admin = args.admin
register_new_user(
- args.user, args.password, server_url, secret, admin, args.user_type
+ args.user,
+ password,
+ server_url,
+ secret,
+ admin,
+ args.user_type,
+ exists_ok=args.exists_ok,
)
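As a quick illustration of the new parameters, here is a minimal sketch of calling `request_registration` directly with `exists_ok=True`; the homeserver URL, shared secret and password below are placeholders, and in the CLI the password may instead come from `--password-file` via `_read_file`.

```python
# Sketch only: exercising the updated helper with the new exists_ok keyword.
# All values are placeholders for illustration.
from synapse._scripts.register_new_matrix_user import request_registration

request_registration(
    "alice",
    "correct horse battery staple",   # in the CLI this may be read from --password-file
    "http://localhost:8008",          # placeholder homeserver URL
    "REGISTRATION_SHARED_SECRET",     # placeholder shared secret
    False,                            # admin
    None,                             # user_type
    exists_ok=True,                   # treat M_USER_IN_USE as success and return
)
```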
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 1e56f46911..3bb4a34938 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -777,22 +777,74 @@ class Porter:
await self._setup_events_stream_seqs()
await self._setup_sequence(
"un_partial_stated_event_stream_sequence",
- ("un_partial_stated_event_stream",),
+ [("un_partial_stated_event_stream", "stream_id")],
)
await self._setup_sequence(
- "device_inbox_sequence", ("device_inbox", "device_federation_outbox")
+ "device_inbox_sequence",
+ [
+ ("device_inbox", "stream_id"),
+ ("device_federation_outbox", "stream_id"),
+ ],
)
await self._setup_sequence(
"account_data_sequence",
- ("room_account_data", "room_tags_revisions", "account_data"),
+ [
+ ("room_account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ("account_data", "stream_id"),
+ ],
+ )
+ await self._setup_sequence(
+ "receipts_sequence",
+ [
+ ("receipts_linearized", "stream_id"),
+ ],
+ )
+ await self._setup_sequence(
+ "presence_stream_sequence",
+ [
+ ("presence_stream", "stream_id"),
+ ],
)
- await self._setup_sequence("receipts_sequence", ("receipts_linearized",))
- await self._setup_sequence("presence_stream_sequence", ("presence_stream",))
await self._setup_auth_chain_sequence()
await self._setup_sequence(
"application_services_txn_id_seq",
- ("application_services_txns",),
- "txn_id",
+ [
+ (
+ "application_services_txns",
+ "txn_id",
+ )
+ ],
+ )
+ await self._setup_sequence(
+ "device_lists_sequence",
+ [
+ ("device_lists_stream", "stream_id"),
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
+ ("device_lists_remote_pending", "stream_id"),
+ ("device_lists_changes_converted_stream_position", "stream_id"),
+ ],
+ )
+ await self._setup_sequence(
+ "e2e_cross_signing_keys_sequence",
+ [
+ ("e2e_cross_signing_keys", "stream_id"),
+ ],
+ )
+ await self._setup_sequence(
+ "push_rules_stream_sequence",
+ [
+ ("push_rules_stream", "stream_id"),
+ ],
+ )
+ await self._setup_sequence(
+ "pushers_sequence",
+ [
+ ("pushers", "id"),
+ ("deleted_pushers", "stream_id"),
+ ],
)
# Step 3. Get tables.
@@ -1101,12 +1153,11 @@ class Porter:
async def _setup_sequence(
self,
sequence_name: str,
- stream_id_tables: Iterable[str],
- column_name: str = "stream_id",
+ stream_id_tables: Iterable[Tuple[str, str]],
) -> None:
"""Set a sequence to the correct value."""
current_stream_ids = []
- for stream_id_table in stream_id_tables:
+ for stream_id_table, column_name in stream_id_tables:
max_stream_id = cast(
int,
await self.sqlite_store.db_pool.simple_select_one_onecol(
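To make the refactor concrete, here is a standalone sketch (not Synapse's actual implementation) of what the reworked `_setup_sequence` computes: the next sequence value is one past the maximum of a per-table column, which is why tables keyed on something other than `stream_id` (e.g. `pushers.id`, `application_services_txns.txn_id`) can now be listed explicitly as `(table, column)` pairs.

```python
# Standalone sketch of the per-(table, column) maximum used to advance a sequence.
import sqlite3
from typing import Iterable, Tuple

def next_sequence_value(
    conn: sqlite3.Connection, stream_id_tables: Iterable[Tuple[str, str]]
) -> int:
    # One MAX() per (table, column) pair, then one past the overall maximum.
    current_ids = [0]
    for table, column in stream_id_tables:
        (max_id,) = conn.execute(
            f"SELECT COALESCE(MAX({column}), 0) FROM {table}"
        ).fetchone()
        current_ids.append(max_id)
    return max(current_ids) + 1

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE pushers (id INTEGER)")
conn.execute("CREATE TABLE deleted_pushers (stream_id INTEGER)")
conn.execute("INSERT INTO pushers VALUES (7)")
conn.execute("INSERT INTO deleted_pushers VALUES (12)")
print(next_sequence_value(conn, [("pushers", "id"), ("deleted_pushers", "stream_id")]))  # 13
```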
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 0a9123c56b..9265a271d2 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -50,7 +50,7 @@ class Membership:
KNOCK: Final = "knock"
LEAVE: Final = "leave"
BAN: Final = "ban"
- LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN)
+ LIST: Final = {INVITE, JOIN, KNOCK, LEAVE, BAN}
class PresenceState:
@@ -238,7 +238,7 @@ class EventUnsignedContentFields:
"""Fields found inside the 'unsigned' data on events"""
# Requesting user's membership, per MSC4115
- MSC4115_MEMBERSHIP: Final = "io.element.msc4115.membership"
+ MEMBERSHIP: Final = "membership"
class RoomTypes:
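For reference, a tiny sketch of what the tuple-to-set change means for callers: `Membership.LIST` is used for containment checks such as the "Invalid membership key" guard added later in `room_member.py`, and a set makes those lookups constant-time without changing the accepted values.

```python
# Values mirror the Membership constants above; "kick" is an action, not a
# membership state, so it fails the containment check.
MEMBERSHIP_LIST = {"invite", "join", "knock", "leave", "ban"}
print("join" in MEMBERSHIP_LIST)  # True
print("kick" in MEMBERSHIP_LIST)  # False -> would raise "Invalid membership key"
```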
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index a73626bc86..a99a9e09fc 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -316,6 +316,10 @@ class Ratelimiter:
)
if not allowed:
+ # We pause for a bit here to stop clients from "tight-looping" on
+ # retrying their request.
+ await self.clock.sleep(0.5)
+
raise LimitExceededError(
limiter_name=self._limiter_name,
retry_after_ms=int(1000 * (time_allowed - time_now_s)),
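The intent of the added sleep is easiest to see in isolation; below is a minimal asyncio sketch of the same pattern (not the Synapse `Ratelimiter` itself): when a request is over the limit, pause briefly before raising so clients that retry immediately cannot tight-loop.

```python
import asyncio

class LimitExceeded(Exception):
    pass

async def check_rate_limit(allowed: bool, retry_after_ms: int) -> None:
    if not allowed:
        # Short pause to slow down clients that retry in a tight loop.
        await asyncio.sleep(0.5)
        raise LimitExceeded(f"retry after {retry_after_ms}ms")

asyncio.run(check_rate_limit(True, 0))  # a permitted request returns immediately
```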
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 3182608f73..4cc260d551 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -68,6 +68,7 @@ from synapse.config._base import format_config_error
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import ListenerConfig, ManholeConfig, TCPListenerConfig
from synapse.crypto import context_factory
+from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.events.presence_router import load_legacy_presence_router
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseSite
@@ -582,6 +583,11 @@ async def start(hs: "HomeServer") -> None:
m = module(config, module_api)
logger.info("Loaded module %s", m)
+ if hs.config.auto_accept_invites.enabled:
+ # Start the local auto_accept_invites module.
+ m = InviteAutoAccepter(hs.config.auto_accept_invites, module_api)
+ logger.info("Loaded local module %s", m)
+
load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
load_legacy_presence_router(hs)
@@ -675,17 +681,17 @@ def setup_sentry(hs: "HomeServer") -> None:
)
# We set some default tags that give some context to this instance
- with sentry_sdk.configure_scope() as scope:
- scope.set_tag("matrix_server_name", hs.config.server.server_name)
+ global_scope = sentry_sdk.Scope.get_global_scope()
+ global_scope.set_tag("matrix_server_name", hs.config.server.server_name)
- app = (
- hs.config.worker.worker_app
- if hs.config.worker.worker_app
- else "synapse.app.homeserver"
- )
- name = hs.get_instance_name()
- scope.set_tag("worker_app", app)
- scope.set_tag("worker_name", name)
+ app = (
+ hs.config.worker.worker_app
+ if hs.config.worker.worker_app
+ else "synapse.app.homeserver"
+ )
+ name = hs.get_instance_name()
+ global_scope.set_tag("worker_app", app)
+ global_scope.set_tag("worker_name", name)
def setup_sdnotify(hs: "HomeServer") -> None:
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index fc51aed234..d9cb0da38b 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -23,6 +23,7 @@ from synapse.config import ( # noqa: F401
api,
appservice,
auth,
+ auto_accept_invites,
background_updates,
cache,
captcha,
@@ -120,6 +121,7 @@ class RootConfig:
federation: federation.FederationConfig
retention: retention.RetentionConfig
background_updates: background_updates.BackgroundUpdateConfig
+ auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig
config_classes: List[Type["Config"]] = ...
config_files: List[str]
diff --git a/synapse/config/auto_accept_invites.py b/synapse/config/auto_accept_invites.py
new file mode 100644
index 0000000000..d90e13a510
--- /dev/null
+++ b/synapse/config/auto_accept_invites.py
@@ -0,0 +1,43 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+from typing import Any
+
+from synapse.types import JsonDict
+
+from ._base import Config
+
+
+class AutoAcceptInvitesConfig(Config):
+ section = "auto_accept_invites"
+
+ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
+ auto_accept_invites_config = config.get("auto_accept_invites") or {}
+
+ self.enabled = auto_accept_invites_config.get("enabled", False)
+
+ self.accept_invites_only_for_direct_messages = auto_accept_invites_config.get(
+ "only_for_direct_messages", False
+ )
+
+ self.accept_invites_only_from_local_users = auto_accept_invites_config.get(
+ "only_from_local_users", False
+ )
+
+ self.worker_to_run_on = auto_accept_invites_config.get("worker_to_run_on")
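A standalone sketch of the section this class reads, expressed as the parsed dict rather than YAML so it stays in Python; the keys and defaults mirror `read_config` above, and the sample values are only examples.

```python
# Mirror of AutoAcceptInvitesConfig.read_config on a plain dict (the real class
# needs the Config machinery); shows the accepted keys and their defaults.
from typing import Any, Dict, Optional, Tuple

def parse_auto_accept_invites(config: Dict[str, Any]) -> Tuple[bool, bool, bool, Optional[str]]:
    section = config.get("auto_accept_invites") or {}
    enabled = section.get("enabled", False)
    only_dms = section.get("only_for_direct_messages", False)
    only_local = section.get("only_from_local_users", False)
    worker = section.get("worker_to_run_on")
    return enabled, only_dms, only_local, worker

# e.g. a homeserver.yaml section enabling the feature for direct messages only:
print(parse_auto_accept_invites(
    {"auto_accept_invites": {"enabled": True, "only_for_direct_messages": True}}
))  # (True, True, False, None)
```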
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index d23dcf96b2..fa59c350c1 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -66,6 +66,17 @@ class CasConfig(Config):
self.cas_enable_registration = cas_config.get("enable_registration", True)
+ self.cas_allow_numeric_ids = cas_config.get("allow_numeric_ids")
+ self.cas_numeric_ids_prefix = cas_config.get("numeric_ids_prefix")
+ if (
+ self.cas_numeric_ids_prefix is not None
+ and self.cas_numeric_ids_prefix.isalnum() is False
+ ):
+ raise ConfigError(
+ "Only alphanumeric characters are allowed for numeric IDs prefix",
+ ("cas_config", "numeric_ids_prefix"),
+ )
+
self.idp_name = cas_config.get("idp_name", "CAS")
self.idp_icon = cas_config.get("idp_icon")
self.idp_brand = cas_config.get("idp_brand")
@@ -77,6 +88,8 @@ class CasConfig(Config):
self.cas_displayname_attribute = None
self.cas_required_attributes = []
self.cas_enable_registration = False
+ self.cas_allow_numeric_ids = False
+ self.cas_numeric_ids_prefix = "u"
# CAS uses a legacy required attributes mapping, not the one provided by
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 749452ce93..1b72727b75 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
+ # MSC3575 (Sliding Sync API endpoints)
+ self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
+
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
@@ -390,9 +393,6 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
- # MSC3967: Do not require UIA when first uploading cross signing keys
- self.msc3967_enabled = experimental.get("msc3967_enabled", False)
-
# MSC3861: Matrix architecture change to delegate authentication via OIDC
try:
self.msc3861 = MSC3861(**experimental.get("msc3861", {}))
@@ -433,6 +433,16 @@ class ExperimentalConfig(Config):
("experimental", "msc4108_delegation_endpoint"),
)
- self.msc4115_membership_on_events = experimental.get(
- "msc4115_membership_on_events", False
+ self.msc3823_account_suspension = experimental.get(
+ "msc3823_account_suspension", False
+ )
+
+ self.msc3916_authenticated_media_enabled = experimental.get(
+ "msc3916_authenticated_media_enabled", False
)
+
+ # MSC4151: Report room API (Client-Server API)
+ self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False)
+
+ # MSC4156: Migrate server_name to via
+ self.msc4156_enabled: bool = experimental.get("msc4156_enabled", False)
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index 9032effac3..cf29fa2562 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -42,6 +42,10 @@ class FederationConfig(Config):
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
+ self.federation_whitelist_endpoint_enabled = config.get(
+ "federation_whitelist_endpoint_enabled", False
+ )
+
federation_metrics_domains = config.get("federation_metrics_domains") or []
validate_config(
_METRICS_FOR_DOMAINS_SCHEMA,
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 72e93ed04f..e36c0bd6ae 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -23,6 +23,7 @@ from .account_validity import AccountValidityConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .auth import AuthConfig
+from .auto_accept_invites import AutoAcceptInvitesConfig
from .background_updates import BackgroundUpdateConfig
from .cache import CacheConfig
from .captcha import CaptchaConfig
@@ -105,4 +106,5 @@ class HomeServerConfig(RootConfig):
RedisConfig,
ExperimentalConfig,
BackgroundUpdateConfig,
+ AutoAcceptInvitesConfig,
]
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index d2cb4576df..3fa33f5373 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -218,3 +218,13 @@ class RatelimitConfig(Config):
"rc_media_create",
defaults={"per_second": 10, "burst_count": 50},
)
+
+ self.remote_media_downloads = RatelimitSettings(
+ key="rc_remote_media_downloads",
+ per_second=self.parse_size(
+ config.get("remote_media_download_per_second", "87K")
+ ),
+ burst_count=self.parse_size(
+ config.get("remote_media_download_burst_count", "500M")
+ ),
+ )
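A sketch of how the two new settings are interpreted; `_parse_size` below is a stand-in for `Config.parse_size`, assuming the usual `1K = 1024` bytes convention used for other size options.

```python
# Stand-in size parser for illustration only.
def _parse_size(value: str) -> int:
    suffixes = {"K": 1024, "M": 1024 * 1024}
    if value[-1].upper() in suffixes:
        return int(value[:-1]) * suffixes[value[-1].upper()]
    return int(value)

per_second = _parse_size("87K")    # sustained remote media download budget, in bytes/s
burst_count = _parse_size("500M")  # burst allowance before throttling kicks in
print(per_second, burst_count)     # 89088 524288000
```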
diff --git a/synapse/events/auto_accept_invites.py b/synapse/events/auto_accept_invites.py
new file mode 100644
index 0000000000..d88ec51d9d
--- /dev/null
+++ b/synapse/events/auto_accept_invites.py
@@ -0,0 +1,196 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2021 The Matrix.org Foundation C.I.C
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import logging
+from http import HTTPStatus
+from typing import Any, Dict, Tuple
+
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.api.errors import SynapseError
+from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
+from synapse.module_api import EventBase, ModuleApi, run_as_background_process
+
+logger = logging.getLogger(__name__)
+
+
+class InviteAutoAccepter:
+ def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi):
+ # Keep a reference to the Module API.
+ self._api = api
+ self._config = config
+
+ if not self._config.enabled:
+ return
+
+ should_run_on_this_worker = config.worker_to_run_on == self._api.worker_name
+
+ if not should_run_on_this_worker:
+ logger.info(
+ "Not accepting invites on this worker (configured: %r, here: %r)",
+ config.worker_to_run_on,
+ self._api.worker_name,
+ )
+ return
+
+ logger.info(
+ "Accepting invites on this worker (here: %r)", self._api.worker_name
+ )
+
+ # Register the callback.
+ self._api.register_third_party_rules_callbacks(
+ on_new_event=self.on_new_event,
+ )
+
+ async def on_new_event(self, event: EventBase, *args: Any) -> None:
+ """Listens for new events, and if the event is an invite for a local user then
+ automatically accepts it.
+
+ Args:
+ event: The incoming event.
+ """
+ # Check if the event is an invite for a local user.
+ is_invite_for_local_user = (
+ event.type == EventTypes.Member
+ and event.is_state()
+ and event.membership == Membership.INVITE
+ and self._api.is_mine(event.state_key)
+ )
+
+ # Only accept invites for direct messages if the configuration mandates it.
+ is_direct_message = event.content.get("is_direct", False)
+ is_allowed_by_direct_message_rules = (
+ not self._config.accept_invites_only_for_direct_messages
+ or is_direct_message is True
+ )
+
+ # Only accept invites from remote users if the configuration mandates it.
+ is_from_local_user = self._api.is_mine(event.sender)
+ is_allowed_by_local_user_rules = (
+ not self._config.accept_invites_only_from_local_users
+ or is_from_local_user is True
+ )
+
+ if (
+ is_invite_for_local_user
+ and is_allowed_by_direct_message_rules
+ and is_allowed_by_local_user_rules
+ ):
+ # Make the user join the room. We run this as a background process to circumvent a race condition
+ # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12)
+ run_as_background_process(
+ "retry_make_join",
+ self._retry_make_join,
+ event.state_key,
+ event.state_key,
+ event.room_id,
+ "join",
+ bg_start_span=False,
+ )
+
+ if is_direct_message:
+ # Mark this room as a direct message!
+ await self._mark_room_as_direct_message(
+ event.state_key, event.sender, event.room_id
+ )
+
+ async def _mark_room_as_direct_message(
+ self, user_id: str, dm_user_id: str, room_id: str
+ ) -> None:
+ """
+ Marks a room (`room_id`) as a direct message with the counterparty `dm_user_id`
+ from the perspective of the user `user_id`.
+
+ Args:
+ user_id: the user whose account data is being updated
+ dm_user_id: the other party in the direct message (the inviter)
+ room_id: the ID of the room being marked as a direct message
+ """
+
+ # This is a dict of User IDs to tuples of Room IDs
+ # (get_global will return a frozendict of tuples as it freezes the data,
+ # but we should accept either frozen or unfrozen variants.)
+ # Be careful: we convert the outer frozendict into a dict here,
+ # but the contents of the dict are still frozen (tuples in lieu of lists,
+ # etc.)
+ dm_map: Dict[str, Tuple[str, ...]] = dict(
+ await self._api.account_data_manager.get_global(
+ user_id, AccountDataTypes.DIRECT
+ )
+ or {}
+ )
+
+ if dm_user_id not in dm_map:
+ dm_map[dm_user_id] = (room_id,)
+ else:
+ dm_rooms_for_user = dm_map[dm_user_id]
+ assert isinstance(dm_rooms_for_user, (tuple, list))
+
+ dm_map[dm_user_id] = tuple(dm_rooms_for_user) + (room_id,)
+
+ await self._api.account_data_manager.put_global(
+ user_id, AccountDataTypes.DIRECT, dm_map
+ )
+
+ async def _retry_make_join(
+ self, sender: str, target: str, room_id: str, new_membership: str
+ ) -> None:
+ """
+ A function to retry sending the `make_join` request with an increasing backoff. This is
+ implemented to work around a race condition when receiving invites over federation.
+
+ Args:
+ sender: the user performing the membership change
+ target: the user for whom the membership is changing
+ room_id: the ID of the room to join
+ new_membership: the type of membership event (in this case will be "join")
+ """
+
+ sleep = 0
+ retries = 0
+ join_event = None
+
+ while retries < 5:
+ try:
+ await self._api.sleep(sleep)
+ join_event = await self._api.update_room_membership(
+ sender=sender,
+ target=target,
+ room_id=room_id,
+ new_membership=new_membership,
+ )
+ except SynapseError as e:
+ if e.code == HTTPStatus.FORBIDDEN:
+ logger.debug(
+ f"Update_room_membership was forbidden. This can sometimes be expected for remote invites. Exception: {e}"
+ )
+ else:
+ logger.warning(
+ f"update_room_membership raised an unexpected SynapseError: {e}"
+ )
+ except Exception as e:
+ logger.warning(
+ f"update_room_membership raised an unexpected exception: {e}"
+ )
+
+ sleep = 2**retries
+ retries += 1
+
+ if join_event is not None:
+ break
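A worked example of the retry schedule `_retry_make_join` produces: the first attempt runs immediately, then the wait doubles each time, giving five attempts spread over roughly fifteen seconds.

```python
# Reproduces the sleep/retries bookkeeping from the loop above.
sleep, retries, schedule = 0, 0, []
while retries < 5:
    schedule.append(sleep)  # time waited before this attempt
    sleep = 2**retries
    retries += 1
print(schedule)  # [0, 1, 2, 4, 8]
```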
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 0772472312..b997d82d71 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -90,6 +90,7 @@ def prune_event(event: EventBase) -> EventBase:
pruned_event.internal_metadata.stream_ordering = (
event.internal_metadata.stream_ordering
)
+ pruned_event.internal_metadata.instance_name = event.internal_metadata.instance_name
pruned_event.internal_metadata.outlier = event.internal_metadata.outlier
# Mark the event as redacted
@@ -116,6 +117,7 @@ def clone_event(event: EventBase) -> EventBase:
new_event.internal_metadata.stream_ordering = (
event.internal_metadata.stream_ordering
)
+ new_event.internal_metadata.instance_name = event.internal_metadata.instance_name
new_event.internal_metadata.outlier = event.internal_metadata.outlier
return new_event
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 62f0b67dbd..73b63b77f2 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -47,9 +47,9 @@ from synapse.events.utils import (
validate_canonicaljson,
)
from synapse.http.servlet import validate_json_object
-from synapse.rest.models import RequestBodyModel
from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
+from synapse.types.rest import RequestBodyModel
class EventValidator:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e613eb87a6..f0f5a37a57 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -56,6 +56,7 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -1877,6 +1878,8 @@ class FederationClient(FederationBase):
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
try:
return await self.transport_layer.download_media_v3(
@@ -1885,6 +1888,8 @@ class FederationClient(FederationBase):
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
@@ -1905,6 +1910,8 @@ class FederationClient(FederationBase):
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 7ffc650aa1..1932fa82a4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -674,7 +674,7 @@ class FederationServer(FederationBase):
# This is in addition to the HS-level rate limiting applied by
# BaseFederationServlet.
# type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
- await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
+ await self._room_member_handler._join_rate_per_room_limiter.ratelimit(
requester=None,
key=room_id,
update=False,
@@ -717,7 +717,7 @@ class FederationServer(FederationBase):
SynapseTags.SEND_JOIN_RESPONSE_IS_PARTIAL_STATE,
caller_supports_partial_state,
)
- await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
+ await self._room_member_handler._join_rate_per_room_limiter.ratelimit(
requester=None,
key=room_id,
update=False,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index de408f7f8d..af1336fe5f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -43,6 +43,7 @@ import ijson
from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
@@ -819,6 +820,8 @@ class TransportLayerClient:
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
@@ -834,6 +837,8 @@ class TransportLayerClient:
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
async def download_media_v3(
@@ -843,6 +848,8 @@ class TransportLayerClient:
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
@@ -862,6 +869,8 @@ class TransportLayerClient:
"allow_redirect": "true",
},
follow_redirects=True,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index bac569e977..edaf0196d6 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -33,6 +33,7 @@ from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
FederationUnstableClientKeysClaimServlet,
+ FederationUnstableMediaDownloadServlet,
)
from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import (
@@ -315,6 +316,13 @@ def register_servlets(
):
continue
+ if servletclass == FederationUnstableMediaDownloadServlet:
+ if (
+ not hs.config.server.enable_media_repo
+ or not hs.config.experimental.msc3916_authenticated_media_enabled
+ ):
+ continue
+
servletclass(
hs=hs,
authenticator=authenticator,
diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py
index db0f5076a9..4e2717b565 100644
--- a/synapse/federation/transport/server/_base.py
+++ b/synapse/federation/transport/server/_base.py
@@ -360,13 +360,29 @@ class BaseFederationServlet:
"request"
)
return None
+ if (
+ func.__self__.__class__.__name__ # type: ignore
+ == "FederationUnstableMediaDownloadServlet"
+ ):
+ response = await func(
+ origin, content, request, *args, **kwargs
+ )
+ else:
+ response = await func(
+ origin, content, request.args, *args, **kwargs
+ )
+ else:
+ if (
+ func.__self__.__class__.__name__ # type: ignore
+ == "FederationUnstableMediaDownloadServlet"
+ ):
+ response = await func(
+ origin, content, request, *args, **kwargs
+ )
+ else:
response = await func(
origin, content, request.args, *args, **kwargs
)
- else:
- response = await func(
- origin, content, request.args, *args, **kwargs
- )
finally:
# if we used the origin's context as the parent, add a new span using
# the servlet span as a parent, so that we have a link
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index a59734785f..67bb907050 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -44,10 +44,13 @@ from synapse.federation.transport.server._base import (
)
from synapse.http.servlet import (
parse_boolean_from_args,
+ parse_integer,
parse_integer_from_args,
parse_string_from_args,
parse_strings_from_args,
)
+from synapse.http.site import SynapseRequest
+from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS
from synapse.types import JsonDict
from synapse.util import SYNAPSE_VERSION
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -787,6 +790,43 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
return 200, {"account_statuses": statuses, "failures": failures}
+class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
+ """
+ Implementation of the new federation media `/download` endpoint outlined in MSC3916. Returns
+ a multipart/mixed response consisting of a JSON object and the requested media
+ item. This endpoint only returns local media.
+ """
+
+ PATH = "/media/download/(?P<media_id>[^/]*)"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
+ RATELIMIT = True
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ ratelimiter: FederationRateLimiter,
+ authenticator: Authenticator,
+ server_name: str,
+ ):
+ super().__init__(hs, authenticator, ratelimiter, server_name)
+ self.media_repo = self.hs.get_media_repository()
+
+ async def on_GET(
+ self,
+ origin: Optional[str],
+ content: Literal[None],
+ request: SynapseRequest,
+ media_id: str,
+ ) -> None:
+ max_timeout_ms = parse_integer(
+ request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+ )
+ max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+ await self.media_repo.get_local_media(
+ request, media_id, None, max_timeout_ms, federation=True
+ )
+
+
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationSendServlet,
FederationEventServlet,
@@ -818,4 +858,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
FederationAccountStatusServlet,
+ FederationUnstableMediaDownloadServlet,
)
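For orientation, a small sketch of the path this servlet serves, assuming `FEDERATION_UNSTABLE_PREFIX` resolves to `/_matrix/federation/unstable` as elsewhere in the codebase; the default timeout shown is only an example, and whatever the caller supplies is clamped to `MAXIMUM_ALLOWED_MAX_TIMEOUT_MS` on the server side.

```python
# Builds the MSC3916 unstable download path for a given media ID (sketch only).
from urllib.parse import urlencode

def msc3916_download_path(media_id: str, timeout_ms: int = 20_000) -> str:
    prefix = "/_matrix/federation/unstable/org.matrix.msc3916"
    return f"{prefix}/media/download/{media_id}?" + urlencode({"timeout_ms": timeout_ms})

print(msc3916_download_path("abcDEF123"))
# /_matrix/federation/unstable/org.matrix.msc3916/media/download/abcDEF123?timeout_ms=20000
```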
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 702d40332c..ec35784c5f 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -42,7 +42,6 @@ class AdminHandler:
self._device_handler = hs.get_device_handler()
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
- self._hs_config = hs.config
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def get_whois(self, user: UserID) -> JsonMapping:
@@ -126,13 +125,7 @@ class AdminHandler:
# Get all rooms the user is in or has been in
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id,
- membership_list=(
- Membership.JOIN,
- Membership.LEAVE,
- Membership.BAN,
- Membership.INVITE,
- Membership.KNOCK,
- ),
+ membership_list=Membership.LIST,
)
# We only try and fetch events for rooms the user has been in. If
@@ -179,7 +172,7 @@ class AdminHandler:
if room.membership == Membership.JOIN:
stream_ordering = self._store.get_room_max_stream_ordering()
else:
- stream_ordering = room.stream_ordering
+ stream_ordering = room.event_pos.stream
from_key = RoomStreamToken(topological=0, stream=0)
to_key = RoomStreamToken(stream=stream_ordering)
@@ -221,7 +214,6 @@ class AdminHandler:
self._storage_controllers,
user_id,
events,
- msc4115_membership_on_events=self._hs_config.experimental.msc4115_membership_on_events,
)
writer.write_events(room_id, events)
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 153123ee83..cc3d641b7d 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -78,6 +78,8 @@ class CasHandler:
self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute
self._cas_required_attributes = hs.config.cas.cas_required_attributes
self._cas_enable_registration = hs.config.cas.cas_enable_registration
+ self._cas_allow_numeric_ids = hs.config.cas.cas_allow_numeric_ids
+ self._cas_numeric_ids_prefix = hs.config.cas.cas_numeric_ids_prefix
self._http_client = hs.get_proxied_http_client()
@@ -188,6 +190,9 @@ class CasHandler:
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
+ # If numeric user IDs are allowed and the username is numeric, add the configured prefix so Synapse can handle it.
+ if self._cas_allow_numeric_ids and user is not None and user.isdigit():
+ user = f"{self._cas_numeric_ids_prefix}{user}"
if child.tag.endswith("attributes"):
for attribute in child:
# ElementTree library expands the namespace in
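A tiny illustration of the prefixing behaviour added above: with `allow_numeric_ids` enabled, an all-digit CAS username gains the configured prefix before Synapse maps it to a localpart; the prefix value shown is the documented default.

```python
# Mirrors the added handler logic; "u" is the default numeric_ids_prefix.
def normalise_cas_user(user: str, allow_numeric_ids: bool, prefix: str = "u") -> str:
    if allow_numeric_ids and user.isdigit():
        return f"{prefix}{user}"
    return user

print(normalise_cas_user("12345", allow_numeric_ids=True))  # u12345
print(normalise_cas_user("alice", allow_numeric_ids=True))  # alice
```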
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 67953a3ed9..0432d97109 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -159,20 +159,32 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
- self, user_id: str, room_ids: StrCollection, from_token: StreamToken
+ self,
+ user_id: str,
+ room_ids: StrCollection,
+ from_token: StreamToken,
+ now_token: Optional[StreamToken] = None,
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
+ now_device_lists_key = self.store.get_device_stream_token()
+ if now_token:
+ now_device_lists_key = now_token.device_list_key
+
changed_users = await self.store.get_device_list_changes_in_rooms(
- room_ids, from_token.device_list_key
+ room_ids,
+ from_token.device_list_key,
+ now_device_lists_key,
)
if changed_users is not None:
# We also check if the given user has changed their device. If
# they're in no rooms then the above query won't include them.
changed = await self.store.get_users_whose_devices_changed(
- from_token.device_list_key, [user_id]
+ from_token.device_list_key,
+ [user_id],
+ to_key=now_device_lists_key,
)
changed_users.update(changed)
return changed_users
@@ -190,7 +202,9 @@ class DeviceWorkerHandler:
tracked_users.add(user_id)
changed = await self.store.get_users_whose_devices_changed(
- from_token.device_list_key, tracked_users
+ from_token.device_list_key,
+ tracked_users,
+ to_key=now_device_lists_key,
)
return changed
@@ -892,6 +906,13 @@ class DeviceHandler(DeviceWorkerHandler):
context=opentracing_context,
)
+ await self.store.mark_redundant_device_lists_pokes(
+ user_id=user_id,
+ device_id=device_id,
+ room_id=room_id,
+ converted_upto_stream_id=stream_id,
+ )
+
# Notify replication that we've updated the device list stream.
self.notifier.notify_replication()
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 79be7c97c8..e56bdb4072 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -236,6 +236,13 @@ class DeviceMessageHandler:
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
+ if not UserID.is_valid(user_id):
+ logger.warning(
+ "Ignoring attempt to send device message to invalid user: %r",
+ user_id,
+ )
+ continue
+
# add an opentracing log entry for each message
for device_id, message_content in by_device.items():
log_kv(
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 1ece54ccfc..668cec513b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -35,6 +35,7 @@ from synapse.api.errors import CodeMessageException, Codes, NotFoundError, Synap
from synapse.handlers.device import DeviceHandler
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
+from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.types import (
JsonDict,
JsonMapping,
@@ -45,7 +46,10 @@ from synapse.types import (
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable
-from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.retryutils import (
+ NotRetryingDestination,
+ filter_destinations_by_retry_limiter,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -53,6 +57,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+ONE_TIME_KEY_UPLOAD = "one_time_key_upload_lock"
+
+
class E2eKeysHandler:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
@@ -62,6 +69,7 @@ class E2eKeysHandler:
self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+ self._worker_lock_handler = hs.get_worker_locks_handler()
federation_registry = hs.get_federation_registry()
@@ -82,6 +90,12 @@ class E2eKeysHandler:
edu_updater.incoming_signing_key_update,
)
+ self.device_key_uploader = self.upload_device_keys_for_user
+ else:
+ self.device_key_uploader = (
+ ReplicationUploadKeysForUserRestServlet.make_client(hs)
+ )
+
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
@@ -145,6 +159,11 @@ class E2eKeysHandler:
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
+ if not UserID.is_valid(user_id):
+ # Ignore invalid user IDs, which is the same behaviour as if
+ # the user existed but had no keys.
+ continue
+
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
@@ -259,10 +278,8 @@ class E2eKeysHandler:
"%d destinations to query devices for", len(remote_queries_not_in_cache)
)
- async def _query(
- destination_queries: Tuple[str, Dict[str, Iterable[str]]]
- ) -> None:
- destination, queries = destination_queries
+ async def _query(destination: str) -> None:
+ queries = remote_queries_not_in_cache[destination]
return await self._query_devices_for_destination(
results,
cross_signing_keys,
@@ -272,9 +289,20 @@ class E2eKeysHandler:
timeout,
)
+ # Only try and fetch keys for destinations that are not marked as
+ # down.
+ filtered_destinations = await filter_destinations_by_retry_limiter(
+ remote_queries_not_in_cache.keys(),
+ self.clock,
+ self.store,
+ # Give a short grace period to hosts that only recently went down.
+ retry_due_within_ms=60 * 1000,
+ )
+
await concurrently_execute(
_query,
- remote_queries_not_in_cache.items(),
+ filtered_destinations,
10,
delay_cancellation=True,
)
@@ -775,36 +803,17 @@ class E2eKeysHandler:
"one_time_keys": A mapping from algorithm to number of keys for that
algorithm, including those previously persisted.
"""
- # This can only be called from the main process.
- assert isinstance(self.device_handler, DeviceHandler)
-
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
- logger.info(
- "Updating device_keys for device %r for user %s at %d",
- device_id,
- user_id,
- time_now,
- )
- log_kv(
- {
- "message": "Updating device_keys for user.",
- "user_id": user_id,
- "device_id": device_id,
- }
- )
- # TODO: Sign the JSON with the server key
- changed = await self.store.set_e2e_device_keys(
- user_id, device_id, time_now, device_keys
+ await self.device_key_uploader(
+ user_id=user_id,
+ device_id=device_id,
+ keys={"device_keys": device_keys},
)
- if changed:
- # Only notify about device updates *if* the keys actually changed
- await self.device_handler.notify_device_update(user_id, [device_id])
- else:
- log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
+
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
log_kv(
@@ -840,60 +849,106 @@ class E2eKeysHandler:
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)
- # the device should have been registered already, but it may have been
- # deleted due to a race with a DELETE request. Or we may be using an
- # old access_token without an associated device_id. Either way, we
- # need to double-check the device is registered to avoid ending up with
- # keys without a corresponding device.
- await self.device_handler.check_device_registered(user_id, device_id)
-
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
- async def _upload_one_time_keys_for_user(
- self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ @tag_args
+ async def upload_device_keys_for_user(
+ self, user_id: str, device_id: str, keys: JsonDict
) -> None:
+ """
+ Args:
+ user_id: user whose keys are being uploaded.
+ device_id: device whose keys are being uploaded.
+ keys: a dict containing the `device_keys` field of a /keys/upload request.
+
+ """
+ # This can only be called from the main process.
+ assert isinstance(self.device_handler, DeviceHandler)
+
+ time_now = self.clock.time_msec()
+
+ device_keys = keys["device_keys"]
logger.info(
- "Adding one_time_keys %r for device %r for user %r at %d",
- one_time_keys.keys(),
+ "Updating device_keys for device %r for user %s at %d",
device_id,
user_id,
time_now,
)
+ log_kv(
+ {
+ "message": "Updating device_keys for user.",
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ )
+ # TODO: Sign the JSON with the server key
+ changed = await self.store.set_e2e_device_keys(
+ user_id, device_id, time_now, device_keys
+ )
+ if changed:
+ # Only notify about device updates *if* the keys actually changed
+ await self.device_handler.notify_device_update(user_id, [device_id])
- # make a list of (alg, id, key) tuples
- key_list = []
- for key_id, key_obj in one_time_keys.items():
- algorithm, key_id = key_id.split(":")
- key_list.append((algorithm, key_id, key_obj))
+ # the device should have been registered already, but it may have been
+ # deleted due to a race with a DELETE request. Or we may be using an
+ # old access_token without an associated device_id. Either way, we
+ # need to double-check the device is registered to avoid ending up with
+ # keys without a corresponding device.
+ await self.device_handler.check_device_registered(user_id, device_id)
- # First we check if we have already persisted any of the keys.
- existing_key_map = await self.store.get_e2e_one_time_keys(
- user_id, device_id, [k_id for _, k_id, _ in key_list]
- )
+ async def _upload_one_time_keys_for_user(
+ self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+ ) -> None:
+ # We take out a lock so that we don't have to worry about a client
+ # sending duplicate requests.
+ lock_key = f"{user_id}_{device_id}"
+ async with self._worker_lock_handler.acquire_lock(
+ ONE_TIME_KEY_UPLOAD, lock_key
+ ):
+ logger.info(
+ "Adding one_time_keys %r for device %r for user %r at %d",
+ one_time_keys.keys(),
+ device_id,
+ user_id,
+ time_now,
+ )
- new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
- for algorithm, key_id, key in key_list:
- ex_json = existing_key_map.get((algorithm, key_id), None)
- if ex_json:
- if not _one_time_keys_match(ex_json, key):
- raise SynapseError(
- 400,
- (
- "One time key %s:%s already exists. "
- "Old key: %s; new key: %r"
+ # make a list of (alg, id, key) tuples
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((algorithm, key_id, key_obj))
+
+ # First we check if we have already persisted any of the keys.
+ existing_key_map = await self.store.get_e2e_one_time_keys(
+ user_id, device_id, [k_id for _, k_id, _ in key_list]
+ )
+
+ new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
+ for algorithm, key_id, key in key_list:
+ ex_json = existing_key_map.get((algorithm, key_id), None)
+ if ex_json:
+ if not _one_time_keys_match(ex_json, key):
+ raise SynapseError(
+ 400,
+ (
+ "One time key %s:%s already exists. "
+ "Old key: %s; new key: %r"
+ )
+ % (algorithm, key_id, ex_json, key),
)
- % (algorithm, key_id, ex_json, key),
+ else:
+ new_keys.append(
+ (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
- else:
- new_keys.append(
- (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
- )
- log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
- await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+ log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
+ await self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, new_keys
+ )
async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict
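An illustrative asyncio analogue (not the `WorkerLocksHandler` API) of the lock taken in `_upload_one_time_keys_for_user`: uploads for the same `(user, device)` pair are serialised, so a client retrying the same request cannot race itself.

```python
import asyncio
from collections import defaultdict
from typing import DefaultDict

_locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

async def upload_one_time_keys(user_id: str, device_id: str) -> None:
    lock_key = f"{user_id}_{device_id}"  # same key shape as the handler above
    async with _locks[lock_key]:
        # ... check existing keys, then insert only the genuinely new ones ...
        await asyncio.sleep(0)

asyncio.run(upload_one_time_keys("@alice:example.org", "DEVICE1"))
```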
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index e76a51ba30..99f9f6e64a 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -247,6 +247,12 @@ class E2eRoomKeysHandler:
if current_room_key:
if self._should_replace_room_key(current_room_key, room_key):
log_kv({"message": "Replacing room key."})
+ logger.debug(
+ "Replacing room key. room=%s session=%s user=%s",
+ room_id,
+ session_id,
+ user_id,
+ )
# updates are done one at a time in the DB, so send
# updates right away rather than batching them up,
# like we do with the inserts
@@ -256,6 +262,12 @@ class E2eRoomKeysHandler:
changed = True
else:
log_kv({"message": "Not replacing room_key."})
+ logger.debug(
+ "Not replacing room key. room=%s session=%s user=%s",
+ room_id,
+ session_id,
+ user_id,
+ )
else:
log_kv(
{
@@ -265,6 +277,12 @@ class E2eRoomKeysHandler:
}
)
log_kv({"message": "Replacing room key."})
+ logger.debug(
+ "Inserting new room key. room=%s session=%s user=%s",
+ room_id,
+ session_id,
+ user_id,
+ )
to_insert.append((room_id, session_id, room_key))
changed = True
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 09d553cff1..3f46032a43 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -148,7 +148,6 @@ class EventHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
- self._config = hs.config
async def get_event(
self,
@@ -194,7 +193,6 @@ class EventHandler:
user.to_string(),
[event],
is_peeking=is_peeking,
- msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events,
)
if not filtered:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d99fc4bec0..bd3c87f5f4 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -199,7 +199,7 @@ class InitialSyncHandler:
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
- stream=event.stream_ordering,
+ stream=event.event_pos.stream,
)
deferred_room_state = run_in_background(
self._state_storage_controller.get_state_for_events,
@@ -224,7 +224,6 @@ class InitialSyncHandler:
self._storage_controllers,
user_id,
messages,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
@@ -383,7 +382,6 @@ class InitialSyncHandler:
requester.user.to_string(),
messages,
is_peeking=is_peeking,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
@@ -498,7 +496,6 @@ class InitialSyncHandler:
requester.user.to_string(),
messages,
is_peeking=is_peeking,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ccaa5508ff..5aa48230ec 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -201,7 +201,7 @@ class MessageHandler:
if at_token:
last_event_id = (
- await self.store.get_last_event_in_room_before_stream_ordering(
+ await self.store.get_last_event_id_in_room_before_stream_ordering(
room_id,
end_token=at_token.room_key,
)
@@ -496,13 +496,6 @@ class EventCreationHandler:
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
- self.membership_types_to_include_profile_data_in = {
- Membership.JOIN,
- Membership.KNOCK,
- }
- if self.hs.config.server.include_profile_data_on_invite:
- self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
-
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
@@ -594,8 +587,6 @@ class EventCreationHandler:
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
- Adds display names to Join membership events.
-
Args:
requester
event_dict: An entire event
@@ -651,6 +642,17 @@ class EventCreationHandler:
"""
await self.auth_blocking.check_auth_blocking(requester=requester)
+ if event_dict["type"] == EventTypes.Message:
+ requester_suspended = await self.store.get_user_suspended_status(
+ requester.user.to_string()
+ )
+ if requester_suspended:
+ raise SynapseError(
+ 403,
+ "Sending messages while account is suspended is not allowed.",
+ Codes.USER_ACCOUNT_SUSPENDED,
+ )
+
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version_id = event_dict["content"]["room_version"]
maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
@@ -672,29 +674,6 @@ class EventCreationHandler:
self.validator.validate_builder(builder)
- if builder.type == EventTypes.Member:
- membership = builder.content.get("membership", None)
- target = UserID.from_string(builder.state_key)
-
- if membership in self.membership_types_to_include_profile_data_in:
- # If event doesn't include a display name, add one.
- profile = self.profile_handler
- content = builder.content
-
- try:
- if "displayname" not in content:
- displayname = await profile.get_displayname(target)
- if displayname is not None:
- content["displayname"] = displayname
- if "avatar_url" not in content:
- avatar_url = await profile.get_avatar_url(target)
- if avatar_url is not None:
- content["avatar_url"] = avatar_url
- except Exception as e:
- logger.info(
- "Failed to get profile information for %r: %s", target, e
- )
-
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
if require_consent and not is_exempt:
await self.assert_accepted_privacy_policy(requester)
@@ -1583,6 +1562,7 @@ class EventCreationHandler:
# stream_ordering entry manually (as it was persisted on
# another worker).
event.internal_metadata.stream_ordering = stream_id
+ event.internal_metadata.instance_name = writer_instance
return event
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6617105cdb..872c85fbad 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -27,7 +27,6 @@ from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
-from synapse.handlers.room import ShutdownRoomParams, ShutdownRoomResponse
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -41,6 +40,7 @@ from synapse.types import (
StreamKeyType,
TaskStatus,
)
+from synapse.types.handlers import ShutdownRoomParams, ShutdownRoomResponse
from synapse.types.state import StateFilter
from synapse.util.async_helpers import ReadWriteLock
from synapse.visibility import filter_events_for_client
@@ -623,7 +623,6 @@ class PaginationHandler:
user_id,
events,
is_peeking=(member_event_id is None),
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
# if after the filter applied there are no more events
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index e51e282a9f..6663d4b271 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -20,7 +20,7 @@
#
import logging
import random
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
from synapse.api.errors import (
AuthError,
@@ -64,8 +64,10 @@ class ProfileHandler:
self.user_directory_handler = hs.get_user_directory_handler()
self.request_ratelimiter = hs.get_request_ratelimiter()
- self.max_avatar_size = hs.config.server.max_avatar_size
- self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes
+ self.max_avatar_size: Optional[int] = hs.config.server.max_avatar_size
+ self.allowed_avatar_mimetypes: Optional[List[str]] = (
+ hs.config.server.allowed_avatar_mimetypes
+ )
self._is_mine_server_name = hs.is_mine_server_name
@@ -337,6 +339,12 @@ class ProfileHandler:
return False
if self.max_avatar_size:
+ if media_info.media_length is None:
+ logger.warning(
+ "Forbidding avatar change to %s: unknown media size",
+ mxc,
+ )
+ return False
# Ensure avatar does not exceed max allowed avatar size
if media_info.media_length > self.max_avatar_size:
logger.warning(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e48e70db04..c200e29569 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -590,7 +590,7 @@ class RegistrationHandler:
# moving away from bare excepts is a good thing to do.
logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
- logger.error("Failed to join new user to %r: %r", r, e)
+ logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)
async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index c5cee8860b..efe31e81f9 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -95,7 +95,6 @@ class RelationsHandler:
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self._event_creation_handler = hs.get_event_creation_handler()
- self._config = hs.config
async def get_relations(
self,
@@ -164,7 +163,6 @@ class RelationsHandler:
user_id,
events,
is_peeking=(member_event_id is None),
- msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events,
)
# The relations returned for the requested event do include their
@@ -393,9 +391,9 @@ class RelationsHandler:
# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
+ room_id,
event_id,
event,
- room_id,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
)
@@ -610,7 +608,6 @@ class RelationsHandler:
user_id,
events,
is_peeking=(member_event_id is None),
- msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events,
)
aggregations = await self.get_bundled_aggregations(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 51739a2653..2302d283a7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -40,7 +40,6 @@ from typing import (
)
import attr
-from typing_extensions import TypedDict
import synapse.events.snapshot
from synapse.api.constants import (
@@ -88,6 +87,7 @@ from synapse.types import (
UserID,
create_requester,
)
+from synapse.types.handlers import ShutdownRoomParams, ShutdownRoomResponse
from synapse.types.state import StateFilter
from synapse.util import stringutils
from synapse.util.caches.response_cache import ResponseCache
@@ -1476,7 +1476,6 @@ class RoomContextHandler:
user.to_string(),
events,
is_peeking=is_peeking,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
event = await self.store.get_event(
@@ -1780,63 +1779,6 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
return self.store.get_current_room_stream_token_for_room_id(room_id)
-class ShutdownRoomParams(TypedDict):
- """
- Attributes:
- requester_user_id:
- User who requested the action. Will be recorded as putting the room on the
- blocking list.
- new_room_user_id:
- If set, a new room will be created with this user ID
- as the creator and admin, and all users in the old room will be
- moved into that room. If not set, no new room will be created
- and the users will just be removed from the old room.
- new_room_name:
- A string representing the name of the room that new users will
- be invited to. Defaults to `Content Violation Notification`
- message:
- A string containing the first message that will be sent as
- `new_room_user_id` in the new room. Ideally this will clearly
- convey why the original room was shut down.
- Defaults to `Sharing illegal content on this server is not
- permitted and rooms in violation will be blocked.`
- block:
- If set to `true`, this room will be added to a blocking list,
- preventing future attempts to join the room. Defaults to `false`.
- purge:
- If set to `true`, purge the given room from the database.
- force_purge:
- If set to `true`, the room will be purged from database
- even if there are still users joined to the room.
- """
-
- requester_user_id: Optional[str]
- new_room_user_id: Optional[str]
- new_room_name: Optional[str]
- message: Optional[str]
- block: bool
- purge: bool
- force_purge: bool
-
-
-class ShutdownRoomResponse(TypedDict):
- """
- Attributes:
- kicked_users: An array of users (`user_id`) that were kicked.
- failed_to_kick_users:
- An array of users (`user_id`) that that were not kicked.
- local_aliases:
- An array of strings representing the local aliases that were
- migrated from the old room to the new.
- new_room_id: A string representing the room ID of the new room.
- """
-
- kicked_users: List[str]
- failed_to_kick_users: List[str]
- local_aliases: List[str]
- new_room_id: Optional[str]
-
-
class RoomShutdownHandler:
DEFAULT_MESSAGE = (
"Sharing illegal content on this server is not permitted and rooms in"
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 655c78e150..51b9772329 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -106,6 +106,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.event_auth_handler = hs.get_event_auth_handler()
self._worker_lock_handler = hs.get_worker_locks_handler()
+ self._membership_types_to_include_profile_data_in = {
+ Membership.JOIN,
+ Membership.KNOCK,
+ }
+ if self.hs.config.server.include_profile_data_on_invite:
+ self._membership_types_to_include_profile_data_in.add(Membership.INVITE)
+
self.member_linearizer: Linearizer = Linearizer(name="member")
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -785,9 +792,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if (
not self.allow_per_room_profiles and not is_requester_server_notices_user
) or requester.shadow_banned:
- # Strip profile data, knowing that new profile data will be added to the
- # event's content in event_creation_handler.create_event() using the target's
- # global profile.
+ # Strip profile data, knowing that new profile data will be added to
+ # the event's content below using the target's global profile.
content.pop("displayname", None)
content.pop("avatar_url", None)
@@ -823,6 +829,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action in ["kick", "unban"]:
effective_membership_state = "leave"
+ if effective_membership_state not in Membership.LIST:
+ raise SynapseError(400, "Invalid membership key")
+
+ # Add profile data for joins etc, if no per-room profile.
+ if (
+ effective_membership_state
+ in self._membership_types_to_include_profile_data_in
+ ):
+ # If event doesn't include a display name, add one.
+ profile = self.profile_handler
+
+ try:
+ if "displayname" not in content:
+ displayname = await profile.get_displayname(target)
+ if displayname is not None:
+ content["displayname"] = displayname
+ if "avatar_url" not in content:
+ avatar_url = await profile.get_avatar_url(target)
+ if avatar_url is not None:
+ content["avatar_url"] = avatar_url
+ except Exception as e:
+ logger.info("Failed to get profile information for %r: %s", target, e)
+
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None:
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index fdbe98de3b..a7d52fa648 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -483,7 +483,6 @@ class SearchHandler:
self._storage_controllers,
user.to_string(),
filtered_events,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
events.sort(key=lambda e: -rank_map[e.event_id])
@@ -585,7 +584,6 @@ class SearchHandler:
self._storage_controllers,
user.to_string(),
filtered_events,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
room_events.extend(events)
@@ -673,14 +671,12 @@ class SearchHandler:
self._storage_controllers,
user.to_string(),
res.events_before,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
events_after = await filter_events_for_client(
self._storage_controllers,
user.to_string(),
res.events_after,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
context: JsonDict = {
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
new file mode 100644
index 0000000000..847a638bba
--- /dev/null
+++ b/synapse/handlers/sliding_sync.py
@@ -0,0 +1,680 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+import logging
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+from immutabledict import immutabledict
+
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
+from synapse.events import EventBase
+from synapse.storage.roommember import RoomsForUser
+from synapse.types import (
+ PersistedEventPosition,
+ Requester,
+ RoomStreamToken,
+ StreamToken,
+ UserID,
+)
+from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
+from synapse.types.state import StateFilter
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+def convert_event_to_rooms_for_user(event: EventBase) -> RoomsForUser:
+ """
+ Quick helper to convert an event to a `RoomsForUser` object.
+ """
+ # These fields should be present for all persisted events
+ assert event.internal_metadata.stream_ordering is not None
+ assert event.internal_metadata.instance_name is not None
+
+ return RoomsForUser(
+ room_id=event.room_id,
+ sender=event.sender,
+ membership=event.membership,
+ event_id=event.event_id,
+ event_pos=PersistedEventPosition(
+ event.internal_metadata.instance_name,
+ event.internal_metadata.stream_ordering,
+ ),
+ room_version_id=event.room_version.identifier,
+ )
+
+
+def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool:
+ """
+ Returns True if the membership event should be included in the sync response,
+ otherwise False.
+
+ Args:
+ membership: The membership state of the user in the room.
+ user_id: The user ID that the membership applies to
+ sender: The person who sent the membership event
+ """
+
+ # Everything except `Membership.LEAVE` because we want everything that's *still*
+ # relevant to the user. There are a few more things to include in the sync response
+ # (newly_left) but those are handled separately.
+ #
+ # This logic includes kicks (leave events where the sender is not the same user)
+ # and can be read as "keep anything that isn't a self-initiated leave".
+ return membership != Membership.LEAVE or sender != user_id
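+#
+# A quick illustration of the predicate above (user IDs are made up for the
+# example): a plain join and a kick (a leave sent by someone other than the
+# user) are kept, while a self-initiated leave is dropped:
+#
+#   membership="join",  user_id="@a:x", sender="@a:x"   -> True
+#   membership="leave", user_id="@a:x", sender="@mod:x" -> True  (a kick)
+#   membership="leave", user_id="@a:x", sender="@a:x"   -> False (own leave)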
+
+
+class SlidingSyncHandler:
+ def __init__(self, hs: "HomeServer"):
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
+ self.auth_blocking = hs.get_auth_blocking()
+ self.notifier = hs.get_notifier()
+ self.event_sources = hs.get_event_sources()
+ self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
+
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SlidingSyncConfig,
+ from_token: Optional[StreamToken] = None,
+ timeout_ms: int = 0,
+ ) -> SlidingSyncResult:
+ """
+ Get the sync for a client if we have new data for it now. Otherwise
+ wait for new data to arrive on the server. If the timeout expires, then
+ return an empty sync result.
+
+ Args:
+ requester: The user making the request
+ sync_config: Sync configuration
+ from_token: The point in the stream to sync from. Token of the end of the
+ previous batch. May be `None` if this is the initial sync request.
+ timeout_ms: The time in milliseconds to wait for new data to arrive. If 0,
+ we respond immediately, but there might not be any new data, in which case we
+ just return an empty response.
+ """
+ # If the user is not part of the mau group, then check that limits have
+ # not been exceeded (if not part of the group by this point, almost certain
+ # auth_blocking will occur)
+ await self.auth_blocking.check_auth_blocking(requester=requester)
+
+ # TODO: If the To-Device extension is enabled and we have a `from_token`, delete
+ # any to-device messages before that token (since we now know that the device
+ # has received them). (see sync v2 for how to do this)
+
+ # If we're working with a user-provided token, we need to make sure to wait for
+ # this worker to catch up with the token so we don't skip past any incoming
+ # events or future events if the user is nefariously, manually modifying the
+ # token.
+ if from_token is not None:
+ # We need to make sure this worker has caught up with the token. If
+ # this returns false, it means we timed out waiting, and we should
+ # just return an empty response.
+ before_wait_ts = self.clock.time_msec()
+ if not await self.notifier.wait_for_stream_token(from_token):
+ logger.warning(
+ "Timed out waiting for worker to catch up. Returning empty response"
+ )
+ return SlidingSyncResult.empty(from_token)
+
+ # If we've spent significant time waiting to catch up, take it off
+ # the timeout.
+ after_wait_ts = self.clock.time_msec()
+ if after_wait_ts - before_wait_ts > 1_000:
+ timeout_ms -= after_wait_ts - before_wait_ts
+ timeout_ms = max(timeout_ms, 0)
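+ # (For example, if we spent 3_000 ms catching up on a request with a
+ # 30_000 ms timeout, we only wait up to 27_000 ms for new events below.)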
+
+ # We're going to respond immediately if the timeout is 0 or if this is an
+ # initial sync (without a `from_token`) so we can avoid calling
+ # `notifier.wait_for_events()`.
+ if timeout_ms == 0 or from_token is None:
+ now_token = self.event_sources.get_current_token()
+ result = await self.current_sync_for_user(
+ sync_config,
+ from_token=from_token,
+ to_token=now_token,
+ )
+ else:
+ # Otherwise, we wait for something to happen and report it to the user.
+ async def current_sync_callback(
+ before_token: StreamToken, after_token: StreamToken
+ ) -> SlidingSyncResult:
+ return await self.current_sync_for_user(
+ sync_config,
+ from_token=from_token,
+ to_token=after_token,
+ )
+
+ result = await self.notifier.wait_for_events(
+ sync_config.user.to_string(),
+ timeout_ms,
+ current_sync_callback,
+ from_token=from_token,
+ )
+
+ return result
+
+ async def current_sync_for_user(
+ self,
+ sync_config: SlidingSyncConfig,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken] = None,
+ ) -> SlidingSyncResult:
+ """
+ Generates the response body of a Sliding Sync result, represented as a
+ `SlidingSyncResult`.
+
+ We fetch data according to the token range (> `from_token` and <= `to_token`).
+
+ Args:
+ sync_config: Sync configuration
+ to_token: The point in the stream to sync up to.
+ from_token: The point in the stream to sync from. Token of the end of the
+ previous batch. May be `None` if this is the initial sync request.
+ """
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+
+ # Assemble sliding window lists
+ lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {}
+ if sync_config.lists:
+ # Get all of the room IDs that the user should be able to see in the sync
+ # response
+ sync_room_map = await self.get_sync_room_ids_for_user(
+ sync_config.user,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ for list_key, list_config in sync_config.lists.items():
+ # Apply filters
+ filtered_sync_room_map = sync_room_map
+ if list_config.filters is not None:
+ filtered_sync_room_map = await self.filter_rooms(
+ sync_config.user, sync_room_map, list_config.filters, to_token
+ )
+
+ sorted_room_info = await self.sort_rooms(
+ filtered_sync_room_map, to_token
+ )
+
+ ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
+ if list_config.ranges:
+ for range in list_config.ranges:
+ ops.append(
+ SlidingSyncResult.SlidingWindowList.Operation(
+ op=OperationType.SYNC,
+ range=range,
+ room_ids=[
+ room_id
+ for room_id, _ in sorted_room_info[
+ range[0] : range[1]
+ ]
+ ],
+ )
+ )
+
+ lists[list_key] = SlidingSyncResult.SlidingWindowList(
+ count=len(sorted_room_info),
+ ops=ops,
+ )
+
+ return SlidingSyncResult(
+ next_pos=to_token,
+ lists=lists,
+ # TODO: Gather room data for rooms in lists and `sync_config.room_subscriptions`
+ rooms={},
+ extensions={},
+ )
+
+ async def get_sync_room_ids_for_user(
+ self,
+ user: UserID,
+ to_token: StreamToken,
+ from_token: Optional[StreamToken] = None,
+ ) -> Dict[str, RoomsForUser]:
+ """
+ Fetch room IDs that should be listed for this user in the sync response (the
+ full room list that will be filtered, sorted, and sliced).
+
+ We're looking for rooms where the user has the following state in the token
+ range (> `from_token` and <= `to_token`):
+
+ - `invite`, `join`, `knock`, `ban` membership events
+ - Kicks (`leave` membership events where `sender` is different from the
+ `user_id`/`state_key`)
+ - `newly_left` (rooms that were left during the given token range)
+ - In order for bans/kicks to not show up in sync, you need to `/forget` those
+ rooms. This doesn't modify the event itself though and only adds the
+ `forgotten` flag to the `room_memberships` table in Synapse. There currently
+ isn't a way to tell when a room was forgotten, so we can't factor it into the
+ from/to range.
+
+ Args:
+ user: User to fetch rooms for
+ to_token: The token to fetch rooms up to.
+ from_token: The point in the stream to sync from.
+
+ Returns:
+ A dictionary of room IDs that should be listed in the sync response along
+ with membership information in that room at the time of `to_token`.
+ """
+ user_id = user.to_string()
+
+ # First grab a current snapshot of rooms for the user
+ # (also handles forgotten rooms)
+ room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is(
+ user_id=user_id,
+ # We want to fetch any kind of membership (joined and left rooms) in order
+ # to get the `event_pos` of the latest room membership event for the
+ # user.
+ #
+ # We will filter out the rooms that don't belong below (see
+ # `filter_membership_for_sync`)
+ membership_list=Membership.LIST,
+ excluded_rooms=self.rooms_to_exclude_globally,
+ )
+
+ # If the user has never joined any rooms before, we can just return an empty dict
+ if not room_for_user_list:
+ return {}
+
+ # Our working list of rooms that can show up in the sync response
+ sync_room_id_set = {
+ room_for_user.room_id: room_for_user
+ for room_for_user in room_for_user_list
+ if filter_membership_for_sync(
+ membership=room_for_user.membership,
+ user_id=user_id,
+ sender=room_for_user.sender,
+ )
+ }
+
+ # Get the `RoomStreamToken` that represents the spot we queried up to when we got
+ # our membership snapshot from `get_rooms_for_local_user_where_membership_is()`.
+ #
+ # First, we need to get the max stream_ordering of each event persister instance
+ # that we queried events from.
+ instance_to_max_stream_ordering_map: Dict[str, int] = {}
+ for room_for_user in room_for_user_list:
+ instance_name = room_for_user.event_pos.instance_name
+ stream_ordering = room_for_user.event_pos.stream
+
+ current_instance_max_stream_ordering = (
+ instance_to_max_stream_ordering_map.get(instance_name)
+ )
+ if (
+ current_instance_max_stream_ordering is None
+ or stream_ordering > current_instance_max_stream_ordering
+ ):
+ instance_to_max_stream_ordering_map[instance_name] = stream_ordering
+
+ # Then assemble the `RoomStreamToken`
+ membership_snapshot_token = RoomStreamToken(
+ # Minimum position in the `instance_map`
+ stream=min(instance_to_max_stream_ordering_map.values()),
+ instance_map=immutabledict(instance_to_max_stream_ordering_map),
+ )
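+ # (Illustrative, with made-up instance names: if the snapshot saw events up
+ # to stream position 10 on "master" and 7 on "worker1", this is roughly
+ # RoomStreamToken(stream=7, instance_map={"master": 10, "worker1": 7}).)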
+
+ # Since we fetched the user's room list at some point in time after the from/to
+ # tokens, we need to revert/rewind some membership changes to match the point in
+ # time of the `to_token`. In particular, we need to make these fixups:
+ #
+ # - 1a) Add back rooms that the user left after the `to_token`
+ # - 1b) Remove rooms that the user joined after the `to_token`
+ # - 2) Add back newly_left rooms (> `from_token` and <= `to_token`)
+ #
+ # Below, we're doing two separate lookups for membership changes. We could
+ # request everything for both fixups in one range, [`from_token.room_key`,
+ # `membership_snapshot_token`), but we want to avoid raw `stream_ordering`
+ # comparison without `instance_name` (which is flawed). We could refactor
+ # `event.internal_metadata` to include `instance_name` but it might turn out a
+ # little difficult and a bigger, broader Synapse change than we want to make.
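+ #
+ # Rough illustration (room IDs made up), with `to_token` at time T:
+ #   - user joined !a after T                  -> fixup 1b drops !a (not a
+ #     member at T)
+ #   - user left !b after T                    -> fixup 1a re-adds !b (still a
+ #     member at T)
+ #   - user left !c between `from_token` and T -> fixup 2 adds !c as newly_left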
+
+ # 1) -----------------------------------------------------
+
+ # 1) Fetch membership changes that fall in the range from `to_token` up to
+ # `membership_snapshot_token`
+ #
+ # If our `to_token` is already the same or ahead of the latest room membership
+ # for the user, we don't need to do any "1)" fix-ups and can just straight-up
+ # use the room list from the snapshot as a base (nothing has changed)
+ membership_change_events_after_to_token = []
+ if not membership_snapshot_token.is_before_or_eq(to_token.room_key):
+ membership_change_events_after_to_token = (
+ await self.store.get_membership_changes_for_user(
+ user_id,
+ from_key=to_token.room_key,
+ to_key=membership_snapshot_token,
+ excluded_rooms=self.rooms_to_exclude_globally,
+ )
+ )
+
+ # 1) Assemble a list of the last membership events in some given ranges. Someone
+ # could have left and joined multiple times during the given range but we only
+ # care about the end result, so we grab the last one.
+ last_membership_change_by_room_id_after_to_token: Dict[str, EventBase] = {}
+ # We also need the first membership event after the `to_token` so we can step
+ # backward to the previous membership that would apply to the from/to range.
+ first_membership_change_by_room_id_after_to_token: Dict[str, EventBase] = {}
+ for event in membership_change_events_after_to_token:
+ last_membership_change_by_room_id_after_to_token[event.room_id] = event
+ # Only set if we haven't already set it
+ first_membership_change_by_room_id_after_to_token.setdefault(
+ event.room_id, event
+ )
+
+ # 1) Fixup
+ for (
+ last_membership_change_after_to_token
+ ) in last_membership_change_by_room_id_after_to_token.values():
+ room_id = last_membership_change_after_to_token.room_id
+
+ # We want to find the first membership change after the `to_token` then step
+ # backward to know the membership in the from/to range.
+ first_membership_change_after_to_token = (
+ first_membership_change_by_room_id_after_to_token.get(room_id)
+ )
+ assert first_membership_change_after_to_token is not None, (
+ "If there was a `last_membership_change_after_to_token` that we're iterating over, "
+ + "then there should be corresponding a first change. For example, even if there "
+ + "is only one event after the `to_token`, the first and last event will be same event. "
+ + "This is probably a mistake in assembling the `last_membership_change_by_room_id_after_to_token`"
+ + "/`first_membership_change_by_room_id_after_to_token` dicts above."
+ )
+ # TODO: Instead of reading from `unsigned`, refactor this to use the
+ # `current_state_delta_stream` table in the future. Probably a new
+ # `get_membership_changes_for_user()` function that uses
+ # `current_state_delta_stream` with a join to `room_memberships`. This would
+ # help in state reset scenarios since `prev_content` is looking at the
+ # current branch vs the current room state. This is all just data given to
+ # the client so no real harm to data integrity, but we'd like to be nice to
+ # the client. Since the `current_state_delta_stream` table is new, it
+ # doesn't have all events in it. Since this is Sliding Sync, if we ever need
+ # to, we can signal the client to throw all of their state away by sending
+ # "operation: RESET".
+ prev_content = first_membership_change_after_to_token.unsigned.get(
+ "prev_content", {}
+ )
+ prev_membership = prev_content.get("membership", None)
+ prev_sender = first_membership_change_after_to_token.unsigned.get(
+ "prev_sender", None
+ )
+
+ # Check if the previous membership (membership that applies to the from/to
+ # range) should be included in our `sync_room_id_set`
+ should_prev_membership_be_included = (
+ prev_membership is not None
+ and prev_sender is not None
+ and filter_membership_for_sync(
+ membership=prev_membership,
+ user_id=user_id,
+ sender=prev_sender,
+ )
+ )
+
+ # Check if the last membership (membership that applies to our snapshot) was
+ # already included in our `sync_room_id_set`
+ was_last_membership_already_included = filter_membership_for_sync(
+ membership=last_membership_change_after_to_token.membership,
+ user_id=user_id,
+ sender=last_membership_change_after_to_token.sender,
+ )
+
+ # 1a) Add back rooms that the user left after the `to_token`
+ #
+ # For example, if the last membership event after the `to_token` is a leave
+ # event, then the room was excluded from `sync_room_id_set` when we first
+ # crafted it above. We should add these rooms back as long as the user also
+ # was part of the room before the `to_token`.
+ if (
+ not was_last_membership_already_included
+ and should_prev_membership_be_included
+ ):
+ sync_room_id_set[room_id] = convert_event_to_rooms_for_user(
+ last_membership_change_after_to_token
+ )
+ # 1b) Remove rooms that the user joined (hasn't left) after the `to_token`
+ #
+ # For example, if the last membership event after the `to_token` is a "join"
+ # event, then the room was included in `sync_room_id_set` when we first crafted
+ # it above. We should remove these rooms as long as the user also wasn't
+ # part of the room before the `to_token`.
+ elif (
+ was_last_membership_already_included
+ and not should_prev_membership_be_included
+ ):
+ del sync_room_id_set[room_id]
+
+ # 2) -----------------------------------------------------
+ # We fix up newly_left rooms after the first fixup because it may have removed
+ # some left rooms that we need in order to figure out newly_left in the code below
+
+ # 2) Fetch membership changes that fall in the range from `from_token` up to `to_token`
+ membership_change_events_in_from_to_range = []
+ if from_token:
+ membership_change_events_in_from_to_range = (
+ await self.store.get_membership_changes_for_user(
+ user_id,
+ from_key=from_token.room_key,
+ to_key=to_token.room_key,
+ excluded_rooms=self.rooms_to_exclude_globally,
+ )
+ )
+
+ # 2) Assemble a list of the last membership events in some given ranges. Someone
+ # could have left and joined multiple times during the given range but we only
+ # care about the end result, so we grab the last one.
+ last_membership_change_by_room_id_in_from_to_range: Dict[str, EventBase] = {}
+ for event in membership_change_events_in_from_to_range:
+ last_membership_change_by_room_id_in_from_to_range[event.room_id] = event
+
+ # 2) Fixup
+ for (
+ last_membership_change_in_from_to_range
+ ) in last_membership_change_by_room_id_in_from_to_range.values():
+ room_id = last_membership_change_in_from_to_range.room_id
+
+ # 2) Add back newly_left rooms (> `from_token` and <= `to_token`). We
+ # include newly_left rooms because the last event that the user should see
+ # is their own leave event
+ if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
+ sync_room_id_set[room_id] = convert_event_to_rooms_for_user(
+ last_membership_change_in_from_to_range
+ )
+
+ return sync_room_id_set
+
+ async def filter_rooms(
+ self,
+ user: UserID,
+ sync_room_map: Dict[str, RoomsForUser],
+ filters: SlidingSyncConfig.SlidingSyncList.Filters,
+ to_token: StreamToken,
+ ) -> Dict[str, RoomsForUser]:
+ """
+ Filter rooms based on the sync request.
+
+ Args:
+ user: User to filter rooms for
+ sync_room_map: Dictionary of room IDs to sort along with membership
+ information in the room at the time of `to_token`.
+ filters: Filters to apply
+ to_token: We filter based on the state of the room at this token
+
+ Returns:
+ A filtered dictionary of room IDs along with membership information in the
+ room at the time of `to_token`.
+ """
+ user_id = user.to_string()
+
+ # TODO: Apply filters
+ #
+ # TODO: Exclude partially stated rooms unless the `required_state` has
+ # `["m.room.member", "$LAZY"]`
+
+ filtered_room_id_set = set(sync_room_map.keys())
+
+ # Filter for Direct-Message (DM) rooms
+ if filters.is_dm is not None:
+ # We're using global account data (`m.direct`) instead of checking for
+ # `is_direct` on membership events because that property only appears for
+ # the invitee membership event (doesn't show up for the inviter). Account
+ # data is set by the client so it needs to be scrutinized.
+ #
+ # We're unable to take `to_token` into account for global account data since
+ # we only keep track of the latest account data for the user.
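+ #
+ # `m.direct` content maps user IDs to lists of room IDs, e.g. (made-up IDs):
+ # {"@friend:example.com": ["!abc:example.com"]}, hence the flattening below.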
+ dm_map = await self.store.get_global_account_data_by_type_for_user(
+ user_id, AccountDataTypes.DIRECT
+ )
+
+ # Flatten out the map
+ dm_room_id_set = set()
+ if isinstance(dm_map, dict):
+ for room_ids in dm_map.values():
+ # Account data should be a list of room IDs. Ignore anything else
+ if isinstance(room_ids, list):
+ for room_id in room_ids:
+ if isinstance(room_id, str):
+ dm_room_id_set.add(room_id)
+
+ if filters.is_dm:
+ # Only DM rooms please
+ filtered_room_id_set = filtered_room_id_set.intersection(dm_room_id_set)
+ else:
+ # Only non-DM rooms please
+ filtered_room_id_set = filtered_room_id_set.difference(dm_room_id_set)
+
+ if filters.spaces:
+ raise NotImplementedError()
+
+ # Filter for encrypted rooms
+ if filters.is_encrypted is not None:
+ # Make a copy so we don't run into an error: `Set changed size during
+ # iteration`, when we filter out and remove items
+ for room_id in list(filtered_room_id_set):
+ state_at_to_token = await self.storage_controllers.state.get_state_at(
+ room_id,
+ to_token,
+ state_filter=StateFilter.from_types(
+ [(EventTypes.RoomEncryption, "")]
+ ),
+ )
+ is_encrypted = state_at_to_token.get((EventTypes.RoomEncryption, ""))
+
+ # If we're looking for encrypted rooms, filter out rooms that are not
+ # encrypted and vice versa
+ if (filters.is_encrypted and not is_encrypted) or (
+ not filters.is_encrypted and is_encrypted
+ ):
+ filtered_room_id_set.remove(room_id)
+
+ # Filter for rooms that the user has been invited to
+ if filters.is_invite is not None:
+ # Make a copy so we don't run into an error: `Set changed size during
+ # iteration`, when we filter out and remove items
+ for room_id in list(filtered_room_id_set):
+ room_for_user = sync_room_map[room_id]
+ # If we're looking for invite rooms, filter out rooms that the user is
+ # not invited to and vice versa
+ if (
+ filters.is_invite and room_for_user.membership != Membership.INVITE
+ ) or (
+ not filters.is_invite
+ and room_for_user.membership == Membership.INVITE
+ ):
+ filtered_room_id_set.remove(room_id)
+
+ if filters.room_types:
+ raise NotImplementedError()
+
+ if filters.not_room_types:
+ raise NotImplementedError()
+
+ if filters.room_name_like:
+ raise NotImplementedError()
+
+ if filters.tags:
+ raise NotImplementedError()
+
+ if filters.not_tags:
+ raise NotImplementedError()
+
+ # Assemble a new sync room map but only with the `filtered_room_id_set`
+ return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set}
+
+ async def sort_rooms(
+ self,
+ sync_room_map: Dict[str, RoomsForUser],
+ to_token: StreamToken,
+ ) -> List[Tuple[str, RoomsForUser]]:
+ """
+ Sort by `stream_ordering` of the last event that the user should see in the
+ room. `stream_ordering` is unique so we get a stable sort.
+
+ Args:
+ sync_room_map: Dictionary of room IDs to sort along with membership
+ information in the room at the time of `to_token`.
+ to_token: We sort based on the events in the room at this token (<= `to_token`)
+
+ Returns:
+ A sorted list of room IDs by `stream_ordering` along with membership information.
+ """
+
+ # Assemble a map of room ID to the `stream_ordering` of the last activity that the
+ # user should see in the room (<= `to_token`)
+ last_activity_in_room_map: Dict[str, int] = {}
+ for room_id, room_for_user in sync_room_map.items():
+ # If they are fully-joined to the room, let's find the latest activity
+ # at/before the `to_token`.
+ if room_for_user.membership == Membership.JOIN:
+ last_event_result = (
+ await self.store.get_last_event_pos_in_room_before_stream_ordering(
+ room_id, to_token.room_key
+ )
+ )
+
+ # If the room has no events at/before the `to_token`, this is probably a
+ # mistake in the code that generates the `sync_room_map` since that should
+ # only give us rooms that the user had membership in during the token range.
+ assert last_event_result is not None
+
+ _, event_pos = last_event_result
+
+ last_activity_in_room_map[room_id] = event_pos.stream
+ else:
+ # Otherwise, if the user has left/been invited/knocked/been banned from
+ # a room, they shouldn't see anything past that point.
+ last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
+
+ return sorted(
+ sync_room_map.items(),
+ # Sort by the last activity (stream_ordering) in the room
+ key=lambda room_info: last_activity_in_room_map[room_info[0]],
+ # We want descending order
+ reverse=True,
+ )
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index f275d4f35a..ee74289b6c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -817,7 +817,7 @@ class SsoHandler:
server_name = profile["avatar_url"].split("/")[-2]
media_id = profile["avatar_url"].split("/")[-1]
if self._is_mine_server_name(server_name):
- media = await self._media_repo.store.get_local_media(media_id)
+ media = await self._media_repo.store.get_local_media(media_id) # type: ignore[has-type]
if media is not None and upload_name == media.upload_name:
logger.info("skipping saving the user avatar")
return True
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 8ff45a3353..e2563428d2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -20,6 +20,7 @@
#
import itertools
import logging
+from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -27,11 +28,14 @@ from typing import (
Dict,
FrozenSet,
List,
+ Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
+ Union,
+ overload,
)
import attr
@@ -112,12 +116,30 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
SyncRequestKey = Tuple[Any, ...]
+class SyncVersion(Enum):
+ """
+ Enum for specifying the version of sync request. This is used to key which type of
+ sync response that we are generating.
+
+ This is different from the `sync_type` you might see used in other code below, which
+ specifies the sub-type of sync request (e.g. initial_sync, full_state_sync,
+ incremental_sync) and is really only relevant for the `/sync` v2 endpoint.
+ """
+
+ # These string values are semantically significant because they are used in the
+ # metrics.
+
+ # Traditional `/sync` endpoint
+ SYNC_V2 = "sync_v2"
+ # Part of MSC3575 Sliding Sync
+ E2EE_SYNC = "e2ee_sync"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncConfig:
user: UserID
filter_collection: FilterCollection
is_guest: bool
- request_key: SyncRequestKey
device_id: Optional[str]
@@ -262,6 +284,47 @@ class SyncResult:
or self.device_lists
)
+ @staticmethod
+ def empty(
+ next_batch: StreamToken,
+ device_one_time_keys_count: JsonMapping,
+ device_unused_fallback_key_types: List[str],
+ ) -> "SyncResult":
+ "Return a new empty result"
+ return SyncResult(
+ next_batch=next_batch,
+ presence=[],
+ account_data=[],
+ joined=[],
+ invited=[],
+ knocked=[],
+ archived=[],
+ to_device=[],
+ device_lists=DeviceListUpdates(),
+ device_one_time_keys_count=device_one_time_keys_count,
+ device_unused_fallback_key_types=device_unused_fallback_key_types,
+ )
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class E2eeSyncResult:
+ """
+ Attributes:
+ next_batch: Token for the next sync
+ to_device: List of direct messages for the device.
+ device_lists: List of user_ids whose devices have changed
+ device_one_time_keys_count: Dict of algorithm to count for one time keys
+ for this device
+ device_unused_fallback_key_types: List of key types that have an unused fallback
+ key
+ """
+
+ next_batch: StreamToken
+ to_device: List[JsonDict]
+ device_lists: DeviceListUpdates
+ device_one_time_keys_count: JsonMapping
+ device_unused_fallback_key_types: List[str]
+
class SyncHandler:
def __init__(self, hs: "HomeServer"):
@@ -305,17 +368,68 @@ class SyncHandler:
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
+ @overload
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
+ request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
- ) -> SyncResult:
+ ) -> SyncResult: ...
+
+ @overload
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> E2eeSyncResult: ...
+
+ @overload
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def wait_for_sync_for_user(
+ self,
+ requester: Requester,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ request_key: SyncRequestKey,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> Union[SyncResult, E2eeSyncResult]:
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
+
+ Args:
+ requester: The user requesting the sync response.
+ sync_config: Config/info necessary to process the sync request.
+ sync_version: Determines what kind of sync response to generate.
+ request_key: The key to use for caching the response.
+ since_token: The point in the stream to sync from.
+ timeout: How long to wait for new data to arrive before giving up.
+ full_state: Whether to return the full state for each room.
+
+ Returns:
+ When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+ When `SyncVersion.E2EE_SYNC`, returns an `E2eeSyncResult`.
"""
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
@@ -324,9 +438,10 @@ class SyncHandler:
await self.auth_blocking.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap(
- sync_config.request_key,
+ request_key,
self._wait_for_sync_for_user,
sync_config,
+ sync_version,
since_token,
timeout,
full_state,
@@ -335,14 +450,48 @@ class SyncHandler:
logger.debug("Returning sync response for %s", user_id)
return res
+ @overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
- ) -> SyncResult:
+ ) -> SyncResult: ...
+
+ @overload
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> E2eeSyncResult: ...
+
+ @overload
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ since_token: Optional[StreamToken],
+ timeout: int,
+ full_state: bool,
+ cache_context: ResponseCacheContext[SyncRequestKey],
+ ) -> Union[SyncResult, E2eeSyncResult]:
"""The start of the machinery that produces a /sync response.
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
@@ -363,9 +512,50 @@ class SyncHandler:
else:
sync_type = "incremental_sync"
+ sync_label = f"{sync_version}:{sync_type}"
+
context = current_context()
if context:
- context.tag = sync_type
+ context.tag = sync_label
+
+ if since_token is not None:
+ # We need to make sure this worker has caught up with the token. If
+ # this returns false it means we timed out waiting, and we should
+ # just return an empty response.
+ start = self.clock.time_msec()
+ if not await self.notifier.wait_for_stream_token(since_token):
+ logger.warning(
+ "Timed out waiting for worker to catch up. Returning empty response"
+ )
+ device_id = sync_config.device_id
+ one_time_keys_count: JsonMapping = {}
+ unused_fallback_key_types: List[str] = []
+ if device_id:
+ user_id = sync_config.user.to_string()
+ # TODO: We should have a way to let clients differentiate between the states of:
+ # * no change in OTK count since the provided since token
+ # * the server has zero OTKs left for this device
+ # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+ one_time_keys_count = await self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+ unused_fallback_key_types = list(
+ await self.store.get_e2e_unused_fallback_key_types(
+ user_id, device_id
+ )
+ )
+
+ cache_context.should_cache = False # Don't cache empty responses
+ return SyncResult.empty(
+ since_token, one_time_keys_count, unused_fallback_key_types
+ )
+
+ # If we've spent significant time waiting to catch up, take it off
+ # the timeout.
+ now = self.clock.time_msec()
+ if now - start > 1_000:
+ timeout -= now - start
+ timeout = max(timeout, 0)
# if we have a since token, delete any to-device messages before that token
# (since we now know that the device has received them)
@@ -383,15 +573,19 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
- result: SyncResult = await self.current_sync_for_user(
- sync_config, since_token, full_state=full_state
+ result: Union[SyncResult, E2eeSyncResult] = (
+ await self.current_sync_for_user(
+ sync_config, sync_version, since_token, full_state=full_state
+ )
)
else:
# Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
- ) -> SyncResult:
- return await self.current_sync_for_user(sync_config, since_token)
+ ) -> Union[SyncResult, E2eeSyncResult]:
+ return await self.current_sync_for_user(
+ sync_config, sync_version, since_token
+ )
result = await self.notifier.wait_for_events(
sync_config.user.to_string(),
@@ -416,27 +610,81 @@ class SyncHandler:
lazy_loaded = "true"
else:
lazy_loaded = "false"
- non_empty_sync_counter.labels(sync_type, lazy_loaded).inc()
+ non_empty_sync_counter.labels(sync_label, lazy_loaded).inc()
return result
+ @overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.SYNC_V2],
since_token: Optional[StreamToken] = None,
full_state: bool = False,
- ) -> SyncResult:
- """Generates the response body of a sync result, represented as a SyncResult.
+ ) -> SyncResult: ...
+
+ @overload
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: Literal[SyncVersion.E2EE_SYNC],
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> E2eeSyncResult: ...
+
+ @overload
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> Union[SyncResult, E2eeSyncResult]: ...
+
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ sync_version: SyncVersion,
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> Union[SyncResult, E2eeSyncResult]:
+ """
+ Generates the response body of a sync result, represented as a
+ `SyncResult`/`E2eeSyncResult`.
This is a wrapper around `generate_sync_result` which starts an open tracing
span to track the sync. See `generate_sync_result` for the next part of your
indoctrination.
+
+ Args:
+ sync_config: Config/info necessary to process the sync request.
+ sync_version: Determines what kind of sync response to generate.
+ since_token: The point in the stream to sync from.
+ full_state: Whether to return the full state for each room.
+
+ Returns:
+ When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
+ When `SyncVersion.E2EE_SYNC`, returns an `E2eeSyncResult`.
"""
with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token})
- sync_result = await self.generate_sync_result(
- sync_config, since_token, full_state
- )
+
+ # Go through the `/sync` v2 path
+ if sync_version == SyncVersion.SYNC_V2:
+ sync_result: Union[SyncResult, E2eeSyncResult] = (
+ await self.generate_sync_result(
+ sync_config, since_token, full_state
+ )
+ )
+ # Go through the MSC3575 Sliding Sync `/sync/e2ee` path
+ elif sync_version == SyncVersion.E2EE_SYNC:
+ sync_result = await self.generate_e2ee_sync_result(
+ sync_config, since_token
+ )
+ else:
+ raise Exception(
+ f"Unknown sync_version (this is a Synapse problem): {sync_version}"
+ )
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result
@@ -596,7 +844,6 @@ class SyncHandler:
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
- msc4115_membership_on_events=self.hs_config.experimental.msc4115_membership_on_events,
)
log_kv({"recents_after_visibility_filtering": len(recents)})
else:
@@ -682,7 +929,6 @@ class SyncHandler:
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
- msc4115_membership_on_events=self.hs_config.experimental.msc4115_membership_on_events,
)
loaded_recents = []
@@ -733,89 +979,6 @@ class SyncHandler:
bundled_aggregations=bundled_aggregations,
)
- async def get_state_after_event(
- self,
- event_id: str,
- state_filter: Optional[StateFilter] = None,
- await_full_state: bool = True,
- ) -> StateMap[str]:
- """
- Get the room state after the given event
-
- Args:
- event_id: event of interest
- state_filter: The state filter used to fetch state from the database.
- await_full_state: if `True`, will block if we do not yet have complete state
- at the event and `state_filter` is not satisfied by partial state.
- Defaults to `True`.
- """
- state_ids = await self._state_storage_controller.get_state_ids_for_event(
- event_id,
- state_filter=state_filter or StateFilter.all(),
- await_full_state=await_full_state,
- )
-
- # using get_metadata_for_events here (instead of get_event) sidesteps an issue
- # with redactions: if `event_id` is a redaction event, and we don't have the
- # original (possibly because it got purged), get_event will refuse to return
- # the redaction event, which isn't terribly helpful here.
- #
- # (To be fair, in that case we could assume it's *not* a state event, and
- # therefore we don't need to worry about it. But still, it seems cleaner just
- # to pull the metadata.)
- m = (await self.store.get_metadata_for_events([event_id]))[event_id]
- if m.state_key is not None and m.rejection_reason is None:
- state_ids = dict(state_ids)
- state_ids[(m.event_type, m.state_key)] = event_id
-
- return state_ids
-
- async def get_state_at(
- self,
- room_id: str,
- stream_position: StreamToken,
- state_filter: Optional[StateFilter] = None,
- await_full_state: bool = True,
- ) -> StateMap[str]:
- """Get the room state at a particular stream position
-
- Args:
- room_id: room for which to get state
- stream_position: point at which to get state
- state_filter: The state filter used to fetch state from the database.
- await_full_state: if `True`, will block if we do not yet have complete state
- at the last event in the room before `stream_position` and
- `state_filter` is not satisfied by partial state. Defaults to `True`.
- """
- # FIXME: This gets the state at the latest event before the stream ordering,
- # which might not be the same as the "current state" of the room at the time
- # of the stream token if there were multiple forward extremities at the time.
- last_event_id = await self.store.get_last_event_in_room_before_stream_ordering(
- room_id,
- end_token=stream_position.room_key,
- )
-
- if last_event_id:
- state = await self.get_state_after_event(
- last_event_id,
- state_filter=state_filter or StateFilter.all(),
- await_full_state=await_full_state,
- )
-
- else:
- # no events in this room - so presumably no state
- state = {}
-
- # (erikj) This should be rarely hit, but we've had some reports that
- # we get more state down gappy syncs than we should, so let's add
- # some logging.
- logger.info(
- "Failed to find any events in room %s at %s",
- room_id,
- stream_position.room_key,
- )
- return state
-
async def compute_summary(
self,
room_id: str,
@@ -1189,7 +1352,7 @@ class SyncHandler:
await_full_state = True
lazy_load_members = False
- state_at_timeline_end = await self.get_state_at(
+ state_at_timeline_end = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@@ -1273,7 +1436,7 @@ class SyncHandler:
# We need to make sure the first event in our batch points to the
# last event in the previous batch.
last_event_id_prev_batch = (
- await self.store.get_last_event_in_room_before_stream_ordering(
+ await self.store.get_last_event_id_in_room_before_stream_ordering(
room_id,
end_token=since_token.room_key,
)
@@ -1317,7 +1480,7 @@ class SyncHandler:
else:
# We can get here if the user has ignored the senders of all
# the recent events.
- state_at_timeline_start = await self.get_state_at(
+ state_at_timeline_start = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@@ -1339,14 +1502,14 @@ class SyncHandler:
# about them).
state_filter = StateFilter.all()
- state_at_previous_sync = await self.get_state_at(
+ state_at_previous_sync = await self._state_storage_controller.get_state_at(
room_id,
stream_position=since_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
- state_at_timeline_end = await self.get_state_at(
+ state_at_timeline_end = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@@ -1518,128 +1681,17 @@ class SyncHandler:
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
- # Note: we get the users room list *before* we get the current token, this
- # avoids checking back in history if rooms are joined after the token is fetched.
- token_before_rooms = self.event_sources.get_current_token()
- mutable_joined_room_ids = set(await self.store.get_rooms_for_user(user_id))
-
- # NB: The now_token gets changed by some of the generate_sync_* methods,
- # this is due to some of the underlying streams not supporting the ability
- # to query up to a given point.
- # Always use the `now_token` in `SyncResultBuilder`
- now_token = self.event_sources.get_current_token()
- log_kv({"now_token": now_token})
-
- # Since we fetched the users room list before the token, there's a small window
- # during which membership events may have been persisted, so we fetch these now
- # and modify the joined room list for any changes between the get_rooms_for_user
- # call and the get_current_token call.
- membership_change_events = []
- if since_token:
- membership_change_events = await self.store.get_membership_changes_for_user(
- user_id,
- since_token.room_key,
- now_token.room_key,
- self.rooms_to_exclude_globally,
- )
-
- mem_last_change_by_room_id: Dict[str, EventBase] = {}
- for event in membership_change_events:
- mem_last_change_by_room_id[event.room_id] = event
-
- # For the latest membership event in each room found, add/remove the room ID
- # from the joined room list accordingly. In this case we only care if the
- # latest change is JOIN.
-
- for room_id, event in mem_last_change_by_room_id.items():
- assert event.internal_metadata.stream_ordering
- if (
- event.internal_metadata.stream_ordering
- < token_before_rooms.room_key.stream
- ):
- continue
-
- logger.info(
- "User membership change between getting rooms and current token: %s %s %s",
- user_id,
- event.membership,
- room_id,
- )
- # User joined a room - we have to then check the room state to ensure we
- # respect any bans if there's a race between the join and ban events.
- if event.membership == Membership.JOIN:
- user_ids_in_room = await self.store.get_users_in_room(room_id)
- if user_id in user_ids_in_room:
- mutable_joined_room_ids.add(room_id)
- # The user left the room, or left and was re-invited but not joined yet
- else:
- mutable_joined_room_ids.discard(room_id)
-
- # Tweak the set of rooms to return to the client for eager (non-lazy) syncs.
- mutable_rooms_to_exclude = set(self.rooms_to_exclude_globally)
- if not sync_config.filter_collection.lazy_load_members():
- # Non-lazy syncs should never include partially stated rooms.
- # Exclude all partially stated rooms from this sync.
- results = await self.store.is_partial_state_room_batched(
- mutable_joined_room_ids
- )
- mutable_rooms_to_exclude.update(
- room_id
- for room_id, is_partial_state in results.items()
- if is_partial_state
- )
- membership_change_events = [
- event
- for event in membership_change_events
- if not results.get(event.room_id, False)
- ]
-
- # Incremental eager syncs should additionally include rooms that
- # - we are joined to
- # - are full-stated
- # - became fully-stated at some point during the sync period
- # (These rooms will have been omitted during a previous eager sync.)
- forced_newly_joined_room_ids: Set[str] = set()
- if since_token and not sync_config.filter_collection.lazy_load_members():
- un_partial_stated_rooms = (
- await self.store.get_un_partial_stated_rooms_between(
- since_token.un_partial_stated_rooms_key,
- now_token.un_partial_stated_rooms_key,
- mutable_joined_room_ids,
- )
- )
- results = await self.store.is_partial_state_room_batched(
- un_partial_stated_rooms
- )
- forced_newly_joined_room_ids.update(
- room_id
- for room_id, is_partial_state in results.items()
- if not is_partial_state
- )
-
- # Now we have our list of joined room IDs, exclude as configured and freeze
- joined_room_ids = frozenset(
- room_id
- for room_id in mutable_joined_room_ids
- if room_id not in mutable_rooms_to_exclude
+ sync_result_builder = await self.get_sync_result_builder(
+ sync_config,
+ since_token,
+ full_state,
)
logger.debug(
"Calculating sync response for %r between %s and %s",
sync_config.user,
- since_token,
- now_token,
- )
-
- sync_result_builder = SyncResultBuilder(
- sync_config,
- full_state,
- since_token=since_token,
- now_token=now_token,
- joined_room_ids=joined_room_ids,
- excluded_room_ids=frozenset(mutable_rooms_to_exclude),
- forced_newly_joined_room_ids=frozenset(forced_newly_joined_room_ids),
- membership_change_events=membership_change_events,
+ sync_result_builder.since_token,
+ sync_result_builder.now_token,
)
logger.debug("Fetching account data")
@@ -1751,6 +1803,242 @@ class SyncHandler:
next_batch=sync_result_builder.now_token,
)
+ async def generate_e2ee_sync_result(
+ self,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken] = None,
+ ) -> E2eeSyncResult:
+ """
+ Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
+
+ This is represented by an `E2eeSyncResult` struct, which is built from small
+ pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
+ mutable ("inout") parameter to various helper functions. These retrieve and
+ process the data which forms the sync body, often writing to the
+ `sync_result_builder` to store their output.
+
+ At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
+ instance to signify that the sync calculation is complete.
+ """
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+
+ sync_result_builder = await self.get_sync_result_builder(
+ sync_config,
+ since_token,
+ full_state=False,
+ )
+
+ # 1. Calculate `to_device` events
+ await self._generate_sync_entry_for_to_device(sync_result_builder)
+
+ # 2. Calculate `device_lists`
+ # Device list updates are sent if a since token is provided.
+ device_lists = DeviceListUpdates()
+ include_device_list_updates = bool(since_token and since_token.device_list_key)
+ if include_device_list_updates:
+ # Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
+ # is used in calculate_user_changes below.
+ #
+ # TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
+ # figure out the membership changes/derived info needed for
+ # `_generate_sync_entry_for_device_list()`. In the future, we should try to
+ # refactor this away.
+ (
+ newly_joined_rooms,
+ newly_left_rooms,
+ ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
+
+ # This uses the sync_result_builder.joined which is set in
+ # `_generate_sync_entry_for_rooms`; if that didn't find any joined
+ # rooms for some reason it is a no-op.
+ (
+ newly_joined_or_invited_or_knocked_users,
+ newly_left_users,
+ ) = sync_result_builder.calculate_user_changes()
+
+ device_lists = await self._generate_sync_entry_for_device_list(
+ sync_result_builder,
+ newly_joined_rooms=newly_joined_rooms,
+ newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
+ newly_left_rooms=newly_left_rooms,
+ newly_left_users=newly_left_users,
+ )
+
+ # 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
+ device_id = sync_config.device_id
+ one_time_keys_count: JsonMapping = {}
+ unused_fallback_key_types: List[str] = []
+ if device_id:
+ # TODO: We should have a way to let clients differentiate between the states of:
+ # * no change in OTK count since the provided since token
+ # * the server has zero OTKs left for this device
+ # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
+ one_time_keys_count = await self.store.count_e2e_one_time_keys(
+ user_id, device_id
+ )
+ unused_fallback_key_types = list(
+ await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
+ )
+
+ return E2eeSyncResult(
+ to_device=sync_result_builder.to_device,
+ device_lists=device_lists,
+ device_one_time_keys_count=one_time_keys_count,
+ device_unused_fallback_key_types=unused_fallback_key_types,
+ next_batch=sync_result_builder.now_token,
+ )
+
+ async def get_sync_result_builder(
+ self,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> "SyncResultBuilder":
+ """
+ Assemble a `SyncResultBuilder` with all of the initial context to
+ start building up the sync response:
+
+ - Membership changes between the last sync and the current sync.
+ - Joined room IDs (minus any rooms to exclude).
+ - Rooms that became fully-stated/un-partial stated since the last sync.
+
+ Args:
+ sync_config: Config/info necessary to process the sync request.
+ since_token: The point in the stream to sync from.
+ full_state: Whether to return the full state for each room.
+
+ Returns:
+ `SyncResultBuilder` ready to start generating parts of the sync response.
+ """
+ user_id = sync_config.user.to_string()
+
+ # Note: we get the user's room list *before* we get the `now_token`; this
+ # avoids checking back in history if rooms are joined after the token is fetched.
+ token_before_rooms = self.event_sources.get_current_token()
+ mutable_joined_room_ids = set(await self.store.get_rooms_for_user(user_id))
+
+ # NB: The `now_token` gets changed by some of the `generate_sync_*` methods
+ # because some of the underlying streams don't support querying up to a
+ # given point.
+ # Always use the `now_token` from the `SyncResultBuilder`.
+ now_token = self.event_sources.get_current_token()
+ log_kv({"now_token": now_token})
+
+ # Since we fetched the user's room list before calculating the `now_token` (see
+ # above), there's a small window during which membership events may have been
+ # persisted, so we fetch those now and adjust the joined room list for any
+ # changes between the `get_rooms_for_user` call and the `get_current_token` call.
+ membership_change_events = []
+ if since_token:
+ membership_change_events = await self.store.get_membership_changes_for_user(
+ user_id,
+ since_token.room_key,
+ now_token.room_key,
+ self.rooms_to_exclude_globally,
+ )
+
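+ # Map each room to its most recent membership change; later events in the
+ # (stream-ordered) list overwrite earlier ones.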
+ last_membership_change_by_room_id: Dict[str, EventBase] = {}
+ for event in membership_change_events:
+ last_membership_change_by_room_id[event.room_id] = event
+
+ # For the latest membership event in each room found, add/remove the room ID
+ # from the joined room list accordingly. In this case we only care if the
+ # latest change is JOIN.
+
+ for room_id, event in last_membership_change_by_room_id.items():
+ assert event.internal_metadata.stream_ordering
+ # As a shortcut, skip any events that happened before we got our
+ # `get_rooms_for_user()` snapshot (any changes are already represented
+ # in that list).
+ if (
+ event.internal_metadata.stream_ordering
+ < token_before_rooms.room_key.stream
+ ):
+ continue
+
+ logger.info(
+ "User membership change between getting rooms and current token: %s %s %s",
+ user_id,
+ event.membership,
+ room_id,
+ )
+ # User joined a room - we then have to check the room state to ensure we
+ # respect any bans if there's a race between the join and ban events.
+ if event.membership == Membership.JOIN:
+ user_ids_in_room = await self.store.get_users_in_room(room_id)
+ if user_id in user_ids_in_room:
+ mutable_joined_room_ids.add(room_id)
+ # The user left the room, or left and was re-invited but not joined yet
+ else:
+ mutable_joined_room_ids.discard(room_id)
+
+ # Tweak the set of rooms to return to the client for eager (non-lazy) syncs.
+ mutable_rooms_to_exclude = set(self.rooms_to_exclude_globally)
+ if not sync_config.filter_collection.lazy_load_members():
+ # Non-lazy syncs should never include partially stated rooms.
+ # Exclude all partially stated rooms from this sync.
+ results = await self.store.is_partial_state_room_batched(
+ mutable_joined_room_ids
+ )
+ mutable_rooms_to_exclude.update(
+ room_id
+ for room_id, is_partial_state in results.items()
+ if is_partial_state
+ )
+ membership_change_events = [
+ event
+ for event in membership_change_events
+ if not results.get(event.room_id, False)
+ ]
+
+ # Incremental eager syncs should additionally include rooms that
+ # - we are joined to
+ # - are fully-stated
+ # - became fully-stated at some point during the sync period
+ # (These rooms will have been omitted during a previous eager sync.)
+ forced_newly_joined_room_ids: Set[str] = set()
+ if since_token and not sync_config.filter_collection.lazy_load_members():
+ un_partial_stated_rooms = (
+ await self.store.get_un_partial_stated_rooms_between(
+ since_token.un_partial_stated_rooms_key,
+ now_token.un_partial_stated_rooms_key,
+ mutable_joined_room_ids,
+ )
+ )
+ results = await self.store.is_partial_state_room_batched(
+ un_partial_stated_rooms
+ )
+ forced_newly_joined_room_ids.update(
+ room_id
+ for room_id, is_partial_state in results.items()
+ if not is_partial_state
+ )
+
+ # Now we have our list of joined room IDs, exclude as configured and freeze
+ joined_room_ids = frozenset(
+ room_id
+ for room_id in mutable_joined_room_ids
+ if room_id not in mutable_rooms_to_exclude
+ )
+
+ sync_result_builder = SyncResultBuilder(
+ sync_config,
+ full_state,
+ since_token=since_token,
+ now_token=now_token,
+ joined_room_ids=joined_room_ids,
+ excluded_room_ids=frozenset(mutable_rooms_to_exclude),
+ forced_newly_joined_room_ids=frozenset(forced_newly_joined_room_ids),
+ membership_change_events=membership_change_events,
+ )
+
+ return sync_result_builder
+
@measure_func("_generate_sync_entry_for_device_list")
async def _generate_sync_entry_for_device_list(
self,
@@ -1799,42 +2087,18 @@ class SyncHandler:
users_that_have_changed = set()
- joined_rooms = sync_result_builder.joined_room_ids
+ joined_room_ids = sync_result_builder.joined_room_ids
# Step 1a, check for changes in devices of users we share a room
# with
- #
- # We do this in two different ways depending on what we have cached.
- # If we already have a list of all the user that have changed since
- # the last sync then it's likely more efficient to compare the rooms
- # they're in with the rooms the syncing user is in.
- #
- # If we don't have that info cached then we get all the users that
- # share a room with our user and check if those users have changed.
- cache_result = self.store.get_cached_device_list_changes(
- since_token.device_list_key
- )
- if cache_result.hit:
- changed_users = cache_result.entities
-
- result = await self.store.get_rooms_for_users(changed_users)
-
- for changed_user_id, entries in result.items():
- # Check if the changed user shares any rooms with the user,
- # or if the changed user is the syncing user (as we always
- # want to include device list updates of their own devices).
- if user_id == changed_user_id or any(
- rid in joined_rooms for rid in entries
- ):
- users_that_have_changed.add(changed_user_id)
- else:
- users_that_have_changed = (
- await self._device_handler.get_device_changes_in_shared_rooms(
- user_id,
- sync_result_builder.joined_room_ids,
- from_token=since_token,
- )
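+ # Fetch the set of users whose devices have changed in rooms we share with
+ # the syncing user, bounded by the sync window.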
+ users_that_have_changed = (
+ await self._device_handler.get_device_changes_in_shared_rooms(
+ user_id,
+ joined_room_ids,
+ from_token=since_token,
+ now_token=sync_result_builder.now_token,
)
+ )
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
@@ -1858,7 +2122,7 @@ class SyncHandler:
# Remove any users that we still share a room with.
left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
for user_id, entries in left_users_rooms.items():
- if any(rid in joined_rooms for rid in entries):
+ if any(rid in joined_room_ids for rid in entries):
newly_left_users.discard(user_id)
return DeviceListUpdates(changed=users_that_have_changed, left=newly_left_users)
@@ -1945,23 +2209,19 @@ class SyncHandler:
)
if push_rules_changed:
- global_account_data = {
- AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user(
- sync_config.user
- ),
- **global_account_data,
- }
+ global_account_data = dict(global_account_data)
+ global_account_data[AccountDataTypes.PUSH_RULES] = (
+ await self._push_rules_handler.push_rules_for_user(sync_config.user)
+ )
else:
all_global_account_data = await self.store.get_global_account_data_for_user(
user_id
)
- global_account_data = {
- AccountDataTypes.PUSH_RULES: await self._push_rules_handler.push_rules_for_user(
- sync_config.user
- ),
- **all_global_account_data,
- }
+ global_account_data = dict(all_global_account_data)
+ global_account_data[AccountDataTypes.PUSH_RULES] = (
+ await self._push_rules_handler.push_rules_for_user(sync_config.user)
+ )
account_data_for_user = (
await sync_config.filter_collection.filter_global_account_data(
@@ -2248,7 +2508,7 @@ class SyncHandler:
continue
if room_id in sync_result_builder.joined_room_ids or has_join:
- old_state_ids = await self.get_state_at(
+ old_state_ids = await self._state_storage_controller.get_state_at(
room_id,
since_token,
state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
@@ -2278,12 +2538,14 @@ class SyncHandler:
newly_left_rooms.append(room_id)
else:
if not old_state_ids:
- old_state_ids = await self.get_state_at(
- room_id,
- since_token,
- state_filter=StateFilter.from_types(
- [(EventTypes.Member, user_id)]
- ),
+ old_state_ids = (
+ await self._state_storage_controller.get_state_at(
+ room_id,
+ since_token,
+ state_filter=StateFilter.from_types(
+ [(EventTypes.Member, user_id)]
+ ),
+ )
)
old_mem_ev_id = old_state_ids.get(
(EventTypes.Member, user_id), None
@@ -2488,7 +2750,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
- StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering)
+ StreamKeyType.ROOM, RoomStreamToken(stream=event.event_pos.stream)
)
room_entries.append(
RoomSyncResultBuilder(
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 7619d91c98..4c87718337 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -477,9 +477,9 @@ class TypingWriterHandler(FollowerTypingHandler):
rows = []
for room_id in changed_rooms:
- serial = self._room_serials[room_id]
- if last_id < serial <= current_id:
- typing = self._room_typing[room_id]
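+ # The room may have disappeared from these maps since `changed_rooms` was
+ # computed, so look it up defensively instead of assuming it is present.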
+ serial = self._room_serials.get(room_id)
+ if serial and last_id < serial <= current_id:
+ typing = self._room_typing.get(room_id, set())
rows.append((serial, [room_id, list(typing)]))
rows.sort()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c73a589e6c..104b803b0f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -57,7 +57,7 @@ from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
@@ -68,6 +68,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
@@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient:
destination: str,
path: str,
output_stream: BinaryIO,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
+ max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
- max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
@@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient:
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
+ download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
+ requester IP
+ ip_address: IP address of the requester
+ max_size: maximum allowable size in bytes of the file
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
@@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient:
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
+ SynapseError: If the requested file exceeds ratelimits
"""
request = MatrixFederationRequest(
method="GET", destination=destination, path=path, query=args
)
+ # check for a minimum balance of 1MiB in ratelimiter before initiating request
+ send_req, _ = await download_ratelimiter.can_do_action(
+ requester=None, key=ip_address, n_actions=1048576, update=False
+ )
+
+ if not send_req:
+ msg = "Requested file size exceeds ratelimits"
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
response = await self._send_request(
request,
retry_on_dns_fail=retry_on_dns_fail,
@@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
+ expected_size = response.length
+ # if we don't get an expected length then use the max length
+ if expected_size == UNKNOWN_LENGTH:
+ expected_size = max_size
+ logger.debug(
+ f"File size unknown, assuming file is max allowable size: {max_size}"
+ )
+
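+ # Deduct the expected number of bytes from the ratelimiter before reading
+ # the response body.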
+ read_body, _ = await download_ratelimiter.can_do_action(
+ requester=None,
+ key=ip_address,
+ n_actions=expected_size,
+ )
+ if not read_body:
+ msg = "Requested file size exceeds ratelimits"
+ logger.warning(
+ "{%s} [%s] %s",
+ request.txn_id,
+ request.destination,
+ msg,
+ )
+ raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
+
try:
- d = read_body_with_max_size(response, output_stream, max_size)
+ # add a byte of headroom to the max size, as read_body_with_max_size errors at >=
+ d = read_body_with_max_size(response, output_stream, expected_size + 1)
d.addTimeout(self.default_timeout_seconds, self.reactor)
length = await make_deferred_yieldable(d)
except BodyExceededMaxSize:
- msg = "Requested file is too large > %r bytes" % (max_size,)
+ msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
"{%s} [%s] %s",
request.txn_id,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index ab12951da8..08b8ff7afd 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -119,14 +119,15 @@ def parse_integer(
default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the parameter is absent,
defaults to False.
- negative: whether to allow negative integers, defaults to True.
+ negative: whether to allow negative integers, defaults to False (disallowing
+ negatives).
Returns:
An int value or the default.
Raises:
SynapseError: if the parameter is absent and required, if the
parameter is present and not an integer, or if the
- parameter is illegitimate negative.
+ parameter is illegitimately negative.
"""
args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
return parse_integer_from_args(args, name, default, required, negative)
@@ -164,7 +165,7 @@ def parse_integer_from_args(
name: str,
default: Optional[int] = None,
required: bool = False,
- negative: bool = True,
+ negative: bool = False,
) -> Optional[int]:
"""Parse an integer parameter from the request string
@@ -174,7 +175,8 @@ def parse_integer_from_args(
default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the parameter is absent,
defaults to False.
- negative: whether to allow negative integers, defaults to True.
+ negative: whether to allow negative integers, defaults to False (disallowing
+ negatives).
Returns:
An int value or the default.
@@ -182,7 +184,7 @@ def parse_integer_from_args(
Raises:
SynapseError: if the parameter is absent and required, if the
parameter is present and not an integer, or if the
- parameter is illegitimate negative.
+ parameter is illegitimately negative.
"""
name_bytes = name.encode("ascii")
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 3fbed6062f..7ad0b7c3cf 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -25,7 +25,16 @@ import os
import urllib
from abc import ABC, abstractmethod
from types import TracebackType
-from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Dict,
+ Generator,
+ List,
+ Optional,
+ Tuple,
+ Type,
+)
import attr
@@ -37,8 +46,13 @@ from synapse.api.errors import Codes, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
+from synapse.util import Clock
from synapse.util.stringutils import is_ascii
+if TYPE_CHECKING:
+ from synapse.storage.databases.main.media_repository import LocalMedia
+
+
logger = logging.getLogger(__name__)
# list all text content types that will have the charset default to UTF-8 when
@@ -260,6 +274,68 @@ def _can_encode_filename_as_token(x: str) -> bool:
return True
+async def respond_with_multipart_responder(
+ clock: Clock,
+ request: SynapseRequest,
+ responder: "Optional[Responder]",
+ media_info: "LocalMedia",
+) -> None:
+ """
+ Responds to requests originating from the federation media `/download` endpoint by
+ streaming a multipart/mixed response
+
+ Args:
+ clock: the homeserver's Clock, used when streaming the multipart response
+ request: the federation request to respond to
+ responder: the responder which will send the response
+ media_info: metadata about the media item
+ """
+ if not responder:
+ respond_404(request)
+ return
+
+ # If we have a responder we *must* use it as a context manager.
+ with responder:
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
+
+ from synapse.media.media_storage import MultipartFileConsumer
+
+ # Note that currently the json_object is just {}; this will change when linked
+ # media is implemented.
+ multipart_consumer = MultipartFileConsumer(
+ clock, request, media_info.media_type, {}, media_info.media_length
+ )
+
+ logger.debug("Responding to media request with responder %s", responder)
+ if media_info.media_length is not None:
+ content_length = multipart_consumer.content_length()
+ assert content_length is not None
+ request.setHeader(b"Content-Length", b"%d" % (content_length,))
+
+ request.setHeader(
+ b"Content-Type",
+ b"multipart/mixed; boundary=%s" % multipart_consumer.boundary,
+ )
+
+ try:
+ await responder.write_to_consumer(multipart_consumer)
+ except Exception as e:
+ # The majority of the time this will be due to the client having gone
+ # away. Unfortunately, Twisted simply throws a generic exception at us
+ # in that case.
+ logger.warning("Failed to write to consumer: %s %s", type(e), e)
+
+ # Unregister the producer, if it has one, so Twisted doesn't complain
+ if request.producer:
+ request.unregisterProducer()
+
+ finish_request(request)
+
+
async def respond_with_responder(
request: SynapseRequest,
responder: "Optional[Responder]",
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py
index 0e875132f6..1436329fad 100644
--- a/synapse/media/media_repository.py
+++ b/synapse/media/media_repository.py
@@ -42,6 +42,7 @@ from synapse.api.errors import (
SynapseError,
cs_error,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.config.repository import ThumbnailRequirement
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
@@ -53,6 +54,7 @@ from synapse.media._base import (
ThumbnailInfo,
get_filename_from_headers,
respond_404,
+ respond_with_multipart_responder,
respond_with_responder,
)
from synapse.media.filepath import MediaFilePaths
@@ -111,6 +113,12 @@ class MediaRepository:
)
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
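+ # Rate limit remote media downloads, keyed by the IP address of the requester.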
+ self.download_ratelimiter = Ratelimiter(
+ store=hs.get_storage_controllers().main,
+ clock=hs.get_clock(),
+ cfg=hs.config.ratelimiting.remote_media_downloads,
+ )
+
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
@@ -422,6 +430,7 @@ class MediaRepository:
media_id: str,
name: Optional[str],
max_timeout_ms: int,
+ federation: bool = False,
) -> None:
"""Responds to requests for local media, if exists, or returns 404.
@@ -433,6 +442,7 @@ class MediaRepository:
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ federation: whether the local media being fetched is for a federation request
Returns:
Resolves once a response has successfully been written to request
@@ -453,9 +463,14 @@ class MediaRepository:
file_info = FileInfo(None, media_id, url_cache=bool(url_cache))
responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(
- request, responder, media_type, media_length, upload_name
- )
+ if federation:
+ await respond_with_multipart_responder(
+ self.clock, request, responder, media_info
+ )
+ else:
+ await respond_with_responder(
+ request, responder, media_type, media_length, upload_name
+ )
async def get_remote_media(
self,
@@ -464,6 +479,7 @@ class MediaRepository:
media_id: str,
name: Optional[str],
max_timeout_ms: int,
+ ip_address: str,
) -> None:
"""Respond to requests for remote media.
@@ -475,6 +491,7 @@ class MediaRepository:
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ ip_address: the IP address of the requester
Returns:
Resolves once a response has successfully been written to request
@@ -500,7 +517,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
- server_name, media_id, max_timeout_ms
+ server_name,
+ media_id,
+ max_timeout_ms,
+ self.download_ratelimiter,
+ ip_address,
)
# We deliberately stream the file outside the lock
@@ -517,7 +538,7 @@ class MediaRepository:
respond_404(request)
async def get_remote_media_info(
- self, server_name: str, media_id: str, max_timeout_ms: int
+ self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -527,6 +548,7 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ ip_address: IP address of the requester
Returns:
The media info of the file
@@ -542,7 +564,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
- server_name, media_id, max_timeout_ms
+ server_name,
+ media_id,
+ max_timeout_ms,
+ self.download_ratelimiter,
+ ip_address,
)
# Ensure we actually use the responder so that it releases resources
@@ -553,7 +579,12 @@ class MediaRepository:
return media_info
async def _get_remote_media_impl(
- self, server_name: str, media_id: str, max_timeout_ms: int
+ self,
+ server_name: str,
+ media_id: str,
+ max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -564,6 +595,9 @@ class MediaRepository:
remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+ requester IP.
+ ip_address: the IP address of the requester
Returns:
A tuple of responder and the media info of the file.
@@ -596,7 +630,7 @@ class MediaRepository:
try:
media_info = await self._download_remote_file(
- server_name, media_id, max_timeout_ms
+ server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
)
except SynapseError:
raise
@@ -630,6 +664,8 @@ class MediaRepository:
server_name: str,
media_id: str,
max_timeout_ms: int,
+ download_ratelimiter: Ratelimiter,
+ ip_address: str,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -641,6 +677,9 @@ class MediaRepository:
locally generated.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
+ download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
+ requester IP
+ ip_address: the IP address of the requester
Returns:
The media info of the file.
@@ -650,7 +689,7 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id)
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers = await self.client.download_media(
server_name,
@@ -658,6 +697,8 @@ class MediaRepository:
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
+ download_ratelimiter=download_ratelimiter,
+ ip_address=ip_address,
)
except RequestSendFailed as e:
logger.warning(
@@ -693,8 +734,6 @@ class MediaRepository:
)
raise SynapseError(502, "Failed to fetch remote media")
- await finish()
-
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
@@ -1045,17 +1084,17 @@ class MediaRepository:
),
)
- with self.media_storage.store_into_file(file_info) as (
- f,
- fname,
- finish,
- ):
+ async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
await self.media_storage.write_to_file(t_byte_source, f)
- await finish()
finally:
t_byte_source.close()
+ # We flush and close the file to ensure that the bytes have
+ # been written before getting the size.
+ f.flush()
+ f.close()
+
t_len = os.path.getsize(fname)
# Write to database
diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py
index b45b319f5c..1be2c9b5f5 100644
--- a/synapse/media/media_storage.py
+++ b/synapse/media/media_storage.py
@@ -19,36 +19,49 @@
#
#
import contextlib
+import json
import logging
import os
import shutil
+from contextlib import closing
+from io import BytesIO
from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
- Awaitable,
+ AsyncIterator,
BinaryIO,
Callable,
- Generator,
+ List,
Optional,
Sequence,
Tuple,
Type,
+ Union,
+ cast,
)
+from uuid import uuid4
import attr
+from zope.interface import implementer
+from twisted.internet import interfaces
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import NotFoundError
-from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.logging.context import (
+ defer_to_thread,
+ make_deferred_yieldable,
+ run_in_background,
+)
from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
+from ..types import JsonDict
from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
@@ -58,6 +71,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+CRLF = b"\r\n"
+
class MediaStorage:
"""Responsible for storing/fetching files from local sources.
@@ -97,11 +112,9 @@ class MediaStorage:
the file path written to in the primary media store
"""
- with self.store_into_file(file_info) as (f, fname, finish_cb):
+ async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository
await self.write_to_file(source, f)
- # Write to the other storage providers
- await finish_cb()
return fname
@@ -111,32 +124,27 @@ class MediaStorage:
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@trace_with_opname("MediaStorage.store_into_file")
- @contextlib.contextmanager
- def store_into_file(
+ @contextlib.asynccontextmanager
+ async def store_into_file(
self, file_info: FileInfo
- ) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
- """Context manager used to get a file like object to write into, as
+ ) -> AsyncIterator[Tuple[BinaryIO, str]]:
+ """Async Context manager used to get a file like object to write into, as
described by file_info.
- Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
- like object that can be written to, fname is the absolute path of file
- on disk, and finish_cb is a function that returns an awaitable.
+ Actually yields a 2-tuple (file, fname), where file is a file-like
+ object that can be written to and fname is the absolute path of the file
+ on disk.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
- finish_cb must be called and waited on after the file has been successfully been
- written to. Should not be called if there was an error. Checks for spam and
- stores the file into the configured storage providers.
-
Args:
file_info: Info about the file to store
Example:
- with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ async with media_storage.store_into_file(info) as (f, fname):
# .. write into f ...
- await finish_cb()
"""
path = self._file_info_to_path(file_info)
@@ -145,69 +153,44 @@ class MediaStorage:
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)
- finished_called = [False]
-
- main_media_repo_write_trace_scope = start_active_span(
- "writing to main media repo"
- )
- main_media_repo_write_trace_scope.__enter__()
-
try:
- with open(fname, "wb") as f:
-
- async def finish() -> None:
- # When someone calls finish, we assume they are done writing to the main media repo
- main_media_repo_write_trace_scope.__exit__(None, None, None)
-
- with start_active_span("writing to other storage providers"):
- # Ensure that all writes have been flushed and close the
- # file.
- f.flush()
- f.close()
-
- spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
- ReadableFileWrapper(self.clock, fname), file_info
- )
- if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
- logger.info("Blocking media due to spam checker")
- # Note that we'll delete the stored media, due to the
- # try/except below. The media also won't be stored in
- # the DB.
- # We currently ignore any additional field returned by
- # the spam-check API.
- raise SpamMediaException(errcode=spam_check[0])
-
- for provider in self.storage_providers:
- with start_active_span(str(provider)):
- await provider.store_file(path, file_info)
-
- finished_called[0] = True
-
- yield f, fname, finish
+ with start_active_span("writing to main media repo"):
+ with open(fname, "wb") as f:
+ yield f, fname
+
+ with start_active_span("writing to other storage providers"):
+ spam_check = (
+ await self._spam_checker_module_callbacks.check_media_file_for_spam(
+ ReadableFileWrapper(self.clock, fname), file_info
+ )
+ )
+ if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
+ logger.info("Blocking media due to spam checker")
+ # Note that we'll delete the stored media, due to the
+ # try/except below. The media also won't be stored in
+ # the DB.
+ # We currently ignore any additional field returned by
+ # the spam-check API.
+ raise SpamMediaException(errcode=spam_check[0])
+
+ for provider in self.storage_providers:
+ with start_active_span(str(provider)):
+ await provider.store_file(path, file_info)
+
except Exception as e:
try:
- main_media_repo_write_trace_scope.__exit__(
- type(e), None, e.__traceback__
- )
os.remove(fname)
except Exception:
pass
raise e from None
- if not finished_called:
- exc = Exception("Finished callback not called")
- main_media_repo_write_trace_scope.__exit__(
- type(exc), None, exc.__traceback__
- )
- raise exc
-
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
- file_info
+ file_info: Metadata about the media file
Returns:
Returns a Responder if the file was found, otherwise None.
@@ -349,7 +332,7 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
- open_file: A file like object to be streamed ot the client,
+ open_file: A file like object to be streamed to the client,
is closed when finished streaming.
"""
@@ -403,3 +386,240 @@ class ReadableFileWrapper:
# We yield to the reactor by sleeping for 0 seconds.
await self.clock.sleep(0)
+
+
+@implementer(interfaces.IConsumer)
+@implementer(interfaces.IPushProducer)
+class MultipartFileConsumer:
+ """Wraps a given consumer so that any data that gets written to it gets
+ converted to a multipart format.
+ """
+
+ def __init__(
+ self,
+ clock: Clock,
+ wrapped_consumer: interfaces.IConsumer,
+ file_content_type: str,
+ json_object: JsonDict,
+ content_length: Optional[int] = None,
+ ) -> None:
+ self.clock = clock
+ self.wrapped_consumer = wrapped_consumer
+ self.json_field = json_object
+ self.json_field_written = False
+ self.content_type_written = False
+ self.file_content_type = file_content_type
+ self.boundary = uuid4().hex.encode("ascii")
+
+ # The producer that registered with us, and if it's a push or pull
+ # producer.
+ self.producer: Optional["interfaces.IProducer"] = None
+ self.streaming: Optional[bool] = None
+
+ # Whether the wrapped consumer has asked us to pause.
+ self.paused = False
+
+ self.length = content_length
+
+ ### IConsumer APIs ###
+
+ def registerProducer(
+ self, producer: "interfaces.IProducer", streaming: bool
+ ) -> None:
+ """
+ Register to receive data from a producer.
+
+ This sets self to be a consumer for a producer. When this object runs
+ out of data (as when a send(2) call on a socket succeeds in moving the
+ last data from a userspace buffer into a kernelspace buffer), it will
+ ask the producer to resumeProducing().
+
+ For L{IPullProducer} providers, C{resumeProducing} will be called once
+ each time data is required.
+
+ For L{IPushProducer} providers, C{pauseProducing} will be called
+ whenever the write buffer fills up and C{resumeProducing} will only be
+ called when it empties. The consumer will only call C{resumeProducing}
+ to balance a previous C{pauseProducing} call; the producer is assumed
+ to start in an un-paused state.
+
+ @param streaming: C{True} if C{producer} provides L{IPushProducer},
+ C{False} if C{producer} provides L{IPullProducer}.
+
+ @raise RuntimeError: If a producer is already registered.
+ """
+ self.producer = producer
+ self.streaming = streaming
+
+ self.wrapped_consumer.registerProducer(self, True)
+
+ # kick off producing if `self.producer` is not a streaming producer
+ if not streaming:
+ self.resumeProducing()
+
+ def unregisterProducer(self) -> None:
+ """
+ Stop consuming data from a producer, without disconnecting.
+ """
+ self.wrapped_consumer.write(CRLF + b"--" + self.boundary + b"--" + CRLF)
+ self.wrapped_consumer.unregisterProducer()
+ self.paused = True
+
+ def write(self, data: bytes) -> None:
+ """
+ The producer will write data by calling this method.
+
+ The implementation must be non-blocking and perform whatever
+ buffering is necessary. If the producer has provided enough data
+ for now and it is a L{IPushProducer}, the consumer may call its
+ C{pauseProducing} method.
+ """
+ if not self.json_field_written:
+ self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
+
+ content_type = Header(b"Content-Type", b"application/json")
+ self.wrapped_consumer.write(bytes(content_type) + CRLF)
+
+ json_field = json.dumps(self.json_field)
+ json_bytes = json_field.encode("utf-8")
+ self.wrapped_consumer.write(CRLF + json_bytes)
+ self.wrapped_consumer.write(CRLF + b"--" + self.boundary + CRLF)
+
+ self.json_field_written = True
+
+ # if we haven't written the content type yet, do so
+ if not self.content_type_written:
+ type = self.file_content_type.encode("utf-8")
+ content_type = Header(b"Content-Type", type)
+ self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF)
+ self.content_type_written = True
+
+ self.wrapped_consumer.write(data)
+
+ ### IPushProducer APIs ###
+
+ def stopProducing(self) -> None:
+ """
+ Stop producing data.
+
+ This tells a producer that its consumer has died, so it must stop
+ producing data for good.
+ """
+ assert self.producer is not None
+
+ self.paused = True
+ self.producer.stopProducing()
+
+ def pauseProducing(self) -> None:
+ """
+ Pause producing data.
+
+ Tells a producer that it has produced too much data to process for
+ the time being, and to stop until C{resumeProducing()} is called.
+ """
+ assert self.producer is not None
+
+ self.paused = True
+
+ if self.streaming:
+ cast("interfaces.IPushProducer", self.producer).pauseProducing()
+ else:
+ self.paused = True
+
+ def resumeProducing(self) -> None:
+ """
+ Resume producing data.
+
+ This tells a producer to re-add itself to the main loop and produce
+ more data for its consumer.
+ """
+ assert self.producer is not None
+
+ if self.streaming:
+ cast("interfaces.IPushProducer", self.producer).resumeProducing()
+ else:
+ # If the producer is not a streaming producer we need to start
+ # repeatedly calling `resumeProducing` in a loop.
+ run_in_background(self._resumeProducingRepeatedly)
+
+ def content_length(self) -> Optional[int]:
+ """
+ Calculate the content length of the multipart response
+ in bytes.
+ """
+ if not self.length:
+ return None
+ # calculate length of json field and content-type header
+ json_field = json.dumps(self.json_field)
+ json_bytes = json_field.encode("utf-8")
+ json_length = len(json_bytes)
+
+ type = self.file_content_type.encode("utf-8")
+ content_type = Header(b"Content-Type", type)
+ type_length = len(bytes(content_type))
+
+ # 154 is the length of the elements that aren't variable, i.e. the
+ # CRLFs, boundary strings, etc.
+ self.length += json_length + type_length + 154
+
+ return self.length
+
+ ### Internal APIs. ###
+
+ async def _resumeProducingRepeatedly(self) -> None:
+ assert self.producer is not None
+ assert not self.streaming
+
+ producer = cast("interfaces.IPullProducer", self.producer)
+
+ self.paused = False
+ while not self.paused:
+ producer.resumeProducing()
+ await self.clock.sleep(0)
+
+
+class Header:
+ """
+ A tiny wrapper that produces request headers. We can't use the standard
+ Python header class because it encodes unicode fields using `=?...?=`
+ encoding, which is correct, but nobody in the HTTP world expects that;
+ everyone wants raw UTF-8 bytes. (Stolen from treq.multipart.)
+
+ """
+
+ def __init__(
+ self,
+ name: bytes,
+ value: Any,
+ params: Optional[List[Tuple[Any, Any]]] = None,
+ ):
+ self.name = name
+ self.value = value
+ self.params = params or []
+
+ def add_param(self, name: Any, value: Any) -> None:
+ self.params.append((name, value))
+
+ def __bytes__(self) -> bytes:
+ with closing(BytesIO()) as h:
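+ # Render as `name: value`, followed by any `; param="value"` pairs, with
+ # the values escaped.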
+ h.write(self.name + b": " + escape(self.value).encode("us-ascii"))
+ if self.params:
+ for name, val in self.params:
+ h.write(b"; ")
+ h.write(escape(name).encode("us-ascii"))
+ h.write(b"=")
+ h.write(b'"' + escape(val).encode("utf-8") + b'"')
+ h.seek(0)
+ return h.read()
+
+
+def escape(value: Union[str, bytes]) -> str:
+ """
+ This function prevents header values from corrupting the request:
+ a newline in the file name parameter makes a form-data request unreadable
+ for a majority of parsers. (Stolen from treq.multipart.)
+ """
+ if isinstance(value, bytes):
+ value = value.decode("utf-8")
+ return value.replace("\r", "").replace("\n", "").replace('"', '\\"')
diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py
index 5538020bec..f8a9560784 100644
--- a/synapse/media/thumbnailer.py
+++ b/synapse/media/thumbnailer.py
@@ -22,11 +22,27 @@
import logging
from io import BytesIO
from types import TracebackType
-from typing import Optional, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from PIL import Image
+from synapse.api.errors import Codes, SynapseError, cs_error
+from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
+from synapse.http.server import respond_with_json
+from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
+from synapse.media._base import (
+ FileInfo,
+ ThumbnailInfo,
+ respond_404,
+ respond_with_file,
+ respond_with_responder,
+)
+from synapse.media.media_storage import MediaStorage
+
+if TYPE_CHECKING:
+ from synapse.media.media_repository import MediaRepository
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -231,3 +247,473 @@ class Thumbnailer:
def __del__(self) -> None:
# Make sure we actually do close the image, rather than leak data.
self.close()
+
+
+class ThumbnailProvider:
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
+ self.hs = hs
+ self.media_repo = media_repo
+ self.media_storage = media_storage
+ self.store = hs.get_datastores().main
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+
+ async def respond_local_thumbnail(
+ self,
+ request: SynapseRequest,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ max_timeout_ms: int,
+ ) -> None:
+ media_info = await self.media_repo.get_local_media_info(
+ request, media_id, max_timeout_ms
+ )
+ if not media_info:
+ return
+
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ media_id,
+ url_cache=bool(media_info.url_cache),
+ server_name=None,
+ )
+
+ async def select_or_generate_local_thumbnail(
+ self,
+ request: SynapseRequest,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ max_timeout_ms: int,
+ ) -> None:
+ media_info = await self.media_repo.get_local_media_info(
+ request, media_id, max_timeout_ms
+ )
+
+ if not media_info:
+ return
+
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
+ for info in thumbnail_infos:
+ t_w = info.width == desired_width
+ t_h = info.height == desired_height
+ t_method = info.method == desired_method
+ t_type = info.type == desired_type
+
+ if t_w and t_h and t_method and t_type:
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ url_cache=bool(media_info.url_cache),
+ thumbnail=info,
+ )
+
+ responder = await self.media_storage.fetch_media(file_info)
+ if responder:
+ await respond_with_responder(
+ request, responder, info.type, info.length
+ )
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
+
+ # Okay, so we generate one.
+ file_path = await self.media_repo.generate_local_exact_thumbnail(
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ url_cache=bool(media_info.url_cache),
+ )
+
+ if file_path:
+ await respond_with_file(request, desired_type, file_path)
+ else:
+ logger.warning("Failed to generate thumbnail")
+ raise SynapseError(400, "Failed to generate thumbnail.")
+
+ async def select_or_generate_remote_thumbnail(
+ self,
+ request: SynapseRequest,
+ server_name: str,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ max_timeout_ms: int,
+ ip_address: str,
+ ) -> None:
+ media_info = await self.media_repo.get_remote_media_info(
+ server_name, media_id, max_timeout_ms, ip_address
+ )
+ if not media_info:
+ respond_404(request)
+ return
+
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
+ server_name, media_id
+ )
+
+ file_id = media_info.filesystem_id
+
+ for info in thumbnail_infos:
+ t_w = info.width == desired_width
+ t_h = info.height == desired_height
+ t_method = info.method == desired_method
+ t_type = info.type == desired_type
+
+ if t_w and t_h and t_method and t_type:
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=info,
+ )
+
+ responder = await self.media_storage.fetch_media(file_info)
+ if responder:
+ await respond_with_responder(
+ request, responder, info.type, info.length
+ )
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
+
+ # Okay, so we generate one.
+ file_path = await self.media_repo.generate_remote_exact_thumbnail(
+ server_name,
+ file_id,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ )
+
+ if file_path:
+ await respond_with_file(request, desired_type, file_path)
+ else:
+ logger.warning("Failed to generate thumbnail")
+ raise SynapseError(400, "Failed to generate thumbnail.")
+
+ async def respond_remote_thumbnail(
+ self,
+ request: SynapseRequest,
+ server_name: str,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ max_timeout_ms: int,
+ ip_address: str,
+ ) -> None:
+ # TODO: Don't download the whole remote file
+ # We should proxy the thumbnail from the remote server instead of
+ # downloading the remote file and generating our own thumbnails.
+ media_info = await self.media_repo.get_remote_media_info(
+ server_name, media_id, max_timeout_ms, ip_address
+ )
+ if not media_info:
+ return
+
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
+ server_name, media_id
+ )
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ media_info.filesystem_id,
+ url_cache=False,
+ server_name=server_name,
+ )
+
+ async def _select_and_respond_with_thumbnail(
+ self,
+ request: SynapseRequest,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[ThumbnailInfo],
+ media_id: str,
+ file_id: str,
+ url_cache: bool,
+ server_name: Optional[str] = None,
+ ) -> None:
+ """
+ Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ request: The incoming request.
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of thumbnail info of candidate thumbnails.
+ media_id: The media ID of the requested content.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: True if this is from a URL cache.
+ server_name: The server name, if this is a remote thumbnail.
+ """
+ logger.debug(
+ "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ thumbnail_infos,
+ )
+
+ # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
+ # different code path to handle it.
+ assert not self.dynamic_thumbnails
+
+ if thumbnail_infos:
+ file_info = self._select_thumbnail(
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ file_id,
+ url_cache,
+ server_name,
+ )
+ if not file_info:
+ logger.info("Couldn't find a thumbnail matching the desired inputs")
+ respond_404(request)
+ return
+
+ # The thumbnail property must exist.
+ assert file_info.thumbnail is not None
+
+ responder = await self.media_storage.fetch_media(file_info)
+ if responder:
+ await respond_with_responder(
+ request,
+ responder,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
+ )
+ return
+
+ # If we can't find the thumbnail we regenerate it. This can happen
+ # if e.g. we've deleted the thumbnails but still have the original
+ # image somewhere.
+ #
+ # Since we have an entry for the thumbnail in the DB we a) know we
+ # have successfully generated the thumbnail in the past (so we
+ # don't need to worry about repeatedly failing to generate
+ # thumbnails), and b) have already calculated that appropriate
+ # width/height/method so we can just call the "generate exact"
+ # methods.
+
+ # First let's check that we do actually have the original image
+ # still. This will throw a 404 if we don't.
+ # TODO: We should refetch the thumbnails for remote media.
+ await self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(server_name, file_id, url_cache=url_cache)
+ )
+
+ if server_name:
+ await self.media_repo.generate_remote_exact_thumbnail(
+ server_name,
+ file_id=file_id,
+ media_id=media_id,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
+ )
+ else:
+ await self.media_repo.generate_local_exact_thumbnail(
+ media_id=media_id,
+ t_width=file_info.thumbnail.width,
+ t_height=file_info.thumbnail.height,
+ t_method=file_info.thumbnail.method,
+ t_type=file_info.thumbnail.type,
+ url_cache=url_cache,
+ )
+
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(
+ request,
+ responder,
+ file_info.thumbnail.type,
+ file_info.thumbnail.length,
+ )
+ else:
+ # This might be because:
+ # 1. We can't create thumbnails for the given media (corrupted or
+ # unsupported file type), or
+ # 2. The thumbnailing process never ran or errored out initially
+ # when the media was first uploaded (these bugs should be
+ # reported and fixed).
+ # Note that we don't attempt to generate a thumbnail now because
+ # `dynamic_thumbnails` is disabled.
+ logger.info("Failed to find any generated thumbnails")
+
+ assert request.path is not None
+ respond_with_json(
+ request,
+ 400,
+ cs_error(
+ "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
+ % (
+ request.path.decode(),
+ ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
+ ),
+ code=Codes.UNKNOWN,
+ ),
+ send_cors=True,
+ )
+
+ def _select_thumbnail(
+ self,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[ThumbnailInfo],
+ file_id: str,
+ url_cache: bool,
+ server_name: Optional[str],
+ ) -> Optional[FileInfo]:
+ """
+ Choose an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: True if this is from a URL cache.
+ server_name: The server name, if this is a remote thumbnail.
+
+ Returns:
+ The thumbnail which best matches the desired parameters.
+ """
+ desired_method = desired_method.lower()
+
+ # The chosen thumbnail.
+ thumbnail_info = None
+
+ d_w = desired_width
+ d_h = desired_height
+
+ if desired_method == "crop":
+ # Thumbnails that match equal or larger sizes of desired width/height.
+ crop_info_list: List[
+ Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
+ ] = []
+ # Other thumbnails.
+ crop_info_list2: List[
+ Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
+ ] = []
+ for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info.method != "crop":
+ continue
+
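+ # Rank candidates by aspect-ratio closeness, whether they cover the
+ # desired size, area difference, content-type match and, finally, file size.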
+ t_w = info.width
+ t_h = info.height
+ aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info.type
+ length_quality = info.length
+ if t_w >= d_w or t_h >= d_h:
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
+ )
+ )
+ else:
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
+ )
+ )
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info and pick the thumbnail that appears earlier
+ # in the list of candidates.
+ if crop_info_list:
+ thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
+ elif crop_info_list2:
+ thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
+ elif desired_method == "scale":
+ # Thumbnails that match equal or larger sizes of desired width/height.
+ info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
+ # Other thumbnails.
+ info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
+
+ for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info.method != "scale":
+ continue
+
+ t_w = info.width
+ t_h = info.height
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info.type
+ length_quality = info.length
+ if t_w >= d_w or t_h >= d_h:
+ info_list.append((size_quality, type_quality, length_quality, info))
+ else:
+ info_list2.append(
+ (size_quality, type_quality, length_quality, info)
+ )
+ # Pick the most appropriate thumbnail. Some values of `desired_width` and
+ # `desired_height` may result in a tie, in which case we avoid comparing on
+ # the thumbnail info and pick the thumbnail that appears earlier
+ # in the list of candidates.
+ if info_list:
+ thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
+ elif info_list2:
+ thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
+
+ if thumbnail_info:
+ return FileInfo(
+ file_id=file_id,
+ url_cache=url_cache,
+ server_name=server_name,
+ thumbnail=thumbnail_info,
+ )
+
+ # No matching thumbnail was found.
+ return None
diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py
index 3897823b35..2e65a04789 100644
--- a/synapse/media/url_previewer.py
+++ b/synapse/media/url_previewer.py
@@ -592,7 +592,7 @@ class UrlPreviewer:
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ async with self.media_storage.store_into_file(file_info) as (f, fname):
if url.startswith("data:"):
if not allow_data_urls:
raise SynapseError(
@@ -603,8 +603,6 @@ class UrlPreviewer:
else:
download_result = await self._download_url(url, f)
- await finish()
-
try:
time_now_ms = self.clock.time_msec()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 7c1cd3b5f2..c87eb748c0 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -721,7 +721,6 @@ class Notifier:
user.to_string(),
new_events,
is_peeking=is_peeking,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
elif keyname == StreamKeyType.PRESENCE:
now = self.clock.time_msec()
@@ -763,6 +762,29 @@ class Notifier:
return result
+ async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
+ """Wait for this worker to catch up with the given stream token."""
+
+ start = self.clock.time_msec()
+ while True:
+ current_token = self.event_sources.get_current_token()
+ if stream_token.is_before_or_eq(current_token):
+ return True
+
+ now = self.clock.time_msec()
+
+ if now - start > 10_000:
+ return False
+
+ logger.info(
+ "Waiting for current token to reach %s; currently at %s",
+ stream_token,
+ current_token,
+ )
+
+ # TODO: be better
+ await self.clock.sleep(0.5)
+
async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
) -> Tuple[StrCollection, bool]:
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 49ce9d6dda..cf611bd90b 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -28,7 +28,7 @@ import jinja2
from markupsafe import Markup
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership, RoomTypes
+from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes
from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig
from synapse.events import EventBase
@@ -532,7 +532,6 @@ class Mailer:
self._storage_controllers,
user_id,
results.events_before,
- msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
)
the_events.append(notif_event)
@@ -717,7 +716,8 @@ class Mailer:
)
if (
create_event
- and create_event.content.get("room_type") == RoomTypes.SPACE
+ and create_event.content.get(EventContentFields.ROOM_TYPE)
+ == RoomTypes.SPACE
):
return self.email_subjects.invite_from_person_to_space % {
"person": inviter_name,
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index ba257d34e6..3dddbb70b4 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -55,6 +55,7 @@ from synapse.replication.tcp.streams.partial_state import (
)
from synapse.types import PersistedEventPosition, ReadReceipt, StreamKeyType, UserID
from synapse.util.async_helpers import Linearizer, timeout_deferred
+from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -111,6 +112,21 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
+ all_room_ids: Set[str] = set()
+ if stream_name == DeviceListsStream.NAME:
+ if any(not row.is_signature and not row.hosts_calculated for row in rows):
+ prev_token = self.store.get_device_stream_token()
+ all_room_ids = await self.store.get_all_device_list_changes(
+ prev_token, token
+ )
+ self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
+
+ # If we're sending federation we need to update the device lists
+ # outbound pokes stream change cache with updated hosts.
+ if self.send_handler and any(row.hosts_calculated for row in rows):
+ hosts = await self.store.get_destinations_for_device(token)
+ self.store.device_lists_outbound_pokes_have_changed(hosts, token)
+
self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
@@ -145,14 +161,14 @@ class ReplicationDataHandler:
StreamKeyType.TO_DEVICE, token, users=entities
)
elif stream_name == DeviceListsStream.NAME:
- all_room_ids: Set[str] = set()
- for row in rows:
- if row.entity.startswith("@") and not row.is_signature:
- room_ids = await self.store.get_rooms_for_user(row.entity)
- all_room_ids.update(room_ids)
- self.notifier.on_new_event(
- StreamKeyType.DEVICE_LIST, token, rooms=all_room_ids
- )
+ # `all_room_ids` can be large, so let's wake up those streams in batches
+ for batched_room_ids in batch_iter(all_room_ids, 100):
+ self.notifier.on_new_event(
+ StreamKeyType.DEVICE_LIST, token, rooms=batched_room_ids
+ )
+
+ # Yield to reactor so that we don't block.
+ await self._clock.sleep(0)
elif stream_name == PushersStream.NAME:
for row in rows:
if row.deleted:
@@ -423,12 +439,11 @@ class FederationSenderHandler:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
- hosts = {
- row.entity
- for row in rows
- if not row.entity.startswith("@") and not row.is_signature
- }
- await self.federation_sender.send_device_messages(hosts, immediate=False)
+ if any(row.hosts_calculated for row in rows):
+ hosts = await self.store.get_destinations_for_device(token)
+ await self.federation_sender.send_device_messages(
+ hosts, immediate=False
+ )
elif stream_name == ToDeviceStream.NAME:
# The to_device stream includes stuff to be pushed to both local
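
`batch_iter` keeps each notifier wake-up small by chunking what can be a large set of room IDs into groups of 100, with a cooperative yield between groups. A sketch of the chunking idiom (the real `synapse.util.iterutils.batch_iter` may differ in detail):

    from itertools import islice
    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
        # Yield successive fixed-size chunks; the final chunk may be shorter.
        it = iter(iterable)
        while True:
            chunk = tuple(islice(it, size))
            if not chunk:
                return
            yield chunk

    # e.g. waking streams in batches of 100 room IDs:
    for batch in batch_iter({f"!room{i}:example.com" for i in range(250)}, 100):
        pass  # notifier.on_new_event(..., rooms=batch)
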
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 661206c841..d021904de7 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -549,10 +549,14 @@ class DeviceListsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
- entity: str
+ user_id: str
# Indicates that a user has signed their own device with their user-signing key
is_signature: bool
+ # Indicates if this is a notification that we've calculated the hosts we
+ # need to send the update to.
+ hosts_calculated: bool
+
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
@@ -594,13 +598,13 @@ class DeviceListsStream(_StreamFromIdGen):
upper_limit_token = min(upper_limit_token, signatures_to_token)
device_updates = [
- (stream_id, (entity, False))
- for stream_id, (entity,) in device_updates
+ (stream_id, (entity, False, hosts))
+ for stream_id, (entity, hosts) in device_updates
if stream_id <= upper_limit_token
]
signatures_updates = [
- (stream_id, (entity, True))
+ (stream_id, (entity, True, False))
for stream_id, (entity,) in signatures_updates
if stream_id <= upper_limit_token
]
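
With the extra field, each parsed row now distinguishes three cases: an ordinary device change whose destination hosts are still to be calculated, a follow-up row indicating the hosts have been calculated, and a self-signature update. An illustrative sketch of the branching the replication handlers above perform, using a plain namedtuple instead of the real attrs class:

    from collections import namedtuple

    DeviceListsStreamRow = namedtuple(
        "DeviceListsStreamRow", ["user_id", "is_signature", "hosts_calculated"]
    )

    rows = [
        DeviceListsStreamRow("@alice:example.com", False, False),  # device change, hosts TBD
        DeviceListsStreamRow("@alice:example.com", False, True),   # hosts now calculated
        DeviceListsStreamRow("@alice:example.com", True, False),   # self-signature update
    ]

    # Mirrors the checks in ReplicationDataHandler / FederationSenderHandler above.
    needs_room_wakeup = any(not r.is_signature and not r.hosts_calculated for r in rows)
    needs_federation_send = any(r.hosts_calculated for r in rows)
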
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 534dc0e276..0024ccf708 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -53,7 +53,7 @@ from synapse.rest.client import (
register,
relations,
rendezvous,
- report_event,
+ reporting,
room,
room_keys,
room_upgrade_rest_servlet,
@@ -128,7 +128,7 @@ class ClientRestResource(JsonResource):
tags.register_servlets(hs, client_resource)
account_data.register_servlets(hs, client_resource)
if is_main_process:
- report_event.register_servlets(hs, client_resource)
+ reporting.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6da1d79168..cdaee17451 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -101,6 +101,7 @@ from synapse.rest.admin.users import (
ResetPasswordRestServlet,
SearchUsersRestServlet,
ShadowBanRestServlet,
+ SuspendAccountRestServlet,
UserAdminServlet,
UserByExternalId,
UserByThreePid,
@@ -327,6 +328,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
ExperimentalFeaturesRestServlet(hs).register(http_server)
+ if hs.config.experimental.msc3823_account_suspension:
+ SuspendAccountRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/experimental_features.py b/synapse/rest/admin/experimental_features.py
index 52eb9e62db..c5a00c490c 100644
--- a/synapse/rest/admin/experimental_features.py
+++ b/synapse/rest/admin/experimental_features.py
@@ -41,7 +41,6 @@ class ExperimentalFeature(str, Enum):
MSC3026 = "msc3026"
MSC3881 = "msc3881"
- MSC3967 = "msc3967"
class ExperimentalFeaturesRestServlet(RestServlet):
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 14ab4644cb..d85a04b825 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -61,8 +61,8 @@ class ListDestinationsRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
- start = parse_integer(request, "from", default=0, negative=False)
- limit = parse_integer(request, "limit", default=100, negative=False)
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
destination = parse_string(request, "destination")
@@ -181,8 +181,8 @@ class DestinationMembershipRestServlet(RestServlet):
if not await self._store.is_destination_known(destination):
raise NotFoundError("Unknown destination")
- start = parse_integer(request, "from", default=0, negative=False)
- limit = parse_integer(request, "limit", default=100, negative=False)
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index a05b7252ec..ee6a681285 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -311,8 +311,8 @@ class DeleteMediaByDateSize(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- before_ts = parse_integer(request, "before_ts", required=True, negative=False)
- size_gt = parse_integer(request, "size_gt", default=0, negative=False)
+ before_ts = parse_integer(request, "before_ts", required=True)
+ size_gt = parse_integer(request, "size_gt", default=0)
keep_profiles = parse_boolean(request, "keep_profiles", default=True)
if before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
@@ -377,8 +377,8 @@ class UserMediaRestServlet(RestServlet):
if user is None:
raise NotFoundError("Unknown user")
- start = parse_integer(request, "from", default=0, negative=False)
- limit = parse_integer(request, "limit", default=100, negative=False)
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
# If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility.
@@ -421,8 +421,8 @@ class UserMediaRestServlet(RestServlet):
if user is None:
raise NotFoundError("Unknown user")
- start = parse_integer(request, "from", default=0, negative=False)
- limit = parse_integer(request, "limit", default=100, negative=False)
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
# If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility.
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 0d86a4e15f..01f9de9ffa 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -35,6 +35,7 @@ from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
assert_params_in_dict,
+ parse_boolean,
parse_enum,
parse_integer,
parse_json,
@@ -242,13 +243,23 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
+ public_rooms = parse_boolean(request, "public_rooms")
+ empty_rooms = parse_boolean(request, "empty_rooms")
+
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
reverse_order = True if direction == Direction.BACKWARDS else False
# Return list of rooms according to parameters
rooms, total_rooms = await self.store.get_rooms_paginate(
- start, limit, order_by, reverse_order, search_term
+ start,
+ limit,
+ order_by,
+ reverse_order,
+ search_term,
+ public_rooms,
+ empty_rooms,
)
+
response = {
# next_token should be opaque, so return a value the client can parse
"offset": start,
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index dc27a41dd9..0adc5b7005 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -63,10 +63,10 @@ class UserMediaStatisticsRestServlet(RestServlet):
),
)
- start = parse_integer(request, "from", default=0, negative=False)
- limit = parse_integer(request, "limit", default=100, negative=False)
- from_ts = parse_integer(request, "from_ts", default=0, negative=False)
- until_ts = parse_integer(request, "until_ts", negative=False)
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ from_ts = parse_integer(request, "from_ts", default=0)
+ until_ts = parse_integer(request, "until_ts")
if until_ts is not None:
if until_ts <= from_ts:
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 5bf12c4979..ad515bd5a3 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -27,11 +27,13 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_and_validate_json_object_from_request,
parse_boolean,
parse_enum,
parse_integer,
@@ -49,10 +51,17 @@ from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.registration import ExternalIDReuseException
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.types import JsonDict, JsonMapping, UserID
+from synapse.types.rest import RequestBodyModel
if TYPE_CHECKING:
from synapse.server import HomeServer
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import StrictBool
+else:
+ from pydantic import StrictBool
+
+
logger = logging.getLogger(__name__)
@@ -90,8 +99,8 @@ class UsersRestServletV2(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- start = parse_integer(request, "from", default=0, negative=False)
- limit = parse_integer(request, "limit", default=100, negative=False)
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
user_id = parse_string(request, "user_id")
name = parse_string(request, "name", encoding="utf-8")
@@ -732,6 +741,36 @@ class DeactivateAccountRestServlet(RestServlet):
return HTTPStatus.OK, {"id_server_unbind_result": id_server_unbind_result}
+class SuspendAccountRestServlet(RestServlet):
+ PATTERNS = admin_patterns("/suspend/(?P<target_user_id>[^/]*)$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.auth = hs.get_auth()
+ self.is_mine = hs.is_mine
+ self.store = hs.get_datastores().main
+
+ class PutBody(RequestBodyModel):
+ suspend: StrictBool
+
+ async def on_PUT(
+ self, request: SynapseRequest, target_user_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester)
+
+ if not self.is_mine(UserID.from_string(target_user_id)):
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only suspend local users")
+
+ if not await self.store.get_user_by_id(target_user_id):
+ raise NotFoundError("User not found")
+
+ body = parse_and_validate_json_object_from_request(request, self.PutBody)
+ suspend = body.suspend
+ await self.store.set_user_suspended_status(target_user_id, suspend)
+
+ return HTTPStatus.OK, {f"user_{target_user_id}_suspended": suspend}
+
+
class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$")
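
A hedged usage sketch for the new suspension endpoint, assuming it is mounted under the usual `/_synapse/admin/v1` prefix, that `msc3823_account_suspension` is enabled, and that the caller holds a server admin's access token (all values below are placeholders):

    import requests

    resp = requests.put(
        "https://homeserver.example.com/_synapse/admin/v1/suspend/@alice:example.com",
        headers={"Authorization": "Bearer <admin_access_token>"},
        json={"suspend": True},
    )
    # Expected shape per the servlet above, e.g.
    # {"user_@alice:example.com_suspended": true}
    print(resp.json())
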
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 6ac07d354c..8daa449f9e 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -56,14 +56,14 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
-from synapse.rest.client.models import (
+from synapse.types import JsonDict
+from synapse.types.rest import RequestBodyModel
+from synapse.types.rest.client import (
AuthenticationData,
ClientSecretStr,
EmailRequestTokenBody,
MsisdnRequestTokenBody,
)
-from synapse.rest.models import RequestBodyModel
-from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed, validate_email
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index b1b803549e..8313d687b7 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -42,9 +42,9 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler
-from synapse.rest.client.models import AuthenticationData
-from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict
+from synapse.types.rest import RequestBodyModel
+from synapse.types.rest.client import AuthenticationData
if TYPE_CHECKING:
from synapse.server import HomeServer
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index 8099fdf3e4..11fdd0f7c6 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -41,8 +41,8 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict, RoomAlias
+from synapse.types.rest import RequestBodyModel
if TYPE_CHECKING:
from synapse.server import HomeServer
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index a0017257ce..67de634eab 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -36,7 +36,6 @@ from synapse.http.servlet import (
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag
-from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken
from synapse.util.cancellation import cancellable
@@ -105,13 +104,8 @@ class KeyUploadServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
-
- if hs.config.worker.worker_app is None:
- # if main process
- self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
- else:
- # then a worker
- self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
+ self._clock = hs.get_clock()
+ self._store = hs.get_datastores().main
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
@@ -151,9 +145,10 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating"
)
- result = await self.key_uploader(
+ result = await self.e2e_keys_handler.upload_keys_for_user(
user_id=user_id, device_id=device_id, keys=body
)
+
return 200, result
@@ -387,44 +382,35 @@ class SigningKeyUploadServlet(RestServlet):
master_key_updatable_without_uia,
) = await self.e2e_keys_handler.check_cross_signing_setup(user_id)
- # Before MSC3967 we required UIA both when setting up cross signing for the
- # first time and when resetting the device signing key. With MSC3967 we only
- # require UIA when resetting cross-signing, and not when setting up the first
- # time. Because there is no UIA in MSC3861, for now we throw an error if the
- # user tries to reset the device signing key when MSC3861 is enabled, but allow
- # first-time setup.
- if self.hs.config.experimental.msc3861.enabled:
- # The auth service has to explicitly mark the master key as replaceable
- # without UIA to reset the device signing key with MSC3861.
- if is_cross_signing_setup and not master_key_updatable_without_uia:
- config = self.hs.config.experimental.msc3861
- if config.account_management_url is not None:
- url = f"{config.account_management_url}?action=org.matrix.cross_signing_reset"
- else:
- url = config.issuer
-
- raise SynapseError(
- HTTPStatus.NOT_IMPLEMENTED,
- "To reset your end-to-end encryption cross-signing identity, "
- f"you first need to approve it at {url} and then try again.",
- Codes.UNRECOGNIZED,
- )
- # But first-time setup is fine
-
- elif self.hs.config.experimental.msc3967_enabled:
- # MSC3967 allows this endpoint to 200 OK for idempotency. Resending exactly the same
- # keys should just 200 OK without doing a UIA prompt.
- keys_are_different = await self.e2e_keys_handler.has_different_keys(
- user_id, body
- )
- if not keys_are_different:
- # FIXME: we do not fallthrough to upload_signing_keys_for_user because confusingly
- # if we do, we 500 as it looks like it tries to INSERT the same key twice, causing a
- # unique key constraint violation. This sounds like a bug?
- return 200, {}
- # the keys are different, is x-signing set up? If no, then the keys don't exist which is
- # why they are different. If yes, then we need to UIA to change them.
- if is_cross_signing_setup:
+ # Resending exactly the same keys should just 200 OK without doing a UIA prompt.
+ keys_are_different = await self.e2e_keys_handler.has_different_keys(
+ user_id, body
+ )
+ if not keys_are_different:
+ return 200, {}
+
+ # The keys are different; is x-signing set up? If no, then this is first-time
+ # setup, and that is allowed without UIA, per MSC3967.
+ # If yes, then we need to authenticate the change.
+ if is_cross_signing_setup:
+ # With MSC3861, UIA is not possible. Instead, the auth service has to
+ # explicitly mark the master key as replaceable.
+ if self.hs.config.experimental.msc3861.enabled:
+ if not master_key_updatable_without_uia:
+ config = self.hs.config.experimental.msc3861
+ if config.account_management_url is not None:
+ url = f"{config.account_management_url}?action=org.matrix.cross_signing_reset"
+ else:
+ url = config.issuer
+
+ raise SynapseError(
+ HTTPStatus.NOT_IMPLEMENTED,
+ "To reset your end-to-end encryption cross-signing identity, "
+ f"you first need to approve it at {url} and then try again.",
+ Codes.UNRECOGNIZED,
+ )
+ else:
+ # Without MSC3861, we require UIA.
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
@@ -433,18 +419,6 @@ class SigningKeyUploadServlet(RestServlet):
# Do not allow skipping of UIA auth.
can_skip_ui_auth=False,
)
- # Otherwise we don't require UIA since we are setting up cross signing for first time
- else:
- # Previous behaviour is to always require UIA but allow it to be skipped
- await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- "add a device signing key to your account",
- # Allow skipping of UI auth since this is frequently called directly
- # after login and it is silly to ask users to re-auth immediately.
- can_skip_ui_auth=True,
- )
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
return 200, result
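
The restructured flow above boils down to a small decision table: identical keys are accepted idempotently, first-time setup needs no UIA, and a reset either requires UIA or, under MSC3861, prior approval by the auth service. A sketch summarising those branches (an illustrative helper, not part of the servlet):

    def signing_key_upload_decision(
        keys_are_different: bool,
        is_cross_signing_setup: bool,
        msc3861_enabled: bool,
        master_key_updatable_without_uia: bool,
    ) -> str:
        # Summarises the branches in SigningKeyUploadServlet.on_POST above.
        if not keys_are_different:
            return "accept"          # idempotent re-upload, 200 OK
        if not is_cross_signing_setup:
            return "accept"          # first-time setup, no UIA per MSC3967
        if msc3861_enabled:
            if master_key_updatable_without_uia:
                return "accept"      # auth service pre-approved the reset
            return "reject"          # 501, point the user at the account manager
        return "require_uia"         # classic flow: re-authenticate the reset
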
diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py
index ff52a9bf8c..e31687fc13 100644
--- a/synapse/rest/client/knock.py
+++ b/synapse/rest/client/knock.py
@@ -53,6 +53,7 @@ class KnockRoomAliasServlet(RestServlet):
super().__init__()
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
+ self._support_via = hs.config.experimental.msc4156_enabled
async def on_POST(
self,
@@ -74,6 +75,13 @@ class KnockRoomAliasServlet(RestServlet):
remote_room_hosts = parse_strings_from_args(
args, "server_name", required=False
)
+ if self._support_via:
+ remote_room_hosts = parse_strings_from_args(
+ args,
+ "org.matrix.msc4156.via",
+ default=remote_room_hosts,
+ required=False,
+ )
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
diff --git a/synapse/rest/client/media.py b/synapse/rest/client/media.py
new file mode 100644
index 0000000000..0c089163c1
--- /dev/null
+++ b/synapse/rest/client/media.py
@@ -0,0 +1,207 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+
+import logging
+import re
+
+from synapse.http.server import (
+ HttpServer,
+ respond_with_json,
+ respond_with_json_bytes,
+ set_corp_headers,
+ set_cors_headers,
+)
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.media._base import (
+ DEFAULT_MAX_TIMEOUT_MS,
+ MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
+ respond_404,
+)
+from synapse.media.media_repository import MediaRepository
+from synapse.media.media_storage import MediaStorage
+from synapse.media.thumbnailer import ThumbnailProvider
+from synapse.server import HomeServer
+from synapse.util.stringutils import parse_and_validate_server_name
+
+logger = logging.getLogger(__name__)
+
+
+class UnstablePreviewURLServlet(RestServlet):
+ """
+ Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview API
+ for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
+ specific additions).
+
+ This does have trade-offs compared to other designs:
+
+ * Pros:
+ * Simple and flexible; can be used by any clients at any point
+ * Cons:
+ * If each homeserver provides one of these independently, all the homeservers in a
+ room may needlessly DoS the target URI
+ * The URL metadata must be stored somewhere, rather than just using Matrix
+ itself to store the media.
+ * Matrix cannot be used to distribute the metadata between homeservers.
+ """
+
+ PATTERNS = [
+ re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/preview_url$")
+ ]
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
+ super().__init__()
+
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.media_repo = media_repo
+ self.media_storage = media_storage
+ assert self.media_repo.url_previewer is not None
+ self.url_previewer = self.media_repo.url_previewer
+
+ async def on_GET(self, request: SynapseRequest) -> None:
+ requester = await self.auth.get_user_by_req(request)
+ url = parse_string(request, "url", required=True)
+ ts = parse_integer(request, "ts")
+ if ts is None:
+ ts = self.clock.time_msec()
+
+ og = await self.url_previewer.preview(url, requester.user, ts)
+ respond_with_json_bytes(request, 200, og, send_cors=True)
+
+
+class UnstableMediaConfigResource(RestServlet):
+ PATTERNS = [
+ re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/config$")
+ ]
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ config = hs.config
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+ self.limits_dict = {"m.upload.size": config.media.max_upload_size}
+
+ async def on_GET(self, request: SynapseRequest) -> None:
+ await self.auth.get_user_by_req(request)
+ respond_with_json(request, 200, self.limits_dict, send_cors=True)
+
+
+class UnstableThumbnailResource(RestServlet):
+ PATTERNS = [
+ re.compile(
+ "/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
+ )
+ ]
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
+ super().__init__()
+
+ self.store = hs.get_datastores().main
+ self.media_repo = media_repo
+ self.media_storage = media_storage
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+ self._is_mine_server_name = hs.is_mine_server_name
+ self._server_name = hs.hostname
+ self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
+ self.thumbnailer = ThumbnailProvider(hs, media_repo, media_storage)
+ self.auth = hs.get_auth()
+
+ async def on_GET(
+ self, request: SynapseRequest, server_name: str, media_id: str
+ ) -> None:
+ # Validate the server name, raising if invalid
+ parse_and_validate_server_name(server_name)
+ await self.auth.get_user_by_req(request)
+
+ set_cors_headers(request)
+ set_corp_headers(request)
+ width = parse_integer(request, "width", required=True)
+ height = parse_integer(request, "height", required=True)
+ method = parse_string(request, "method", "scale")
+ # TODO Parse the Accept header to get an prioritised list of thumbnail types.
+ # TODO Parse the Accept header to get a prioritised list of thumbnail types.
+ m_type = "image/png"
+ max_timeout_ms = parse_integer(
+ request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
+ )
+ max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
+
+ if self._is_mine_server_name(server_name):
+ if self.dynamic_thumbnails:
+ await self.thumbnailer.select_or_generate_local_thumbnail(
+ request, media_id, width, height, method, m_type, max_timeout_ms
+ )
+ else:
+ await self.thumbnailer.respond_local_thumbnail(
+ request, media_id, width, height, method, m_type, max_timeout_ms
+ )
+ self.media_repo.mark_recently_accessed(None, media_id)
+ else:
+ # Don't let users download media from configured domains, even if it
+ # is already downloaded. This is Trust & Safety tooling to make some
+ # media inaccessible to local users.
+ # See `prevent_media_downloads_from` config docs for more info.
+ if server_name in self.prevent_media_downloads_from:
+ respond_404(request)
+ return
+
+ ip_address = request.getClientAddress().host
+ remote_resp_function = (
+ self.thumbnailer.select_or_generate_remote_thumbnail
+ if self.dynamic_thumbnails
+ else self.thumbnailer.respond_remote_thumbnail
+ )
+ await remote_resp_function(
+ request,
+ server_name,
+ media_id,
+ width,
+ height,
+ method,
+ m_type,
+ max_timeout_ms,
+ ip_address,
+ )
+ self.media_repo.mark_recently_accessed(server_name, media_id)
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+ if hs.config.experimental.msc3916_authenticated_media_enabled:
+ media_repo = hs.get_media_repository()
+ if hs.config.media.url_preview_enabled:
+ UnstablePreviewURLServlet(
+ hs, media_repo, media_repo.media_storage
+ ).register(http_server)
+ UnstableMediaConfigResource(hs).register(http_server)
+ UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
+ http_server
+ )
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
deleted file mode 100644
index fc1aed2889..0000000000
--- a/synapse/rest/client/models.py
+++ /dev/null
@@ -1,99 +0,0 @@
-#
-# This file is licensed under the Affero General Public License (AGPL) version 3.
-#
-# Copyright 2022 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# See the GNU Affero General Public License for more details:
-# <https://www.gnu.org/licenses/agpl-3.0.html>.
-#
-# Originally licensed under the Apache License, Version 2.0:
-# <http://www.apache.org/licenses/LICENSE-2.0>.
-#
-# [This file includes modifications made by New Vector Limited]
-#
-#
-from typing import TYPE_CHECKING, Dict, Optional
-
-from synapse._pydantic_compat import HAS_PYDANTIC_V2
-
-if TYPE_CHECKING or HAS_PYDANTIC_V2:
- from pydantic.v1 import Extra, StrictInt, StrictStr, constr, validator
-else:
- from pydantic import Extra, StrictInt, StrictStr, constr, validator
-
-from synapse.rest.models import RequestBodyModel
-from synapse.util.threepids import validate_email
-
-
-class AuthenticationData(RequestBodyModel):
- """
- Data used during user-interactive authentication.
-
- (The name "Authentication Data" is taken directly from the spec.)
-
- Additional keys will be present, depending on the `type` field. Use
- `.dict(exclude_unset=True)` to access them.
- """
-
- class Config:
- extra = Extra.allow
-
- session: Optional[StrictStr] = None
- type: Optional[StrictStr] = None
-
-
-if TYPE_CHECKING:
- ClientSecretStr = StrictStr
-else:
- # See also assert_valid_client_secret()
- ClientSecretStr = constr(
- regex="[0-9a-zA-Z.=_-]", # noqa: F722
- min_length=1,
- max_length=255,
- strict=True,
- )
-
-
-class ThreepidRequestTokenBody(RequestBodyModel):
- client_secret: ClientSecretStr
- id_server: Optional[StrictStr]
- id_access_token: Optional[StrictStr]
- next_link: Optional[StrictStr]
- send_attempt: StrictInt
-
- @validator("id_access_token", always=True)
- def token_required_for_identity_server(
- cls, token: Optional[str], values: Dict[str, object]
- ) -> Optional[str]:
- if values.get("id_server") is not None and token is None:
- raise ValueError("id_access_token is required if an id_server is supplied.")
- return token
-
-
-class EmailRequestTokenBody(ThreepidRequestTokenBody):
- email: StrictStr
-
- # Canonicalise the email address. The addresses are all stored canonicalised
- # in the database. This allows the user to reset his password without having to
- # know the exact spelling (eg. upper and lower case) of address in the database.
- # Without this, an email stored in the database as "foo@bar.com" would cause
- # user requests for "FOO@bar.com" to raise a Not Found error.
- _email_validator = validator("email", allow_reuse=True)(validate_email)
-
-
-if TYPE_CHECKING:
- ISO3116_1_Alpha_2 = StrictStr
-else:
- # Per spec: two-letter uppercase ISO-3166-1-alpha-2
- ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
-
-
-class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
- country: ISO3116_1_Alpha_2
- phone_number: StrictStr
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index be9b584748..168ce50d3f 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -32,6 +32,7 @@ from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
+from ...api.errors import SynapseError
from ._base import client_patterns
if TYPE_CHECKING:
@@ -56,7 +57,22 @@ class NotificationsServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- from_token = parse_string(request, "from", required=False)
+ # While this is intended to be "string" to clients, the 'from' token
+ # is actually based on a numeric ID. So it must parse to an int.
+ from_token_str = parse_string(request, "from", required=False)
+ if from_token_str is not None:
+ # Parse to an integer.
+ try:
+ from_token = int(from_token_str)
+ except ValueError:
+ # If it doesn't parse to an integer, then this cannot possibly be a valid
+ # pagination token, as we only hand out integers.
+ raise SynapseError(
+ 400, 'Query parameter "from" contains unrecognised token'
+ )
+ else:
+ from_token = None
+
limit = parse_integer(request, "limit", default=50)
only = parse_string(request, "only", required=False)
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index 0323f6afa1..c1a80c5c3d 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -108,6 +108,19 @@ class ProfileDisplaynameRestServlet(RestServlet):
propagate = _read_propagate(self.hs, request)
+ requester_suspended = (
+ await self.hs.get_datastores().main.get_user_suspended_status(
+ requester.user.to_string()
+ )
+ )
+
+ if requester_suspended:
+ raise SynapseError(
+ 403,
+ "Updating displayname while account is suspended is not allowed.",
+ Codes.USER_ACCOUNT_SUSPENDED,
+ )
+
await self.profile_handler.set_displayname(
user, requester, new_name, is_admin, propagate=propagate
)
@@ -167,6 +180,19 @@ class ProfileAvatarURLRestServlet(RestServlet):
propagate = _read_propagate(self.hs, request)
+ requester_suspended = (
+ await self.hs.get_datastores().main.get_user_suspended_status(
+ requester.user.to_string()
+ )
+ )
+
+ if requester_suspended:
+ raise SynapseError(
+ 403,
+ "Updating avatar URL while account is suspended is not allowed.",
+ Codes.USER_ACCOUNT_SUSPENDED,
+ )
+
await self.profile_handler.set_avatar_url(
user, requester, new_avatar_url, is_admin, propagate=propagate
)
diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/reporting.py
index 447281931e..4eee53e5a8 100644
--- a/synapse/rest/client/report_event.py
+++ b/synapse/rest/client/reporting.py
@@ -23,17 +23,28 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+ RestServlet,
+ parse_and_validate_json_object_from_request,
+ parse_json_object_from_request,
+)
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
+from synapse.types.rest import RequestBodyModel
from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import StrictStr
+else:
+ from pydantic import StrictStr
+
logger = logging.getLogger(__name__)
@@ -95,5 +106,57 @@ class ReportEventRestServlet(RestServlet):
return 200, {}
+class ReportRoomRestServlet(RestServlet):
+ """This endpoint lets clients report a room for abuse.
+
+ Whilst MSC4151 is not yet merged, this unstable endpoint is enabled on matrix.org
+ for content moderation purposes, and therefore backwards compatibility should be
+ carefully considered when changing anything on this endpoint.
+
+ More details on the MSC: https://github.com/matrix-org/matrix-spec-proposals/pull/4151
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc4151/rooms/(?P<room_id>[^/]*)/report$",
+ releases=[],
+ v1=False,
+ unstable=True,
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastores().main
+
+ class PostBody(RequestBodyModel):
+ reason: StrictStr
+
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
+
+ room = await self.store.get_room(room_id)
+ if room is None:
+ raise NotFoundError("Room does not exist")
+
+ await self.store.add_room_report(
+ room_id=room_id,
+ user_id=user_id,
+ reason=body.reason,
+ received_ts=self.clock.time_msec(),
+ )
+
+ return 200, {}
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReportEventRestServlet(hs).register(http_server)
+
+ if hs.config.experimental.msc4151_enabled:
+ ReportRoomRestServlet(hs).register(http_server)
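
A hedged usage sketch of the unstable room-report endpoint added above, assuming `msc4151_enabled` is set and using placeholder homeserver, room ID, and token values:

    import requests

    resp = requests.post(
        "https://homeserver.example.com/_matrix/client/unstable/org.matrix.msc4151/"
        "rooms/!abcdef:example.com/report",
        headers={"Authorization": "Bearer <access_token>"},
        json={"reason": "Spam"},
    )
    assert resp.status_code == 200
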
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index fb4d44211e..903c74f6d8 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -292,6 +292,9 @@ class RoomStateEventRestServlet(RestServlet):
try:
if event_type == EventTypes.Member:
membership = content.get("membership", None)
+ if not isinstance(membership, str):
+ raise SynapseError(400, "Invalid membership (must be a string)")
+
event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
@@ -414,6 +417,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
self.auth = hs.get_auth()
+ self._support_via = hs.config.experimental.msc4156_enabled
def register(self, http_server: HttpServer) -> None:
# /join/$room_identifier[/$txn_id]
@@ -432,6 +436,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
args: Dict[bytes, List[bytes]] = request.args # type: ignore
remote_room_hosts = parse_strings_from_args(args, "server_name", required=False)
+ if self._support_via:
+ remote_room_hosts = parse_strings_from_args(
+ args,
+ "org.matrix.msc4156.via",
+ default=remote_room_hosts,
+ required=False,
+ )
room_id, remote_room_hosts = await self.resolve_room_id(
room_identifier,
remote_room_hosts,
@@ -499,7 +510,7 @@ class PublicRoomListRestServlet(RestServlet):
if server:
raise e
- limit: Optional[int] = parse_integer(request, "limit", 0, negative=False)
+ limit: Optional[int] = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since")
if limit == 0:
@@ -1109,6 +1120,20 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
+ requester_suspended = await self._store.get_user_suspended_status(
+ requester.user.to_string()
+ )
+
+ if requester_suspended:
+ event = await self._store.get_event(event_id, allow_none=True)
+ if event:
+ if event.sender != requester.user.to_string():
+ raise SynapseError(
+ 403,
+ "You can only redact your own events while account is suspended.",
+ Codes.USER_ACCOUNT_SUSPENDED,
+ )
+
# Ensure the redacts property in the content matches the one provided in
# the URL.
room_version = await self._store.get_room_version(room_id)
@@ -1419,16 +1444,7 @@ class RoomHierarchyRestServlet(RestServlet):
requester = await self._auth.get_user_by_req(request, allow_guest=True)
max_depth = parse_integer(request, "max_depth")
- if max_depth is not None and max_depth < 0:
- raise SynapseError(
- 400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON
- )
-
limit = parse_integer(request, "limit")
- if limit is not None and limit <= 0:
- raise SynapseError(
- 400, "'limit' must be a positive integer", Codes.BAD_JSON
- )
return 200, await self._room_summary_handler.get_room_hierarchy(
requester,
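
Where `msc4156_enabled` is set, the join and knock servlets above prefer `org.matrix.msc4156.via` query parameters over the older `server_name` ones. A hedged sketch of a client join sending both for compatibility (placeholder homeserver, alias, and token; the alias is percent-encoded in the path):

    import requests

    resp = requests.post(
        "https://homeserver.example.com/_matrix/client/v3/join/%23room%3Aexample.com",
        params=[
            ("server_name", "example.com"),             # legacy parameter
            ("org.matrix.msc4156.via", "example.com"),  # preferred when supported
        ],
        headers={"Authorization": "Bearer <access_token>"},
        json={},
    )
    resp.raise_for_status()
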
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 2b103ca6a8..b5ab0d8534 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -33,6 +33,7 @@ from synapse.events.utils import (
format_event_raw,
)
from synapse.handlers.presence import format_user_presence_state
+from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult
from synapse.handlers.sync import (
ArchivedSyncResult,
InvitedSyncResult,
@@ -40,13 +41,22 @@ from synapse.handlers.sync import (
KnockedSyncResult,
SyncConfig,
SyncResult,
+ SyncVersion,
)
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
+from synapse.http.servlet import (
+ RestServlet,
+ parse_and_validate_json_object_from_request,
+ parse_boolean,
+ parse_integer,
+ parse_string,
+)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict, Requester, StreamToken
+from synapse.types.rest.client import SlidingSyncBody
from synapse.util import json_decoder
+from synapse.util.caches.lrucache import LruCache
from ._base import client_patterns, set_timeline_upper_limit
@@ -110,6 +120,11 @@ class SyncRestServlet(RestServlet):
self._msc2654_enabled = hs.config.experimental.msc2654_enabled
self._msc3773_enabled = hs.config.experimental.msc3773_enabled
+ self._json_filter_cache: LruCache[str, bool] = LruCache(
+ max_size=1000,
+ cache_name="sync_valid_filter",
+ )
+
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None
@@ -177,7 +192,13 @@ class SyncRestServlet(RestServlet):
filter_object = json_decoder.decode(filter_id)
except Exception:
raise SynapseError(400, "Invalid filter JSON", errcode=Codes.NOT_JSON)
- self.filtering.check_valid_filter(filter_object)
+
+ # We cache the validation, as this can get quite expensive if people use
+ # a literal json blob as a query param.
+ if not self._json_filter_cache.get(filter_id):
+ self.filtering.check_valid_filter(filter_object)
+ self._json_filter_cache[filter_id] = True
+
set_timeline_upper_limit(
filter_object, self.hs.config.server.filter_timeline_limit
)
@@ -197,7 +218,6 @@ class SyncRestServlet(RestServlet):
user=user,
filter_collection=filter_collection,
is_guest=requester.is_guest,
- request_key=request_key,
device_id=device_id,
)
@@ -220,6 +240,8 @@ class SyncRestServlet(RestServlet):
sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
+ SyncVersion.SYNC_V2,
+ request_key,
since_token=since_token,
timeout=timeout,
full_state=full_state,
@@ -553,5 +575,396 @@ class SyncRestServlet(RestServlet):
return result
+class SlidingSyncE2eeRestServlet(RestServlet):
+ """
+ API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
+ of Sliding Sync but doesn't have any sliding window component. It's just a way to
+ get E2EE events without having to sit through a big initial sync (`/sync` v2). And
+ we can avoid encryption events being backed up by the main sync response.
+
+ Having To-Device messages split out to this sync endpoint also helps when clients
+ need to have 2 or more sync streams open at a time, e.g a push notification process
+ and a main process. This can cause the two processes to race to fetch the To-Device
+ events, resulting in the need for complex synchronisation rules to ensure the token
+ is correctly and atomically exchanged between processes.
+
+ GET parameters::
+ timeout(int): How long to wait for new events in milliseconds.
+ since(batch_token): Batch token when asking for incremental deltas.
+
+ Response JSON::
+ {
+ "next_batch": // batch token for the next /sync
+ "to_device": {
+ // list of to-device events
+ "events": [
+ {
+ "content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
+ "type": "m.room.encrypted",
+ "sender": "@alice:example.com",
+ }
+ // ...
+ ]
+ },
+ "device_lists": {
+ "changed": ["@alice:example.com"],
+ "left": ["@bob:example.com"]
+ },
+ "device_one_time_keys_count": {
+ "signed_curve25519": 50
+ },
+ "device_unused_fallback_key_types": [
+ "signed_curve25519"
+ ]
+ }
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self.sync_handler = hs.get_sync_handler()
+
+ # Filtering only matters for the `device_lists` because it requires a bunch of
+ # derived information from rooms (see how `_generate_sync_entry_for_rooms()`
+ # prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
+ self.only_member_events_filter_collection = FilterCollection(
+ self.hs,
+ {
+ "room": {
+ # We only care about membership events for the `device_lists`.
+ # Membership will tell us whether a user has joined/left a room and
+ # if there are new devices to encrypt for.
+ "timeline": {
+ "types": ["m.room.member"],
+ },
+ "state": {
+ "types": ["m.room.member"],
+ },
+ # We don't want any extra account_data generated because it's not
+ # returned by this endpoint. This helps us avoid work in
+ # `_generate_sync_entry_for_rooms()`
+ "account_data": {
+ "not_types": ["*"],
+ },
+ # We don't want any extra ephemeral data generated because it's not
+ # returned by this endpoint. This helps us avoid work in
+ # `_generate_sync_entry_for_rooms()`
+ "ephemeral": {
+ "not_types": ["*"],
+ },
+ },
+ # We don't want any extra account_data generated because it's not
+ # returned by this endpoint. (This is just here for good measure)
+ "account_data": {
+ "not_types": ["*"],
+ },
+ # We don't want any extra presence data generated because it's not
+ # returned by this endpoint. (This is just here for good measure)
+ "presence": {
+ "not_types": ["*"],
+ },
+ },
+ )
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user = requester.user
+ device_id = requester.device_id
+
+ timeout = parse_integer(request, "timeout", default=0)
+ since = parse_string(request, "since")
+
+ sync_config = SyncConfig(
+ user=user,
+ filter_collection=self.only_member_events_filter_collection,
+ is_guest=requester.is_guest,
+ device_id=device_id,
+ )
+
+ since_token = None
+ if since is not None:
+ since_token = await StreamToken.from_string(self.store, since)
+
+ # Request cache key
+ request_key = (
+ SyncVersion.E2EE_SYNC,
+ user,
+ timeout,
+ since,
+ )
+
+ # Gather data for the response
+ sync_result = await self.sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config,
+ SyncVersion.E2EE_SYNC,
+ request_key,
+ since_token=since_token,
+ timeout=timeout,
+ full_state=False,
+ )
+
+ # The client may have disconnected by now; don't bother to serialize the
+ # response if so.
+ if request._disconnected:
+ logger.info("Client has disconnected; not serializing response.")
+ return 200, {}
+
+ response: JsonDict = defaultdict(dict)
+ response["next_batch"] = await sync_result.next_batch.to_string(self.store)
+
+ if sync_result.to_device:
+ response["to_device"] = {"events": sync_result.to_device}
+
+ if sync_result.device_lists.changed:
+ response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
+ if sync_result.device_lists.left:
+ response["device_lists"]["left"] = list(sync_result.device_lists.left)
+
+ # We always include this because of https://github.com/vector-im/element-android/issues/3725.
+ # The spec isn't terribly clear on when this can be omitted, nor how a client would tell
+ # the difference between "no keys present" and "nothing changed": the whole field being
+ # absent versus an individual key type entry being absent.
+ # Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
+ response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
+
+ # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
+ # states that this field should always be included, as long as the server supports the feature.
+ response["device_unused_fallback_key_types"] = (
+ sync_result.device_unused_fallback_key_types
+ )
+
+ return 200, response
+
+
+class SlidingSyncRestServlet(RestServlet):
+ """
+ API endpoint for MSC3575 Sliding Sync `/sync`. Allows for clients to request a
+ subset (sliding window) of rooms, state, and timeline events (just what they need)
+ in order to bootstrap quickly and subscribe to only what the client cares about.
+ Because the client can specify what it cares about, we can respond quickly and skip
+ all of the work we would normally have to do with a sync v2 response.
+
+ Request query parameters:
+ timeout: How long to wait for new events in milliseconds.
+ pos: Stream position token when asking for incremental deltas.
+
+ Request body::
+ {
+ // Sliding Window API
+ "lists": {
+ "foo-list": {
+ "ranges": [ [0, 99] ],
+ "sort": [ "by_notification_level", "by_recency", "by_name" ],
+ "required_state": [
+ ["m.room.join_rules", ""],
+ ["m.room.history_visibility", ""],
+ ["m.space.child", "*"]
+ ],
+ "timeline_limit": 10,
+ "filters": {
+ "is_dm": true
+ },
+ "bump_event_types": [ "m.room.message", "m.room.encrypted" ],
+ }
+ },
+ // Room Subscriptions API
+ "room_subscriptions": {
+ "!sub1:bar": {
+ "required_state": [ ["*","*"] ],
+ "timeline_limit": 10,
+ "include_old_rooms": {
+ "timeline_limit": 1,
+ "required_state": [ ["m.room.tombstone", ""], ["m.room.create", ""] ],
+ }
+ }
+ },
+ // Extensions API
+ "extensions": {}
+ }
+
+ Response JSON::
+ {
+ "next_pos": "s58_224_0_13_10_1_1_16_0_1",
+ "lists": {
+ "foo-list": {
+ "count": 1337,
+ "ops": [{
+ "op": "SYNC",
+ "range": [0, 99],
+ "room_ids": [
+ "!foo:bar",
+ // ... 99 more room IDs
+ ]
+ }]
+ }
+ },
+ // Aggregated rooms from lists and room subscriptions
+ "rooms": {
+ // Room from room subscription
+ "!sub1:bar": {
+ "name": "Alice and Bob",
+ "avatar": "mxc://...",
+ "initial": true,
+ "required_state": [
+ {"sender":"@alice:example.com","type":"m.room.create", "state_key":"", "content":{"creator":"@alice:example.com"}},
+ {"sender":"@alice:example.com","type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"invite"}},
+ {"sender":"@alice:example.com","type":"m.room.history_visibility", "state_key":"", "content":{"history_visibility":"joined"}},
+ {"sender":"@alice:example.com","type":"m.room.member", "state_key":"@alice:example.com", "content":{"membership":"join"}}
+ ],
+ "timeline": [
+ {"sender":"@alice:example.com","type":"m.room.create", "state_key":"", "content":{"creator":"@alice:example.com"}},
+ {"sender":"@alice:example.com","type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"invite"}},
+ {"sender":"@alice:example.com","type":"m.room.history_visibility", "state_key":"", "content":{"history_visibility":"joined"}},
+ {"sender":"@alice:example.com","type":"m.room.member", "state_key":"@alice:example.com", "content":{"membership":"join"}},
+ {"sender":"@alice:example.com","type":"m.room.message", "content":{"body":"A"}},
+ {"sender":"@alice:example.com","type":"m.room.message", "content":{"body":"B"}},
+ ],
+ "prev_batch": "t111_222_333",
+ "joined_count": 41,
+ "invited_count": 1,
+ "notification_count": 1,
+ "highlight_count": 0
+ },
+ // rooms from list
+ "!foo:bar": {
+ "name": "The calculated room name",
+ "avatar": "mxc://...",
+ "initial": true,
+ "required_state": [
+ {"sender":"@alice:example.com","type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"invite"}},
+ {"sender":"@alice:example.com","type":"m.room.history_visibility", "state_key":"", "content":{"history_visibility":"joined"}},
+ {"sender":"@alice:example.com","type":"m.space.child", "state_key":"!foo:example.com", "content":{"via":["example.com"]}},
+ {"sender":"@alice:example.com","type":"m.space.child", "state_key":"!bar:example.com", "content":{"via":["example.com"]}},
+ {"sender":"@alice:example.com","type":"m.space.child", "state_key":"!baz:example.com", "content":{"via":["example.com"]}}
+ ],
+ "timeline": [
+ {"sender":"@alice:example.com","type":"m.room.join_rules", "state_key":"", "content":{"join_rule":"invite"}},
+ {"sender":"@alice:example.com","type":"m.room.message", "content":{"body":"A"}},
+ {"sender":"@alice:example.com","type":"m.room.message", "content":{"body":"B"}},
+ {"sender":"@alice:example.com","type":"m.room.message", "content":{"body":"C"}},
+ {"sender":"@alice:example.com","type":"m.room.message", "content":{"body":"D"}},
+ ],
+ "prev_batch": "t111_222_333",
+ "joined_count": 4,
+ "invited_count": 0,
+ "notification_count": 54,
+ "highlight_count": 3
+ },
+ // ... 99 more items
+ },
+ "extensions": {}
+ }
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.simplified_msc3575/sync$", releases=[], v1=False, unstable=True
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self.filtering = hs.get_filtering()
+ self.sliding_sync_handler = hs.get_sliding_sync_handler()
+
+ # TODO: Update this to `on_GET` once we figure out how we want to handle params
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user = requester.user
+ device_id = requester.device_id
+
+ timeout = parse_integer(request, "timeout", default=0)
+ # Position in the stream
+ from_token_string = parse_string(request, "pos")
+
+ from_token = None
+ if from_token_string is not None:
+ from_token = await StreamToken.from_string(self.store, from_token_string)
+
+ # TODO: We currently don't know whether we're going to use sticky params or
+ # maybe some filters like sync v2 where they are built up once and referenced
+ # by filter ID. For now, we will just prototype with always passing everything
+ # in.
+ body = parse_and_validate_json_object_from_request(request, SlidingSyncBody)
+ logger.info("Sliding sync request: %r", body)
+
+ sync_config = SlidingSyncConfig(
+ user=user,
+ device_id=device_id,
+ # FIXME: Currently, we're just manually copying the fields from the
+ # `SlidingSyncBody` into the config. How can we guarantee into the future
+ # that we don't forget any? I would like something more structured like
+ # `copy_attributes(from=body, to=config)`
+ lists=body.lists,
+ room_subscriptions=body.room_subscriptions,
+ extensions=body.extensions,
+ )
+
+ sliding_sync_results = await self.sliding_sync_handler.wait_for_sync_for_user(
+ requester,
+ sync_config,
+ from_token,
+ timeout,
+ )
+
+ # The client may have disconnected by now; don't bother to serialize the
+ # response if so.
+ if request._disconnected:
+ logger.info("Client has disconnected; not serializing response.")
+ return 200, {}
+
+ response_content = await self.encode_response(sliding_sync_results)
+
+ return 200, response_content
+
+ # TODO: Is there a better way to encode things?
+ async def encode_response(
+ self,
+ sliding_sync_result: SlidingSyncResult,
+ ) -> JsonDict:
+ response: JsonDict = defaultdict(dict)
+
+ response["next_pos"] = await sliding_sync_result.next_pos.to_string(self.store)
+ serialized_lists = self.encode_lists(sliding_sync_result.lists)
+ if serialized_lists:
+ response["lists"] = serialized_lists
+ response["rooms"] = {} # TODO: sliding_sync_result.rooms
+ response["extensions"] = {} # TODO: sliding_sync_result.extensions
+
+ return response
+
+ def encode_lists(
+ self, lists: Dict[str, SlidingSyncResult.SlidingWindowList]
+ ) -> JsonDict:
+ def encode_operation(
+ operation: SlidingSyncResult.SlidingWindowList.Operation,
+ ) -> JsonDict:
+ return {
+ "op": operation.op.value,
+ "range": operation.range,
+ "room_ids": operation.room_ids,
+ }
+
+ serialized_lists = {}
+ for list_key, list_result in lists.items():
+ serialized_lists[list_key] = {
+ "count": list_result.count,
+ "ops": [encode_operation(op) for op in list_result.ops],
+ }
+
+ return serialized_lists
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
+
+ if hs.config.experimental.msc3575_enabled:
+ SlidingSyncRestServlet(hs).register(http_server)
+ SlidingSyncE2eeRestServlet(hs).register(http_server)
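
Editor's note (not part of the patch): a minimal sketch of how a client might exercise the prototype servlet registered above. The homeserver URL and access token are placeholders, and the per-list options ("ranges") follow the MSC3575 draft rather than anything defined in this diff; only `lists`, `room_subscriptions`, `extensions`, `pos` and `timeout` are taken from the code.

    import requests

    HOMESERVER = "https://matrix.example.com"  # hypothetical
    ACCESS_TOKEN = "syt_..."                   # hypothetical access token

    # Initial request: no `pos` query parameter yet; `timeout` of 0 returns
    # immediately (mirroring the defaults parsed in `on_POST` above).
    resp = requests.post(
        f"{HOMESERVER}/_matrix/client/unstable/org.matrix.simplified_msc3575/sync",
        params={"timeout": 0},
        headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
        json={
            "lists": {
                # The list options here ("ranges") are an MSC3575-style assumption.
                "foo-list": {"ranges": [[0, 99]]},
            },
            "room_subscriptions": {},
            "extensions": {},
        },
    )
    resp.raise_for_status()
    body = resp.json()

    # `next_pos` is fed back as the `pos` query parameter on the next poll;
    # `lists` maps each list key to {"count": ..., "ops": [...]}.
    print(body["next_pos"], body.get("lists", {}))
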
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 56de6906d0..f428158139 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -149,6 +149,8 @@ class VersionsRestServlet(RestServlet):
is not None
)
),
+ # MSC4151: Report room API (Client-Server API)
+ "org.matrix.msc4151": self.config.experimental.msc4151_enabled,
},
},
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index dc7325fc57..a411ed614e 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -41,9 +41,9 @@ from synapse.http.servlet import (
parse_and_validate_json_object_from_request,
parse_integer,
)
-from synapse.rest.models import RequestBodyModel
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
+from synapse.types.rest import RequestBodyModel
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 8ba723c8d4..1628d58926 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -97,6 +97,12 @@ class DownloadResource(RestServlet):
respond_404(request)
return
+ ip_address = request.getClientAddress().host
await self.media_repo.get_remote_media(
- request, server_name, media_id, file_name, max_timeout_ms
+ request,
+ server_name,
+ media_id,
+ file_name,
+ max_timeout_ms,
+ ip_address,
)
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 7cb335c7c3..ce511c6dce 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -22,23 +22,18 @@
import logging
import re
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING
-from synapse.api.errors import Codes, SynapseError, cs_error
-from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
-from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
+from synapse.http.server import set_corp_headers, set_cors_headers
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
- FileInfo,
- ThumbnailInfo,
respond_404,
- respond_with_file,
- respond_with_responder,
)
from synapse.media.media_storage import MediaStorage
+from synapse.media.thumbnailer import ThumbnailProvider
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
@@ -66,10 +61,11 @@ class ThumbnailResource(RestServlet):
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.media_storage = media_storage
- self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
+ self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
+ self.thumbnail_provider = ThumbnailProvider(hs, media_repo, media_storage)
async def on_GET(
self, request: SynapseRequest, server_name: str, media_id: str
@@ -91,11 +87,11 @@ class ThumbnailResource(RestServlet):
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
- await self._select_or_generate_local_thumbnail(
+ await self.thumbnail_provider.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
else:
- await self._respond_local_thumbnail(
+ await self.thumbnail_provider.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
self.media_repo.mark_recently_accessed(None, media_id)
@@ -108,10 +104,11 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
+ ip_address = request.getClientAddress().host
remote_resp_function = (
- self._select_or_generate_remote_thumbnail
+ self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
- else self._respond_remote_thumbnail
+ else self.thumbnail_provider.respond_remote_thumbnail
)
await remote_resp_function(
request,
@@ -122,459 +119,6 @@ class ThumbnailResource(RestServlet):
method,
m_type,
max_timeout_ms,
+ ip_address,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
-
- async def _respond_local_thumbnail(
- self,
- request: SynapseRequest,
- media_id: str,
- width: int,
- height: int,
- method: str,
- m_type: str,
- max_timeout_ms: int,
- ) -> None:
- media_info = await self.media_repo.get_local_media_info(
- request, media_id, max_timeout_ms
- )
- if not media_info:
- return
-
- thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
- await self._select_and_respond_with_thumbnail(
- request,
- width,
- height,
- method,
- m_type,
- thumbnail_infos,
- media_id,
- media_id,
- url_cache=bool(media_info.url_cache),
- server_name=None,
- )
-
- async def _select_or_generate_local_thumbnail(
- self,
- request: SynapseRequest,
- media_id: str,
- desired_width: int,
- desired_height: int,
- desired_method: str,
- desired_type: str,
- max_timeout_ms: int,
- ) -> None:
- media_info = await self.media_repo.get_local_media_info(
- request, media_id, max_timeout_ms
- )
-
- if not media_info:
- return
-
- thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
- for info in thumbnail_infos:
- t_w = info.width == desired_width
- t_h = info.height == desired_height
- t_method = info.method == desired_method
- t_type = info.type == desired_type
-
- if t_w and t_h and t_method and t_type:
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- url_cache=bool(media_info.url_cache),
- thumbnail=info,
- )
-
- responder = await self.media_storage.fetch_media(file_info)
- if responder:
- await respond_with_responder(
- request, responder, info.type, info.length
- )
- return
-
- logger.debug("We don't have a thumbnail of that size. Generating")
-
- # Okay, so we generate one.
- file_path = await self.media_repo.generate_local_exact_thumbnail(
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- url_cache=bool(media_info.url_cache),
- )
-
- if file_path:
- await respond_with_file(request, desired_type, file_path)
- else:
- logger.warning("Failed to generate thumbnail")
- raise SynapseError(400, "Failed to generate thumbnail.")
-
- async def _select_or_generate_remote_thumbnail(
- self,
- request: SynapseRequest,
- server_name: str,
- media_id: str,
- desired_width: int,
- desired_height: int,
- desired_method: str,
- desired_type: str,
- max_timeout_ms: int,
- ) -> None:
- media_info = await self.media_repo.get_remote_media_info(
- server_name, media_id, max_timeout_ms
- )
- if not media_info:
- respond_404(request)
- return
-
- thumbnail_infos = await self.store.get_remote_media_thumbnails(
- server_name, media_id
- )
-
- file_id = media_info.filesystem_id
-
- for info in thumbnail_infos:
- t_w = info.width == desired_width
- t_h = info.height == desired_height
- t_method = info.method == desired_method
- t_type = info.type == desired_type
-
- if t_w and t_h and t_method and t_type:
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- thumbnail=info,
- )
-
- responder = await self.media_storage.fetch_media(file_info)
- if responder:
- await respond_with_responder(
- request, responder, info.type, info.length
- )
- return
-
- logger.debug("We don't have a thumbnail of that size. Generating")
-
- # Okay, so we generate one.
- file_path = await self.media_repo.generate_remote_exact_thumbnail(
- server_name,
- file_id,
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- )
-
- if file_path:
- await respond_with_file(request, desired_type, file_path)
- else:
- logger.warning("Failed to generate thumbnail")
- raise SynapseError(400, "Failed to generate thumbnail.")
-
- async def _respond_remote_thumbnail(
- self,
- request: SynapseRequest,
- server_name: str,
- media_id: str,
- width: int,
- height: int,
- method: str,
- m_type: str,
- max_timeout_ms: int,
- ) -> None:
- # TODO: Don't download the whole remote file
- # We should proxy the thumbnail from the remote server instead of
- # downloading the remote file and generating our own thumbnails.
- media_info = await self.media_repo.get_remote_media_info(
- server_name, media_id, max_timeout_ms
- )
- if not media_info:
- return
-
- thumbnail_infos = await self.store.get_remote_media_thumbnails(
- server_name, media_id
- )
- await self._select_and_respond_with_thumbnail(
- request,
- width,
- height,
- method,
- m_type,
- thumbnail_infos,
- media_id,
- media_info.filesystem_id,
- url_cache=False,
- server_name=server_name,
- )
-
- async def _select_and_respond_with_thumbnail(
- self,
- request: SynapseRequest,
- desired_width: int,
- desired_height: int,
- desired_method: str,
- desired_type: str,
- thumbnail_infos: List[ThumbnailInfo],
- media_id: str,
- file_id: str,
- url_cache: bool,
- server_name: Optional[str] = None,
- ) -> None:
- """
- Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
-
- Args:
- request: The incoming request.
- desired_width: The desired width, the returned thumbnail may be larger than this.
- desired_height: The desired height, the returned thumbnail may be larger than this.
- desired_method: The desired method used to generate the thumbnail.
- desired_type: The desired content-type of the thumbnail.
- thumbnail_infos: A list of thumbnail info of candidate thumbnails.
- file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: True if this is from a URL cache.
- server_name: The server name, if this is a remote thumbnail.
- """
- logger.debug(
- "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
- media_id,
- desired_width,
- desired_height,
- desired_method,
- thumbnail_infos,
- )
-
- # If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
- # different code path to handle it.
- assert not self.dynamic_thumbnails
-
- if thumbnail_infos:
- file_info = self._select_thumbnail(
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- thumbnail_infos,
- file_id,
- url_cache,
- server_name,
- )
- if not file_info:
- logger.info("Couldn't find a thumbnail matching the desired inputs")
- respond_404(request)
- return
-
- # The thumbnail property must exist.
- assert file_info.thumbnail is not None
-
- responder = await self.media_storage.fetch_media(file_info)
- if responder:
- await respond_with_responder(
- request,
- responder,
- file_info.thumbnail.type,
- file_info.thumbnail.length,
- )
- return
-
- # If we can't find the thumbnail we regenerate it. This can happen
- # if e.g. we've deleted the thumbnails but still have the original
- # image somewhere.
- #
- # Since we have an entry for the thumbnail in the DB we a) know we
- # have have successfully generated the thumbnail in the past (so we
- # don't need to worry about repeatedly failing to generate
- # thumbnails), and b) have already calculated that appropriate
- # width/height/method so we can just call the "generate exact"
- # methods.
-
- # First let's check that we do actually have the original image
- # still. This will throw a 404 if we don't.
- # TODO: We should refetch the thumbnails for remote media.
- await self.media_storage.ensure_media_is_in_local_cache(
- FileInfo(server_name, file_id, url_cache=url_cache)
- )
-
- if server_name:
- await self.media_repo.generate_remote_exact_thumbnail(
- server_name,
- file_id=file_id,
- media_id=media_id,
- t_width=file_info.thumbnail.width,
- t_height=file_info.thumbnail.height,
- t_method=file_info.thumbnail.method,
- t_type=file_info.thumbnail.type,
- )
- else:
- await self.media_repo.generate_local_exact_thumbnail(
- media_id=media_id,
- t_width=file_info.thumbnail.width,
- t_height=file_info.thumbnail.height,
- t_method=file_info.thumbnail.method,
- t_type=file_info.thumbnail.type,
- url_cache=url_cache,
- )
-
- responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(
- request,
- responder,
- file_info.thumbnail.type,
- file_info.thumbnail.length,
- )
- else:
- # This might be because:
- # 1. We can't create thumbnails for the given media (corrupted or
- # unsupported file type), or
- # 2. The thumbnailing process never ran or errored out initially
- # when the media was first uploaded (these bugs should be
- # reported and fixed).
- # Note that we don't attempt to generate a thumbnail now because
- # `dynamic_thumbnails` is disabled.
- logger.info("Failed to find any generated thumbnails")
-
- assert request.path is not None
- respond_with_json(
- request,
- 400,
- cs_error(
- "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
- % (
- request.path.decode(),
- ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
- ),
- code=Codes.UNKNOWN,
- ),
- send_cors=True,
- )
-
- def _select_thumbnail(
- self,
- desired_width: int,
- desired_height: int,
- desired_method: str,
- desired_type: str,
- thumbnail_infos: List[ThumbnailInfo],
- file_id: str,
- url_cache: bool,
- server_name: Optional[str],
- ) -> Optional[FileInfo]:
- """
- Choose an appropriate thumbnail from the previously generated thumbnails.
-
- Args:
- desired_width: The desired width, the returned thumbnail may be larger than this.
- desired_height: The desired height, the returned thumbnail may be larger than this.
- desired_method: The desired method used to generate the thumbnail.
- desired_type: The desired content-type of the thumbnail.
- thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
- file_id: The ID of the media that a thumbnail is being requested for.
- url_cache: True if this is from a URL cache.
- server_name: The server name, if this is a remote thumbnail.
-
- Returns:
- The thumbnail which best matches the desired parameters.
- """
- desired_method = desired_method.lower()
-
- # The chosen thumbnail.
- thumbnail_info = None
-
- d_w = desired_width
- d_h = desired_height
-
- if desired_method == "crop":
- # Thumbnails that match equal or larger sizes of desired width/height.
- crop_info_list: List[
- Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
- ] = []
- # Other thumbnails.
- crop_info_list2: List[
- Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
- ] = []
- for info in thumbnail_infos:
- # Skip thumbnails generated with different methods.
- if info.method != "crop":
- continue
-
- t_w = info.width
- t_h = info.height
- aspect_quality = abs(d_w * t_h - d_h * t_w)
- min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info.type
- length_quality = info.length
- if t_w >= d_w or t_h >= d_h:
- crop_info_list.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
- )
- else:
- crop_info_list2.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
- )
- # Pick the most appropriate thumbnail. Some values of `desired_width` and
- # `desired_height` may result in a tie, in which case we avoid comparing on
- # the thumbnail info and pick the thumbnail that appears earlier
- # in the list of candidates.
- if crop_info_list:
- thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
- elif crop_info_list2:
- thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
- elif desired_method == "scale":
- # Thumbnails that match equal or larger sizes of desired width/height.
- info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
- # Other thumbnails.
- info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
-
- for info in thumbnail_infos:
- # Skip thumbnails generated with different methods.
- if info.method != "scale":
- continue
-
- t_w = info.width
- t_h = info.height
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info.type
- length_quality = info.length
- if t_w >= d_w or t_h >= d_h:
- info_list.append((size_quality, type_quality, length_quality, info))
- else:
- info_list2.append(
- (size_quality, type_quality, length_quality, info)
- )
- # Pick the most appropriate thumbnail. Some values of `desired_width` and
- # `desired_height` may result in a tie, in which case we avoid comparing on
- # the thumbnail info and pick the thumbnail that appears earlier
- # in the list of candidates.
- if info_list:
- thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
- elif info_list2:
- thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
-
- if thumbnail_info:
- return FileInfo(
- file_id=file_id,
- url_cache=url_cache,
- server_name=server_name,
- thumbnail=thumbnail_info,
- )
-
- # No matching thumbnail was found.
- return None
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index ba6576d4db..7b5bfc0421 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Mapping
from twisted.web.resource import Resource
+from synapse.rest.synapse.client.federation_whitelist import FederationWhitelistResource
from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
from synapse.rest.synapse.client.pick_idp import PickIdpResource
from synapse.rest.synapse.client.pick_username import pick_username_resource
@@ -77,6 +78,9 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
# To be removed in Synapse v1.32.0.
resources["/_matrix/saml2"] = res
+ if hs.config.federation.federation_whitelist_endpoint_enabled:
+ resources[FederationWhitelistResource.PATH] = FederationWhitelistResource(hs)
+
if hs.config.experimental.msc4108_enabled:
resources["/_synapse/client/rendezvous"] = MSC4108RendezvousSessionResource(hs)
diff --git a/synapse/rest/synapse/client/federation_whitelist.py b/synapse/rest/synapse/client/federation_whitelist.py
new file mode 100644
index 0000000000..2b8f0320e0
--- /dev/null
+++ b/synapse/rest/synapse/client/federation_whitelist.py
@@ -0,0 +1,66 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.http.server import DirectServeJsonResource
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class FederationWhitelistResource(DirectServeJsonResource):
+ """Custom endpoint (disabled by default) to fetch the federation whitelist
+ config.
+
+    Only registered if the `federation_whitelist_endpoint_enabled` config option is enabled.
+
+ Response format:
+
+ {
+ "whitelist_enabled": true, // Whether the federation whitelist is being enforced
+ "whitelist": [ // Which server names are allowed by the whitelist
+ "example.com"
+ ]
+ }
+ """
+
+ PATH = "/_synapse/client/v1/config/federation_whitelist"
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+
+ self._federation_whitelist = hs.config.federation.federation_domain_whitelist
+
+ self._auth = hs.get_auth()
+
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self._auth.get_user_by_req(request)
+
+ whitelist = []
+ if self._federation_whitelist:
+ # federation_whitelist is actually a dict, not a list
+ whitelist = list(self._federation_whitelist)
+
+ return_dict: JsonDict = {
+ "whitelist_enabled": self._federation_whitelist is not None,
+ "whitelist": whitelist,
+ }
+
+ return 200, return_dict
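
Editor's note (not part of the patch): a minimal sketch of querying the new resource, assuming `federation_whitelist_endpoint_enabled` is turned on and using placeholder homeserver/access-token values. The path and response keys are taken from `FederationWhitelistResource` above.

    import requests

    HOMESERVER = "https://matrix.example.com"  # hypothetical
    ACCESS_TOKEN = "syt_..."                   # hypothetical access token

    resp = requests.get(
        f"{HOMESERVER}/_synapse/client/v1/config/federation_whitelist",
        headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
    )
    resp.raise_for_status()
    config = resp.json()

    if config["whitelist_enabled"]:
        print("Federation restricted to:", ", ".join(config["whitelist"]))
    else:
        print("No federation whitelist is enforced.")
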
diff --git a/synapse/server.py b/synapse/server.py
index 95e319d2e6..ae927c3904 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -109,6 +109,7 @@ from synapse.handlers.room_summary import RoomSummaryHandler
from synapse.handlers.search import SearchHandler
from synapse.handlers.send_email import SendEmailHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.sliding_sync import SlidingSyncHandler
from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
@@ -554,6 +555,9 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)
+ def get_sliding_sync_handler(self) -> SlidingSyncHandler:
+ return SlidingSyncHandler(self)
+
@cache_in_self
def get_room_list_handler(self) -> RoomListHandler:
return RoomListHandler(self)
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index 84699a2ee1..d0e015bf19 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
room_id, chunk
)
+ with Measure(self._clock, "calculate_chain_cover_index_for_events"):
+ # We now calculate chain ID/sequence numbers for any state events we're
+ # persisting. We ignore out of band memberships as we're not in the room
+ # and won't have their auth chain (we'll fix it up later if we join the
+ # room).
+ #
+ # See: docs/auth_chain_difference_algorithm.md
+ new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
+ room_id, [e for e, _ in chunk]
+ )
+
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
@@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
+ new_event_links=new_event_links,
)
return replaced_events
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index f9eced23bf..cc9b162ae4 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -45,7 +45,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker,
)
from synapse.synapse_rust.acl import ServerAclEvaluator
-from synapse.types import MutableStateMap, StateMap, get_domain_from_id
+from synapse.types import MutableStateMap, StateMap, StreamToken, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
@@ -372,6 +372,91 @@ class StateStorageController:
)
return state_map[event_id]
+ async def get_state_after_event(
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
+ ) -> StateMap[str]:
+ """
+ Get the room state after the given event
+
+ Args:
+ event_id: event of interest
+ state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
+ """
+ state_ids = await self.get_state_ids_for_event(
+ event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
+ )
+
+ # using get_metadata_for_events here (instead of get_event) sidesteps an issue
+ # with redactions: if `event_id` is a redaction event, and we don't have the
+ # original (possibly because it got purged), get_event will refuse to return
+ # the redaction event, which isn't terribly helpful here.
+ #
+ # (To be fair, in that case we could assume it's *not* a state event, and
+ # therefore we don't need to worry about it. But still, it seems cleaner just
+ # to pull the metadata.)
+ m = (await self.stores.main.get_metadata_for_events([event_id]))[event_id]
+ if m.state_key is not None and m.rejection_reason is None:
+ state_ids = dict(state_ids)
+ state_ids[(m.event_type, m.state_key)] = event_id
+
+ return state_ids
+
+ async def get_state_at(
+ self,
+ room_id: str,
+ stream_position: StreamToken,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
+ ) -> StateMap[str]:
+ """Get the room state at a particular stream position
+
+ Args:
+ room_id: room for which to get state
+ stream_position: point at which to get state
+ state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the last event in the room before `stream_position` and
+ `state_filter` is not satisfied by partial state. Defaults to `True`.
+ """
+ # FIXME: This gets the state at the latest event before the stream ordering,
+ # which might not be the same as the "current state" of the room at the time
+ # of the stream token if there were multiple forward extremities at the time.
+ last_event_id = (
+ await self.stores.main.get_last_event_id_in_room_before_stream_ordering(
+ room_id,
+ end_token=stream_position.room_key,
+ )
+ )
+
+ if last_event_id:
+ state = await self.get_state_after_event(
+ last_event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
+ )
+
+ else:
+ # no events in this room - so presumably no state
+ state = {}
+
+ # (erikj) This should be rarely hit, but we've had some reports that
+ # we get more state down gappy syncs than we should, so let's add
+ # some logging.
+ logger.info(
+ "Failed to find any events in room %s at %s",
+ room_id,
+ stream_position.room_key,
+ )
+ return state
+
@trace
@tag_args
async def get_state_for_groups(
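
Editor's note (not part of the patch): a sketch of how a handler might use the new `get_state_at` controller method to read a single piece of state at a sync position. `hs` (the `HomeServer`), `room_id` and `token` (a `StreamToken`) are assumed to come from the caller.

    from typing import Optional

    from synapse.api.constants import EventTypes
    from synapse.types.state import StateFilter


    async def get_room_name_at(hs, room_id: str, token) -> Optional[str]:
        """Look up the room name as of the given stream token (sketch)."""
        state_ids = await hs.get_storage_controllers().state.get_state_at(
            room_id,
            stream_position=token,
            # Only fetch the m.room.name state event ID.
            state_filter=StateFilter.from_types([(EventTypes.Name, "")]),
        )
        name_event_id = state_ids.get((EventTypes.Name, ""))
        if name_event_id is None:
            return None
        event = await hs.get_datastores().main.get_event(name_event_id)
        return event.content.get("name")
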
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d9c85e411e..569f618193 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -2461,7 +2461,11 @@ class DatabasePool:
def make_in_list_sql_clause(
- database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
+ database_engine: BaseDatabaseEngine,
+ column: str,
+ iterable: Collection[Any],
+ *,
+ negative: bool = False,
) -> Tuple[str, list]:
"""Returns an SQL clause that checks the given column is in the iterable.
@@ -2474,6 +2478,7 @@ def make_in_list_sql_clause(
database_engine
column: Name of the column
iterable: The values to check the column against.
+ negative: Whether we should check for inequality, i.e. `NOT IN`
Returns:
A tuple of SQL query and the args
@@ -2482,9 +2487,19 @@ def make_in_list_sql_clause(
if database_engine.supports_using_any_list:
# This should hopefully be faster, but also makes postgres query
# stats easier to understand.
- return "%s = ANY(?)" % (column,), [list(iterable)]
+ if not negative:
+ clause = f"{column} = ANY(?)"
+ else:
+ clause = f"{column} != ALL(?)"
+
+ return clause, [list(iterable)]
else:
- return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+ params = ",".join("?" for _ in iterable)
+ if not negative:
+ clause = f"{column} IN ({params})"
+ else:
+ clause = f"{column} NOT IN ({params})"
+ return clause, list(iterable)
# These overloads ensure that `columns` and `iterable` values have the same length.
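
Editor's note (not part of the patch): the clause shapes produced by the updated `make_in_list_sql_clause`, shown for both engine branches. A tiny duck-typed stand-in is used instead of a real `BaseDatabaseEngine`, since the helper only reads `supports_using_any_list`.

    from synapse.storage.database import make_in_list_sql_clause


    class _FakeEngine:
        """Stand-in exposing only the attribute the helper inspects."""

        def __init__(self, supports_using_any_list: bool) -> None:
            self.supports_using_any_list = supports_using_any_list


    pg, lite = _FakeEngine(True), _FakeEngine(False)

    print(make_in_list_sql_clause(pg, "room_id", ["!a", "!b"]))
    # -> ('room_id = ANY(?)', [['!a', '!b']])
    print(make_in_list_sql_clause(pg, "room_id", ["!a", "!b"], negative=True))
    # -> ('room_id != ALL(?)', [['!a', '!b']])
    print(make_in_list_sql_clause(lite, "room_id", ["!a", "!b"], negative=True))
    # -> ('room_id NOT IN (?,?)', ['!a', '!b'])
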
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 563450a97e..9611a84932 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -43,11 +43,9 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
- StreamIdGenerator,
)
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
@@ -75,37 +73,20 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._account_data_id_gen: AbstractStreamIdGenerator
- if isinstance(database.engine, PostgresEngine):
- self._account_data_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- notifier=hs.get_replication_notifier(),
- stream_name="account_data",
- instance_name=self._instance_name,
- tables=[
- ("room_account_data", "instance_name", "stream_id"),
- ("room_tags_revisions", "instance_name", "stream_id"),
- ("account_data", "instance_name", "stream_id"),
- ],
- sequence_name="account_data_sequence",
- writers=hs.config.worker.writers.account_data,
- )
- else:
- # Multiple writers are not supported for SQLite.
- #
- # We shouldn't be running in worker mode with SQLite, but its useful
- # to support it for unit tests.
- self._account_data_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "room_account_data",
- "stream_id",
- extra_tables=[
- ("account_data", "stream_id"),
- ("room_tags_revisions", "stream_id"),
- ],
- is_writer=self._instance_name in hs.config.worker.writers.account_data,
- )
+ self._account_data_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="account_data",
+ instance_name=self._instance_name,
+ tables=[
+ ("room_account_data", "instance_name", "stream_id"),
+ ("room_tags_revisions", "instance_name", "stream_id"),
+ ("account_data", "instance_name", "stream_id"),
+ ],
+ sequence_name="account_data_sequence",
+ writers=hs.config.worker.writers.account_data,
+ )
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index bfd492d95d..c6787faea0 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -318,7 +318,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
# Caches which might leak edits must be invalidated for the event being
# redacted.
- self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
+ self._attempt_to_invalidate_cache(
+ "get_relations_for_event",
+ (
+ room_id,
+ redacts,
+ ),
+ )
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
@@ -345,7 +351,13 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
if relates_to:
- self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
+ self._attempt_to_invalidate_cache(
+ "get_relations_for_event",
+ (
+ room_id,
+ relates_to,
+ ),
+ )
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
@@ -380,9 +392,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache(
"get_unread_event_push_actions_by_room_for_user", (room_id,)
)
+ self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
- self._attempt_to_invalidate_cache("get_relations_for_event", None)
self._attempt_to_invalidate_cache("get_applicable_edit", None)
self._attempt_to_invalidate_cache("get_thread_id", None)
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", None)
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index e17821ff6e..07333efff8 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -50,16 +50,15 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
-from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
- StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -89,35 +88,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
expiry_ms=30 * 60 * 1000,
)
- if isinstance(database.engine, PostgresEngine):
- self._can_write_to_device = (
- self._instance_name in hs.config.worker.writers.to_device
- )
+ self._can_write_to_device = (
+ self._instance_name in hs.config.worker.writers.to_device
+ )
- self._to_device_msg_id_gen: AbstractStreamIdGenerator = (
- MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- notifier=hs.get_replication_notifier(),
- stream_name="to_device",
- instance_name=self._instance_name,
- tables=[
- ("device_inbox", "instance_name", "stream_id"),
- ("device_federation_outbox", "instance_name", "stream_id"),
- ],
- sequence_name="device_inbox_sequence",
- writers=hs.config.worker.writers.to_device,
- )
- )
- else:
- self._can_write_to_device = True
- self._to_device_msg_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "device_inbox",
- "stream_id",
- extra_tables=[("device_federation_outbox", "stream_id")],
- )
+ self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="to_device",
+ instance_name=self._instance_name,
+ tables=[
+ ("device_inbox", "instance_name", "stream_id"),
+ ("device_federation_outbox", "instance_name", "stream_id"),
+ ],
+ sequence_name="device_inbox_sequence",
+ writers=hs.config.worker.writers.to_device,
+ )
max_device_inbox_id = self._to_device_msg_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
@@ -978,6 +965,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"
+ CLEANUP_DEVICE_FEDERATION_OUTBOX = "cleanup_device_federation_outbox"
def __init__(
self,
@@ -1003,6 +991,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self._remove_dead_devices_from_device_inbox,
)
+ self.db_pool.updates.register_background_update_handler(
+ self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
+ self._cleanup_device_federation_outbox,
+ )
+
async def _background_drop_index_device_inbox(
self, progress: JsonDict, batch_size: int
) -> int:
@@ -1094,6 +1087,75 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
return batch_size
+ async def _cleanup_device_federation_outbox(
+ self,
+ progress: JsonDict,
+ batch_size: int,
+ ) -> int:
+ def _cleanup_device_federation_outbox_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
+ if "max_stream_id" in progress:
+ max_stream_id = progress["max_stream_id"]
+ else:
+ txn.execute("SELECT max(stream_id) FROM device_federation_outbox")
+ res = cast(Tuple[Optional[int]], txn.fetchone())
+ if res[0] is None:
+                    # this can only happen if the `device_federation_outbox` table is
+                    # empty, in which case we have no work to do.
+ return True
+ else:
+ max_stream_id = res[0]
+
+ start = progress.get("stream_id", 0)
+ stop = start + batch_size
+
+ sql = """
+ SELECT destination FROM device_federation_outbox
+ WHERE ? < stream_id AND stream_id <= ?
+ """
+
+ txn.execute(sql, (start, stop))
+
+ destinations = {d for d, in txn}
+ to_remove = set()
+ for d in destinations:
+ try:
+ parse_and_validate_server_name(d)
+ except ValueError:
+ to_remove.add(d)
+
+ self.db_pool.simple_delete_many_txn(
+ txn,
+ table="device_federation_outbox",
+ column="destination",
+ values=to_remove,
+ keyvalues={},
+ )
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
+ {
+ "stream_id": stop,
+ "max_stream_id": max_stream_id,
+ },
+ )
+
+ return stop >= max_stream_id
+
+ finished = await self.db_pool.runInteraction(
+ "_cleanup_device_federation_outbox",
+ _cleanup_device_federation_outbox_txn,
+ )
+
+ if finished:
+ await self.db_pool.updates._end_background_update(
+ self.CLEANUP_DEVICE_FEDERATION_OUTBOX,
+ )
+
+ return batch_size
+
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
pass
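
Editor's note (not part of the patch): the validity check that the `cleanup_device_federation_outbox` background update applies to each `destination`, shown in isolation with made-up destination values.

    from synapse.util.stringutils import parse_and_validate_server_name

    destinations = {"matrix.org", "example.com:8448", "not a server name!"}

    to_remove = set()
    for d in destinations:
        try:
            # Raises ValueError for strings that are not valid server names.
            parse_and_validate_server_name(d)
        except ValueError:
            to_remove.add(d)

    print(to_remove)  # only the malformed entry would be deleted from the outbox
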
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8dbcb3f5a0..59a035dd62 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,10 +57,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import (
JsonDict,
JsonMapping,
@@ -70,10 +67,7 @@ from synapse.types import (
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
-from synapse.util.caches.stream_change_cache import (
- AllEntitiesChangedResult,
- StreamChangeCache,
-)
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -102,19 +96,26 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._device_list_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "device_lists_stream",
- "stream_id",
- extra_tables=[
- ("user_signature_stream", "stream_id"),
- ("device_lists_outbound_pokes", "stream_id"),
- ("device_lists_changes_in_room", "stream_id"),
- ("device_lists_remote_pending", "stream_id"),
- ("device_lists_changes_converted_stream_position", "stream_id"),
+ self._device_list_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="device_lists_stream",
+ instance_name=self._instance_name,
+ tables=[
+ ("device_lists_stream", "instance_name", "stream_id"),
+ ("user_signature_stream", "instance_name", "stream_id"),
+ ("device_lists_outbound_pokes", "instance_name", "stream_id"),
+ ("device_lists_changes_in_room", "instance_name", "stream_id"),
+ ("device_lists_remote_pending", "instance_name", "stream_id"),
+ (
+ "device_lists_changes_converted_stream_position",
+ "instance_name",
+ "stream_id",
+ ),
],
- is_writer=hs.config.worker.worker_app is None,
+ sequence_name="device_lists_sequence",
+ writers=["master"],
)
device_list_max = self._device_list_id_gen.get_current_token()
@@ -132,6 +133,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=device_list_prefill,
)
+ device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_changes_in_room",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_room_stream_cache = StreamChangeCache(
+ "DeviceListRoomStreamChangeCache",
+ min_device_list_room_id,
+ prefilled_cache=device_list_room_prefill,
+ )
+
(
user_signature_stream_prefill,
user_signature_stream_list_id,
@@ -149,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=user_signature_stream_prefill,
)
- (
- device_list_federation_prefill,
- device_list_federation_list_id,
- ) = self.db_pool.get_cache_dict(
- db_conn,
- "device_lists_outbound_pokes",
- entity_column="destination",
- stream_column="stream_id",
- max_value=device_list_max,
- limit=10000,
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache",
- device_list_federation_list_id,
- prefilled_cache=device_list_federation_prefill,
- )
+ self._device_list_federation_stream_cache = None
+ if hs.should_send_federation():
+ (
+ device_list_federation_prefill,
+ device_list_federation_list_id,
+ ) = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_outbound_pokes",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=10000,
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache",
+ device_list_federation_list_id,
+ prefilled_cache=device_list_federation_prefill,
+ )
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
@@ -192,23 +209,37 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None:
for row in rows:
if row.is_signature:
- self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
continue
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
- if row.entity.startswith("@"):
- self._device_list_stream_cache.entity_has_changed(row.entity, token)
- self.get_cached_devices_for_user.invalidate((row.entity,))
- self._get_cached_user_device.invalidate((row.entity,))
- self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
-
- else:
- self._device_list_federation_stream_cache.entity_has_changed(
- row.entity, token
+ if not row.hosts_calculated:
+ self._device_list_stream_cache.entity_has_changed(row.user_id, token)
+ self.get_cached_devices_for_user.invalidate((row.user_id,))
+ self._get_cached_user_device.invalidate((row.user_id,))
+ self.get_device_list_last_stream_id_for_remote.invalidate(
+ (row.user_id,)
)
+ def device_lists_outbound_pokes_have_changed(
+ self, destinations: StrCollection, token: int
+ ) -> None:
+ assert self._device_list_federation_stream_cache is not None
+
+ for destination in destinations:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ destination, token
+ )
+
+ def device_lists_in_rooms_have_changed(
+ self, room_ids: StrCollection, token: int
+ ) -> None:
+ "Record that device lists have changed in rooms"
+ for room_id in room_ids:
+ self._device_list_room_stream_cache.entity_has_changed(room_id, token)
+
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@@ -341,6 +372,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
EDU contents.
"""
now_stream_id = self.get_device_stream_token()
+ if from_stream_id == now_stream_id:
+ return now_stream_id, []
+
+ if self._device_list_federation_stream_cache is None:
+ raise Exception("Func can only be used on federation senders")
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -744,6 +780,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"stream_id": stream_id,
"from_user_id": from_user_id,
"user_ids": json_encoder.encode(user_ids),
+ "instance_name": self._instance_name,
},
)
@@ -832,16 +869,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
return {device[0]: db_to_json(device[1]) for device in devices}
- def get_cached_device_list_changes(
- self,
- from_key: int,
- ) -> AllEntitiesChangedResult:
- """Get set of users whose devices have changed since `from_key`, or None
- if that information is not in our cache.
- """
-
- return self._device_list_stream_cache.get_all_entities_changed(from_key)
-
@cancellable
async def get_all_devices_changed(
self,
@@ -1005,10 +1032,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
- SELECT stream_id, entity FROM (
- SELECT stream_id, user_id AS entity FROM device_lists_stream
+ SELECT stream_id, user_id, hosts FROM (
+ SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
UNION ALL
- SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
@@ -1457,7 +1484,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable
async def get_device_list_changes_in_rooms(
- self, room_ids: Collection[str], from_id: int
+ self, room_ids: Collection[str], from_id: int, to_id: int
) -> Optional[Set[str]]:
"""Return the set of users whose devices have changed in the given rooms
since the given stream ID.
@@ -1473,9 +1500,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
if min_stream_id > from_id:
return None
+ changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
+ room_ids, from_id
+ )
+ if not changed_room_ids:
+ return set()
+
sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room
- WHERE {clause} AND stream_id >= ?
+ WHERE {clause} AND stream_id > ? AND stream_id <= ?
"""
def _get_device_list_changes_in_rooms_txn(
@@ -1487,11 +1520,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {user_id for user_id, in txn}
changes = set()
- for chunk in batch_iter(room_ids, 1000):
+ for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk
)
args.append(from_id)
+ args.append(to_id)
changes |= await self.db_pool.runInteraction(
"get_device_list_changes_in_rooms",
@@ -1502,6 +1536,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return changes
+ async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
+ """Return the set of rooms where devices have changed since the given
+ stream ID.
+
+ Will raise an exception if the given stream ID is too old.
+ """
+
+ min_stream_id = await self._get_min_device_lists_changes_in_room()
+
+ if min_stream_id > from_id:
+ raise Exception("stream ID is too old")
+
+ sql = """
+ SELECT DISTINCT room_id FROM device_lists_changes_in_room
+ WHERE stream_id > ? AND stream_id <= ?
+ """
+
+ def _get_all_device_list_changes_txn(
+ txn: LoggingTransaction,
+ ) -> Set[str]:
+ txn.execute(sql, (from_id, to_id))
+ return {room_id for room_id, in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_all_device_list_changes",
+ _get_all_device_list_changes_txn,
+ )
+
async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
@@ -1529,6 +1591,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
get_device_list_changes_in_room_txn,
)
+ async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
+ return await self.db_pool.simple_select_onecol(
+ table="device_lists_outbound_pokes",
+ keyvalues={"stream_id": stream_id},
+ retcol="destination",
+ desc="get_destinations_for_device",
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(
@@ -1539,6 +1609,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
+ self._instance_name = hs.get_instance_name()
+
self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
@@ -1651,6 +1723,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"device_lists_outbound_pokes",
{
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"destination": destination,
"user_id": user_id,
"device_id": device_id,
@@ -1687,10 +1760,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
- # Because we have write access, this will be a StreamIdGenerator
- # (see DeviceWorkerStore.__init__)
- _device_list_id_gen: AbstractStreamIdGenerator
-
def __init__(
self,
database: DatabasePool,
@@ -1962,8 +2031,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self,
user_id: str,
- device_ids: Collection[str],
- room_ids: Collection[str],
+ device_ids: StrCollection,
+ room_ids: StrCollection,
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -2049,9 +2118,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_stream",
- keys=("stream_id", "user_id", "device_id"),
+ keys=("instance_name", "stream_id", "user_id", "device_id"),
values=[
- (stream_id, user_id, device_id)
+ (self._instance_name, stream_id, user_id, device_id)
for stream_id, device_id in zip(stream_ids, device_ids)
],
)
@@ -2062,18 +2131,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_id: str,
hosts: Collection[str],
- stream_ids: List[int],
+ stream_id: int,
context: Optional[Dict[str, str]],
) -> None:
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_ids[-1],
- )
+ if self._device_list_federation_stream_cache:
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_id,
+ )
now = self._clock.time_msec()
- stream_id_iterator = iter(stream_ids)
encoded_context = json_encoder.encode(context)
mark_sent = not self.hs.is_mine_id(user_id)
@@ -2081,7 +2150,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values = [
(
destination,
- next(stream_id_iterator),
+ self._instance_name,
+ stream_id,
user_id,
device_id,
mark_sent,
@@ -2096,6 +2166,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_outbound_pokes",
keys=(
"destination",
+ "instance_name",
"stream_id",
"user_id",
"device_id",
@@ -2114,16 +2185,40 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id,
{
stream_id: destination
- for (destination, stream_id, _, _, _, _, _) in values
+ for (destination, _, stream_id, _, _, _, _, _) in values
},
)
+ async def mark_redundant_device_lists_pokes(
+ self,
+ user_id: str,
+ device_id: str,
+ room_id: str,
+ converted_upto_stream_id: int,
+ ) -> None:
+ """If we've calculated the outbound pokes for a given room/device list
+ update, mark any subsequent changes as already converted"""
+
+ sql = """
+ UPDATE device_lists_changes_in_room
+ SET converted_to_destinations = true
+ WHERE stream_id > ? AND user_id = ? AND device_id = ?
+ AND room_id = ? AND NOT converted_to_destinations
+ """
+
+ def mark_redundant_device_lists_pokes_txn(txn: LoggingTransaction) -> None:
+ txn.execute(sql, (converted_upto_stream_id, user_id, device_id, room_id))
+
+ return await self.db_pool.runInteraction(
+ "mark_redundant_device_lists_pokes", mark_redundant_device_lists_pokes_txn
+ )
+
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
- device_ids: Iterable[str],
- room_ids: Collection[str],
+ device_ids: StrCollection,
+ room_ids: StrCollection,
stream_ids: List[int],
context: Dict[str, str],
) -> None:
@@ -2143,6 +2238,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"device_id",
"room_id",
"stream_id",
+ "instance_name",
"converted_to_destinations",
"opentracing_context",
),
@@ -2152,6 +2248,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id,
room_id,
stream_id,
+ self._instance_name,
# We only need to calculate outbound pokes for local users
not self.hs.is_mine_id(user_id),
encoded_context,
@@ -2161,6 +2258,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
+ txn.call_after(
+ self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
+ )
+
async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
@@ -2235,22 +2336,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return
def add_device_list_outbound_pokes_txn(
- txn: LoggingTransaction, stream_ids: List[int]
+ txn: LoggingTransaction, stream_id: int
) -> None:
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id=user_id,
device_id=device_id,
hosts=hosts,
- stream_ids=stream_ids,
+ stream_id=stream_id,
context=context,
)
- async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+ async with self._device_list_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
- stream_ids,
+ stream_id,
)
async def add_remote_device_list_to_pending(
@@ -2267,7 +2368,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"user_id": user_id,
"device_id": device_id,
},
- values={"stream_id": stream_id},
+ values={
+ "stream_id": stream_id,
+ "instance_name": self._instance_name,
+ },
desc="add_remote_device_list_to_pending",
)
@@ -2317,15 +2421,16 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
`FALSE` have not been converted.
"""
- return cast(
- Tuple[int, str],
- await self.db_pool.simple_select_one(
- table="device_lists_changes_converted_stream_position",
- keyvalues={},
- retcols=["stream_id", "room_id"],
- desc="get_device_change_last_converted_pos",
- ),
+ # There should be only one row in this table, though we want to
+ # future-proof ourselves for when we have multiple rows (one for each
+ # instance). So to handle that case we take the minimum of all rows.
+ rows = await self.db_pool.simple_select_list(
+ table="device_lists_changes_converted_stream_position",
+ keyvalues={},
+ retcols=["stream_id", "room_id"],
+ desc="get_device_change_last_converted_pos",
)
+ return cast(Tuple[int, str], min(rows))
async def set_device_change_last_converted_pos(
self,
@@ -2340,6 +2445,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
await self.db_pool.simple_update_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
- updatevalues={"stream_id": stream_id, "room_id": room_id},
+ updatevalues={
+ "stream_id": stream_id,
+ "instance_name": self._instance_name,
+ "room_id": room_id,
+ },
desc="set_device_change_last_converted_pos",
)
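
Editor's note (not part of the patch): why `min(rows)` in `get_device_change_last_converted_pos` above picks the least-advanced position. Rows are `(stream_id, room_id)` tuples and Python compares tuples element-wise, so the smallest `stream_id` wins and `room_id` only breaks ties. The values below are illustrative.

    rows = [(105, "!b:example.com"), (97, "!a:example.com"), (97, "!c:example.com")]
    print(min(rows))  # -> (97, '!a:example.com')
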
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b219ea70ee..9e6c9561ae 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -58,7 +58,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if stream_name == DeviceListsStream.NAME:
for row in rows:
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
- if row.entity.startswith("@"):
+ if not row.hosts_calculated:
self._get_e2e_device_keys_for_federation_query_inner.invalidate(
- (row.entity,)
+ (row.user_id,)
)
super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._cross_signing_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "e2e_cross_signing_keys",
- "stream_id",
+ self._cross_signing_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="e2e_cross_signing_keys",
+ instance_name=self._instance_name,
+ tables=[
+ ("e2e_cross_signing_keys", "instance_name", "stream_id"),
+ ],
+ sequence_name="e2e_cross_signing_keys_sequence",
+ writers=["master"],
)
async def set_e2e_device_keys(
@@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"keytype": key_type,
"keydata": json_encoder.encode(key),
"stream_id": stream_id,
+ "instance_name": self._instance_name,
},
)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 9b3ced9edb..9da9723674 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -159,6 +159,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
500000, "_event_auth_cache", size_callback=len
)
+ # Flag used by unit tests to disable fallback when there is no chain cover
+ # index.
+ self.tests_allow_no_chain_cover_index = True
+
self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
if isinstance(self.database_engine, PostgresEngine):
@@ -231,8 +235,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
- # for the events in question, so we fall back to the old method.
- pass
+ # for the events in question, so we fall back to the old method
+ # (except in tests)
+ if not self.tests_allow_no_chain_cover_index:
+ raise
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
@@ -282,7 +288,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
- logger.info(
+ logger.error(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
@@ -579,8 +585,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
- # for the events in question, so we fall back to the old method.
- pass
+ # for the events in question, so we fall back to the old method
+ # (except in tests)
+ if not self.tests_allow_no_chain_cover_index:
+ raise
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
@@ -807,7 +815,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
if events_missing_chain_info - event_to_auth_ids.keys():
# Uh oh, we somehow haven't correctly done the chain cover index,
# bail and fall back to the old method.
- logger.info(
+ logger.error(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info - event_to_auth_ids.keys(),
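The `tests_allow_no_chain_cover_index` flag introduced above defaults to `True`, which preserves the old fallback behaviour in production; a unit test that wants a missing chain cover index to be fatal can flip it off. A hedged sketch of such a test (the `store` fixture name is assumed):

    # Make a missing chain cover index raise instead of silently falling back.
    store.tests_allow_no_chain_cover_index = False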
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index bdd0781c48..0ebf5b53d5 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -1829,7 +1829,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
async def get_push_actions_for_user(
self,
user_id: str,
- before: Optional[str] = None,
+ before: Optional[int] = None,
limit: int = 50,
only_highlight: bool = False,
) -> List[UserPushAction]:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 990698aa5c..1f7acdb859 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -34,7 +34,6 @@ from typing import (
Optional,
Set,
Tuple,
- Union,
cast,
)
@@ -95,6 +94,27 @@ class DeltaState:
to_insert: StateMap[str]
no_longer_in_room: bool = False
+ def is_noop(self) -> bool:
+ """Whether this state delta is actually empty"""
+ return not self.to_delete and not self.to_insert and not self.no_longer_in_room
+
+
+@attr.s(slots=True, auto_attribs=True)
+class NewEventChainLinks:
+ """Information about new auth chain links that need to be added to the DB.
+
+ Attributes:
+ chain_id, sequence_number: the IDs corresponding to the event being
+ inserted, and the starting point of the links
+ links: The links that need to be added, each a 2-tuple of the chain
+ ID/sequence number of the link's end point.

+ """
+
+ chain_id: int
+ sequence_number: int
+
+ links: List[Tuple[int, int]] = attr.Factory(list)
+
class PersistEventsStore:
"""Contains all the functions for writing events to the database.
@@ -144,6 +164,7 @@ class PersistEventsStore:
*,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
+ new_event_links: Dict[str, NewEventChainLinks],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
@@ -203,6 +224,7 @@ class PersistEventsStore:
async with stream_ordering_manager as stream_orderings:
for (event, _), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
+ event.internal_metadata.instance_name = self._instance_name
await self.db_pool.runInteraction(
"persist_events",
@@ -212,6 +234,7 @@ class PersistEventsStore:
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
+ new_event_links=new_event_links,
)
persist_event_counter.inc(len(events_and_contexts))
@@ -238,6 +261,87 @@ class PersistEventsStore:
(room_id,), frozenset(new_forward_extremities)
)
+ async def calculate_chain_cover_index_for_events(
+ self, room_id: str, events: Collection[EventBase]
+ ) -> Dict[str, NewEventChainLinks]:
+ # Filter to state events, and ensure there are no duplicates.
+ state_events = []
+ seen_events = set()
+ for event in events:
+ if not event.is_state() or event.event_id in seen_events:
+ continue
+
+ state_events.append(event)
+ seen_events.add(event.event_id)
+
+ if not state_events:
+ return {}
+
+ return await self.db_pool.runInteraction(
+ "_calculate_chain_cover_index_for_events",
+ self.calculate_chain_cover_index_for_events_txn,
+ room_id,
+ state_events,
+ )
+
+ def calculate_chain_cover_index_for_events_txn(
+ self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
+ ) -> Dict[str, NewEventChainLinks]:
+ # We now calculate chain ID/sequence numbers for any state events we're
+ # persisting. We ignore out of band memberships as we're not in the room
+ # and won't have their auth chain (we'll fix it up later if we join the
+ # room).
+ #
+ # See: docs/auth_chain_difference_algorithm.md
+
+ # We ignore legacy rooms that we aren't filling the chain cover index
+ # for.
+ row = self.db_pool.simple_select_one_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("room_id", "has_auth_chain_index"),
+ allow_none=True,
+ )
+ if row is None or row[1] is False:
+ return {}
+
+ # Filter out events that we've already calculated.
+ rows = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth_chains",
+ column="event_id",
+ iterable=[e.event_id for e in state_events],
+ keyvalues={},
+ retcols=("event_id",),
+ )
+ already_persisted_events = {event_id for event_id, in rows}
+ state_events = [
+ event
+ for event in state_events
+ if event.event_id not in already_persisted_events
+ ]
+
+ if not state_events:
+ return {}
+
+ # We need to know the type/state_key and auth events of the events we're
+ # calculating chain IDs for. We don't rely on having the full Event
+ # instances as we'll potentially be pulling more events from the DB and
+ # we don't need the overhead of fetching/parsing the full event JSON.
+ event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
+ event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
+ event_to_room_id = {e.event_id: e.room_id for e in state_events}
+
+ return self._calculate_chain_cover_index(
+ txn,
+ self.db_pool,
+ self.store.event_chain_id_gen,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ )
+
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.
@@ -353,6 +457,7 @@ class PersistEventsStore:
inhibit_local_membership_updates: bool,
state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]],
+ new_event_links: Dict[str, NewEventChainLinks],
) -> None:
"""Insert some number of room events into the necessary database tables.
@@ -461,7 +566,9 @@ class PersistEventsStore:
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
- self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
+ self._persist_event_auth_chain_txn(
+ txn, [e for e, _ in events_and_contexts], new_event_links
+ )
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@@ -491,7 +598,11 @@ class PersistEventsStore:
self,
txn: LoggingTransaction,
events: List[EventBase],
+ new_event_links: Dict[str, NewEventChainLinks],
) -> None:
+ if new_event_links:
+ self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
+
- # We only care about state events, so this if there are no state events.
+ # We only care about state events, so skip this if there are no state events.
if not any(e.is_state() for e in events):
return
@@ -514,62 +625,37 @@ class PersistEventsStore:
],
)
- # We now calculate chain ID/sequence numbers for any state events we're
- # persisting. We ignore out of band memberships as we're not in the room
- # and won't have their auth chain (we'll fix it up later if we join the
- # room).
- #
- # See: docs/auth_chain_difference_algorithm.md
-
- # We ignore legacy rooms that we aren't filling the chain cover index
- # for.
- rows = cast(
- List[Tuple[str, Optional[Union[int, bool]]]],
- self.db_pool.simple_select_many_txn(
- txn,
- table="rooms",
- column="room_id",
- iterable={event.room_id for event in events if event.is_state()},
- keyvalues={},
- retcols=("room_id", "has_auth_chain_index"),
- ),
- )
- rooms_using_chain_index = {
- room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
- }
-
- state_events = {
- event.event_id: event
- for event in events
- if event.is_state() and event.room_id in rooms_using_chain_index
- }
-
- if not state_events:
- return
+ @classmethod
+ def _add_chain_cover_index(
+ cls,
+ txn: LoggingTransaction,
+ db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, StrCollection],
+ ) -> None:
+ """Calculate and persist the chain cover index for the given events.
- # We need to know the type/state_key and auth events of the events we're
- # calculating chain IDs for. We don't rely on having the full Event
- # instances as we'll potentially be pulling more events from the DB and
- # we don't need the overhead of fetching/parsing the full event JSON.
- event_to_types = {
- e.event_id: (e.type, e.state_key) for e in state_events.values()
- }
- event_to_auth_chain = {
- e.event_id: e.auth_event_ids() for e in state_events.values()
- }
- event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+ Args:
+ event_to_room_id: Event ID to the room ID of the event
+ event_to_types: Event ID to type and state_key of the event
+ event_to_auth_chain: Event ID to list of auth event IDs of the
+ event (events with no auth events can be excluded).
+ """
- self._add_chain_cover_index(
+ new_event_links = cls._calculate_chain_cover_index(
txn,
- self.db_pool,
- self.store.event_chain_id_gen,
+ db_pool,
+ event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
+ cls._persist_chain_cover_index(txn, db_pool, new_event_links)
@classmethod
- def _add_chain_cover_index(
+ def _calculate_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
@@ -577,7 +663,7 @@ class PersistEventsStore:
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
- ) -> None:
+ ) -> Dict[str, NewEventChainLinks]:
"""Calculate the chain cover index for the given events.
Args:
@@ -585,6 +671,10 @@ class PersistEventsStore:
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
+
+ Returns:
+ A mapping with any new auth chain links we need to add, keyed by
+ event ID.
"""
# Map from event ID to chain ID/sequence number.
@@ -703,11 +793,11 @@ class PersistEventsStore:
room_id = event_to_room_id.get(event_id)
if room_id:
e_type, state_key = event_to_types[event_id]
- db_pool.simple_insert_txn(
+ db_pool.simple_upsert_txn(
txn,
table="event_auth_chain_to_calculate",
+ keyvalues={"event_id": event_id},
values={
- "event_id": event_id,
"room_id": room_id,
"type": e_type,
"state_key": state_key,
@@ -719,7 +809,7 @@ class PersistEventsStore:
break
if not events_to_calc_chain_id_for:
- return
+ return {}
# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
@@ -734,23 +824,10 @@ class PersistEventsStore:
)
chain_map.update(new_chain_tuples)
- db_pool.simple_insert_many_txn(
- txn,
- table="event_auth_chains",
- keys=("event_id", "chain_id", "sequence_number"),
- values=[
- (event_id, c_id, seq)
- for event_id, (c_id, seq) in new_chain_tuples.items()
- ],
- )
-
- db_pool.simple_delete_many_txn(
- txn,
- table="event_auth_chain_to_calculate",
- keyvalues={},
- column="event_id",
- values=new_chain_tuples,
- )
+ to_return = {
+ event_id: NewEventChainLinks(chain_id, sequence_number)
+ for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
+ }
# Now we need to calculate any new links between chains caused by
# the new events.
@@ -820,10 +897,38 @@ class PersistEventsStore:
auth_chain_id, auth_sequence_number = chain_map[auth_id]
# Step 2a, add link between the event and auth event
+ to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
chain_links.add_link(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
)
+ return to_return
+
+ @classmethod
+ def _persist_chain_cover_index(
+ cls,
+ txn: LoggingTransaction,
+ db_pool: DatabasePool,
+ new_event_links: Dict[str, NewEventChainLinks],
+ ) -> None:
+ db_pool.simple_insert_many_txn(
+ txn,
+ table="event_auth_chains",
+ keys=("event_id", "chain_id", "sequence_number"),
+ values=[
+ (event_id, new_links.chain_id, new_links.sequence_number)
+ for event_id, new_links in new_event_links.items()
+ ],
+ )
+
+ db_pool.simple_delete_many_txn(
+ txn,
+ table="event_auth_chain_to_calculate",
+ keyvalues={},
+ column="event_id",
+ values=new_event_links,
+ )
+
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chain_links",
@@ -833,7 +938,16 @@ class PersistEventsStore:
"target_chain_id",
"target_sequence_number",
),
- values=list(chain_links.get_additions()),
+ values=[
+ (
+ new_links.chain_id,
+ new_links.sequence_number,
+ target_chain_id,
+ target_sequence_number,
+ )
+ for new_links in new_event_links.values()
+ for (target_chain_id, target_sequence_number) in new_links.links
+ ],
)
@staticmethod
@@ -1017,6 +1131,9 @@ class PersistEventsStore:
) -> None:
"""Update the current state stored in the datatabase for the given room"""
+ if state_delta.is_noop():
+ return
+
async with self._stream_id_gen.get_next() as stream_ordering:
await self.db_pool.runInteraction(
"update_current_state",
@@ -1923,7 +2040,12 @@ class PersistEventsStore:
# Any relation information for the related event must be cleared.
self.store._invalidate_cache_and_stream(
- txn, self.store.get_relations_for_event, (redacted_relates_to,)
+ txn,
+ self.store.get_relations_for_event,
+ (
+ room_id,
+ redacted_relates_to,
+ ),
)
if rel_type == RelationTypes.REFERENCE:
self.store._invalidate_cache_and_stream(
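To make the chain-cover refactor above easier to follow: `calculate_chain_cover_index_for_events` now computes the index in its own interaction and, as the new `new_event_links` parameter suggests, the result is handed to the persist step, where `_persist_chain_cover_index` writes it into `event_auth_chains` and `event_auth_chain_links`. A small sketch of the returned shape (IDs invented):

    links = NewEventChainLinks(chain_id=42, sequence_number=1)
    links.links.append((17, 3))  # this event's chain links to auth-event chain 17, sequence 3
    new_event_links = {"$state_event_id": links}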
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 6c979f9f2c..64d303e330 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -1181,7 +1181,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
results = list(txn)
- # (event_id, parent_id, rel_type) for each relation
+ # (room_id, event_id, parent_id, rel_type) for each relation
- relations_to_insert: List[Tuple[str, str, str]] = []
+ relations_to_insert: List[Tuple[str, str, str, str]] = []
for event_id, event_json_raw in results:
try:
event_json = db_to_json(event_json_raw)
@@ -1214,7 +1214,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if not isinstance(parent_id, str):
continue
- relations_to_insert.append((event_id, parent_id, rel_type))
+ room_id = event_json["room_id"]
+ relations_to_insert.append((room_id, event_id, parent_id, rel_type))
# Insert the missing data, note that we upsert here in case the event
# has already been processed.
@@ -1223,18 +1224,27 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn=txn,
table="event_relations",
key_names=("event_id",),
- key_values=[(r[0],) for r in relations_to_insert],
+ key_values=[(r[1],) for r in relations_to_insert],
value_names=("relates_to_id", "relation_type"),
- value_values=[r[1:] for r in relations_to_insert],
+ value_values=[r[2:] for r in relations_to_insert],
)
# Iterate the parent IDs and invalidate caches.
- cache_tuples = {(r[1],) for r in relations_to_insert}
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
- txn, self.get_relations_for_event, cache_tuples # type: ignore[attr-defined]
+ txn,
+ self.get_relations_for_event, # type: ignore[attr-defined]
+ {
+ (
+ r[0], # room_id
+ r[2], # parent_id
+ )
+ for r in relations_to_insert
+ },
)
self._invalidate_cache_and_stream_bulk( # type: ignore[attr-defined]
- txn, self.get_thread_summary, cache_tuples # type: ignore[attr-defined]
+ txn,
+ self.get_thread_summary, # type: ignore[attr-defined]
+ {(r[1],) for r in relations_to_insert},
)
if results:
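For readability, the index juggling in the relations background update above boils down to this tuple layout (values are illustrative):

    r = ("!room:example.org", "$child_event", "$parent_event", "m.thread")
    # r[0] = room_id, r[1] = event_id, r[2] = parent_id (relates_to_id), r[3] = rel_type
    key_values = [(r[1],)]              # event_relations is keyed on the child event_id
    value_values = [r[2:]]              # (relates_to_id, relation_type)
    relations_cache_key = (r[0], r[2])  # get_relations_for_event is now keyed on (room_id, parent_id)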
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e39d4b9624..e264d36f02 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -75,12 +75,10 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
- StreamIdGenerator,
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
@@ -158,6 +156,7 @@ class _EventRow:
event_id: str
stream_ordering: int
+ instance_name: str
json: str
internal_metadata: str
format_version: Optional[int]
@@ -195,51 +194,35 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen: AbstractStreamIdGenerator
self._backfill_id_gen: AbstractStreamIdGenerator
- if isinstance(database.engine, PostgresEngine):
- # If we're using Postgres than we can use `MultiWriterIdGenerator`
- # regardless of whether this process writes to the streams or not.
- self._stream_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- notifier=hs.get_replication_notifier(),
- stream_name="events",
- instance_name=hs.get_instance_name(),
- tables=[("events", "instance_name", "stream_ordering")],
- sequence_name="events_stream_seq",
- writers=hs.config.worker.writers.events,
- )
- self._backfill_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- notifier=hs.get_replication_notifier(),
- stream_name="backfill",
- instance_name=hs.get_instance_name(),
- tables=[("events", "instance_name", "stream_ordering")],
- sequence_name="events_backfill_stream_seq",
- positive=False,
- writers=hs.config.worker.writers.events,
- )
- else:
- # Multiple writers are not supported for SQLite.
- #
- # We shouldn't be running in worker mode with SQLite, but its useful
- # to support it for unit tests.
- self._stream_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "events",
- "stream_ordering",
- is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
- )
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
- is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
- )
+
+ self._stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="events",
+ instance_name=hs.get_instance_name(),
+ tables=[
+ ("events", "instance_name", "stream_ordering"),
+ ("current_state_delta_stream", "instance_name", "stream_id"),
+ ("ex_outlier_stream", "instance_name", "event_stream_ordering"),
+ ],
+ sequence_name="events_stream_seq",
+ writers=hs.config.worker.writers.events,
+ )
+ self._backfill_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="backfill",
+ instance_name=hs.get_instance_name(),
+ tables=[
+ ("events", "instance_name", "stream_ordering"),
+ ("ex_outlier_stream", "instance_name", "event_stream_ordering"),
+ ],
+ sequence_name="events_backfill_stream_seq",
+ positive=False,
+ writers=hs.config.worker.writers.events,
+ )
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
@@ -309,27 +292,17 @@ class EventsWorkerStore(SQLBaseStore):
self._un_partial_stated_events_stream_id_gen: AbstractStreamIdGenerator
- if isinstance(database.engine, PostgresEngine):
- self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- notifier=hs.get_replication_notifier(),
- stream_name="un_partial_stated_event_stream",
- instance_name=hs.get_instance_name(),
- tables=[
- ("un_partial_stated_event_stream", "instance_name", "stream_id")
- ],
- sequence_name="un_partial_stated_event_stream_sequence",
- # TODO(faster_joins, multiple writers) Support multiple writers.
- writers=["master"],
- )
- else:
- self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "un_partial_stated_event_stream",
- "stream_id",
- )
+ self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="un_partial_stated_event_stream",
+ instance_name=hs.get_instance_name(),
+ tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")],
+ sequence_name="un_partial_stated_event_stream_sequence",
+ # TODO(faster_joins, multiple writers) Support multiple writers.
+ writers=["master"],
+ )
def get_un_partial_stated_events_token(self, instance_name: str) -> int:
return (
@@ -1382,6 +1355,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected_reason=rejected_reason,
)
original_ev.internal_metadata.stream_ordering = row.stream_ordering
+ original_ev.internal_metadata.instance_name = row.instance_name
original_ev.internal_metadata.outlier = row.outlier
# Consistency check: if the content of the event has been modified in the
@@ -1467,6 +1441,7 @@ class EventsWorkerStore(SQLBaseStore):
SELECT
e.event_id,
e.stream_ordering,
+ e.instance_name,
ej.internal_metadata,
ej.json,
ej.format_version,
@@ -1490,13 +1465,14 @@ class EventsWorkerStore(SQLBaseStore):
event_dict[event_id] = _EventRow(
event_id=event_id,
stream_ordering=row[1],
- internal_metadata=row[2],
- json=row[3],
- format_version=row[4],
- room_version_id=row[5],
- rejected_reason=row[6],
+ instance_name=row[2],
+ internal_metadata=row[3],
+ json=row[4],
+ format_version=row[5],
+ room_version_id=row[6],
+ rejected_reason=row[7],
redactions=[],
- outlier=bool(row[7]), # This is an int in SQLite3
+ outlier=bool(row[8]), # This is an int in SQLite3
)
# check for redactions
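One detail worth calling out in the events_worker change above: the backfill generator is created with `positive=False`, so it hands out negative, decreasing stream orderings, which is what keeps backfilled events ordered before events persisted live. A rough sketch of the intent (values invented):

    live_orderings = [101, 102, 103]   # from self._stream_id_gen
    backfill_orderings = [-1, -2, -3]  # from self._backfill_id_gen
    assert max(backfill_orderings) < min(live_orderings)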
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 567c2d30bd..923e764491 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -40,13 +40,11 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
- StreamIdGenerator,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -91,21 +89,16 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._instance_name in hs.config.worker.writers.presence
)
- if isinstance(database.engine, PostgresEngine):
- self._presence_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- notifier=hs.get_replication_notifier(),
- stream_name="presence_stream",
- instance_name=self._instance_name,
- tables=[("presence_stream", "instance_name", "stream_id")],
- sequence_name="presence_stream_sequence",
- writers=hs.config.worker.writers.presence,
- )
- else:
- self._presence_id_gen = StreamIdGenerator(
- db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
- )
+ self._presence_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="presence_stream",
+ instance_name=self._instance_name,
+ tables=[("presence_stream", "instance_name", "stream_id")],
+ sequence_name="presence_stream_sequence",
+ writers=hs.config.worker.writers.presence,
+ )
self.hs = hs
self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 660c834518..2a39dc9f90 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder, unwrapFirstError
@@ -126,7 +126,7 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
- _push_rules_stream_id_gen: StreamIdGenerator
+ _push_rules_stream_id_gen: MultiWriterIdGenerator
def __init__(
self,
@@ -140,14 +140,17 @@ class PushRulesWorkerStore(
hs.get_instance_name() in hs.config.worker.writers.push_rules
)
- # In the worker store this is an ID tracker which we overwrite in the non-worker
- # class below that is used on the main process.
- self._push_rules_stream_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "push_rules_stream",
- "stream_id",
- is_writer=self._is_push_writer,
+ self._push_rules_stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="push_rules_stream",
+ instance_name=self._instance_name,
+ tables=[
+ ("push_rules_stream", "instance_name", "stream_id"),
+ ],
+ sequence_name="push_rules_stream_sequence",
+ writers=hs.config.worker.writers.push_rules,
)
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -880,6 +883,7 @@ class PushRulesWorkerStore(
raise Exception("Not a push writer")
values = {
+ "instance_name": self._instance_name,
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
"user_id": user_id,
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 39e22d3b43..a8a37b6c85 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -40,10 +40,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- # In the worker store this is an ID tracker which we overwrite in the non-worker
- # class below that is used on the main process.
- self._pushers_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "pushers",
- "id",
- extra_tables=[("deleted_pushers", "stream_id")],
- is_writer=hs.config.worker.worker_app is None,
+ self._instance_name = hs.get_instance_name()
+
+ self._pushers_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="pushers",
+ instance_name=self._instance_name,
+ tables=[
+ ("pushers", "instance_name", "id"),
+ ("deleted_pushers", "instance_name", "stream_id"),
+ ],
+ sequence_name="pushers_sequence",
+ writers=["master"],
)
self.db_pool.updates.register_background_update_handler(
@@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
# Because we have write access, this will be a StreamIdGenerator
# (see PusherWorkerStore.__init__)
- _pushers_id_gen: AbstractStreamIdGenerator
+ _pushers_id_gen: MultiWriterIdGenerator
async def add_pusher(
self,
@@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
+ "instance_name": self._instance_name,
"enabled": enabled,
"device_id": device_id,
# XXX(quenting): We're only really persisting the access token ID
@@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
table="deleted_pushers",
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"app_id": app_id,
"pushkey": pushkey,
"user_id": user_id,
@@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
self.db_pool.simple_insert_many_txn(
txn,
table="deleted_pushers",
- keys=("stream_id", "app_id", "pushkey", "user_id"),
+ keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
values=[
- (stream_id, pusher.app_id, pusher.pushkey, user_id)
+ (
+ stream_id,
+ self._instance_name,
+ pusher.app_id,
+ pusher.pushkey,
+ user_id,
+ )
for stream_id, pusher in zip(stream_ids, pushers)
],
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 13387a3839..8432560a89 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -44,12 +44,10 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
MultiWriterIdGenerator,
- StreamIdGenerator,
)
from synapse.types import (
JsonDict,
@@ -80,35 +78,20 @@ class ReceiptsWorkerStore(SQLBaseStore):
# class below that is used on the main process.
self._receipts_id_gen: AbstractStreamIdGenerator
- if isinstance(database.engine, PostgresEngine):
- self._can_write_to_receipts = (
- self._instance_name in hs.config.worker.writers.receipts
- )
+ self._can_write_to_receipts = (
+ self._instance_name in hs.config.worker.writers.receipts
+ )
- self._receipts_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- notifier=hs.get_replication_notifier(),
- stream_name="receipts",
- instance_name=self._instance_name,
- tables=[("receipts_linearized", "instance_name", "stream_id")],
- sequence_name="receipts_sequence",
- writers=hs.config.worker.writers.receipts,
- )
- else:
- self._can_write_to_receipts = True
-
- # Multiple writers are not supported for SQLite.
- #
- # We shouldn't be running in worker mode with SQLite, but its useful
- # to support it for unit tests.
- self._receipts_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "receipts_linearized",
- "stream_id",
- is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
- )
+ self._receipts_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="receipts",
+ instance_name=self._instance_name,
+ tables=[("receipts_linearized", "instance_name", "stream_id")],
+ sequence_name="receipts_sequence",
+ writers=hs.config.worker.writers.receipts,
+ )
super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 77f3641525..29a001ff92 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -169,9 +169,9 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
+ room_id: str,
event_id: str,
event: EventBase,
- room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,
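Moving `room_id` to the front of `get_relations_for_event`'s signature matters because the descriptor is declared with `tree=True`, which allows invalidation by key prefix; that is what the `(room_id, redacted_relates_to)` invalidations in events.py and events_bg_updates.py rely on. A sketch (room and event IDs invented):

    # Drops every cached variant (any relation_type/event_type/limit) for this parent event.
    store.get_relations_for_event.invalidate(("!room:example.org", "$parent_event"))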
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 81c7bf3712..d5627b1d6e 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,13 +21,11 @@
#
import logging
-from abc import abstractmethod
from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
- Awaitable,
Collection,
Dict,
List,
@@ -53,20 +51,18 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.replication.tcp.streams.partial_state import UnPartialStatedRoomStream
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
IdGenerator,
MultiWriterIdGenerator,
- StreamIdGenerator,
)
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
@@ -157,27 +153,17 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
- if isinstance(database.engine, PostgresEngine):
- self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- notifier=hs.get_replication_notifier(),
- stream_name="un_partial_stated_room_stream",
- instance_name=self._instance_name,
- tables=[
- ("un_partial_stated_room_stream", "instance_name", "stream_id")
- ],
- sequence_name="un_partial_stated_room_stream_sequence",
- # TODO(faster_joins, multiple writers) Support multiple writers.
- writers=["master"],
- )
- else:
- self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
- db_conn,
- hs.get_replication_notifier(),
- "un_partial_stated_room_stream",
- "stream_id",
- )
+ self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ notifier=hs.get_replication_notifier(),
+ stream_name="un_partial_stated_room_stream",
+ instance_name=self._instance_name,
+ tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")],
+ sequence_name="un_partial_stated_room_stream_sequence",
+ # TODO(faster_joins, multiple writers) Support multiple writers.
+ writers=["master"],
+ )
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
@@ -620,6 +606,8 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
order_by: str,
reverse_order: bool,
search_term: Optional[str],
+ public_rooms: Optional[bool],
+ empty_rooms: Optional[bool],
) -> Tuple[List[Dict[str, Any]], int]:
"""Function to retrieve a paginated list of rooms as json.
@@ -631,30 +619,49 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
search_term: a string to filter room names,
canonical alias and room ids by.
Room ID must match exactly. Canonical alias must match a substring of the local part.
+ public_rooms: Optional flag to filter public and non-public rooms. If true,
+ only public rooms are queried. If false, public rooms are excluded from the
+ query. When it is none (the default), both public and non-public rooms are queried.
+ empty_rooms: Optional flag to filter empty and non-empty rooms.
+ A room is empty if joined_members is zero.
+ If true, only empty rooms are queried.
+ If false, empty rooms are excluded from the query. When it is
+ none (the default), both empty and non-empty rooms are queried.
Returns:
A list of room dicts and an integer representing the total number of
rooms that exist given this query
"""
# Filter room names by a string
- where_statement = ""
- search_pattern: List[object] = []
+ filter_ = []
+ where_args = []
if search_term:
- where_statement = """
- WHERE LOWER(state.name) LIKE ?
- OR LOWER(state.canonical_alias) LIKE ?
- OR state.room_id = ?
- """
+ filter_ = [
+ "(LOWER(state.name) LIKE ? OR "
+ "LOWER(state.canonical_alias) LIKE ? OR "
+ "state.room_id = ?)"
+ ]
# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
- search_pattern = [
- "%" + search_term.lower() + "%",
- "#%" + search_term.lower() + "%:%",
+ where_args = [
+ f"%{search_term.lower()}%",
+ f"#%{search_term.lower()}%:%",
search_term,
]
+ if public_rooms is not None:
+ filter_arg = "1" if public_rooms else "0"
+ filter_.append(f"rooms.is_public = '{filter_arg}'")
+
+ if empty_rooms is not None:
+ if empty_rooms:
+ filter_.append("curr.joined_members = 0")
+ else:
+ filter_.append("curr.joined_members <> 0")
+
+ where_clause = "WHERE " + " AND ".join(filter_) if len(filter_) > 0 else ""
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
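As a worked example of the filter assembly above (inputs invented): with `search_term="matrix"`, `public_rooms=True` and `empty_rooms=False`, the pieces line up as follows (the first element is the single concatenated search condition):

    filter_ = [
        "(LOWER(state.name) LIKE ? OR LOWER(state.canonical_alias) LIKE ? OR state.room_id = ?)",
        "rooms.is_public = '1'",
        "curr.joined_members <> 0",
    ]
    where_args = ["%matrix%", "#%matrix%:%", "matrix"]
    where_clause = "WHERE " + " AND ".join(filter_)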
@@ -731,7 +738,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
LIMIT ?
OFFSET ?
""".format(
- where=where_statement,
+ where=where_clause,
order_by=order_by_column,
direction="ASC" if order_by_asc else "DESC",
)
@@ -740,10 +747,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
count_sql = """
SELECT count(*) FROM (
SELECT room_id FROM room_stats_state state
+ INNER JOIN room_stats_current curr USING (room_id)
+ INNER JOIN rooms USING (room_id)
{where}
) AS get_room_ids
""".format(
- where=where_statement,
+ where=where_clause,
)
def _get_rooms_paginate_txn(
@@ -751,7 +760,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
) -> Tuple[List[Dict[str, Any]], int]:
# Add the search term into the WHERE clause
# and execute the data query
- txn.execute(info_sql, search_pattern + [limit, start])
+ txn.execute(info_sql, where_args + [limit, start])
# Refactor room query data into a structured dictionary
rooms = []
@@ -781,7 +790,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# Execute the count query
# Add the search term into the WHERE clause if present
- txn.execute(count_sql, search_pattern)
+ txn.execute(count_sql, where_args)
room_count = cast(Tuple[int], txn.fetchone())
return rooms, room_count[0]
@@ -1684,6 +1693,58 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
return True
+ async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
+ await self.db_pool.simple_update_one(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"is_public": is_public},
+ desc="set_room_is_public",
+ )
+
+ async def set_room_is_public_appservice(
+ self, room_id: str, appservice_id: str, network_id: str, is_public: bool
+ ) -> None:
+ """Edit the appservice/network specific public room list.
+
+ Each appservice can have a number of published room lists associated
+ with them, keyed off of an appservice defined `network_id`, which
+ basically represents a single instance of a bridge to a third party
+ network.
+
+ Args:
+ room_id
+ appservice_id
+ network_id
+ is_public: Whether to publish or unpublish the room from the list.
+ """
+
+ if is_public:
+ await self.db_pool.simple_upsert(
+ table="appservice_room_list",
+ keyvalues={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id,
+ },
+ values={},
+ insertion_values={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id,
+ },
+ desc="set_room_is_public_appservice_true",
+ )
+ else:
+ await self.db_pool.simple_delete(
+ table="appservice_room_list",
+ keyvalues={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id,
+ },
+ desc="set_room_is_public_appservice_false",
+ )
+
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1702,7 +1763,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
)
-class RoomBackgroundUpdateStore(SQLBaseStore):
+class RoomBackgroundUpdateStore(RoomWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -1935,14 +1996,6 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return len(rooms)
- @abstractmethod
- def set_room_is_public(self, room_id: str, is_public: bool) -> Awaitable[None]:
- # this will need to be implemented if a background update is performed with
- # existing (tombstoned, public) rooms in the database.
- #
- # It's overridden by RoomStore for the synapse master.
- raise NotImplementedError()
-
async def has_auth_chain_index(self, room_id: str) -> bool:
"""Check if the room has (or can have) a chain cover index.
@@ -2177,6 +2230,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
super().__init__(database, db_conn, hs)
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
+ self._room_reports_id_gen = IdGenerator(db_conn, "room_reports", "id")
self._instance_name = hs.get_instance_name()
@@ -2349,62 +2403,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
},
)
- async def set_room_is_public(self, room_id: str, is_public: bool) -> None:
- await self.db_pool.simple_update_one(
- table="rooms",
- keyvalues={"room_id": room_id},
- updatevalues={"is_public": is_public},
- desc="set_room_is_public",
- )
-
- self.hs.get_notifier().on_new_replication_data()
-
- async def set_room_is_public_appservice(
- self, room_id: str, appservice_id: str, network_id: str, is_public: bool
- ) -> None:
- """Edit the appservice/network specific public room list.
-
- Each appservice can have a number of published room lists associated
- with them, keyed off of an appservice defined `network_id`, which
- basically represents a single instance of a bridge to a third party
- network.
-
- Args:
- room_id
- appservice_id
- network_id
- is_public: Whether to publish or unpublish the room from the list.
- """
-
- if is_public:
- await self.db_pool.simple_upsert(
- table="appservice_room_list",
- keyvalues={
- "appservice_id": appservice_id,
- "network_id": network_id,
- "room_id": room_id,
- },
- values={},
- insertion_values={
- "appservice_id": appservice_id,
- "network_id": network_id,
- "room_id": room_id,
- },
- desc="set_room_is_public_appservice_true",
- )
- else:
- await self.db_pool.simple_delete(
- table="appservice_room_list",
- keyvalues={
- "appservice_id": appservice_id,
- "network_id": network_id,
- "room_id": room_id,
- },
- desc="set_room_is_public_appservice_false",
- )
-
- self.hs.get_notifier().on_new_replication_data()
-
async def add_event_report(
self,
room_id: str,
@@ -2442,6 +2440,37 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
)
return next_id
+ async def add_room_report(
+ self,
+ room_id: str,
+ user_id: str,
+ reason: str,
+ received_ts: int,
+ ) -> int:
+ """Add a room report
+
+ Args:
+ room_id: The room ID being reported.
+ user_id: User who reports the room.
+ reason: Description that the user specifies.
+ received_ts: Time when the user submitted the report (milliseconds).
+ Returns:
+ Id of the room report.
+ """
+ next_id = self._room_reports_id_gen.get_next()
+ await self.db_pool.simple_insert(
+ table="room_reports",
+ values={
+ "id": next_id,
+ "received_ts": received_ts,
+ "room_id": room_id,
+ "user_id": user_id,
+ "reason": reason,
+ },
+ desc="add_room_report",
+ )
+ return next_id
+
async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 9fddbb2caf..d8b54dc4e3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
sql = """
- SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering, r.room_version
+ SELECT room_id, e.sender, c.membership, event_id, e.instance_name, e.stream_ordering, r.room_version
FROM local_current_membership AS c
INNER JOIN events AS e USING (room_id, event_id)
INNER JOIN rooms AS r USING (room_id)
@@ -488,7 +488,17 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(*r) for r in txn]
+ results = [
+ RoomsForUser(
+ room_id=room_id,
+ sender=sender,
+ membership=membership,
+ event_id=event_id,
+ event_pos=PersistedEventPosition(instance_name, stream_ordering),
+ room_version_id=room_version,
+ )
+ for room_id, sender, membership, event_id, instance_name, stream_ordering, room_version in txn
+ ]
return results
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7ab6003f61..ff0d723684 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -895,7 +895,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_room_event_before_stream_ordering", _f
)
- async def get_last_event_in_room_before_stream_ordering(
+ async def get_last_event_id_in_room_before_stream_ordering(
self,
room_id: str,
end_token: RoomStreamToken,
@@ -910,16 +910,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
The ID of the most recent event, or None if there are no events in the room
before this stream ordering.
"""
+ last_event_result = (
+ await self.get_last_event_pos_in_room_before_stream_ordering(
+ room_id, end_token
+ )
+ )
- def get_last_event_in_room_before_stream_ordering_txn(
- txn: LoggingTransaction,
- ) -> Optional[str]:
- # We need to handle the fact that the stream tokens can be vector
- # clocks. We do this by getting all rows between the minimum and
- # maximum stream ordering in the token, plus one row less than the
- # minimum stream ordering. We then filter the results against the
- # token and return the first row that matches.
+ if last_event_result:
+ return last_event_result[0]
+
+ return None
+
+ async def get_last_event_pos_in_room_before_stream_ordering(
+ self,
+ room_id: str,
+ end_token: RoomStreamToken,
+ ) -> Optional[Tuple[str, PersistedEventPosition]]:
+ """
+ Returns the ID and event position of the last event in a room at or before a
+ stream ordering.
+ Args:
+ room_id
+ end_token: The token used to stream from
+
+ Returns:
+ The ID of the most recent event and its position, or None if there are no
+ events in the room before this stream ordering.
+ """
+
+ def get_last_event_pos_in_room_before_stream_ordering_txn(
+ txn: LoggingTransaction,
+ ) -> Optional[Tuple[str, PersistedEventPosition]]:
+ # We're looking for the closest event at or before the token. We need to
+ # handle the fact that the stream token can be a vector clock (with an
+ # `instance_map`) and events can be persisted on different instances
+ # (sharded event persisters). The first subquery handles the events that
+ # would be within the vector clock and gets all rows between the minimum and
+ # maximum stream ordering in the token which need to be filtered against the
+ # `instance_map`. The second subquery handles the "before" case and finds
+ # the first row before the token. We then filter out any results past the
+ # token's vector clock and return the first row that matches.
+ min_stream = end_token.stream
+ max_stream = end_token.get_max_stream_pos()
+
+ # We use `union all` because we don't need any of the deduplication logic
+ # (`union` is really a union + distinct). `UNION ALL` does preserve the
+ # ordering of the operand queries but there is no actual guarantee that it
+ # has this behavior in all scenarios so we need the extra `ORDER BY` at the
+ # bottom.
sql = """
SELECT * FROM (
SELECT instance_name, stream_ordering, topological_ordering, event_id
@@ -931,7 +970,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
AND rejections.event_id IS NULL
ORDER BY stream_ordering DESC
) AS a
- UNION
+ UNION ALL
SELECT * FROM (
SELECT instance_name, stream_ordering, topological_ordering, event_id
FROM events
@@ -943,15 +982,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ORDER BY stream_ordering DESC
LIMIT 1
) AS b
+ ORDER BY stream_ordering DESC
"""
txn.execute(
sql,
(
room_id,
- end_token.stream,
- end_token.get_max_stream_pos(),
+ min_stream,
+ max_stream,
room_id,
- end_token.stream,
+ min_stream,
),
)
@@ -963,13 +1003,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
):
- return event_id
+ return event_id, PersistedEventPosition(
+ instance_name, stream_ordering
+ )
return None
return await self.db_pool.runInteraction(
- "get_last_event_in_room_before_stream_ordering",
- get_last_event_in_room_before_stream_ordering_txn,
+ "get_last_event_pos_in_room_before_stream_ordering",
+ get_last_event_pos_in_room_before_stream_ordering_txn,
)
async def get_current_room_stream_token_for_room_id(
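A hedged usage sketch of the reshaped helper above (the `store`, `room_id` and `end_token` names are assumed): callers that need to know where the event was persisted use the new method, while existing callers keep working through the ID-only wrapper.

    result = await store.get_last_event_pos_in_room_before_stream_ordering(room_id, end_token)
    if result is not None:
        event_id, pos = result  # pos is a PersistedEventPosition(instance_name, stream_ordering)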
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 0513e7dc06..6e18f714d7 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -1281,7 +1281,7 @@ def _parse_words_with_regex(search_term: str) -> List[str]:
Break down search term into words, when we don't have ICU available.
See: `_parse_words`
"""
- return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+ return re.findall(r"([\w-]+)", search_term, re.UNICODE)
def _parse_words_with_icu(search_term: str) -> List[str]:
@@ -1303,15 +1303,69 @@ def _parse_words_with_icu(search_term: str) -> List[str]:
if j < 0:
break
- result = search_term[i:j]
+ # We want to make sure that we split on `@` and `:` specifically, as
+ # they occur in user IDs.
+ for result in re.split(r"[@:]+", search_term[i:j]):
+ results.append(result.strip())
+
+ i = j
+
+ # libicu will break up words that have punctuation in them, but to handle
+ # cases where user IDs have '-', '.' and '_' in them we want to *not* break
+ # those into words and instead allow the DB to tokenise them how it wants.
+ #
+ # In particular, user-71 in postgres gets tokenised to "user, -71", and this
+ # will not match a query for "user, 71".
+ new_results: List[str] = []
+ i = 0
+ while i < len(results):
+ curr = results[i]
+
+ prev = None
+ next = None
+ if i > 0:
+ prev = results[i - 1]
+ if i + 1 < len(results):
+ next = results[i + 1]
+
+ i += 1
# libicu considers spaces and punctuation between words as words, but we don't
# want to include those in results as they would result in syntax errors in SQL
# queries (e.g. "foo bar" would result in the search query including "foo & &
# bar").
- if len(re.findall(r"([\w\-]+)", result, re.UNICODE)):
- results.append(result)
+ if not curr:
+ continue
+
+ if curr in ["-", ".", "_"]:
+ prefix = ""
+ suffix = ""
+
+ # Check if the next item is a word, and if so use it as the suffix.
+ # We check whether it's a word as we don't want to concatenate
+ # multiple punctuation marks.
+ if next is not None and re.match(r"\w", next):
+ suffix = next
+ i += 1 # We're using next, so we skip it in the outer loop.
+ else:
+ # We want to avoid creating terms like "user-", as we should
+ # strip trailing punctuation.
+ continue
- i = j
+ if prev and re.match(r"\w", prev) and new_results:
+ prefix = new_results[-1]
+ new_results.pop()
+
+ # We might not have a prefix here, but that's fine as we want to
+ # ensure that we don't strip preceding punctuation e.g. '-71'
+ # shouldn't be converted to '71'.
+
+ new_results.append(f"{prefix}{curr}{suffix}")
+ continue
+ elif not re.match(r"\w", curr):
+ # Ignore other punctuation
+ continue
+
+ new_results.append(curr)
- return results
+ return new_results
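Assuming libicu splits `"user-71"` into the tokens `["user", "-", "71"]`, as the comment above describes, the rejoining pass glues the hyphen back onto its word neighbours, while a leading `-71` (no word before the hyphen) keeps its sign rather than being stripped to `71`:

    # Hedged sketch; exact tokenisation depends on the installed ICU data.
    assert _parse_words_with_icu("user-71") == ["user-71"]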
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index b9168ee074..90641d5a18 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -142,6 +142,10 @@ class PostgresEngine(
apply stricter checks on new databases versus existing database.
"""
+ allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
+ if allow_unsafe_locale:
+ return
+
collation, ctype = self.get_db_locale(txn)
errors = []
@@ -155,7 +159,9 @@ class PostgresEngine(
if errors:
raise IncorrectDatabaseSetup(
"Database is incorrectly configured:\n\n%s\n\n"
- "See docs/postgres.md for more information." % ("\n".join(errors))
+ "See docs/postgres.md for more information. You can override this check by"
+ "setting 'allow_unsafe_locale' to true in the database config.",
+ "\n".join(errors),
)
def convert_param_style(self, sql: str) -> str:
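For completeness, the early return above is driven by the homeserver's database configuration; a sketch of the opt-out, assuming the flag sits at the top level of the `database:` block that `self.config.get("allow_unsafe_locale", False)` reads (shown as the equivalent Python dict):

    database_config = {
        "name": "psycopg2",
        "allow_unsafe_locale": True,  # skips the COLLATE/CTYPE check above
        "args": {"user": "synapse_user", "database": "synapse"},
    }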
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 7471f81a19..80c9630867 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -35,7 +35,7 @@ class RoomsForUser:
sender: str
membership: str
event_id: str
- stream_ordering: int
+ event_pos: PersistedEventPosition
room_version_id: str
diff --git a/synapse/storage/schema/main/delta/85/02_add_instance_names.sql b/synapse/storage/schema/main/delta/85/02_add_instance_names.sql
new file mode 100644
index 0000000000..d604595f73
--- /dev/null
+++ b/synapse/storage/schema/main/delta/85/02_add_instance_names.sql
@@ -0,0 +1,27 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+-- Add `instance_name` columns to stream tables to allow them to be used with
+-- `MultiWriterIdGenerator`
+ALTER TABLE device_lists_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE user_signature_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE device_lists_outbound_pokes ADD COLUMN instance_name TEXT;
+ALTER TABLE device_lists_changes_in_room ADD COLUMN instance_name TEXT;
+ALTER TABLE device_lists_remote_pending ADD COLUMN instance_name TEXT;
+
+ALTER TABLE e2e_cross_signing_keys ADD COLUMN instance_name TEXT;
+
+ALTER TABLE push_rules_stream ADD COLUMN instance_name TEXT;
+
+ALTER TABLE pushers ADD COLUMN instance_name TEXT;
+ALTER TABLE deleted_pushers ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres b/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres
new file mode 100644
index 0000000000..9d34066bf5
--- /dev/null
+++ b/synapse/storage/schema/main/delta/85/03_new_sequences.sql.postgres
@@ -0,0 +1,54 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+-- Add sequences for stream tables to allow them to be used with
+-- `MultiWriterIdGenerator`
+CREATE SEQUENCE IF NOT EXISTS device_lists_sequence;
+
+-- We need to take the max across all the device lists tables as they share the
+-- ID generator
+SELECT setval('device_lists_sequence', (
+ SELECT GREATEST(
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_stream),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM user_signature_stream),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_outbound_pokes),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_in_room),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_remote_pending),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM device_lists_changes_converted_stream_position)
+ )
+));
+
+CREATE SEQUENCE IF NOT EXISTS e2e_cross_signing_keys_sequence;
+
+SELECT setval('e2e_cross_signing_keys_sequence', (
+ SELECT COALESCE(MAX(stream_id), 1) FROM e2e_cross_signing_keys
+));
+
+
+CREATE SEQUENCE IF NOT EXISTS push_rules_stream_sequence;
+
+SELECT setval('push_rules_stream_sequence', (
+ SELECT COALESCE(MAX(stream_id), 1) FROM push_rules_stream
+));
+
+
+CREATE SEQUENCE IF NOT EXISTS pushers_sequence;
+
+-- We need to take the max across all the pusher tables as they share the
+-- ID generator
+SELECT setval('pushers_sequence', (
+ SELECT GREATEST(
+ (SELECT COALESCE(MAX(id), 1) FROM pushers),
+ (SELECT COALESCE(MAX(stream_id), 1) FROM deleted_pushers)
+ )
+));
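Each `setval` above seeds a new sequence with the maximum across every table that shares the ID generator. A rough Python helper that produces a statement in the same style; the function name and formatting are purely illustrative and not part of the patch:

from typing import Sequence, Tuple


def make_setval_sql(sequence: str, tables: Sequence[Tuple[str, str]]) -> str:
    # One COALESCE(MAX(...), 1) per table, combined with GREATEST so the
    # sequence starts at or above every table that shares the ID generator.
    selects = ",\n        ".join(
        "(SELECT COALESCE(MAX(%s), 1) FROM %s)" % (column, table)
        for table, column in tables
    )
    return "SELECT setval('%s', (\n    SELECT GREATEST(\n        %s\n    )\n));" % (
        sequence,
        selects,
    )


print(make_setval_sql("pushers_sequence", [("pushers", "id"), ("deleted_pushers", "stream_id")]))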
diff --git a/synapse/storage/schema/main/delta/85/04_cleanup_device_federation_outbox.sql b/synapse/storage/schema/main/delta/85/04_cleanup_device_federation_outbox.sql
new file mode 100644
index 0000000000..041b17b0ee
--- /dev/null
+++ b/synapse/storage/schema/main/delta/85/04_cleanup_device_federation_outbox.sql
@@ -0,0 +1,15 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (8504, 'cleanup_device_federation_outbox', '{}');
diff --git a/synapse/storage/schema/main/delta/85/05_add_instance_names_converted_pos.sql b/synapse/storage/schema/main/delta/85/05_add_instance_names_converted_pos.sql
new file mode 100644
index 0000000000..c3f2b6a1dd
--- /dev/null
+++ b/synapse/storage/schema/main/delta/85/05_add_instance_names_converted_pos.sql
@@ -0,0 +1,16 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+-- Add `instance_name` columns to stream tables to allow them to be used with
+-- `MultiWriterIdGenerator`
+ALTER TABLE device_lists_changes_converted_stream_position ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/schema/main/delta/85/06_add_room_reports.sql b/synapse/storage/schema/main/delta/85/06_add_room_reports.sql
new file mode 100644
index 0000000000..f7b45276cf
--- /dev/null
+++ b/synapse/storage/schema/main/delta/85/06_add_room_reports.sql
@@ -0,0 +1,20 @@
+--
+-- This file is licensed under the Affero General Public License (AGPL) version 3.
+--
+-- Copyright (C) 2024 New Vector, Ltd
+--
+-- This program is free software: you can redistribute it and/or modify
+-- it under the terms of the GNU Affero General Public License as
+-- published by the Free Software Foundation, either version 3 of the
+-- License, or (at your option) any later version.
+--
+-- See the GNU Affero General Public License for more details:
+-- <https://www.gnu.org/licenses/agpl-3.0.html>.
+
+CREATE TABLE room_reports (
+ id BIGINT NOT NULL PRIMARY KEY,
+ received_ts BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ reason TEXT NOT NULL
+);
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index fadc75cc80..48f88a6f8a 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -23,15 +23,12 @@ import abc
import heapq
import logging
import threading
-from collections import OrderedDict
-from contextlib import contextmanager
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Dict,
- Generator,
Generic,
Iterable,
List,
@@ -53,9 +50,11 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
+ make_in_list_sql_clause,
)
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
-from synapse.storage.util.sequence import PostgresSequenceGenerator
+from synapse.storage.util.sequence import build_sequence_generator
if TYPE_CHECKING:
from synapse.notifier import ReplicationNotifier
@@ -177,161 +176,6 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
raise NotImplementedError()
-class StreamIdGenerator(AbstractStreamIdGenerator):
- """Generates and tracks stream IDs for a stream with a single writer.
-
- This class must only be used when the current Synapse process is the sole
- writer for a stream.
-
- Args:
- db_conn(connection): A database connection to use to fetch the
- initial value of the generator from.
- table(str): A database table to read the initial value of the id
- generator from.
- column(str): The column of the database table to read the initial
- value from the id generator from.
- extra_tables(list): List of pairs of database tables and columns to
- use to source the initial value of the generator from. The value
- with the largest magnitude is used.
- step(int): which direction the stream ids grow in. +1 to grow
- upwards, -1 to grow downwards.
-
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
-
- def __init__(
- self,
- db_conn: LoggingDatabaseConnection,
- notifier: "ReplicationNotifier",
- table: str,
- column: str,
- extra_tables: Iterable[Tuple[str, str]] = (),
- step: int = 1,
- is_writer: bool = True,
- ) -> None:
- assert step != 0
- self._lock = threading.Lock()
- self._step: int = step
- self._current: int = _load_current_id(db_conn, table, column, step)
- self._is_writer = is_writer
- for table, column in extra_tables:
- self._current = (max if step > 0 else min)(
- self._current, _load_current_id(db_conn, table, column, step)
- )
-
- # We use this as an ordered set, as we want to efficiently append items,
- # remove items and get the first item. Since we insert IDs in order, the
- # insertion ordering will ensure its in the correct ordering.
- #
- # The key and values are the same, but we never look at the values.
- self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
-
- self._notifier = notifier
-
- def advance(self, instance_name: str, new_id: int) -> None:
- # Advance should never be called on a writer instance, only over replication
- if self._is_writer:
- raise Exception("Replication is not supported by writer StreamIdGenerator")
-
- self._current = (max if self._step > 0 else min)(self._current, new_id)
-
- def get_next(self) -> AsyncContextManager[int]:
- with self._lock:
- self._current += self._step
- next_id = self._current
-
- self._unfinished_ids[next_id] = next_id
-
- @contextmanager
- def manager() -> Generator[int, None, None]:
- try:
- yield next_id
- finally:
- with self._lock:
- self._unfinished_ids.pop(next_id)
-
- self._notifier.notify_replication()
-
- return _AsyncCtxManagerWrapper(manager())
-
- def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
- with self._lock:
- next_ids = range(
- self._current + self._step,
- self._current + self._step * (n + 1),
- self._step,
- )
- self._current += n * self._step
-
- for next_id in next_ids:
- self._unfinished_ids[next_id] = next_id
-
- @contextmanager
- def manager() -> Generator[Sequence[int], None, None]:
- try:
- yield next_ids
- finally:
- with self._lock:
- for next_id in next_ids:
- self._unfinished_ids.pop(next_id)
-
- self._notifier.notify_replication()
-
- return _AsyncCtxManagerWrapper(manager())
-
- def get_next_txn(self, txn: LoggingTransaction) -> int:
- """
- Retrieve the next stream ID from within a database transaction.
-
- Clean-up functions will be called when the transaction finishes.
-
- Args:
- txn: The database transaction object.
-
- Returns:
- The next stream ID.
- """
- if not self._is_writer:
- raise Exception("Tried to allocate stream ID on non-writer")
-
- # Get the next stream ID.
- with self._lock:
- self._current += self._step
- next_id = self._current
-
- self._unfinished_ids[next_id] = next_id
-
- def clear_unfinished_id(id_to_clear: int) -> None:
- """A function to mark processing this ID as finished"""
- with self._lock:
- self._unfinished_ids.pop(id_to_clear)
-
- # Mark this ID as finished once the database transaction itself finishes.
- txn.call_after(clear_unfinished_id, next_id)
- txn.call_on_exception(clear_unfinished_id, next_id)
-
- # Return the new ID.
- return next_id
-
- def get_current_token(self) -> int:
- if not self._is_writer:
- return self._current
-
- with self._lock:
- if self._unfinished_ids:
- return next(iter(self._unfinished_ids)) - self._step
-
- return self._current
-
- def get_current_token_for_writer(self, instance_name: str) -> int:
- return self.get_current_token()
-
- def get_minimal_local_current_token(self) -> int:
- return self.get_current_token()
-
-
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers.
@@ -432,7 +276,19 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
- self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+ self._sequence_gen = build_sequence_generator(
+ db_conn=db_conn,
+ database_engine=db.engine,
+ get_first_callback=lambda _: self._persisted_upto_position,
+ sequence_name=sequence_name,
+ # We only need to set the below if we want it to call
+ # `check_consistency`, but we do that ourselves below so we can
+ # leave them blank.
+ table=None,
+ id_column=None,
+ stream_name=None,
+ positive=positive,
+ )
# We check that the table and sequence haven't diverged.
for table, _, id_column in tables:
@@ -445,7 +301,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
)
# This goes and fills out the above state from the database.
- self._load_current_ids(db_conn, tables)
+        # This may read from the PostgreSQL sequence, and
+        # SequenceGenerator.check_consistency might have fixed up the sequence, which
+        # means the SequenceGenerator needs to be set up before we read the value
+        # from the sequence.
+ self._load_current_ids(db_conn, tables, sequence_name)
self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1
@@ -471,6 +331,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self,
db_conn: LoggingDatabaseConnection,
tables: List[Tuple[str, str, str]],
+ sequence_name: str,
) -> None:
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -480,13 +341,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# important if we add back a writer after a long time; we want to
# consider that a "new" writer, rather than using the old stale
# entry here.
- sql = """
+ clause, args = make_in_list_sql_clause(
+ self._db.engine, "instance_name", self._writers, negative=True
+ )
+
+ sql = f"""
DELETE FROM stream_positions
WHERE
stream_name = ?
- AND instance_name != ALL(?)
+ AND {clause}
"""
- cur.execute(sql, (self._stream_name, self._writers))
+ cur.execute(sql, [self._stream_name] + args)
sql = """
SELECT instance_name, stream_id FROM stream_positions
@@ -500,6 +365,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance in self._writers
}
+        # If we're a writer, we can assume we're at the end of the stream.
+        # Usually, we would get that from the stream_positions table, but in some
+        # cases, like if we rolled back Synapse, the stream_positions table might
+        # not be up to date. If we're using Postgres for the sequences, we can just
+        # use the current sequence value as our own position.
+ if self._instance_name in self._writers:
+ if isinstance(self._db.engine, PostgresEngine):
+ cur.execute(f"SELECT last_value FROM {sequence_name}")
+ row = cur.fetchone()
+ assert row is not None
+ self._current_positions[self._instance_name] = row[0]
+
# We set the `_persisted_upto_position` to be the minimum of all current
# positions. If empty we use the max stream ID from the DB table.
min_stream_id = min(self._current_positions.values(), default=None)
@@ -508,12 +385,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
+ greatest_func = (
+ "GREATEST" if isinstance(self._db.engine, PostgresEngine) else "MAX"
+ )
max_stream_id = 1
for table, _, id_column in tables:
sql = """
- SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ SELECT %(greatest_func)s(COALESCE(%(agg)s(%(id)s), 1), 1)
FROM %(table)s
""" % {
+ "greatest_func": greatest_func,
"id": id_column,
"table": table,
"agg": "MAX" if self._positive else "-MIN",
@@ -913,6 +794,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# We upsert the value, ensuring on conflict that we always increase the
# value (or decrease if stream goes backwards).
+ if isinstance(self._db.engine, PostgresEngine):
+ agg = "GREATEST" if self._positive else "LEAST"
+ else:
+ agg = "MAX" if self._positive else "MIN"
+
sql = """
INSERT INTO stream_positions (stream_name, instance_name, stream_id)
VALUES (?, ?, ?)
@@ -920,10 +806,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
DO UPDATE SET
stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
""" % {
- "agg": "GREATEST" if self._positive else "LEAST",
+ "agg": agg,
}
- pos = (self.get_current_token_for_writer(self._instance_name),)
+ pos = self.get_current_token_for_writer(self._instance_name)
txn.execute(sql, (self._stream_name, self._instance_name, pos))
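The upsert above now picks its aggregate per engine because SQLite has no GREATEST/LEAST, while its two-argument MAX/MIN serve the same purpose. A minimal sketch of that selection, assuming a plain boolean for the engine check instead of Synapse's engine classes:

def stream_position_agg(is_postgres: bool, positive: bool) -> str:
    # Postgres: GREATEST/LEAST; SQLite: two-argument MAX/MIN do the same job
    # for keeping the stored position monotonic.
    if is_postgres:
        return "GREATEST" if positive else "LEAST"
    return "MAX" if positive else "MIN"


sql = (
    "INSERT INTO stream_positions (stream_name, instance_name, stream_id) "
    "VALUES (?, ?, ?) "
    "ON CONFLICT (stream_name, instance_name) "
    "DO UPDATE SET stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)"
) % {"agg": stream_position_agg(is_postgres=False, positive=True)}
print(sql)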
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index f57e7ec41c..c4c0602b28 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -36,21 +36,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_INCONSISTENT_SEQUENCE_ERROR = """
-Postgres sequence '%(seq)s' is inconsistent with associated
-table '%(table)s'. This can happen if Synapse has been downgraded and
-then upgraded again, or due to a bad migration.
-
-To fix this error, shut down Synapse (including any and all workers)
-and run the following SQL:
-
- SELECT setval('%(seq)s', (
- %(max_id_sql)s
- ));
-
-See docs/postgres.md for more information.
-"""
-
_INCONSISTENT_STREAM_ERROR = """
Postgres sequence '%(seq)s' is inconsistent with associated stream position
of '%(stream_name)s' in the 'stream_positions' table.
@@ -169,25 +154,33 @@ class PostgresSequenceGenerator(SequenceGenerator):
if row:
max_in_stream_positions = row[0]
- txn.close()
-
# If `is_called` is False then `last_value` is actually the value that
# will be generated next, so we decrement to get the true "last value".
if not is_called:
last_value -= 1
if max_stream_id > last_value:
+            # The sequence is lagging behind the tables. This is probably due to
+            # rolling back to a version before the sequence was used and then
+            # rolling forward again. We resolve this by setting the sequence to
+            # the right value.
logger.warning(
- "Postgres sequence %s is behind table %s: %d < %d",
+ "Postgres sequence %s is behind table %s: %d < %d. Updating sequence.",
self._sequence_name,
table,
last_value,
max_stream_id,
)
- raise IncorrectDatabaseSetup(
- _INCONSISTENT_SEQUENCE_ERROR
- % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
- )
+
+ sql = f"""
+ SELECT setval('{self._sequence_name}', GREATEST(
+ (SELECT last_value FROM {self._sequence_name}),
+ ({table_sql})
+ ));
+ """
+ txn.execute(sql)
+
+ txn.close()
# If we have values in the stream positions table then they have to be
# less than or equal to `last_value`
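Rather than refusing to start, `check_consistency` now repairs a lagging sequence in place. A small sketch of the decision and the fix-up statement it issues, under the assumption that `table_sql` is the same MAX-over-table query used above (names are illustrative only):

def needs_repair(last_value: int, is_called: bool, max_stream_id: int) -> bool:
    # If `is_called` is False, `last_value` is the value that will be handed
    # out next, so the true last value is one less.
    if not is_called:
        last_value -= 1
    return max_stream_id > last_value


def repair_sql(sequence_name: str, table_sql: str) -> str:
    # Bump the sequence with GREATEST so it can only ever move forwards.
    return (
        "SELECT setval('%(seq)s', GREATEST(\n"
        "    (SELECT last_value FROM %(seq)s),\n"
        "    (%(table_sql)s)\n"
        "));" % {"seq": sequence_name, "table_sql": table_sql}
    )


assert needs_repair(last_value=10, is_called=True, max_stream_id=12)
assert not needs_repair(last_value=12, is_called=True, max_stream_id=12)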
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index eeafe889de..9fee5bfb92 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -75,9 +75,6 @@ class PaginationConfig:
raise SynapseError(400, "'to' parameter is invalid")
limit = parse_integer(request, "limit", default=default_limit)
- if limit < 0:
- raise SynapseError(400, "Limit must be 0 or above")
-
limit = min(limit, MAX_LIMIT)
try:
diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi
index 69837617f5..1682d0d151 100644
--- a/synapse/synapse_rust/events.pyi
+++ b/synapse/synapse_rust/events.pyi
@@ -19,6 +19,8 @@ class EventInternalMetadata:
stream_ordering: Optional[int]
"""the stream ordering of this event. None, until it has been persisted."""
+ instance_name: Optional[str]
+ """the instance name of the server that persisted this event. None, until it has been persisted."""
outlier: bool
"""whether this event is an outlier (ie, whether we have the state at that
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 509a2d3a0f..151658df53 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -48,7 +48,7 @@ import attr
from immutabledict import immutabledict
from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
-from typing_extensions import TypedDict
+from typing_extensions import Self, TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
@@ -515,6 +515,27 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
# at `self.stream`.
return self.instance_map.get(instance_name, self.stream)
+ def is_before_or_eq(self, other_token: Self) -> bool:
+        """Whether this token is before the other token, i.e. every constituent
+ part is before the other.
+
+ Essentially it is `self <= other`.
+
+ Note: if `self.is_before_or_eq(other_token) is False` then that does not
+ imply that the reverse is True.
+ """
+ if self.stream > other_token.stream:
+ return False
+
+ instances = self.instance_map.keys() | other_token.instance_map.keys()
+ for instance in instances:
+ if self.instance_map.get(
+ instance, self.stream
+ ) > other_token.instance_map.get(instance, other_token.stream):
+ return False
+
+ return True
+
@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken(AbstractMultiWriterStreamToken):
@@ -1008,6 +1029,41 @@ class StreamToken:
"""Returns the stream ID for the given key."""
return getattr(self, key.value)
+ def is_before_or_eq(self, other_token: "StreamToken") -> bool:
+        """Whether this token is before the other token, i.e. every constituent
+ part is before the other.
+
+ Essentially it is `self <= other`.
+
+ Note: if `self.is_before_or_eq(other_token) is False` then that does not
+ imply that the reverse is True.
+ """
+
+ for _, key in StreamKeyType.__members__.items():
+ if key == StreamKeyType.TYPING:
+ # Typing stream is allowed to "reset", and so comparisons don't
+ # really make sense as is.
+ # TODO: Figure out a better way of tracking resets.
+ continue
+
+ self_value = self.get_field(key)
+ other_value = other_token.get_field(key)
+
+ if isinstance(self_value, RoomStreamToken):
+ assert isinstance(other_value, RoomStreamToken)
+ if not self_value.is_before_or_eq(other_value):
+ return False
+ elif isinstance(self_value, MultiWriterStreamToken):
+ assert isinstance(other_value, MultiWriterStreamToken)
+ if not self_value.is_before_or_eq(other_value):
+ return False
+ else:
+ assert isinstance(other_value, int)
+ if self_value > other_value:
+ return False
+
+ return True
+
StreamToken.START = StreamToken(
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
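The new `is_before_or_eq` defines only a partial order: two tokens can each fail to be before the other. A standalone sketch of the comparison, with plain dicts standing in for the token's instance map:

from typing import Mapping


def is_before_or_eq(
    stream_a: int, map_a: Mapping[str, int], stream_b: int, map_b: Mapping[str, int]
) -> bool:
    # The default stream position stands in for any writer missing from a map.
    if stream_a > stream_b:
        return False
    for instance in map_a.keys() | map_b.keys():
        if map_a.get(instance, stream_a) > map_b.get(instance, stream_b):
            return False
    return True


assert is_before_or_eq(5, {"worker1": 7}, 6, {"worker1": 7})
# Not comparable in either direction: neither token dominates the other.
assert not is_before_or_eq(5, {"worker1": 9}, 6, {"worker1": 7})
assert not is_before_or_eq(6, {"worker1": 7}, 5, {"worker1": 9})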
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
new file mode 100644
index 0000000000..1d65551d5b
--- /dev/null
+++ b/synapse/types/handlers/__init__.py
@@ -0,0 +1,252 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright (C) 2024 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+from enum import Enum
+from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
+
+import attr
+from typing_extensions import TypedDict
+
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import Extra
+else:
+ from pydantic import Extra
+
+from synapse.events import EventBase
+from synapse.types import JsonMapping, StreamToken, UserID
+from synapse.types.rest.client import SlidingSyncBody
+
+
+class ShutdownRoomParams(TypedDict):
+ """
+ Attributes:
+ requester_user_id:
+ User who requested the action. Will be recorded as putting the room on the
+ blocking list.
+ new_room_user_id:
+ If set, a new room will be created with this user ID
+ as the creator and admin, and all users in the old room will be
+ moved into that room. If not set, no new room will be created
+ and the users will just be removed from the old room.
+ new_room_name:
+ A string representing the name of the room that new users will
+ be invited to. Defaults to `Content Violation Notification`
+ message:
+ A string containing the first message that will be sent as
+ `new_room_user_id` in the new room. Ideally this will clearly
+ convey why the original room was shut down.
+ Defaults to `Sharing illegal content on this server is not
+ permitted and rooms in violation will be blocked.`
+ block:
+ If set to `true`, this room will be added to a blocking list,
+ preventing future attempts to join the room. Defaults to `false`.
+ purge:
+ If set to `true`, purge the given room from the database.
+ force_purge:
+ If set to `true`, the room will be purged from database
+ even if there are still users joined to the room.
+ """
+
+ requester_user_id: Optional[str]
+ new_room_user_id: Optional[str]
+ new_room_name: Optional[str]
+ message: Optional[str]
+ block: bool
+ purge: bool
+ force_purge: bool
+
+
+class ShutdownRoomResponse(TypedDict):
+ """
+ Attributes:
+ kicked_users: An array of users (`user_id`) that were kicked.
+ failed_to_kick_users:
+            An array of users (`user_id`) that were not kicked.
+ local_aliases:
+ An array of strings representing the local aliases that were
+ migrated from the old room to the new.
+ new_room_id: A string representing the room ID of the new room.
+ """
+
+ kicked_users: List[str]
+ failed_to_kick_users: List[str]
+ local_aliases: List[str]
+ new_room_id: Optional[str]
+
+
+class SlidingSyncConfig(SlidingSyncBody):
+ """
+ Inherit from `SlidingSyncBody` since we need all of the same fields and add a few
+ extra fields that we need in the handler
+ """
+
+ user: UserID
+ device_id: Optional[str]
+
+ # Pydantic config
+ class Config:
+ # By default, ignore fields that we don't recognise.
+ extra = Extra.ignore
+ # By default, don't allow fields to be reassigned after parsing.
+ allow_mutation = False
+ # Allow custom types like `UserID` to be used in the model
+ arbitrary_types_allowed = True
+
+
+class OperationType(Enum):
+ """
+ Represents the operation types in a Sliding Sync window.
+
+ Attributes:
+        SYNC: Sets a range of entries. Clients SHOULD discard what they previously knew about
+ entries in this range.
+ INSERT: Sets a single entry. If the position is not empty then clients MUST move
+ entries to the left or the right depending on where the closest empty space is.
+ DELETE: Remove a single entry. Often comes before an INSERT to allow entries to move
+ places.
+ INVALIDATE: Remove a range of entries. Clients MAY persist the invalidated range for
+ offline support, but they should be treated as empty when additional operations
+ which concern indexes in the range arrive from the server.
+ """
+
+ SYNC: Final = "SYNC"
+ INSERT: Final = "INSERT"
+ DELETE: Final = "DELETE"
+ INVALIDATE: Final = "INVALIDATE"
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SlidingSyncResult:
+ """
+ The Sliding Sync result to be serialized to JSON for a response.
+
+ Attributes:
+ next_pos: The next position token in the sliding window to request (next_batch).
+ lists: Sliding window API. A map of list key to list results.
+ rooms: Room subscription API. A map of room ID to room subscription to room results.
+ extensions: Extensions API. A map of extension key to extension results.
+ """
+
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class RoomResult:
+ """
+ Attributes:
+ name: Room name or calculated room name.
+ avatar: Room avatar
+ heroes: List of stripped membership events (containing `user_id` and optionally
+ `avatar_url` and `displayname`) for the users used to calculate the room name.
+ initial: Flag which is set when this is the first time the server is sending this
+ data on this connection. Clients can use this flag to replace or update
+ their local state. When there is an update, servers MUST omit this flag
+ entirely and NOT send "initial":false as this is wasteful on bandwidth. The
+ absence of this flag means 'false'.
+ required_state: The current state of the room
+ timeline: Latest events in the room. The last event is the most recent
+ is_dm: Flag to specify whether the room is a direct-message room (most likely
+ between two people).
+ invite_state: Stripped state events. Same as `rooms.invite.$room_id.invite_state`
+ in sync v2, absent on joined/left rooms
+ prev_batch: A token that can be passed as a start parameter to the
+ `/rooms/<room_id>/messages` API to retrieve earlier messages.
+            limited: True if there are more events than fit between the given position and now.
+ Sync again to get more.
+ joined_count: The number of users with membership of join, including the client's
+ own user ID. (same as sync `v2 m.joined_member_count`)
+ invited_count: The number of users with membership of invite. (same as sync v2
+ `m.invited_member_count`)
+ notification_count: The total number of unread notifications for this room. (same
+ as sync v2)
+ highlight_count: The number of unread notifications for this room with the highlight
+ flag set. (same as sync v2)
+ num_live: The number of timeline events which have just occurred and are not historical.
+ The last N events are 'live' and should be treated as such. This is mostly
+ useful to determine whether a given @mention event should make a noise or not.
+ Clients cannot rely solely on the absence of `initial: true` to determine live
+ events because if a room not in the sliding window bumps into the window because
+ of an @mention it will have `initial: true` yet contain a single live event
+ (with potentially other old events in the timeline).
+ """
+
+ name: str
+ avatar: Optional[str]
+ heroes: Optional[List[EventBase]]
+ initial: bool
+ required_state: List[EventBase]
+ timeline: List[EventBase]
+ is_dm: bool
+ invite_state: List[EventBase]
+ prev_batch: StreamToken
+ limited: bool
+ joined_count: int
+ invited_count: int
+ notification_count: int
+ highlight_count: int
+ num_live: int
+
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class SlidingWindowList:
+ """
+ Attributes:
+ count: The total number of entries in the list. Always present if this list
+ is.
+ ops: The sliding list operations to perform.
+ """
+
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class Operation:
+ """
+ Attributes:
+ op: The operation type to perform.
+ range: Which index positions are affected by this operation. These are
+ both inclusive.
+ room_ids: Which room IDs are affected by this operation. These IDs match
+ up to the positions in the `range`, so the last room ID in this list
+ matches the 9th index. The room data is held in a separate object.
+ """
+
+ op: OperationType
+ range: Tuple[int, int]
+ room_ids: List[str]
+
+ count: int
+ ops: List[Operation]
+
+ next_pos: StreamToken
+ lists: Dict[str, SlidingWindowList]
+ rooms: Dict[str, RoomResult]
+ extensions: JsonMapping
+
+ def __bool__(self) -> bool:
+ """Make the result appear empty if there are no updates. This is used
+ to tell if the notifier needs to wait for more events when polling for
+ events.
+ """
+ return bool(self.lists or self.rooms or self.extensions)
+
+ @staticmethod
+ def empty(next_pos: StreamToken) -> "SlidingSyncResult":
+ "Return a new empty result"
+ return SlidingSyncResult(
+ next_pos=next_pos,
+ lists={},
+ rooms={},
+ extensions={},
+ )
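The `__bool__` override is what lets the notifier treat a result with no list, room or extension updates as "nothing to send yet" and keep waiting. A toy stand-in (assumption, not the real `SlidingSyncResult`) showing the emptiness check:

from typing import Dict

import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class FakeResult:
    lists: Dict[str, object]
    rooms: Dict[str, object]
    extensions: Dict[str, object]

    def __bool__(self) -> bool:
        # Falsy when there is nothing to serialize for the client.
        return bool(self.lists or self.rooms or self.extensions)


assert not FakeResult(lists={}, rooms={}, extensions={})
assert FakeResult(lists={"foo": object()}, rooms={}, extensions={})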
diff --git a/synapse/rest/models.py b/synapse/types/rest/__init__.py
index 2b6f5ed35a..2b6f5ed35a 100644
--- a/synapse/rest/models.py
+++ b/synapse/types/rest/__init__.py
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
new file mode 100644
index 0000000000..e2c79c4106
--- /dev/null
+++ b/synapse/types/rest/client/__init__.py
@@ -0,0 +1,309 @@
+#
+# This file is licensed under the Affero General Public License (AGPL) version 3.
+#
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+# Copyright (C) 2023 New Vector, Ltd
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# See the GNU Affero General Public License for more details:
+# <https://www.gnu.org/licenses/agpl-3.0.html>.
+#
+# Originally licensed under the Apache License, Version 2.0:
+# <http://www.apache.org/licenses/LICENSE-2.0>.
+#
+# [This file includes modifications made by New Vector Limited]
+#
+#
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import (
+ Extra,
+ StrictBool,
+ StrictInt,
+ StrictStr,
+ conint,
+ constr,
+ validator,
+ )
+else:
+ from pydantic import (
+ Extra,
+ StrictBool,
+ StrictInt,
+ StrictStr,
+ conint,
+ constr,
+ validator,
+ )
+
+from synapse.types.rest import RequestBodyModel
+from synapse.util.threepids import validate_email
+
+
+class AuthenticationData(RequestBodyModel):
+ """
+ Data used during user-interactive authentication.
+
+ (The name "Authentication Data" is taken directly from the spec.)
+
+ Additional keys will be present, depending on the `type` field. Use
+ `.dict(exclude_unset=True)` to access them.
+ """
+
+ class Config:
+ extra = Extra.allow
+
+ session: Optional[StrictStr] = None
+ type: Optional[StrictStr] = None
+
+
+if TYPE_CHECKING:
+ ClientSecretStr = StrictStr
+else:
+ # See also assert_valid_client_secret()
+ ClientSecretStr = constr(
+ regex="[0-9a-zA-Z.=_-]", # noqa: F722
+ min_length=1,
+ max_length=255,
+ strict=True,
+ )
+
+
+class ThreepidRequestTokenBody(RequestBodyModel):
+ client_secret: ClientSecretStr
+ id_server: Optional[StrictStr]
+ id_access_token: Optional[StrictStr]
+ next_link: Optional[StrictStr]
+ send_attempt: StrictInt
+
+ @validator("id_access_token", always=True)
+ def token_required_for_identity_server(
+ cls, token: Optional[str], values: Dict[str, object]
+ ) -> Optional[str]:
+ if values.get("id_server") is not None and token is None:
+ raise ValueError("id_access_token is required if an id_server is supplied.")
+ return token
+
+
+class EmailRequestTokenBody(ThreepidRequestTokenBody):
+ email: StrictStr
+
+ # Canonicalise the email address. The addresses are all stored canonicalised
+    # in the database. This allows the user to reset their password without having
+    # to know the exact spelling (e.g. upper/lower case) of the address in the database.
+ # Without this, an email stored in the database as "foo@bar.com" would cause
+ # user requests for "FOO@bar.com" to raise a Not Found error.
+ _email_validator = validator("email", allow_reuse=True)(validate_email)
+
+
+if TYPE_CHECKING:
+ ISO3116_1_Alpha_2 = StrictStr
+else:
+ # Per spec: two-letter uppercase ISO-3166-1-alpha-2
+ ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
+
+
+class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
+ country: ISO3116_1_Alpha_2
+ phone_number: StrictStr
+
+
+class SlidingSyncBody(RequestBodyModel):
+ """
+ Sliding Sync API request body.
+
+ Attributes:
+ lists: Sliding window API. A map of list key to list information
+ (:class:`SlidingSyncList`). Max lists: 100. The list keys should be
+ arbitrary strings which the client is using to refer to the list. Keep this
+ small as it needs to be sent a lot. Max length: 64 bytes.
+ room_subscriptions: Room subscription API. A map of room ID to room subscription
+ information. Used to subscribe to a specific room. Sometimes clients know
+            exactly which room they want to get information about, e.g. by following a
+ permalink or by refreshing a webapp currently viewing a specific room. The
+ sliding window API alone is insufficient for this use case because there's
+ no way to say "please track this room explicitly".
+ extensions: Extensions API. A map of extension key to extension config.
+ """
+
+ class CommonRoomParameters(RequestBodyModel):
+ """
+ Common parameters shared between the sliding window and room subscription APIs.
+
+ Attributes:
+ required_state: Required state for each room returned. An array of event
+ type and state key tuples. Elements in this array are ORd together to
+ produce the final set of state events to return. One unique exception is
+ when you request all state events via `["*", "*"]`. When used, all state
+ events are returned by default, and additional entries FILTER OUT the
+ returned set of state events. These additional entries cannot use `*`
+ themselves. For example, `["*", "*"], ["m.room.member",
+ "@alice:example.com"]` will *exclude* every `m.room.member` event
+ *except* for `@alice:example.com`, and include every other state event.
+ In addition, `["*", "*"], ["m.space.child", "*"]` is an error, the
+ `m.space.child` filter is not required as it would have been returned
+ anyway.
+ timeline_limit: The maximum number of timeline events to return per response.
+ (Max 1000 messages)
+ include_old_rooms: Determines if `predecessor` rooms are included in the
+ `rooms` response. The user MUST be joined to old rooms for them to show up
+ in the response.
+ """
+
+ class IncludeOldRooms(RequestBodyModel):
+ timeline_limit: StrictInt
+ required_state: List[Tuple[StrictStr, StrictStr]]
+
+ required_state: List[Tuple[StrictStr, StrictStr]]
+ # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
+ if TYPE_CHECKING:
+ timeline_limit: int
+ else:
+ timeline_limit: conint(le=1000, strict=True) # type: ignore[valid-type]
+ include_old_rooms: Optional[IncludeOldRooms] = None
+
+ class SlidingSyncList(CommonRoomParameters):
+ """
+ Attributes:
+ ranges: Sliding window ranges. If this field is missing, no sliding window
+ is used and all rooms are returned in this list. Integers are
+ *inclusive*.
+ slow_get_all_rooms: Just get all rooms (for clients that don't want to deal with
+ sliding windows). When true, the `ranges` field is ignored.
+ required_state: Required state for each room returned. An array of event
+ type and state key tuples. Elements in this array are ORd together to
+ produce the final set of state events to return.
+
+ One unique exception is when you request all state events via `["*",
+ "*"]`. When used, all state events are returned by default, and
+ additional entries FILTER OUT the returned set of state events. These
+ additional entries cannot use `*` themselves. For example, `["*", "*"],
+ ["m.room.member", "@alice:example.com"]` will *exclude* every
+ `m.room.member` event *except* for `@alice:example.com`, and include
+ every other state event. In addition, `["*", "*"], ["m.space.child",
+ "*"]` is an error, the `m.space.child` filter is not required as it
+ would have been returned anyway.
+
+ Room members can be lazily-loaded by using the special `$LAZY` state key
+ (`["m.room.member", "$LAZY"]`). Typically, when you view a room, you
+ want to retrieve all state events except for m.room.member events which
+ you want to lazily load. To get this behaviour, clients can send the
+ following::
+
+ {
+ "required_state": [
+ // activate lazy loading
+ ["m.room.member", "$LAZY"],
+ // request all state events _except_ for m.room.member
+ events which are lazily loaded
+ ["*", "*"]
+ ]
+ }
+
+ timeline_limit: The maximum number of timeline events to return per response.
+ include_old_rooms: Determines if `predecessor` rooms are included in the
+ `rooms` response. The user MUST be joined to old rooms for them to show up
+ in the response.
+ include_heroes: Return a stripped variant of membership events (containing
+ `user_id` and optionally `avatar_url` and `displayname`) for the users used
+ to calculate the room name.
+ filters: Filters to apply to the list before sorting.
+ """
+
+ class Filters(RequestBodyModel):
+ """
+ All fields are applied with AND operators, hence if `is_dm: True` and
+ `is_encrypted: True` then only Encrypted DM rooms will be returned. The
+ absence of fields implies no filter on that criteria: it does NOT imply
+ `False`. These fields may be expanded through use of extensions.
+
+ Attributes:
+ is_dm: Flag which only returns rooms present (or not) in the DM section
+ of account data. If unset, both DM rooms and non-DM rooms are returned.
+ If False, only non-DM rooms are returned. If True, only DM rooms are
+ returned.
+ spaces: Filter the room based on the space they belong to according to
+ `m.space.child` state events. If multiple spaces are present, a room can
+ be part of any one of the listed spaces (OR'd). The server will inspect
+ the `m.space.child` state events for the JOINED space room IDs given.
+ Servers MUST NOT navigate subspaces. It is up to the client to give a
+ complete list of spaces to navigate. Only rooms directly mentioned as
+ `m.space.child` events in these spaces will be returned. Unknown spaces
+ or spaces the user is not joined to will be ignored.
+ is_encrypted: Flag which only returns rooms which have an
+ `m.room.encryption` state event. If unset, both encrypted and
+ unencrypted rooms are returned. If `False`, only unencrypted rooms are
+ returned. If `True`, only encrypted rooms are returned.
+ is_invite: Flag which only returns rooms the user is currently invited
+ to. If unset, both invited and joined rooms are returned. If `False`, no
+ invited rooms are returned. If `True`, only invited rooms are returned.
+ room_types: If specified, only rooms where the `m.room.create` event has
+ a `type` matching one of the strings in this array will be returned. If
+ this field is unset, all rooms are returned regardless of type. This can
+ be used to get the initial set of spaces for an account. For rooms which
+ do not have a room type, use `null`/`None` to include them.
+ not_room_types: Same as `room_types` but inverted. This can be used to
+ filter out spaces from the room list. If a type is in both `room_types`
+ and `not_room_types`, then `not_room_types` wins and they are not included
+ in the result.
+ room_name_like: Filter the room name. Case-insensitive partial matching
+                e.g. 'foo' matches 'abFooab'. The term 'like' is inspired by SQL 'LIKE',
+ and the text here is similar to '%foo%'.
+ tags: Filter the room based on its room tags. If multiple tags are
+ present, a room can have any one of the listed tags (OR'd).
+ not_tags: Filter the room based on its room tags. Takes priority over
+ `tags`. For example, a room with tags A and B with filters `tags: [A]`
+ `not_tags: [B]` would NOT be included because `not_tags` takes priority over
+ `tags`. This filter is useful if your rooms list does NOT include the
+ list of favourite rooms again.
+ """
+
+ is_dm: Optional[StrictBool] = None
+ spaces: Optional[List[StrictStr]] = None
+ is_encrypted: Optional[StrictBool] = None
+ is_invite: Optional[StrictBool] = None
+ room_types: Optional[List[Union[StrictStr, None]]] = None
+ not_room_types: Optional[List[StrictStr]] = None
+ room_name_like: Optional[StrictStr] = None
+ tags: Optional[List[StrictStr]] = None
+ not_tags: Optional[List[StrictStr]] = None
+
+ # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
+ if TYPE_CHECKING:
+ ranges: Optional[List[Tuple[int, int]]] = None
+ else:
+ ranges: Optional[List[Tuple[conint(ge=0, strict=True), conint(ge=0, strict=True)]]] = None # type: ignore[valid-type]
+ slow_get_all_rooms: Optional[StrictBool] = False
+ include_heroes: Optional[StrictBool] = False
+ filters: Optional[Filters] = None
+
+ class RoomSubscription(CommonRoomParameters):
+ pass
+
+ class Extension(RequestBodyModel):
+ enabled: Optional[StrictBool] = False
+ lists: Optional[List[StrictStr]] = None
+ rooms: Optional[List[StrictStr]] = None
+
+ # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
+ if TYPE_CHECKING:
+ lists: Optional[Dict[str, SlidingSyncList]] = None
+ else:
+ lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
+ room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
+ extensions: Optional[Dict[StrictStr, Extension]] = None
+
+ @validator("lists")
+ def lists_length_check(
+ cls, value: Optional[Dict[str, SlidingSyncList]]
+ ) -> Optional[Dict[str, SlidingSyncList]]:
+ if value is not None:
+ assert len(value) <= 100, f"Max lists: 100 but saw {len(value)}"
+ return value
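A tiny standalone model, in the same pydantic v1 style as the patch, showing how the strict range types and the 100-list cap behave; the class names here are illustrative only and not part of Synapse:

from typing import Dict, List, Optional, Tuple

try:
    from pydantic.v1 import BaseModel, StrictBool, conint, validator
except ImportError:
    from pydantic import BaseModel, StrictBool, conint, validator


class ToyList(BaseModel):
    # Strict non-negative integers, as with SlidingSyncList.ranges above.
    ranges: Optional[List[Tuple[conint(ge=0, strict=True), conint(ge=0, strict=True)]]] = None
    slow_get_all_rooms: Optional[StrictBool] = False


class ToyBody(BaseModel):
    lists: Optional[Dict[str, ToyList]] = None

    @validator("lists")
    def lists_length_check(cls, value):
        if value is not None:
            assert len(value) <= 100, f"Max lists: 100 but saw {len(value)}"
        return value


body = ToyBody.parse_obj({"lists": {"foo": {"ranges": [[0, 99]]}}})
assert body.lists is not None and body.lists["foo"].ranges == [(0, 99)]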
diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py
index 01d05c9ed6..448960b297 100644
--- a/synapse/util/task_scheduler.py
+++ b/synapse/util/task_scheduler.py
@@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
from twisted.python.failure import Failure
-from synapse.logging.context import nested_logging_context
+from synapse.logging.context import (
+ ContextResourceUsage,
+ LoggingContext,
+ nested_logging_context,
+ set_current_context,
+)
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
@@ -81,6 +86,8 @@ class TaskScheduler:
MAX_CONCURRENT_RUNNING_TASKS = 5
# Time from the last task update after which we will log a warning
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
+ # Report a running task's status and usage every so often.
+ OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes
def __init__(self, hs: "HomeServer"):
self._hs = hs
@@ -346,6 +353,33 @@ class TaskScheduler:
assert task.id not in self._running_tasks
await self._store.delete_scheduled_task(task.id)
+ @staticmethod
+ def _log_task_usage(
+ state: str, task: ScheduledTask, usage: ContextResourceUsage, active_time: float
+ ) -> None:
+ """
+ Log a line describing the state and usage of a task.
+ The log line is inspired by / a copy of the request log line format,
+ but with irrelevant fields removed.
+
+ active_time: Time that the task has been running for, in seconds.
+ """
+
+ logger.info(
+ "Task %s: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
+ " [%d dbevts] %r, %r",
+ state,
+ active_time,
+ usage.ru_utime,
+ usage.ru_stime,
+ usage.db_sched_duration_sec,
+ usage.db_txn_duration_sec,
+ int(usage.db_txn_count),
+ usage.evt_db_fetch_count,
+ task.resource_id,
+ task.params,
+ )
+
async def _launch_task(self, task: ScheduledTask) -> None:
"""Launch a scheduled task now.
@@ -360,8 +394,32 @@ class TaskScheduler:
)
function = self._actions[task.action]
+ def _occasional_report(
+ task_log_context: LoggingContext, start_time: float
+ ) -> None:
+ """
+ Helper to log a 'Task continuing' line every so often.
+ """
+
+ current_time = self._clock.time()
+ calling_context = set_current_context(task_log_context)
+ try:
+ usage = task_log_context.get_resource_usage()
+ TaskScheduler._log_task_usage(
+ "continuing", task, usage, current_time - start_time
+ )
+ finally:
+ set_current_context(calling_context)
+
async def wrapper() -> None:
- with nested_logging_context(task.id):
+ with nested_logging_context(task.id) as log_context:
+ start_time = self._clock.time()
+ occasional_status_call = self._clock.looping_call(
+ _occasional_report,
+ TaskScheduler.OCCASIONAL_REPORT_INTERVAL_MS,
+ log_context,
+ start_time,
+ )
try:
(status, result, error) = await function(task)
except Exception:
@@ -383,6 +441,13 @@ class TaskScheduler:
)
self._running_tasks.remove(task.id)
+ current_time = self._clock.time()
+ usage = log_context.get_resource_usage()
+ TaskScheduler._log_task_usage(
+ status.value, task, usage, current_time - start_time
+ )
+ occasional_status_call.stop()
+
# Try launch a new task since we've finished with this one.
self._clock.call_later(0.1, self._launch_scheduled_tasks)
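The "occasional report" callback fires on a timer while the task runs and is stopped once the task completes. A minimal sketch of the same pattern using plain threading instead of Synapse's `Clock.looping_call` (all names here are illustrative):

import threading
import time


def run_with_reports(task, interval_sec: float = 300.0):
    """Run `task()`, printing a 'continuing' line every `interval_sec` seconds."""
    start_time = time.time()
    stop = threading.Event()

    def reporter() -> None:
        # wait() returns False on timeout (report and loop) and True once stopped.
        while not stop.wait(interval_sec):
            print("Task continuing: %.3fsec elapsed" % (time.time() - start_time,))

    thread = threading.Thread(target=reporter, daemon=True)
    thread.start()
    try:
        return task()
    finally:
        stop.set()
        print("Task finished: %.3fsec elapsed" % (time.time() - start_time,))


run_with_reports(lambda: time.sleep(0.1), interval_sec=0.05)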
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 09a947ef15..128413c8aa 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -82,7 +82,6 @@ async def filter_events_for_client(
is_peeking: bool = False,
always_include_ids: FrozenSet[str] = frozenset(),
filter_send_to_client: bool = True,
- msc4115_membership_on_events: bool = False,
) -> List[EventBase]:
"""
Check which events a user is allowed to see. If the user can see the event but its
@@ -101,12 +100,10 @@ async def filter_events_for_client(
filter_send_to_client: Whether we're checking an event that's going to be
sent to a client. This might not always be the case since this function can
also be called to check whether a user can see the state at a given point.
- msc4115_membership_on_events: Whether to include the requesting user's
- membership in the "unsigned" data, per MSC4115.
Returns:
- The filtered events. If `msc4115_membership_on_events` is true, the `unsigned`
- data is annotated with the membership state of `user_id` at each event.
+ The filtered events. The `unsigned` data is annotated with the membership state
+ of `user_id` at each event.
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
@@ -151,7 +148,7 @@ async def filter_events_for_client(
filter_send_to_client=filter_send_to_client,
sender_ignored=event.sender in ignore_list,
always_include_ids=always_include_ids,
- retention_policy=retention_policies[room_id],
+ retention_policy=retention_policies[event.room_id],
state=state_after_event,
is_peeking=is_peeking,
sender_erased=erased_senders.get(event.sender, False),
@@ -159,9 +156,6 @@ async def filter_events_for_client(
if filtered is None:
return None
- if not msc4115_membership_on_events:
- return filtered
-
# Annotate the event with the user's membership after the event.
#
# Normally we just look in `state_after_event`, but if the event is an outlier
@@ -186,7 +180,7 @@ async def filter_events_for_client(
# Copy the event before updating the unsigned data: this shouldn't be persisted
# to the cache!
cloned = clone_event(filtered)
- cloned.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] = user_membership
+ cloned.unsigned[EventUnsignedContentFields.MEMBERSHIP] = user_membership
return cloned