diff options
author | H. Shay <hillerys@element.io> | 2022-10-20 14:22:21 -0700 |
---|---|---|
committer | H. Shay <hillerys@element.io> | 2022-10-20 14:22:21 -0700 |
commit | 63f8ee4007e736be8f9c8666f534cc3e867bbf2f (patch) | |
tree | e2aa4481f7a331bbdd26f4bbebee97dfc5feb35f /synapse | |
parent | add version of eventcontext without state group (diff) | |
parent | Use servlets for /key/ endpoints. (#14229) (diff) | |
download | synapse-63f8ee4007e736be8f9c8666f534cc3e867bbf2f.tar.xz |
Merge branch 'develop' into shay/batch_state_groups
Diffstat (limited to 'synapse')
76 files changed, 1610 insertions, 702 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 5fa599e70e..d850e54e17 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore @@ -206,6 +207,7 @@ class Store( PusherWorkerStore, PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, + RelationsWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py index fb1fb83f50..0adf94bba6 100755..100644 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -15,7 +15,6 @@ import argparse import logging -import sys from typing import cast import yaml @@ -100,13 +99,6 @@ def main() -> None: # Load, process and sanity-check the config. hs_config = yaml.safe_load(args.database_config) - if "database" not in hs_config and "databases" not in hs_config: - sys.stderr.write( - "The configuration file must have a 'database' or 'databases' section. " - "See https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#database" - ) - sys.exit(4) - config = HomeServerConfig() config.parse_config_dict(hs_config, "", "") diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c606207569..400dd12aba 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -155,7 +155,13 @@ class RedirectException(CodeMessageException): class SynapseError(CodeMessageException): """A base exception type for matrix errors which have an errcode and error - message (as well as an HTTP status code). + message (as well as an HTTP status code). These often bubble all the way up to the + client API response so the error code and status often reach the client directly as + defined here. If the error doesn't make sense to present to a client, then it + probably shouldn't be a `SynapseError`. For example, if we contact another + homeserver over federation, we shouldn't automatically ferry response errors back to + the client on our end (a 500 from a remote server does not make sense to a client + when our server did not experience a 500). Attributes: errcode: Matrix error code e.g 'M_FORBIDDEN' @@ -600,8 +606,20 @@ def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict": class FederationError(RuntimeError): - """This class is used to inform remote homeservers about erroneous - PDUs they sent us. + """ + Raised when we process an erroneous PDU. + + There are two kinds of scenarios where this exception can be raised: + + 1. We may pull an invalid PDU from a remote homeserver (e.g. during backfill). We + raise this exception to signal an error to the rest of the application. + 2. We may be pushed an invalid PDU as part of a `/send` transaction from a remote + homeserver. We raise so that we can respond to the transaction and include the + error string in the "PDU Processing Result". The message which will likely be + ignored by the remote homeserver and is not machine parse-able since it's just a + string. + + TODO: In the future, we should split these usage scenarios into their own error types. FATAL: The remote server could not interpret the source event. (e.g., it was missing a required field) @@ -640,6 +658,27 @@ class FederationError(RuntimeError): } +class FederationPullAttemptBackoffError(RuntimeError): + """ + Raised to indicate that we are are deliberately not attempting to pull the given + event over federation because we've already done so recently and are backing off. + + Attributes: + event_id: The event_id which we are refusing to pull + message: A custom error message that gives more context + """ + + def __init__(self, event_ids: List[str], message: Optional[str]): + self.event_ids = event_ids + + if message: + error_message = message + else: + error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)." + + super().__init__(error_message) + + class HttpResponseException(CodeMessageException): """ Represents an HTTP-level failure of an outbound request diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cc31cf8cc7..26be377d03 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -36,7 +36,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -53,6 +53,12 @@ FILTER_SCHEMA = { # check types are valid event types "types": {"type": "array", "items": {"type": "string"}}, "not_types": {"type": "array", "items": {"type": "string"}}, + # MSC3874, filtering /messages. + "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc3874.not_rel_types": { + "type": "array", + "items": {"type": "string"}, + }, }, } @@ -334,8 +340,15 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - self.related_by_senders = self.filter_json.get("related_by_senders", None) - self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + self.related_by_senders = filter_json.get("related_by_senders", None) + self.related_by_rel_types = filter_json.get("related_by_rel_types", None) + + # For compatibility with _check_fields. + self.rel_types = None + self.not_rel_types = [] + if hs.config.experimental.msc3874_enabled: + self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None) + self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -386,11 +399,19 @@ class Filter: # check if there is a string url field in the content for filtering purposes labels = content.get(EventContentFields.LABELS, []) + # Check if the event has a relation. + rel_type = None + if isinstance(event, EventBase): + relation = relation_from_event(event) + if relation: + rel_type = relation.rel_type + field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, + "rel_types": lambda v: rel_type == v, } result = self._check_fields(field_matchers) diff --git a/synapse/api/urls.py b/synapse/api/urls.py index bd49fa6a5f..a918579f50 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -28,7 +28,7 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1" FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" -SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" +SERVER_KEY_PREFIX = "/_matrix/key" MEDIA_R0_PREFIX = "/_matrix/media/r0" MEDIA_V3_PREFIX = "/_matrix/media/v3" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 5e3825fca6..2a9f039367 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -28,7 +28,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, ) from synapse.app import _base from synapse.app._base import ( @@ -65,6 +65,7 @@ from synapse.rest.client import ( push_rule, read_marker, receipts, + relations, room, room_batch, room_keys, @@ -88,7 +89,7 @@ from synapse.rest.client.register import ( RegistrationTokenValidityRestServlet, ) from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -308,6 +309,7 @@ class GenericWorkerServer(HomeServer): sync.register_servlets(self, resource) events.register_servlets(self, resource) room.register_servlets(self, resource, is_worker=True) + relations.register_servlets(self, resource) room.register_deprecated_servlets(self, resource) initial_sync.register_servlets(self, resource) room_batch.register_servlets(self, resource) @@ -323,13 +325,13 @@ class GenericWorkerServer(HomeServer): presence.register_servlets(self, resource) - resources.update({CLIENT_API_PREFIX: resource}) + resources[CLIENT_API_PREFIX] = resource resources.update(build_synapse_client_resource_tree(self)) - resources.update({"/.well-known": well_known_resource(self)}) + resources["/.well-known"] = well_known_resource(self) elif name == "federation": - resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) + resources[FEDERATION_PREFIX] = TransportLayerServer(self) elif name == "media": if self.config.media.can_load_media_repo: media_repo = self.get_media_repository_resource() @@ -357,16 +359,12 @@ class GenericWorkerServer(HomeServer): # Only load the openid resource separately if federation resource # is not specified since federation resource includes openid # resource. - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "replication": resources[REPLICATION_PREFIX] = ReplicationRestResource(self) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 883f2fd2ec..de3f08876f 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -31,7 +31,7 @@ from synapse.api.urls import ( LEGACY_MEDIA_PREFIX, MEDIA_R0_PREFIX, MEDIA_V3_PREFIX, - SERVER_KEY_V2_PREFIX, + SERVER_KEY_PREFIX, STATIC_PREFIX, ) from synapse.app import _base @@ -60,7 +60,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource -from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.key.v2 import KeyResource from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.well_known import well_known_resource from synapse.server import HomeServer @@ -215,30 +215,22 @@ class SynapseHomeServer(HomeServer): consent_resource: Resource = ConsentResource(self) if compress: consent_resource = gz_wrap(consent_resource) - resources.update({"/_matrix/consent": consent_resource}) + resources["/_matrix/consent"] = consent_resource if name == "federation": federation_resource: Resource = TransportLayerServer(self) if compress: federation_resource = gz_wrap(federation_resource) - resources.update({FEDERATION_PREFIX: federation_resource}) + resources[FEDERATION_PREFIX] = federation_resource if name == "openid": - resources.update( - { - FEDERATION_PREFIX: TransportLayerServer( - self, servlet_groups=["openid"] - ) - } + resources[FEDERATION_PREFIX] = TransportLayerServer( + self, servlet_groups=["openid"] ) if name in ["static", "client"]: - resources.update( - { - STATIC_PREFIX: StaticResource( - os.path.join(os.path.dirname(synapse.__file__), "static") - ) - } + resources[STATIC_PREFIX] = StaticResource( + os.path.join(os.path.dirname(synapse.__file__), "static") ) if name in ["media", "federation", "client"]: @@ -257,7 +249,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["keys", "federation"]: - resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + resources[SERVER_KEY_PREFIX] = KeyResource(self) if name == "metrics" and self.config.metrics.enable_metrics: metrics_resource: Resource = MetricsResource(RegistryProxy) diff --git a/synapse/config/cache.py b/synapse/config/cache.py index 2db8cfb005..eb4194a5a9 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -159,7 +159,7 @@ class CacheConfig(Config): self.track_memory_usage = cache_config.get("track_memory_usage", False) if self.track_memory_usage: - check_requirements("cache_memory") + check_requirements("cache-memory") expire_caches = cache_config.get("expire_caches", True) cache_entry_ttl = cache_config.get("cache_entry_ttl", "30m") diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e00cb7096c..4009add01d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import attr @@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3772: A push rule for mutual relations. - self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) @@ -119,3 +117,11 @@ class ExperimentalConfig(Config): self.msc3882_token_timeout = self.parse_duration( experimental.get("msc3882_token_timeout", "5m") ) + + # MSC3874: Filtering /messages with rel_types / not_rel_types. + self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) + + # MSC3886: Simple client rendezvous capability + self.msc3886_endpoint: Optional[str] = experimental.get( + "msc3886_endpoint", None + ) diff --git a/synapse/config/groups.py b/synapse/config/groups.py deleted file mode 100644 index baa051fdd4..0000000000 --- a/synapse/config/groups.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -from synapse.types import JsonDict - -from ._base import Config - - -class GroupsConfig(Config): - section = "groups" - - def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.enable_group_creation = config.get("enable_group_creation", False) - self.group_creation_prefix = config.get("group_creation_prefix", "") diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 1033496bb4..e4759711ed 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -205,7 +205,7 @@ class ContentRepositoryConfig(Config): ) self.url_preview_enabled = config.get("url_preview_enabled", False) if self.url_preview_enabled: - check_requirements("url_preview") + check_requirements("url-preview") proxy_env = getproxies_environment() if "url_preview_ip_range_blacklist" not in config: diff --git a/synapse/config/server.py b/synapse/config/server.py index f2353ce5fb..ec46ca63ad 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -207,6 +207,9 @@ class HttpListenerConfig: additional_resources: Dict[str, dict] = attr.Factory(dict) tag: Optional[str] = None request_id_header: Optional[str] = None + # If true, the listener will return CORS response headers compatible with MSC3886: + # https://github.com/matrix-org/matrix-spec-proposals/pull/3886 + experimental_cors_msc3886: bool = False @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -935,6 +938,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig: additional_resources=listener.get("additional_resources", {}), tag=listener.get("tag"), request_id_header=listener.get("request_id_header"), + experimental_cors_msc3886=listener.get("experimental_cors_msc3886", False), ) return ListenerConfig(port, bind_addresses, listener_type, tls, http_config) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c7d5ef92fc..bab31e33c5 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -15,7 +15,18 @@ import logging import typing -from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -134,6 +145,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: async def check_state_independent_auth_rules( store: _EventSourceStore, event: "EventBase", + batched_auth_events: Optional[Mapping[str, "EventBase"]] = None, ) -> None: """Check that an event complies with auth rules that are independent of room state @@ -143,6 +155,8 @@ async def check_state_independent_auth_rules( Args: store: the datastore; used to fetch the auth events for validation event: the event being checked. + batched_auth_events: if the event being authed is part of a batch, any events + from the same batch that may be necessary to auth the current event Raises: AuthError if the checks fail @@ -162,6 +176,9 @@ async def check_state_independent_auth_rules( redact_behaviour=EventRedactBehaviour.as_is, allow_rejected=True, ) + if batched_auth_events: + auth_events.update(batched_auth_events) + room_id = event.room_id auth_dict: MutableStateMap[str] = {} expected_auth_types = auth_types_for_event(event.room_version, event) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 826a84894e..14ac875ce0 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -65,7 +65,8 @@ class EventContext: None does not necessarily mean that ``state_group`` does not have a prev_group! - If the event is a state event, this is normally the same as ``prev_group``. + If the event is a state event, this is normally the same as + ``state_group_before_event``. If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` will always also be ``None``. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4dca711cd2..b220ab43fc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1294,7 +1294,7 @@ class FederationClient(FederationBase): return resp[1] async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict: - """Attempts to send a knock event to given a list of servers. Iterates + """Attempts to send a knock event to a given list of servers. Iterates through the list until one attempt succeeds. Doing so will cause the remote server to add the event to the graph, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 907940e19e..59e351595b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -481,6 +481,14 @@ class FederationServer(FederationBase): pdu_results[pdu.event_id] = await process_pdu(pdu) async def process_pdu(pdu: EventBase) -> JsonDict: + """ + Processes a pushed PDU sent to us via a `/send` transaction + + Returns: + JsonDict representing a "PDU Processing Result" that will be bundled up + with the other processed PDU's in the `/send` transaction and sent back + to remote homeserver. + """ event_id = pdu.event_id with nested_logging_context(event_id): try: @@ -824,7 +832,14 @@ class FederationServer(FederationBase): context, self._room_prejoin_state_types ) ) - return {"knock_state_events": stripped_room_state} + return { + "knock_room_state": stripped_room_state, + # Since v1.37, Synapse incorrectly used "knock_state_events" for this field. + # Thus, we also populate a 'knock_state_events' with the same content to + # support old instances. + # See https://github.com/matrix-org/synapse/issues/14088. + "knock_state_events": stripped_room_state, + } async def _on_send_membership_event( self, origin: str, content: JsonDict, membership_type: str, room_id: str diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index a6cb3ba58f..774ecd81b6 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -353,21 +353,25 @@ class FederationSender(AbstractFederationSender): last_token = await self.store.get_federation_out_pos("events") ( next_token, - events, event_to_received_ts, - ) = await self.store.get_all_new_events_stream( + ) = await self.store.get_all_new_event_ids_stream( last_token, self._last_poked_id, limit=100 ) + event_ids = event_to_received_ts.keys() + event_entries = await self.store.get_unredacted_events_from_cache_or_db( + event_ids + ) + logger.debug( "Handling %i -> %i: %i events to send (current id %i)", last_token, next_token, - len(events), + len(event_entries), self._last_poked_id, ) - if not events and next_token >= self._last_poked_id: + if not event_entries and next_token >= self._last_poked_id: logger.debug("All events processed") break @@ -508,8 +512,14 @@ class FederationSender(AbstractFederationSender): await handle_event(event) events_by_room: Dict[str, List[EventBase]] = {} - for event in events: - events_by_room.setdefault(event.room_id, []).append(event) + + for event_id in event_ids: + # `event_entries` is unsorted, so we have to iterate over `event_ids` + # to ensure the events are in the right order + event_cache = event_entries.get(event_id) + if event_cache: + event = event_cache.event + events_by_room.setdefault(event.room_id, []).append(event) await make_deferred_yieldable( defer.gatherResults( @@ -524,9 +534,10 @@ class FederationSender(AbstractFederationSender): logger.debug("Successfully handled up to %i", next_token) await self.store.update_federation_out_pos("events", next_token) - if events: + if event_entries: now = self.clock.time_msec() - ts = event_to_received_ts[events[-1].event_id] + last_id = next(reversed(event_ids)) + ts = event_to_received_ts[last_id] assert ts is not None synapse.metrics.event_processing_lag.labels( @@ -536,7 +547,7 @@ class FederationSender(AbstractFederationSender): "federation_sender" ).set(ts) - events_processed_counter.inc(len(events)) + events_processed_counter.inc(len(event_entries)) event_processing_loop_room_count.labels("federation_sender").inc( len(events_by_room) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 6bb4659c4c..205fd16daa 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -489,7 +489,7 @@ class FederationV2InviteServlet(BaseFederationServerServlet): room_version = content["room_version"] event = content["event"] - invite_room_state = content["invite_room_state"] + invite_room_state = content.get("invite_room_state", []) # Synapse expects invite_room_state to be in unsigned, as it is in v1 # API @@ -499,6 +499,11 @@ class FederationV2InviteServlet(BaseFederationServerServlet): result = await self.handler.on_invite_request( origin, event, room_version_id=room_version ) + + # We only store invite_room_state for internal use, so remove it before + # returning the event to the remote homeserver. + result["event"].get("unsigned", {}).pop("invite_room_state", None) + return 200, result diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 0478448b47..fc21d58001 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 203b62e015..66f5b8d108 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -109,10 +109,13 @@ class ApplicationServicesHandler: last_token = await self.store.get_appservice_last_pos() ( upper_bound, - events, event_to_received_ts, - ) = await self.store.get_all_new_events_stream( - last_token, self.current_max, limit=100, get_prev_content=True + ) = await self.store.get_all_new_event_ids_stream( + last_token, self.current_max, limit=100 + ) + + events = await self.store.get_events_as_list( + event_to_received_ts.keys(), get_prev_content=True ) events_by_room: Dict[str, List[EventBase]] = {} diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index f9cc5bddbc..c597639a7f 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -937,7 +937,10 @@ class DeviceListUpdater: # Check if we are partially joining any rooms. If so we need to store # all device list updates so that we can handle them correctly once we # know who is in the room. - partial_rooms = await self.store.get_partial_state_rooms_and_servers() + # TODO(faster joins): this fetches and processes a bunch of data that we don't + # use. Could be replaced by a tighter query e.g. + # SELECT EXISTS(SELECT 1 FROM partial_state_rooms) + partial_rooms = await self.store.get_partial_state_room_resync_info() if partial_rooms: await self.store.add_remote_device_list_to_pending( user_id, diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7127d5aefc..d52ebada6b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -16,6 +16,8 @@ import logging import string from typing import TYPE_CHECKING, Iterable, List, Optional +from typing_extensions import Literal + from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( AuthError, @@ -429,7 +431,10 @@ class DirectoryHandler: return await self.auth.check_can_change_room_list(room_id, requester) async def edit_published_room_list( - self, requester: Requester, room_id: str, visibility: str + self, + requester: Requester, + room_id: str, + visibility: Literal["public", "private"], ) -> None: """Edit the entry of the room in the published room list. @@ -451,9 +456,6 @@ class DirectoryHandler: if requester.is_guest: raise AuthError(403, "Guests cannot edit the published room list") - if visibility not in ["public", "private"]: - raise SynapseError(400, "Invalid visibility setting") - if visibility == "public" and not self.enable_room_list_search: # The room list has been disabled. raise AuthError( @@ -505,7 +507,11 @@ class DirectoryHandler: await self.store.set_room_is_public(room_id, making_public) async def edit_published_appservice_room_list( - self, appservice_id: str, network_id: str, room_id: str, visibility: str + self, + appservice_id: str, + network_id: str, + room_id: str, + visibility: Literal["public", "private"], ) -> None: """Add or remove a room from the appservice/network specific public room list. @@ -516,9 +522,6 @@ class DirectoryHandler: room_id visibility: either "public" or "private" """ - if visibility not in ["public", "private"]: - raise SynapseError(400, "Invalid visibility setting") - await self.store.set_room_is_public_appservice( room_id, appservice_id, network_id, visibility == "public" ) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 8249ca1ed2..3bbad0271b 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Optional, Union +from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,6 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.events.snapshot import EventContext from synapse.types import StateMap, get_domain_from_id if TYPE_CHECKING: @@ -51,12 +50,21 @@ class EventAuthHandler: async def check_auth_rules_from_context( self, event: EventBase, - context: EventContext, + batched_auth_events: Optional[Mapping[str, EventBase]] = None, ) -> None: - """Check an event passes the auth rules at its own auth events""" - await check_state_independent_auth_rules(self._store, event) + """Check an event passes the auth rules at its own auth events + Args: + event: event to be authed + batched_auth_events: if the event being authed is part of a batch, any events + from the same batch that may be necessary to auth the current event + """ + await check_state_independent_auth_rules( + self._store, event, batched_auth_events + ) auth_event_ids = event.auth_event_ids() auth_events_by_id = await self._store.get_events(auth_event_ids) + if batched_auth_events: + auth_events_by_id.update(batched_auth_events) check_state_dependent_auth_rules(event, auth_events_by_id.values()) def compute_auth_events( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 986ffed3d5..275a37a575 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, LimitExceededError, NotFoundError, @@ -631,6 +632,7 @@ class FederationHandler: room_id=room_id, servers=ret.servers_in_room, device_lists_stream_id=self.store.get_device_stream_token(), + joined_via=origin, ) try: @@ -781,15 +783,27 @@ class FederationHandler: # Send the signed event back to the room, and potentially receive some # further information about the room in the form of partial state events - stripped_room_state = await self.federation_client.send_knock( - target_hosts, event - ) + knock_response = await self.federation_client.send_knock(target_hosts, event) # Store any stripped room state events in the "unsigned" key of the event. # This is a bit of a hack and is cribbing off of invites. Basically we # store the room state here and retrieve it again when this event appears # in the invitee's sync stream. It is stripped out for all other local users. - event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] + stripped_room_state = ( + knock_response.get("knock_room_state") + # Since v1.37, Synapse incorrectly used "knock_state_events" for this field. + # Thus, we also check for a 'knock_state_events' to support old instances. + # See https://github.com/matrix-org/synapse/issues/14088. + or knock_response.get("knock_state_events") + ) + + if stripped_room_state is None: + raise KeyError( + "Missing 'knock_room_state' (or legacy 'knock_state_events') field in " + "send_knock response" + ) + + event.unsigned["knock_room_state"] = stripped_room_state context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( @@ -928,7 +942,7 @@ class FederationHandler: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) return event async def on_invite_request( @@ -1109,7 +1123,7 @@ class FederationHandler: try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Failed to create new leave %r because %s", event, e) raise e @@ -1168,7 +1182,7 @@ class FederationHandler: try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Failed to create new knock %r because %s", event, e) raise e @@ -1334,9 +1348,7 @@ class FederationHandler: try: validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context( - event, context - ) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e @@ -1386,7 +1398,7 @@ class FederationHandler: try: validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) raise e @@ -1602,13 +1614,13 @@ class FederationHandler: """Resumes resyncing of all partial-state rooms after a restart.""" assert not self.config.worker.worker_app - partial_state_rooms = await self.store.get_partial_state_rooms_and_servers() - for room_id, servers_in_room in partial_state_rooms.items(): + partial_state_rooms = await self.store.get_partial_state_room_resync_info() + for room_id, resync_info in partial_state_rooms.items(): run_as_background_process( desc="sync_partial_state_room", func=self._sync_partial_state_room, - initial_destination=None, - other_destinations=servers_in_room, + initial_destination=resync_info.joined_via, + other_destinations=resync_info.servers_in_room, room_id=room_id, ) @@ -1637,28 +1649,12 @@ class FederationHandler: # really leave, that might mean we have difficulty getting the room state over # federation. # https://github.com/matrix-org/synapse/issues/12802 - # - # TODO(faster_joins): we need some way of prioritising which homeservers in - # `other_destinations` to try first, otherwise we'll spend ages trying dead - # homeservers for large rooms. - # https://github.com/matrix-org/synapse/issues/12999 - - if initial_destination is None and len(other_destinations) == 0: - raise ValueError( - f"Cannot resync state of {room_id}: no destinations provided" - ) # Make an infinite iterator of destinations to try. Once we find a working # destination, we'll stick with it until it flakes. - destinations: Collection[str] - if initial_destination is not None: - # Move `initial_destination` to the front of the list. - destinations = list(other_destinations) - if initial_destination in destinations: - destinations.remove(initial_destination) - destinations = [initial_destination] + destinations - else: - destinations = other_destinations + destinations = _prioritise_destinations_for_partial_state_resync( + initial_destination, other_destinations, room_id + ) destination_iter = itertools.cycle(destinations) # `destination` is the current remote homeserver we're pulling from. @@ -1708,7 +1704,22 @@ class FederationHandler: destination, event ) break + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_sync_partial_state_room: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: + # TODO: We should `record_event_failed_pull_attempt` here, + # see https://github.com/matrix-org/synapse/issues/13700 + if attempt == len(destinations) - 1: # We have tried every remote server for this event. Give up. # TODO(faster_joins) giving up isn't the right thing to do @@ -1741,3 +1752,29 @@ class FederationHandler: room_id, destination, ) + + +def _prioritise_destinations_for_partial_state_resync( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, +) -> Collection[str]: + """Work out the order in which we should ask servers to resync events. + + If an `initial_destination` is given, it takes top priority. Otherwise + all servers are treated equally. + + :raises ValueError: if no destination is provided at all. + """ + if initial_destination is None and len(other_destinations) == 0: + raise ValueError(f"Cannot resync state of {room_id}: no destinations provided") + + if initial_destination is None: + return other_destinations + + # Move `initial_destination` to the front of the list. + destinations = list(other_destinations) + if initial_destination in destinations: + destinations.remove(initial_destination) + destinations = [initial_destination] + destinations + return destinations diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index da319943cc..06e41b5cc0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -44,6 +44,7 @@ from synapse.api.errors import ( AuthError, Codes, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, RequestSendFailed, SynapseError, @@ -414,7 +415,9 @@ class FederationEventHandler: # First, precalculate the joined hosts so that the federation sender doesn't # need to. - await self._event_creation_handler.cache_joined_hosts_for_event(event, context) + await self._event_creation_handler.cache_joined_hosts_for_events( + [(event, context)] + ) await self._check_for_soft_fail(event, context=context, origin=origin) await self._run_push_actions_and_persist_event(event, context) @@ -565,6 +568,9 @@ class FederationEventHandler: event: partial-state event to be de-partial-stated Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to request state from the remote server. """ logger.info("Updating state for %s", event.event_id) @@ -792,9 +798,42 @@ class FederationEventHandler: ], ) + # Check if we already any of these have these events. + # Note: we currently make a lookup in the database directly here rather than + # checking the event cache, due to: + # https://github.com/matrix-org/synapse/issues/13476 + existing_events_map = await self._store._get_events_from_db( + [event.event_id for event in events] + ) + + new_events = [] + for event in events: + event_id = event.event_id + + # If we've already seen this event ID... + if event_id in existing_events_map: + existing_event = existing_events_map[event_id] + + # ...and the event itself was not previously stored as an outlier... + if not existing_event.event.internal_metadata.is_outlier(): + # ...then there's no need to persist it. We have it already. + logger.info( + "_process_pulled_event: Ignoring received event %s which we " + "have already seen", + event.event_id, + ) + continue + + # While we have seen this event before, it was stored as an outlier. + # We'll now persist it as a non-outlier. + logger.info("De-outliering event %s", event_id) + + # Continue on with the events that are new to us. + new_events.append(event) + # We want to sort these by depth so we process them and # tell clients about them in order. - sorted_events = sorted(events, key=lambda x: x.depth) + sorted_events = sorted(new_events, key=lambda x: x.depth) for ev in sorted_events: with nested_logging_context(ev.event_id): await self._process_pulled_event(origin, ev, backfilled=backfilled) @@ -846,18 +885,6 @@ class FederationEventHandler: event_id = event.event_id - existing = await self._store.get_event( - event_id, allow_none=True, allow_rejected=True - ) - if existing: - if not existing.internal_metadata.is_outlier(): - logger.info( - "_process_pulled_event: Ignoring received event %s which we have already seen", - event_id, - ) - return - logger.info("De-outliering event %s", event_id) - try: self._sanity_check_event(event) except SynapseError as err: @@ -899,6 +926,18 @@ class FederationEventHandler: context, backfilled=backfilled, ) + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_process_pulled_event: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: await self._store.record_event_failed_pull_attempt( event.room_id, event_id, str(e) @@ -945,6 +984,9 @@ class FederationEventHandler: The event context. Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to get the state from the remote server after any missing `prev_event`s. """ @@ -955,6 +997,18 @@ class FederationEventHandler: seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen + # If we've already recently attempted to pull this missing event, don't + # try it again so soon. Since we have to fetch all of the prev_events, we can + # bail early here if we find any to ignore. + prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff( + room_id, missing_prevs + ) + if len(prevs_to_ignore) > 0: + raise FederationPullAttemptBackoffError( + event_ids=prevs_to_ignore, + message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).", + ) + if not missing_prevs: return await self._state_handler.compute_event_context(event) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 860c82c110..9c335e6863 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -57,13 +57,7 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, - Optional[StreamToken], - Optional[StreamToken], - str, - Optional[int], - bool, - bool, + str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() @@ -154,11 +148,6 @@ class InitialSyncHandler: public_room_ids = await self.store.get_public_room_ids() - if pagin_config.limit is not None: - limit = pagin_config.limit - else: - limit = 10 - serializer_options = SerializeEventConfig(as_client_event=as_client_event) async def handle_room(event: RoomsForUser) -> None: @@ -210,7 +199,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, event.room_id, - limit=limit, + limit=pagin_config.limit, end_token=room_end_token, ), deferred_room_state, @@ -360,15 +349,11 @@ class InitialSyncHandler: member_event_id ) - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - leave_position = await self.store.get_position_for_event(member_event_id) stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( - room_id, limit=limit, end_token=stream_token + room_id, limit=pagin_config.limit, end_token=stream_token ) messages = await filter_events_for_client( @@ -420,10 +405,6 @@ class InitialSyncHandler: now_token = self.hs.get_event_sources().get_current_token() - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - room_members = [ m for m in current_state.values() @@ -467,7 +448,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, room_id, - limit=limit, + limit=pagin_config.limit, end_token=now_token.room_key, ), ), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da1acea275..15b828dd74 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1360,8 +1360,16 @@ class EventCreationHandler: else: try: validate_event_for_room_version(event) + # If we are persisting a batch of events the event(s) needed to auth the + # current event may be part of the batch and will not be in the DB yet + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + batched_auth_events = {} + for event_id in event.auth_event_ids(): + auth_event = event_id_to_event.get(event_id) + if auth_event: + batched_auth_events[event_id] = auth_event await self._event_auth_handler.check_auth_rules_from_context( - event, context + event, batched_auth_events ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) @@ -1390,7 +1398,7 @@ class EventCreationHandler: extra_users=extra_users, ), run_in_background( - self.cache_joined_hosts_for_event, event, context + self.cache_joined_hosts_for_events, events_and_context ).addErrback( log_failure, "cache_joined_hosts_for_event failed" ), @@ -1491,62 +1499,65 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise - async def cache_joined_hosts_for_event( - self, event: EventBase, context: EventContext + async def cache_joined_hosts_for_events( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Precalculate the joined hosts at the event, when using Redis, so that + """Precalculate the joined hosts at each of the given events, when using Redis, so that external federation senders don't have to recalculate it themselves. """ - if not self._external_cache.is_enabled(): - return + for event, _ in events_and_context: + if not self._external_cache.is_enabled(): + return - # If external cache is enabled we should always have this. - assert self._external_cache_joined_hosts_updates is not None + # If external cache is enabled we should always have this. + assert self._external_cache_joined_hosts_updates is not None - # We actually store two mappings, event ID -> prev state group, - # state group -> joined hosts, which is much more space efficient - # than event ID -> joined hosts. - # - # Note: We have to cache event ID -> prev state group, as we don't - # store that in the DB. - # - # Note: We set the state group -> joined hosts cache if it hasn't been - # set for a while, so that the expiry time is reset. - - state_entry = await self.state.resolve_state_groups_for_events( - event.room_id, event_ids=event.prev_event_ids() - ) + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We set the state group -> joined hosts cache if it hasn't been + # set for a while, so that the expiry time is reset. - if state_entry.state_group: - await self._external_cache.set( - "event_to_prev_state_group", - event.event_id, - state_entry.state_group, - expiry_ms=60 * 60 * 1000, + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() ) - if state_entry.state_group in self._external_cache_joined_hosts_updates: - return + if state_entry.state_group: + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) - state = await state_entry.get_state( - self._storage_controllers.state, StateFilter.all() - ) - with opentracing.start_active_span("get_joined_hosts"): - joined_hosts = await self.store.get_joined_hosts( - event.room_id, state, state_entry + if state_entry.state_group in self._external_cache_joined_hosts_updates: + return + + state = await state_entry.get_state( + self._storage_controllers.state, StateFilter.all() ) + with opentracing.start_active_span("get_joined_hosts"): + joined_hosts = await self.store.get_joined_hosts( + event.room_id, state, state_entry + ) - # Note that the expiry times must be larger than the expiry time in - # _external_cache_joined_hosts_updates. - await self._external_cache.set( - "get_joined_hosts", - str(state_entry.state_group), - list(joined_hosts), - expiry_ms=60 * 60 * 1000, - ) + # Note that the expiry times must be larger than the expiry time in + # _external_cache_joined_hosts_updates. + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) - self._external_cache_joined_hosts_updates[state_entry.state_group] = None + self._external_cache_joined_hosts_updates[ + state_entry.state_group + ] = None async def _validate_canonical_alias( self, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 1f83bab836..a4ca9cb8b4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -458,11 +458,6 @@ class PaginationHandler: # `/messages` should still works with live tokens when manually provided. assert from_token.room_key.topological is not None - if pagin_config.limit is None: - # This shouldn't happen as we've set a default limit before this - # gets called. - raise Exception("limit not set") - room_token = from_token.room_key async with self.pagination_lock.read(room_id): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 4e575ffbaa..2670e561d7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1596,7 +1596,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self, user: UserID, from_key: Optional[int], - limit: Optional[int] = None, + # Having a default limit doesn't match the EventSource API, but some + # callers do not provide it. It is unused in this class. + limit: int = 0, room_ids: Optional[Collection[str]] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4a7ec9e426..ac01582442 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -257,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index cc5e45c241..0a0c6d938e 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple @@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -32,6 +33,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ThreadsListInclude(str, enum.Enum): + """Valid values for the 'include' flag of /threads.""" + + all = "all" + participated = "participated" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: # The latest event in the thread. @@ -108,9 +116,6 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") - # TODO Update pagination config to not allow None limits. - assert pagin_config.limit is not None - # Note that ignored users are not passed into get_relations_for_event # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). @@ -482,3 +487,79 @@ class RelationsHandler: results.setdefault(event_id, BundledAggregations()).replace = edit return results + + async def get_threads( + self, + requester: Requester, + room_id: str, + include: ThreadsListInclude, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + Args: + requester: The user requesting the relations. + room_id: The room the event belongs to. + include: One of "all" or "participated" to indicate which threads should + be returned. + limit: Only fetch the most recent `limit` events. + from_token: Fetch rows from the given token, or from the start if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) + + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). + thread_roots, next_batch = await self._main_store.get_threads( + room_id=room_id, limit=limit, from_token=from_token + ) + + events = await self._main_store.get_events_as_list(thread_roots) + + if include == ThreadsListInclude.participated: + # Pre-seed thread participation with whether the requester sent the event. + participated = {event.event_id: event.sender == user_id for event in events} + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [eid for eid, p in participated.items() if not p], + user_id, + ) + ) + + # Limit the returned threads to those the user has participated in. + events = [event for event in events if participated[event.event_id]] + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + + now = self._clock.time_msec() + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value: JsonDict = {"chunk": serialized_events} + + if next_batch: + return_value["next_batch"] = str(next_batch) + + return return_value diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 57ab05ad25..638f54051a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -229,9 +229,7 @@ class RoomCreationHandler: }, ) validate_event_for_room_version(tombstone_event) - await self._event_auth_handler.check_auth_rules_from_context( - tombstone_event, tombstone_context - ) + await self._event_auth_handler.check_auth_rules_from_context(tombstone_event) # Upgrade the room # @@ -1646,7 +1644,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): self, user: UserID, from_key: RoomStreamToken, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e035677b8a..5943f08e91 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -874,7 +874,7 @@ class SsoHandler: ) async def handle_terms_accepted( - self, request: Request, session_id: str, terms_version: str + self, request: SynapseRequest, session_id: str, terms_version: str ) -> None: """Handle a request to the new-user 'consent' endpoint diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index f953691669..a0ea719430 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/http/server.py b/synapse/http/server.py index bcbfac2c9f..b26e34bceb 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -19,6 +19,7 @@ import logging import types import urllib from http import HTTPStatus +from http.client import FOUND from inspect import isawaitable from typing import ( TYPE_CHECKING, @@ -339,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - _unrecognised_request_handler(request) + return _unrecognised_request_handler(request) @abc.abstractmethod def _send_response( @@ -598,7 +599,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: Request) -> bytes: + def render_OPTIONS(self, request: SynapseRequest) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -763,7 +764,7 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, + request: SynapseRequest, code: int, json_bytes: bytes, send_cors: bool = False, @@ -859,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request) -> None: +def set_cors_headers(request: SynapseRequest) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -870,10 +871,20 @@ def set_cors_headers(request: Request) -> None: request.setHeader( b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date", - ) + if request.experimental_cors_msc3886: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", + ) + request.setHeader( + b"Access-Control-Expose-Headers", + b"ETag, Location, X-Max-Bytes", + ) + else: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date", + ) def set_corp_headers(request: Request) -> None: @@ -942,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None: request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") -def respond_with_redirect(request: Request, url: bytes) -> None: - """Write a 302 response to the request, if it is still alive.""" +def respond_with_redirect( + request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False +) -> None: + """ + Write a 302 (or other specified status code) response to the request, if it is still alive. + + Args: + request: The http request to respond to. + url: The URL to redirect to. + statusCode: The HTTP status code to use for the redirect (defaults to 302). + cors: Whether to set CORS headers on the response. + """ logger.debug("Redirect to %s", url.decode("utf-8")) - request.redirect(url) + + if cors: + set_cors_headers(request) + + request.setResponseCode(statusCode) + request.setHeader(b"location", url) finish_request(request) diff --git a/synapse/http/site.py b/synapse/http/site.py index 55a6afce35..3dbd541fed 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -82,6 +82,7 @@ class SynapseRequest(Request): self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 + self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -622,6 +623,8 @@ class SynapseSite(Site): request_id_header = config.http_options.request_id_header + self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index eced182fd5..a75386f6a0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, - Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -38,7 +35,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator +from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -117,9 +114,6 @@ class BulkPushRuleEvaluator: resizable=False, ) - # Whether to support MSC3772 is supported. - self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled - async def _get_rules_for_event( self, event: EventBase, @@ -200,51 +194,6 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _get_mutual_relations( - self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - parent_id: The event ID which is targeted by relations. - rules: The push rules which will be processed for this event. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - - # If the experimental feature is not enabled, skip fetching relations. - if not self._relations_match_enabled: - return {} - - # Pre-filter to figure out which relation types are interesting. - rel_types = set() - for rule, enabled in rules: - if not enabled: - continue - - for condition in rule.conditions: - if condition["kind"] != "org.matrix.msc3772.relation_match": - continue - - # rel_type is required. - rel_type = condition.get("rel_type") - if rel_type: - rel_types.add(rel_type) - - # If no valid rules were found, no mutual relations. - if not rel_types: - return {} - - # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations(parent_id, rel_types) - @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -276,23 +225,18 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + # Find the event's thread ID. relation = relation_from_event(event) - # If the event does not have a relation, then cannot have any mutual - # relations or thread ID. - relations = {} + # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE if relation: - relations = await self._get_mutual_relations( - relation.parent_id, - itertools.chain(*(r.rules() for r in rules_by_user.values())), - ) # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id else: # Since the event has not yet been persisted we check whether # the parent is part of a thread. - thread_id = await self.store.get_thread_id(relation.parent_id) or "main" + thread_id = await self.store.get_thread_id(relation.parent_id) # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. @@ -306,8 +250,6 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, - relations, - self._relations_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index 61abb529c8..976c283360 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -39,6 +39,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint): self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() + # Default value if the worker that sent the replication request did not include + # an 'approved' property. + if ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ): + self._approval_default = False + else: + self._approval_default = True + @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, @@ -92,6 +102,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint): await self.registration_handler.check_registration_ratelimit(content["address"]) + # Always default admin users to approved (since it means they were created by + # an admin). + approved_default = self._approval_default + if content["admin"]: + approved_default = True + await self.registration_handler.register_with_store( user_id=user_id, password_hash=content["password_hash"], @@ -103,7 +119,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type=content["user_type"], address=content["address"], shadow_banned=content["shadow_banned"], - approved=content["approved"], + approved=content.get("approved", approved_default), ) return 200, {} diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 9a2ab99ede..28542cd774 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client import ( receipts, register, relations, + rendezvous, report_event, room, room_batch, @@ -132,3 +133,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) login_token_request.register_servlets(hs, client_resource) + rendezvous.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py index bc1b18c92d..f17b4c8d22 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py @@ -13,15 +13,22 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple + +from pydantic import StrictStr +from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import ( + RestServlet, + parse_and_validate_json_object_from_request, +) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.rest.models import RequestBodyModel from synapse.types import JsonDict, RoomAlias if TYPE_CHECKING: @@ -54,6 +61,12 @@ class ClientDirectoryServer(RestServlet): return 200, res + class PutBody(RequestBodyModel): + # TODO: get Pydantic to validate that this is a valid room id? + room_id: StrictStr + # `servers` is unspecced + servers: Optional[List[StrictStr]] = None + async def on_PUT( self, request: SynapseRequest, room_alias: str ) -> Tuple[int, JsonDict]: @@ -61,31 +74,22 @@ class ClientDirectoryServer(RestServlet): raise SynapseError(400, "Room alias invalid", errcode=Codes.INVALID_PARAM) room_alias_obj = RoomAlias.from_string(room_alias) - content = parse_json_object_from_request(request) - if "room_id" not in content: - raise SynapseError( - 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON - ) + content = parse_and_validate_json_object_from_request(request, self.PutBody) logger.debug("Got content: %s", content) logger.debug("Got room name: %s", room_alias_obj.to_string()) - room_id = content["room_id"] - servers = content["servers"] if "servers" in content else None - - logger.debug("Got room_id: %s", room_id) - logger.debug("Got servers: %s", servers) + logger.debug("Got room_id: %s", content.room_id) + logger.debug("Got servers: %s", content.servers) - # TODO(erikj): Check types. - - room = await self.store.get_room(room_id) + room = await self.store.get_room(content.room_id) if room is None: raise SynapseError(400, "Room does not exist") requester = await self.auth.get_user_by_req(request) await self.directory_handler.create_association( - requester, room_alias_obj, room_id, servers + requester, room_alias_obj, content.room_id, content.servers ) return 200, {} @@ -137,16 +141,18 @@ class ClientDirectoryListServer(RestServlet): return 200, {"visibility": "public" if room["is_public"] else "private"} + class PutBody(RequestBodyModel): + visibility: Literal["public", "private"] = "public" + async def on_PUT( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") + content = parse_and_validate_json_object_from_request(request, self.PutBody) await self.directory_handler.edit_published_room_list( - requester, room_id, visibility + requester, room_id, content.visibility ) return 200, {} @@ -163,12 +169,14 @@ class ClientAppserviceDirectoryListServer(RestServlet): self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() + class PutBody(RequestBodyModel): + visibility: Literal["public", "private"] = "public" + async def on_PUT( self, request: SynapseRequest, network_id: str, room_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") - return await self._edit(request, network_id, room_id, visibility) + content = parse_and_validate_json_object_from_request(request, self.PutBody) + return await self._edit(request, network_id, room_id, content.visibility) async def on_DELETE( self, request: SynapseRequest, network_id: str, room_id: str @@ -176,7 +184,11 @@ class ClientAppserviceDirectoryListServer(RestServlet): return await self._edit(request, network_id, room_id, "private") async def _edit( - self, request: SynapseRequest, network_id: str, room_id: str, visibility: str + self, + request: SynapseRequest, + network_id: str, + room_id: str, + visibility: Literal["public", "private"], ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not requester.app_service: diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index 916f5230f1..782e7d14e8 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet): raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") - pagin_config = await PaginationConfig.from_request(self.store, request) + pagin_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in args: try: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py index cfadcb8e50..9b1bb8b521 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py @@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) args: Dict[bytes, List[bytes]] = request.args # type: ignore as_client_event = b"raw" not in args - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 14dec7ac4e..18a282b22c 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet): ) # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): + if not await self._is_event_in_thread(event_id, thread_id): raise SynapseError( 400, f"event_id {event_id} is not related to thread {thread_id}", @@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet): return 200, {} + async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: + """ + The event must be related to the thread ID (in a vague sense) to ensure + clients aren't sending bogus receipts. + + A thread ID is considered valid for a given event E if: + + 1. E has a thread relation which matches the thread ID; + 2. E has another event which has a thread relation to E matching the + thread ID; or + 3. E is recursively related (via any rel_type) to an event which + satisfies 1 or 2. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + It is valid to send a receipt for thread A on A, B, C, D, or E. + + It is valid to send a receipt for the main timeline on A, D, and E. + + Args: + event_id: The event ID to check. + thread_id: The thread ID the event is potentially part of. + + Returns: + True if the event belongs to the given thread, otherwise False. + """ + + # If the receipt is on the main timeline, it is enough to check whether + # the event is directly related to a thread. + if thread_id == MAIN_TIMELINE: + return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id) + + # Otherwise, check if the event is directly part of a thread, or is the + # root message (or related to the root message) of a thread. + return thread_id == await self._main_store.get_thread_id_for_receipts(event_id) + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b31ce5a0d3..9dd59196d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -13,12 +13,15 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.storage.databases.main.relations import ThreadsNextBatch from synapse.streams.config import PaginationConfig from synapse.types import JsonDict @@ -78,5 +81,45 @@ class RelationPaginationServlet(RestServlet): return 200, result +class ThreadsServlet(RestServlet): + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + include = parse_string( + request, + "include", + default=ThreadsListInclude.all.value, + allowed_values=[v.value for v in ThreadsListInclude], + ) + + # Return the relations + from_token = None + if from_token_str: + from_token = ThreadsNextBatch.from_string(from_token_str) + + result = await self._relations_handler.get_threads( + requester=requester, + room_id=room_id, + include=ThreadsListInclude(include), + limit=limit, + from_token=from_token, + ) + + return 200, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + ThreadsServlet(hs).register(http_server) diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py new file mode 100644 index 0000000000..89176b1ffa --- /dev/null +++ b/synapse/rest/client/rendezvous.py @@ -0,0 +1,74 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from http.client import TEMPORARY_REDIRECT +from typing import TYPE_CHECKING, Optional + +from synapse.http.server import HttpServer, respond_with_redirect +from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RendezvousServlet(RestServlet): + """ + This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886) + simple client rendezvous capability that is used by the "Sign in with QR" functionality. + + This implementation only serves as a 307 redirect to a configured server rather than being a full implementation. + + A module that implements the full functionality is available at: https://pypi.org/project/matrix-http-rendezvous-synapse/. + + Request: + + POST /rendezvous HTTP/1.1 + Content-Type: ... + + ... + + Response: + + HTTP/1.1 307 + Location: <configured endpoint> + """ + + PATTERNS = client_patterns( + "/org.matrix.msc3886/rendezvous$", releases=[], v1=False, unstable=True + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + redirection_target: Optional[str] = hs.config.experimental.msc3886_endpoint + assert ( + redirection_target is not None + ), "Servlet is only registered if there is a redirection target" + self.endpoint = redirection_target.encode("utf-8") + + async def on_POST(self, request: SynapseRequest) -> None: + respond_with_redirect( + request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True + ) + + # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target. + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3886_endpoint is not None: + RendezvousServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index b6dedbed04..01e5079963 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -729,7 +729,9 @@ class RoomInitialSyncRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index d1d2e5f7e3..9b1b72c68a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -76,6 +76,7 @@ class VersionsRestServlet(RestServlet): "v1.1", "v1.2", "v1.3", + "v1.4", ], # as per MSC1497: "unstable_features": { @@ -113,6 +114,11 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, + # Adds support for filtering /messages by event relation. + "org.matrix.msc3874": self.config.experimental.msc3874_enabled, + # Adds support for simple HTTP rendezvous as per MSC3886 + "org.matrix.msc3886": self.config.experimental.msc3886_endpoint + is not None, }, }, ) diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py index 7f8c1de1ff..26403facb8 100644 --- a/synapse/rest/key/v2/__init__.py +++ b/synapse/rest/key/v2/__init__.py @@ -14,17 +14,20 @@ from typing import TYPE_CHECKING -from twisted.web.resource import Resource - -from .local_key_resource import LocalKey -from .remote_key_resource import RemoteKey +from synapse.http.server import HttpServer, JsonResource +from synapse.rest.key.v2.local_key_resource import LocalKey +from synapse.rest.key.v2.remote_key_resource import RemoteKey if TYPE_CHECKING: from synapse.server import HomeServer -class KeyApiV2Resource(Resource): +class KeyResource(JsonResource): def __init__(self, hs: "HomeServer"): - Resource.__init__(self) - self.putChild(b"server", LocalKey(hs)) - self.putChild(b"query", RemoteKey(hs)) + super().__init__(hs, canonical_json=True) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None: + LocalKey(hs).register(http_server) + RemoteKey(hs).register(http_server) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index 0c9f042c84..d03e728d42 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -13,16 +13,15 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +import re +from typing import TYPE_CHECKING, Optional, Tuple -from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -from twisted.web.resource import Resource from twisted.web.server import Request -from synapse.http.server import respond_with_json_bytes +from synapse.http.servlet import RestServlet from synapse.types import JsonDict if TYPE_CHECKING: @@ -31,7 +30,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LocalKey(Resource): +class LocalKey(RestServlet): """HTTP resource containing encoding the TLS X.509 certificate and NACL signature verification keys for this server:: @@ -61,18 +60,17 @@ class LocalKey(Resource): } """ - isLeaf = True + PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P<key_id>[^/]*))?$"),) def __init__(self, hs: "HomeServer"): self.config = hs.config self.clock = hs.get_clock() self.update_response_body(self.clock.time_msec()) - Resource.__init__(self) def update_response_body(self, time_now_msec: int) -> None: refresh_interval = self.config.key.key_refresh_interval self.valid_until_ts = int(time_now_msec + refresh_interval) - self.response_body = encode_canonical_json(self.response_json_object()) + self.response_body = self.response_json_object() def response_json_object(self) -> JsonDict: verify_keys = {} @@ -99,9 +97,11 @@ class LocalKey(Resource): json_object = sign_json(json_object, self.config.server.server_name, key) return json_object - def render_GET(self, request: Request) -> Optional[int]: + def on_GET( + self, request: Request, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) - return respond_with_json_bytes(request, 200, self.response_body) + return 200, self.response_body diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index 7f8ad29566..19820886f5 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -13,15 +13,20 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Set +import re +from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple from signedjson.sign import sign_json -from synapse.api.errors import Codes, SynapseError +from twisted.web.server import Request + from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json -from synapse.http.servlet import parse_integer, parse_json_object_from_request -from synapse.http.site import SynapseRequest +from synapse.http.server import HttpServer +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, +) from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -32,7 +37,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class RemoteKey(DirectServeJsonResource): +class RemoteKey(RestServlet): """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks @@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource): } """ - isLeaf = True - def __init__(self, hs: "HomeServer"): - super().__init__() - self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -101,36 +102,48 @@ class RemoteKey(DirectServeJsonResource): ) self.config = hs.config - async def _async_render_GET(self, request: SynapseRequest) -> None: - assert request.postpath is not None - if len(request.postpath) == 1: - (server,) = request.postpath - query: dict = {server.decode("ascii"): {}} - elif len(request.postpath) == 2: - server, key_id = request.postpath + def register(self, http_server: HttpServer) -> None: + http_server.register_paths( + "GET", + ( + re.compile( + "^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$" + ), + ), + self.on_GET, + self.__class__.__name__, + ) + http_server.register_paths( + "POST", + (re.compile("^/_matrix/key/v2/query$"),), + self.on_POST, + self.__class__.__name__, + ) + + async def on_GET( + self, request: Request, server: str, key_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: + if server and key_id: minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") arguments = {} if minimum_valid_until_ts is not None: arguments["minimum_valid_until_ts"] = minimum_valid_until_ts - query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}} + query = {server: {key_id: arguments}} else: - raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) + query = {server: {}} - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) - async def _async_render_POST(self, request: SynapseRequest) -> None: + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: content = parse_json_object_from_request(request) query = content["server_keys"] - await self.query_keys(request, query, query_remote_on_cache_miss=True) + return 200, await self.query_keys(query, query_remote_on_cache_miss=True) async def query_keys( - self, - request: SynapseRequest, - query: JsonDict, - query_remote_on_cache_miss: bool = False, - ) -> None: + self, query: JsonDict, query_remote_on_cache_miss: bool = False + ) -> JsonDict: logger.info("Handling query for keys %r", query) store_queries = [] @@ -232,7 +245,7 @@ class RemoteKey(DirectServeJsonResource): for server_name, keys in cache_misses.items() ), ) - await self.query_keys(request, query, query_remote_on_cache_miss=False) + return await self.query_keys(query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json_raw in json_results: @@ -244,6 +257,4 @@ class RemoteKey(DirectServeJsonResource): signed_keys.append(key_json) - response = {"server_keys": signed_keys} - - respond_with_json(request, 200, response, canonical_json=True) + return {"server_keys": signed_keys} diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py index 1c1c7b3613..22784157e6 100644 --- a/synapse/rest/synapse/client/new_user_consent.py +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest from synapse.types import UserID from synapse.util.templates import build_jinja_env @@ -88,7 +89,7 @@ class NewUserConsentResource(DirectServeHtmlResource): html = template.render(template_params) respond_with_html(request, 200, html) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: try: session_id = get_username_mapping_session_cookie_from_request(request) except SynapseError as e: diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index 6f7ac54c65..e2174fdfea 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -18,6 +18,7 @@ from twisted.web.resource import Resource from twisted.web.server import Request from synapse.http.server import set_cors_headers +from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.stringutils import parse_server_name @@ -63,7 +64,7 @@ class ClientWellKnownResource(Resource): Resource.__init__(self) self._well_known_builder = WellKnownBuilder(hs) - def render_GET(self, request: Request) -> bytes: + def render_GET(self, request: SynapseRequest) -> bytes: set_cors_headers(request) r = self._well_known_builder.get_well_known() if not r: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 7bb21f8f81..4717c9728a 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -1658,7 +1658,7 @@ class DatabasePool: table: string giving the table name keyvalues: dict of column names and values to select the row with retcol: string giving the name of the column to return - allow_none: If true, return None instead of failing if the SELECT + allow_none: If true, return None instead of raising StoreError if the SELECT statement returns no rows desc: description of the transaction, for logging and metrics """ diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 3b8ed1f7ee..ddb7397714 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -244,12 +244,18 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # redacted. self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self._attempt_to_invalidate_cache( "get_invited_rooms_for_local_user", (state_key,) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (state_key,) + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,)) if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) @@ -259,9 +265,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 18358eca46..830b076a32 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -539,9 +539,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): "device_id": device_id, "prev_id": [prev_id] if prev_id else [], "stream_id": stream_id, - "org.matrix.opentracing_context": opentracing_context, } + if opentracing_context != "{}": + result["org.matrix.opentracing_context"] = opentracing_context + prev_id = stream_id if device is not None: @@ -549,7 +551,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): if keys: result["keys"] = keys - device_display_name = device.display_name + device_display_name = None + if ( + self.hs.config.federation.allow_device_name_lookup_over_federation + ): + device_display_name = device.display_name if device_display_name: result["device_display_name"] = device_display_name else: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 6b9a629edd..309a4ba664 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id: The event that failed to be fetched or processed cause: The error message or reason that we failed to pull the event """ + logger.debug( + "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s", + room_id, + event_id, + cause, + ) await self.db_pool.runInteraction( "record_event_failed_pull_attempt", self._record_event_failed_pull_attempt_upsert_txn, @@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + @trace + async def get_event_ids_to_not_pull_from_backoff( + self, + room_id: str, + event_ids: Collection[str], + ) -> List[str]: + """ + Filter down the events to ones that we've failed to pull before recently. Uses + exponential backoff. + + Args: + room_id: The room that the events belong to + event_ids: A list of events to filter down + + Returns: + List of event_ids that should not be attempted to be pulled + """ + event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", + ) + + current_time = self._clock.time_msec() + return [ + event_failed_pull_attempt["event_id"] + for event_failed_pull_attempt in event_failed_pull_attempts + # Exponential back-off (up to the upper bound) so we don't try to + # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + if current_time + < event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) + ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS + ] + async def get_missing_events( self, room_id: str, diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 332e13d1c9..b283ab0f9c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -294,6 +294,44 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas self._background_backfill_thread_id, ) + # Indexes which will be used to quickly make the thread_id column non-null. + self.db_pool.updates.register_background_index_update( + "event_push_actions_thread_id_null", + index_name="event_push_actions_thread_id_null", + table="event_push_actions", + columns=["thread_id"], + where_clause="thread_id IS NULL", + ) + self.db_pool.updates.register_background_index_update( + "event_push_summary_thread_id_null", + index_name="event_push_summary_thread_id_null", + table="event_push_summary", + columns=["thread_id"], + where_clause="thread_id IS NULL", + ) + + # Check ASAP (and then later, every 1s) to see if we have finished + # background updates the event_push_actions and event_push_summary tables. + self._clock.call_later(0.0, self._check_event_push_backfill_thread_id) + self._event_push_backfill_thread_id_done = False + + @wrap_as_background_process("check_event_push_backfill_thread_id") + async def _check_event_push_backfill_thread_id(self) -> None: + """ + Has thread_id finished backfilling? + + If not, we need to just-in-time update it so the queries work. + """ + done = await self.db_pool.updates.has_completed_background_update( + "event_push_backfill_thread_id" + ) + + if done: + self._event_push_backfill_thread_id_done = True + else: + # Reschedule to run. + self._clock.call_later(15.0, self._check_event_push_backfill_thread_id) + async def _background_backfill_thread_id( self, progress: JsonDict, batch_size: int ) -> int: @@ -310,11 +348,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas event_push_actions_done = progress.get("event_push_actions_done", False) def add_thread_id_txn( - txn: LoggingTransaction, table_name: str, start_stream_ordering: int + txn: LoggingTransaction, start_stream_ordering: int ) -> int: - sql = f""" + sql = """ SELECT stream_ordering - FROM {table_name} + FROM event_push_actions WHERE thread_id IS NULL AND stream_ordering > ? @@ -326,7 +364,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # No more rows to process. rows = txn.fetchall() if not rows: - progress[f"{table_name}_done"] = True + progress["event_push_actions_done"] = True self.db_pool.updates._background_update_progress_txn( txn, "event_push_backfill_thread_id", progress ) @@ -335,16 +373,65 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Update the thread ID for any of those rows. max_stream_ordering = rows[-1][0] - sql = f""" - UPDATE {table_name} + sql = """ + UPDATE event_push_actions SET thread_id = 'main' - WHERE stream_ordering <= ? AND thread_id IS NULL + WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL """ - txn.execute(sql, (max_stream_ordering,)) + txn.execute( + sql, + ( + start_stream_ordering, + max_stream_ordering, + ), + ) # Update progress. processed_rows = txn.rowcount - progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering + progress["max_event_push_actions_stream_ordering"] = max_stream_ordering + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + + return processed_rows + + def add_thread_id_summary_txn(txn: LoggingTransaction) -> int: + min_user_id = progress.get("max_summary_user_id", "") + min_room_id = progress.get("max_summary_room_id", "") + + # Slightly overcomplicated query for getting the Nth user ID / room + # ID tuple, or the last if there are less than N remaining. + sql = """ + SELECT user_id, room_id FROM ( + SELECT user_id, room_id FROM event_push_summary + WHERE (user_id, room_id) > (?, ?) + AND thread_id IS NULL + ORDER BY user_id, room_id + LIMIT ? + ) AS e + ORDER BY user_id DESC, room_id DESC + LIMIT 1 + """ + + txn.execute(sql, (min_user_id, min_room_id, batch_size)) + row = txn.fetchone() + if not row: + return 0 + + max_user_id, max_room_id = row + + sql = """ + UPDATE event_push_summary + SET thread_id = 'main' + WHERE + (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?) + AND thread_id IS NULL + """ + txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id)) + processed_rows = txn.rowcount + + progress["max_summary_user_id"] = max_user_id + progress["max_summary_room_id"] = max_room_id self.db_pool.updates._background_update_progress_txn( txn, "event_push_backfill_thread_id", progress ) @@ -360,15 +447,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas result = await self.db_pool.runInteraction( "event_push_backfill_thread_id", add_thread_id_txn, - "event_push_actions", progress.get("max_event_push_actions_stream_ordering", 0), ) else: result = await self.db_pool.runInteraction( "event_push_backfill_thread_id", - add_thread_id_txn, - "event_push_summary", - progress.get("max_event_push_summary_stream_ordering", 0), + add_thread_id_summary_txn, ) # Only done after the event_push_summary table is done. @@ -480,6 +564,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), ) + # First ensure that the existing rows have an updated thread_id field. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + # First we pull the counts from the summary table. # # We check that `last_receipt_stream_ordering` matches the stream ordering of the @@ -1295,6 +1398,25 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas (room_id, user_id, stream_ordering, *thread_args), ) + # First ensure that the existing rows have an updated thread_id field. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + (MAIN_TIMELINE, room_id, user_id), + ) + # Fetch the notification counts between the stream ordering of the # latest receipt and what was previously summarised. unread_counts = self._get_notif_unread_count_for_user_room( @@ -1429,6 +1551,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas rotate_to_stream_ordering: The new maximum event stream ordering to summarise. """ + # Ensure that any new actions have an updated thread_id. + if not self._event_push_backfill_thread_id_done: + txn.execute( + """ + UPDATE event_push_actions + SET thread_id = ? + WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL + """, + (MAIN_TIMELINE, old_rotate_stream_ordering, rotate_to_stream_ordering), + ) + + # XXX Do we need to update summaries here too? + # Calculate the new counts that should be upserted into event_push_summary sql = """ SELECT user_id, room_id, thread_id, @@ -1491,6 +1626,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas logger.info("Rotating notifications, handling %d rows", len(summaries)) + # Ensure that any updated threads have the proper thread_id. + if not self._event_push_backfill_thread_id_done: + txn.execute_batch( + """ + UPDATE event_push_summary + SET thread_id = ? + WHERE room_id = ? AND user_id = ? AND thread_id is NULL + """, + [ + (MAIN_TIMELINE, room_id, user_id) + for user_id, room_id, _ in summaries + ], + ) + self.db_pool.simple_upsert_many_txn( txn, table="event_push_summary", diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3e15827986..6698cbf664 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -1616,7 +1616,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1866,6 +1866,34 @@ class PersistEventsStore: }, ) + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ + + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) + def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -1989,13 +2017,14 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ @@ -2025,9 +2054,7 @@ class PersistEventsStore: txn, self.store.get_thread_participated, (redacted_relates_to,) ) self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), + txn, self.store.get_threads, (room_id,) ) self.db_pool.simple_delete_txn( diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7cdc9fe98f..69fea452ad 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -374,7 +374,7 @@ class EventsWorkerStore(SQLBaseStore): If there is a mismatch, behave as per allow_none. Returns: - The event, or None if the event was not found. + The event, or None if the event was not found and allow_none is `True`. """ if not isinstance(event_id, str): raise TypeError("Invalid event event_id %r" % (event_id,)) @@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = await self._get_events_from_cache_or_db( + event_entry_map = await self.get_unredacted_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = await self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self.get_unredacted_events_from_cache_or_db( + [redacted_event_id] + ) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore): return events @cancellable - async def _get_events_from_cache_or_db( - self, event_ids: Iterable[str], allow_rejected: bool = False + async def get_unredacted_events_from_cache_or_db( + self, + event_ids: Iterable[str], + allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. + Note that the events pulled by this function will not have any redactions + applied, and no guarantee is made about the ordering of the events returned. + If events are pulled from the database, they will be cached for future lookups. Unknown events are omitted from the response. @@ -1495,21 +1502,15 @@ class EventsWorkerStore(SQLBaseStore): Returns: a dict {event_id -> bool} """ - # if the event cache contains the event, obviously we've seen it. - - cache_results = { - event_id - for event_id in event_ids - if await self._get_event_cache.contains((event_id,)) - } - results = dict.fromkeys(cache_results, True) - remaining = [ - event_id for event_id in event_ids if event_id not in cache_results - ] - if not remaining: - return results + # TODO: We used to query the _get_event_cache here as a fast-path before + # hitting the database. For if an event were in the cache, we've presumably + # seen it before. + # + # But this is currently an invalid assumption due to the _get_event_cache + # not being invalidated when purging events from a room. The optimisation can + # be re-added after https://github.com/matrix-org/synapse/issues/13476 - def have_seen_events_txn(txn: LoggingTransaction) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1517,16 +1518,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", remaining + txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update({eid: (eid in found_events) for eid in remaining}) + return {eid: (eid in found_events) for eid in event_ids} - await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) - return results + return await self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn + ) @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: @@ -1969,12 +1971,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_backward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_backward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question has any of its prev_events listed as a # backward extremity, it's next to a gap. @@ -2024,12 +2031,17 @@ class EventsWorkerStore(SQLBaseStore): Args: room_id: room where the event lives - event_id: event to check + event: event to check (can't be an `outlier`) Returns: Boolean indicating whether it's an extremity """ + assert not event.internal_metadata.is_outlier(), ( + "is_event_next_to_forward_gap(...) can't be used with `outlier` events. " + "This function relies on `event_edges` and `event_forward_extremities` which won't be filled in for `outliers`." + ) + def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: # If the event in question is a forward extremity, we will just # consider any potential forward gap as not a gap since it's one of @@ -2110,13 +2122,33 @@ class EventsWorkerStore(SQLBaseStore): The closest event_id otherwise None if we can't find any event in the given direction. """ + if direction == "b": + # Find closest event *before* a given timestamp. We use descending + # (which gives values largest to smallest) because we want the + # largest possible timestamp *before* the given timestamp. + comparison_operator = "<=" + order = "DESC" + else: + # Find closest event *after* a given timestamp. We use ascending + # (which gives values smallest to largest) because we want the + # closest possible timestamp *after* the given timestamp. + comparison_operator = ">=" + order = "ASC" - sql_template = """ + sql_template = f""" SELECT event_id FROM events LEFT JOIN rejections USING (event_id) WHERE - origin_server_ts %s ? - AND room_id = ? + room_id = ? + AND origin_server_ts {comparison_operator} ? + /** + * Make sure the event isn't an `outlier` because we have no way + * to later check whether it's next to a gap. `outliers` do not + * have entries in the `event_edges`, `event_forward_extremeties`, + * and `event_backward_extremities` tables to check against + * (used by `is_event_next_to_backward_gap` and `is_event_next_to_forward_gap`). + */ + AND NOT outlier /* Make sure event is not rejected */ AND rejections.event_id IS NULL /** @@ -2126,27 +2158,14 @@ class EventsWorkerStore(SQLBaseStore): * Finally, we can tie-break based on when it was received on the server * (`stream_ordering`). */ - ORDER BY origin_server_ts %s, depth %s, stream_ordering %s + ORDER BY origin_server_ts {order}, depth {order}, stream_ordering {order} LIMIT 1; """ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: - if direction == "b": - # Find closest event *before* a given timestamp. We use descending - # (which gives values largest to smallest) because we want the - # largest possible timestamp *before* the given timestamp. - comparison_operator = "<=" - order = "DESC" - else: - # Find closest event *after* a given timestamp. We use ascending - # (which gives values smallest to largest) because we want the - # closest possible timestamp *after* the given timestamp. - comparison_operator = ">=" - order = "ASC" - txn.execute( - sql_template % (comparison_operator, order, order, order), - (timestamp, room_id), + sql_template, + (room_id, timestamp), ) row = txn.fetchone() if row: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8295322b0e..51416b2236 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,7 +29,6 @@ from typing import ( ) from synapse.api.errors import StoreError -from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -63,9 +62,7 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], - enabled_map: Dict[str, bool], - experimental_config: ExperimentalConfig, + rawrules: List[JsonDict], enabled_map: Dict[str, bool] ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -83,9 +80,7 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled - ) + filtered_rules = FilteredPushRules(push_rules, enabled_map) return filtered_rules @@ -165,7 +160,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules(rows, enabled_map) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -224,9 +219,7 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 246f78ac1f..dc6989527e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -418,6 +418,8 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type[row["user_id"]] = db_to_json(row["data"]) + if row["thread_id"]: + receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] results = { room_id: [results[room_id]] if room_id in results else [] diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 116abef9de..1de62ee9df 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,6 +14,7 @@ import logging from typing import ( + TYPE_CHECKING, Collection, Dict, FrozenSet, @@ -28,19 +29,48 @@ from typing import ( import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadsNextBatch: + topological_ordering: int + stream_ordering: int + + def __str__(self) -> str: + return f"{self.topological_ordering}_{self.stream_ordering}" + + @classmethod + def from_string(cls, string: str) -> "ThreadsNextBatch": + """ + Creates a ThreadsNextBatch from its textual representation. + """ + try: + keys = (int(s) for s in string.split("_")) + return cls(*keys) + except Exception: + raise SynapseError(400, "Invalid threads token") + + +@attr.s(slots=True, frozen=True, auto_attribs=True) class _RelatedEvent: """ Contains enough information about a related event in order to properly filter @@ -56,6 +86,76 @@ class _RelatedEvent: class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "threads_backfill", self._backfill_threads + ) + + async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int: + """Backfill the threads table.""" + + def threads_backfill_txn(txn: LoggingTransaction) -> int: + last_thread_id = progress.get("last_thread_id", "") + + # Get the latest event in each thread by topo ordering / stream ordering. + # + # Note that the MAX(event_id) is needed to abide by the rules of group by, + # but doesn't actually do anything since there should only be a single event + # ID per topo/stream ordering pair. + sql = f""" + SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id > ? AND + relation_type = '{RelationTypes.THREAD}' + GROUP BY room_id, relates_to_id + ORDER BY relates_to_id + LIMIT ? + """ + txn.execute(sql, (last_thread_id, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + return 0 + + # Insert the rows into the threads table. If a matching thread already exists, + # assume it is from a newer event. + sql = """ + INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id) + VALUES %s + ON CONFLICT (room_id, thread_id) + DO NOTHING + """ + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows) + + # Mark the progress. + self.db_pool.updates._background_update_progress_txn( + txn, "threads_backfill", {"last_thread_id": rows[-1][1]} + ) + + return txn.rowcount + + result = await self.db_pool.runInteraction( + "threads_backfill", threads_backfill_txn + ) + + if not result: + await self.db_pool.updates._end_background_update("threads_backfill") + + return result + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, @@ -776,95 +876,194 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from a previous next_batch, or from the start if None. Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type + A tuple of: + A list of thread root event IDs. + + The next_batch, if one exists. """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is equal / before the last + # thread's topo ordering and earlier in stream ordering. + pagination_clause = "" + pagination_args: tuple = () + if from_token: + pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?" + pagination_args = ( + from_token.topological_ordering, + from_token.stream_ordering, + ) sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} + SELECT thread_id, topological_ordering, stream_ordering + FROM threads + WHERE + room_id = ? + {pagination_clause} + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT ? """ - def _get_event_relations( + def _get_threads_txn( txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + txn.execute(sql, (room_id, *pagination_args, limit + 1)) - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations - ) + rows = cast(List[Tuple[str, int, int]], txn.fetchall()) + thread_ids = [r[0] for r in rows] + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) @cached() - async def get_thread_id(self, event_id: str) -> Optional[str]: + async def get_thread_id(self, event_id: str) -> str: """ Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. + + See also get_thread_id_for_receipts. + Args: event_id: The event ID to fetch the thread ID for. Returns: The event ID of the root event in the thread, if this event is part - of a thread. None, otherwise. + of a thread. "main", otherwise. """ - # Since event relations form a tree, we should only ever find 0 or 1 - # results from the below query. + + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type + SELECT event_id, relates_to_id, relation_type, 0 depth FROM event_relations WHERE event_id = ? - UNION SELECT e.event_id, e.relates_to_id, e.relation_type + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id - ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ - def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) - # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] - return None + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction( + "get_related_thread_id", _get_related_thread_id + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index e41c99027a..7d97f8f60e 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -97,6 +97,12 @@ class RoomSortOrder(Enum): STATE_EVENTS = "state_events" +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PartialStateResyncInfo: + joined_via: Optional[str] + servers_in_room: List[str] = attr.ib(factory=list) + + class RoomWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1160,17 +1166,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): desc="get_partial_state_servers_at_join", ) - async def get_partial_state_rooms_and_servers( + async def get_partial_state_room_resync_info( self, - ) -> Mapping[str, Collection[str]]: - """Get all rooms containing events with partial state, and the servers known - to be in the room. + ) -> Mapping[str, PartialStateResyncInfo]: + """Get all rooms containing events with partial state, and the information + needed to restart a "resync" of those rooms. Returns: A dictionary of rooms with partial state, with room IDs as keys and lists of servers in rooms as values. """ - room_servers: Dict[str, List[str]] = {} + room_servers: Dict[str, PartialStateResyncInfo] = {} + + rows = await self.db_pool.simple_select_list( + table="partial_state_rooms", + keyvalues={}, + retcols=("room_id", "joined_via"), + desc="get_server_which_served_partial_join", + ) + + for row in rows: + room_id = row["room_id"] + joined_via = row["joined_via"] + room_servers[room_id] = PartialStateResyncInfo(joined_via=joined_via) rows = await self.db_pool.simple_select_list( "partial_state_rooms_servers", @@ -1182,7 +1200,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): for row in rows: room_id = row["room_id"] server_name = row["server_name"] - room_servers.setdefault(room_id, []).append(server_name) + entry = room_servers.get(room_id) + if entry is None: + # There is a foreign key constraint which enforces that every room_id in + # partial_state_rooms_servers appears in partial_state_rooms. So we + # expect `entry` to be non-null. (This reasoning fails if we've + # partial-joined between the two SELECTs, but this is unlikely to happen + # in practice.) + continue + entry.servers_in_room.append(server_name) return room_servers @@ -1827,6 +1853,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: """Mark the given room as containing events with partial state. @@ -1842,6 +1869,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): servers: other servers known to be in the room device_lists_stream_id: the device_lists stream ID at the time when we first joined the room. + joined_via: the server name we requested a partial join from. """ await self.db_pool.runInteraction( "store_partial_state_room", @@ -1849,6 +1877,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id, servers, device_lists_stream_id, + joined_via, ) def _store_partial_state_room_txn( @@ -1857,6 +1886,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): room_id: str, servers: Collection[str], device_lists_stream_id: int, + joined_via: str, ) -> None: DatabasePool.simple_insert_txn( txn, @@ -1866,6 +1896,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): "device_lists_stream_id": device_lists_stream_id, # To be updated later once the join event is persisted. "join_event_id": None, + "joined_via": joined_via, }, ) DatabasePool.simple_insert_many_txn( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 2337289d88..32e1e983a5 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -666,7 +666,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cached_method_name="get_rooms_for_user", list_name="user_ids", ) - async def get_rooms_for_users( + async def _get_rooms_for_users( self, user_ids: Collection[str] ) -> Dict[str, FrozenSet[str]]: """A batched version of `get_rooms_for_user`. @@ -697,6 +697,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {key: frozenset(rooms) for key, rooms in user_rooms.items()} + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched wrapper around `_get_rooms_for_users`, to prevent locking + other calls to `get_rooms_for_user` for large user lists. + """ + all_user_rooms: Dict[str, FrozenSet[str]] = {} + + # 250 users is pretty arbitrary but the data can be quite large if users + # are in many rooms. + for batch_user_ids in batch_iter(user_ids, 250): + all_user_rooms.update(await self._get_rooms_for_users(batch_user_ids)) + + return all_user_rooms + @cached(max_entries=10000) async def does_pair_of_users_share_a_room( self, user_id: str, other_user_id: str diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 530f04e149..09ce855aa8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1024,28 +1042,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - async def get_all_new_events_stream( - self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False - ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: + async def get_all_new_event_ids_stream( + self, + from_id: int, + current_id: int, + limit: int, + ) -> Tuple[int, Dict[str, Optional[int]]]: """Get all new events - Returns all events with from_id < stream_ordering <= current_id. + Returns all event ids with from_id < stream_ordering <= current_id. Args: from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return - get_prev_content: whether to fetch previous event content Returns: - A tuple of (next_id, events, event_to_received_ts), where `next_id` + A tuple of (next_id, event_to_received_ts), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, the `current_id`). The `event_to_received_ts` is - a dictionary mapping event ID to the event `received_ts`. + a dictionary mapping event ID to the event `received_ts`, sorted by ascending + stream_ordering. """ - def get_all_new_events_stream_txn( + def get_all_new_event_ids_stream_txn( txn: LoggingTransaction, ) -> Tuple[int, Dict[str, Optional[int]]]: sql = ( @@ -1070,15 +1091,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, event_to_received_ts upper_bound, event_to_received_ts = await self.db_pool.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn - ) - - events = await self.get_events_as_list( - event_to_received_ts.keys(), - get_prev_content=get_prev_content, + "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn ) - return upper_bound, events, event_to_received_ts + return upper_bound, event_to_received_ts async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: @@ -1202,8 +1218,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - assert int(limit) >= 0 - # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. @@ -1282,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1298,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql deleted file mode 100644 index 0ffde9bbeb..0000000000 --- a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2022 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- Forces the background updates from 06thread_notifications.sql to run in the --- foreground as code will now require those to be "done". - -DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id'; - --- Overwrite any null thread_id columns. -UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL; -UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL; -UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL; - --- Do not run the event_push_summary_unique_index job if it is pending; the --- thread_id field will be made required. -DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index'; -DROP INDEX IF EXISTS event_push_summary_unique_index; diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_thread_id_idx.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_thread_id_idx.sql new file mode 100644 index 0000000000..8b3c636594 --- /dev/null +++ b/synapse/storage/schema/main/delta/73/06thread_notifications_thread_id_idx.sql @@ -0,0 +1,23 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Allow there to be multiple summaries per user/room. +DROP INDEX IF EXISTS event_push_summary_unique_index; + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7306, 'event_push_actions_thread_id_null', '{}', 'event_push_backfill_thread_id'); + +INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES + (7306, 'event_push_summary_thread_id_null', '{}', 'event_push_backfill_thread_id'); diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite deleted file mode 100644 index 5322ad77a4..0000000000 --- a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2022 The Matrix.org Foundation C.I.C - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - --- SQLite doesn't support modifying columns to an existing table, so it must --- be recreated. - --- Create the new tables. -CREATE TABLE event_push_actions_staging_new ( - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - actions TEXT NOT NULL, - notif SMALLINT NOT NULL, - highlight SMALLINT NOT NULL, - unread SMALLINT, - thread_id TEXT NOT NULL, - inserted_ts BIGINT -); - -CREATE TABLE event_push_actions_new ( - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - user_id TEXT NOT NULL, - profile_tag VARCHAR(32), - actions TEXT NOT NULL, - topological_ordering BIGINT, - stream_ordering BIGINT, - notif SMALLINT, - highlight SMALLINT, - unread SMALLINT, - thread_id TEXT NOT NULL, - CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag) -); - -CREATE TABLE event_push_summary_new ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - notif_count BIGINT NOT NULL, - stream_ordering BIGINT NOT NULL, - unread_count BIGINT, - last_receipt_stream_ordering BIGINT, - thread_id TEXT NOT NULL -); - --- Swap the indexes. -DROP INDEX IF EXISTS event_push_actions_staging_id; -CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging_new(event_id); - -DROP INDEX IF EXISTS event_push_actions_room_id_user_id; -DROP INDEX IF EXISTS event_push_actions_rm_tokens; -DROP INDEX IF EXISTS event_push_actions_stream_ordering; -DROP INDEX IF EXISTS event_push_actions_u_highlight; -DROP INDEX IF EXISTS event_push_actions_highlights_index; -CREATE INDEX event_push_actions_room_id_user_id on event_push_actions_new(room_id, user_id); -CREATE INDEX event_push_actions_rm_tokens on event_push_actions_new( user_id, room_id, topological_ordering, stream_ordering ); -CREATE INDEX event_push_actions_stream_ordering on event_push_actions_new( stream_ordering, user_id ); -CREATE INDEX event_push_actions_u_highlight ON event_push_actions_new (user_id, stream_ordering); -CREATE INDEX event_push_actions_highlights_index ON event_push_actions_new (user_id, room_id, topological_ordering, stream_ordering); - --- Copy the data. -INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts) - SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts - FROM event_push_actions_staging; - -INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id) - SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id - FROM event_push_actions; - -INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id) - SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id - FROM event_push_summary; - --- Drop the old tables. -DROP TABLE event_push_actions_staging; -DROP TABLE event_push_actions; -DROP TABLE event_push_summary; - --- Rename the tables. -ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging; -ALTER TABLE event_push_actions_new RENAME TO event_push_actions; -ALTER TABLE event_push_summary_new RENAME TO event_push_summary; - --- Re-run background updates from 72/02event_push_actions_index.sql and --- 72/06thread_notifications.sql. -INSERT INTO background_updates (ordering, update_name, progress_json) VALUES - (7307, 'event_push_summary_unique_index2', '{}') - ON CONFLICT (update_name) DO NOTHING; -INSERT INTO background_updates (ordering, update_name, progress_json) VALUES - (7307, 'event_push_actions_stream_highlight_index', '{}') - ON CONFLICT (update_name) DO NOTHING; diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres b/synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql index 33674f8c62..066d602b18 100644 --- a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres +++ b/synapse/storage/schema/main/delta/73/09partial_joined_via_destination.sql @@ -13,7 +13,6 @@ * limitations under the License. */ --- The columns can now be made non-nullable. -ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL; -ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL; -ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL; +-- When we resync partial state, we prioritise doing so using the server we +-- partial-joined from. To do this we need to record which server that was! +ALTER TABLE partial_state_rooms ADD COLUMN joined_via TEXT; diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql new file mode 100644 index 0000000000..aa7c5e9a2e --- /dev/null +++ b/synapse/storage/schema/main/delta/73/09threads_table.sql @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE threads ( + room_id TEXT NOT NULL, + -- The event ID of the root event in the thread. + thread_id TEXT NOT NULL, + -- The latest event ID and corresponding topo / stream ordering. + latest_event_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id) +); + +CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7309, 'threads_backfill', '{}'); diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py index 806b671305..2dcd43d0a2 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py @@ -27,7 +27,7 @@ class EventSource(Generic[K, R]): self, user: UserID, from_key: K, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/streams/config.py b/synapse/streams/config.py index f6f7bf3d8b..6df2de919c 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -35,14 +35,14 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] direction: str - limit: Optional[int] + limit: int @classmethod async def from_request( cls, store: "DataStore", request: SynapseRequest, - default_limit: Optional[int] = None, + default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": direction = parse_string( @@ -69,12 +69,10 @@ class PaginationConfig: raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) + if limit < 0: + raise SynapseError(400, "Limit must be 0 or above") - if limit: - if limit < 0: - raise SynapseError(400, "Limit must be 0 or above") - - limit = min(int(limit), MAX_LIMIT) + limit = min(limit, MAX_LIMIT) try: return PaginationConfig(from_tok, to_tok, direction, limit) diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 6425f851ea..bcb1cba362 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -395,8 +395,8 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache.pop should either return a CacheEntry, or, in the # case of a TreeCache, a dict of keys to cache entries. Either way calling # iterate_tree_cache_entry on it will do the right thing. - for entry in iterate_tree_cache_entry(entry): - for cb in entry.get_invalidation_callbacks(key): + for iter_entry in iterate_tree_cache_entry(entry): + for cb in iter_entry.get_invalidation_callbacks(key): cb() def invalidate_all(self) -> None: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 0391966462..b3c748ef44 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -432,7 +432,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): num_args = cached_method.num_args if num_args != self.num_args: - raise Exception( + raise TypeError( "Number of args (%s) does not match underlying cache_method_name=%s (%s)." % (self.num_args, self.cached_method_name, num_args) ) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 27a363d7e5..4961fe9313 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -86,7 +86,7 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: ValueError if the server name could not be parsed. """ try: - if server_name[-1] == "]": + if server_name and server_name[-1] == "]": # ipv6 literal, hopefully return server_name, None @@ -123,7 +123,7 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int] # that nobody is sneaking IP literals in that look like hostnames, etc. # look for ipv6 literals - if host[0] == "[": + if host and host[0] == "[": if host[-1] != "]": raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) diff --git a/synapse/visibility.py b/synapse/visibility.py index c4048d2477..40a9c5b53f 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -84,7 +84,15 @@ async def filter_events_for_client( """ # Filter out events that have been soft failed so that we don't relay them # to clients. + events_before_filtering = events events = [e for e in events if not e.internal_metadata.is_soft_failed()] + if len(events_before_filtering) != len(events): + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "filter_events_for_client: Filtered out soft-failed events: Before=%s, After=%s", + [event.event_id for event in events_before_filtering], + [event.event_id for event in events], + ) types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) @@ -301,6 +309,10 @@ def _check_client_allowed_to_see_event( _check_filter_send_to_client(event, clock, retention_policy, sender_ignored) == _CheckFilter.DENIED ): + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because `_check_filter_send_to_client` returned `_CheckFilter.DENIED`", + event.event_id, + ) return None if event.event_id in always_include_ids: @@ -312,9 +324,17 @@ def _check_client_allowed_to_see_event( # for out-of-band membership events (eg, incoming invites, or rejections of # said invite) for the user themselves. if event.type == EventTypes.Member and event.state_key == user_id: - logger.debug("Returning out-of-band-membership event %s", event) + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Returning out-of-band-membership event %s", + event.event_id, + event, + ) return event + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because it's an outlier", + event.event_id, + ) return None if state is None: @@ -337,11 +357,21 @@ def _check_client_allowed_to_see_event( membership_result = _check_membership(user_id, event, visibility, state, is_peeking) if not membership_result.allowed: + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because the user can't see the event because of their membership, membership_result.allowed=%s membership_result.joined=%s", + event.event_id, + membership_result.allowed, + membership_result.joined, + ) return None # If the sender has been erased and the user was not joined at the time, we # must only return the redacted form. if sender_erased and not membership_result.joined: + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Returning pruned event because `sender_erased` and the user was not joined at the time", + event.event_id, + ) event = prune_event(event) return event |