summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py2
-rw-r--r--synapse/api/errors.py21
-rw-r--r--synapse/api/filtering.py27
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/config/experimental.py5
-rw-r--r--synapse/config/repository.py2
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/federation/federation_server.py9
-rw-r--r--synapse/federation/sender/__init__.py29
-rw-r--r--synapse/federation/transport/server/federation.py2
-rw-r--r--synapse/handlers/account_data.py2
-rw-r--r--synapse/handlers/appservice.py9
-rw-r--r--synapse/handlers/directory.py19
-rw-r--r--synapse/handlers/federation.py36
-rw-r--r--synapse/handlers/federation_event.py35
-rw-r--r--synapse/handlers/initial_sync.py27
-rw-r--r--synapse/handlers/message.py91
-rw-r--r--synapse/handlers/pagination.py5
-rw-r--r--synapse/handlers/presence.py4
-rw-r--r--synapse/handlers/receipts.py2
-rw-r--r--synapse/handlers/relations.py89
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/typing.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py66
-rw-r--r--synapse/replication/http/register.py18
-rw-r--r--synapse/rest/client/directory.py58
-rw-r--r--synapse/rest/client/events.py4
-rw-r--r--synapse/rest/client/initial_sync.py4
-rw-r--r--synapse/rest/client/receipts.py44
-rw-r--r--synapse/rest/client/relations.py45
-rw-r--r--synapse/rest/client/room.py4
-rw-r--r--synapse/rest/client/versions.py3
-rw-r--r--synapse/storage/databases/main/cache.py10
-rw-r--r--synapse/storage/databases/main/event_federation.py54
-rw-r--r--synapse/storage/databases/main/event_push_actions.py72
-rw-r--r--synapse/storage/databases/main/events.py39
-rw-r--r--synapse/storage/databases/main/events_worker.py46
-rw-r--r--synapse/storage/databases/main/push_rule.py15
-rw-r--r--synapse/storage/databases/main/receipts.py2
-rw-r--r--synapse/storage/databases/main/relations.py305
-rw-r--r--synapse/storage/databases/main/roommember.py17
-rw-r--r--synapse/storage/databases/main/stream.py59
-rw-r--r--synapse/storage/schema/main/delta/73/09threads_table.sql30
-rw-r--r--synapse/streams/__init__.py2
-rw-r--r--synapse/streams/config.py12
-rw-r--r--synapse/visibility.py32
46 files changed, 1021 insertions, 344 deletions
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py

index 5fa599e70e..d850e54e17 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py
@@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore @@ -206,6 +207,7 @@ class Store( PusherWorkerStore, PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, + RelationsWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index c606207569..e0873b1913 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -640,6 +640,27 @@ class FederationError(RuntimeError): } +class FederationPullAttemptBackoffError(RuntimeError): + """ + Raised to indicate that we are are deliberately not attempting to pull the given + event over federation because we've already done so recently and are backing off. + + Attributes: + event_id: The event_id which we are refusing to pull + message: A custom error message that gives more context + """ + + def __init__(self, event_ids: List[str], message: Optional[str]): + self.event_ids = event_ids + + if message: + error_message = message + else: + error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)." + + super().__init__(error_message) + + class HttpResponseException(CodeMessageException): """ Represents an HTTP-level failure of an outbound request diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index cc31cf8cc7..26be377d03 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py
@@ -36,7 +36,7 @@ from jsonschema import FormatChecker from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.types import JsonDict, RoomID, UserID if TYPE_CHECKING: @@ -53,6 +53,12 @@ FILTER_SCHEMA = { # check types are valid event types "types": {"type": "array", "items": {"type": "string"}}, "not_types": {"type": "array", "items": {"type": "string"}}, + # MSC3874, filtering /messages. + "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}}, + "org.matrix.msc3874.not_rel_types": { + "type": "array", + "items": {"type": "string"}, + }, }, } @@ -334,8 +340,15 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - self.related_by_senders = self.filter_json.get("related_by_senders", None) - self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + self.related_by_senders = filter_json.get("related_by_senders", None) + self.related_by_rel_types = filter_json.get("related_by_rel_types", None) + + # For compatibility with _check_fields. + self.rel_types = None + self.not_rel_types = [] + if hs.config.experimental.msc3874_enabled: + self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None) + self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", []) def filters_all_types(self) -> bool: return "*" in self.not_types @@ -386,11 +399,19 @@ class Filter: # check if there is a string url field in the content for filtering purposes labels = content.get(EventContentFields.LABELS, []) + # Check if the event has a relation. + rel_type = None + if isinstance(event, EventBase): + relation = relation_from_event(event) + if relation: + rel_type = relation.rel_type + field_matchers = { "rooms": lambda v: room_id == v, "senders": lambda v: sender == v, "types": lambda v: _matches_wildcard(ev_type, v), "labels": lambda v: v in labels, + "rel_types": lambda v: rel_type == v, } result = self._check_fields(field_matchers) diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5e3825fca6..dc49840f73 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -65,6 +65,7 @@ from synapse.rest.client import ( push_rule, read_marker, receipts, + relations, room, room_batch, room_keys, @@ -308,6 +309,7 @@ class GenericWorkerServer(HomeServer): sync.register_servlets(self, resource) events.register_servlets(self, resource) room.register_servlets(self, resource, is_worker=True) + relations.register_servlets(self, resource) room.register_deprecated_servlets(self, resource) initial_sync.register_servlets(self, resource) room_batch.register_servlets(self, resource) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index e00cb7096c..f9a49451d8 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py
@@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3772: A push rule for mutual relations. - self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) @@ -119,3 +117,6 @@ class ExperimentalConfig(Config): self.msc3882_token_timeout = self.parse_duration( experimental.get("msc3882_token_timeout", "5m") ) + + # MSC3874: Filtering /messages with rel_types / not_rel_types. + self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False) diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 1033496bb4..e4759711ed 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py
@@ -205,7 +205,7 @@ class ContentRepositoryConfig(Config): ) self.url_preview_enabled = config.get("url_preview_enabled", False) if self.url_preview_enabled: - check_requirements("url_preview") + check_requirements("url-preview") proxy_env = getproxies_environment() if "url_preview_ip_range_blacklist" not in config: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4dca711cd2..b220ab43fc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -1294,7 +1294,7 @@ class FederationClient(FederationBase): return resp[1] async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict: - """Attempts to send a knock event to given a list of servers. Iterates + """Attempts to send a knock event to a given list of servers. Iterates through the list until one attempt succeeds. Doing so will cause the remote server to add the event to the graph, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 907940e19e..28097664b4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -824,7 +824,14 @@ class FederationServer(FederationBase): context, self._room_prejoin_state_types ) ) - return {"knock_state_events": stripped_room_state} + return { + "knock_room_state": stripped_room_state, + # Since v1.37, Synapse incorrectly used "knock_state_events" for this field. + # Thus, we also populate a 'knock_state_events' with the same content to + # support old instances. + # See https://github.com/matrix-org/synapse/issues/14088. + "knock_state_events": stripped_room_state, + } async def _on_send_membership_event( self, origin: str, content: JsonDict, membership_type: str, room_id: str diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a6cb3ba58f..774ecd81b6 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py
@@ -353,21 +353,25 @@ class FederationSender(AbstractFederationSender): last_token = await self.store.get_federation_out_pos("events") ( next_token, - events, event_to_received_ts, - ) = await self.store.get_all_new_events_stream( + ) = await self.store.get_all_new_event_ids_stream( last_token, self._last_poked_id, limit=100 ) + event_ids = event_to_received_ts.keys() + event_entries = await self.store.get_unredacted_events_from_cache_or_db( + event_ids + ) + logger.debug( "Handling %i -> %i: %i events to send (current id %i)", last_token, next_token, - len(events), + len(event_entries), self._last_poked_id, ) - if not events and next_token >= self._last_poked_id: + if not event_entries and next_token >= self._last_poked_id: logger.debug("All events processed") break @@ -508,8 +512,14 @@ class FederationSender(AbstractFederationSender): await handle_event(event) events_by_room: Dict[str, List[EventBase]] = {} - for event in events: - events_by_room.setdefault(event.room_id, []).append(event) + + for event_id in event_ids: + # `event_entries` is unsorted, so we have to iterate over `event_ids` + # to ensure the events are in the right order + event_cache = event_entries.get(event_id) + if event_cache: + event = event_cache.event + events_by_room.setdefault(event.room_id, []).append(event) await make_deferred_yieldable( defer.gatherResults( @@ -524,9 +534,10 @@ class FederationSender(AbstractFederationSender): logger.debug("Successfully handled up to %i", next_token) await self.store.update_federation_out_pos("events", next_token) - if events: + if event_entries: now = self.clock.time_msec() - ts = event_to_received_ts[events[-1].event_id] + last_id = next(reversed(event_ids)) + ts = event_to_received_ts[last_id] assert ts is not None synapse.metrics.event_processing_lag.labels( @@ -536,7 +547,7 @@ class FederationSender(AbstractFederationSender): "federation_sender" ).set(ts) - events_processed_counter.inc(len(events)) + events_processed_counter.inc(len(event_entries)) event_processing_loop_room_count.labels("federation_sender").inc( len(events_by_room) diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 6bb4659c4c..6f11138b57 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py
@@ -489,7 +489,7 @@ class FederationV2InviteServlet(BaseFederationServerServlet): room_version = content["room_version"] event = content["event"] - invite_room_state = content["invite_room_state"] + invite_room_state = content.get("invite_room_state", []) # Synapse expects invite_room_state to be in unsigned, as it is in v1 # API diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 0478448b47..fc21d58001 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py
@@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 203b62e015..66f5b8d108 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -109,10 +109,13 @@ class ApplicationServicesHandler: last_token = await self.store.get_appservice_last_pos() ( upper_bound, - events, event_to_received_ts, - ) = await self.store.get_all_new_events_stream( - last_token, self.current_max, limit=100, get_prev_content=True + ) = await self.store.get_all_new_event_ids_stream( + last_token, self.current_max, limit=100 + ) + + events = await self.store.get_events_as_list( + event_to_received_ts.keys(), get_prev_content=True ) events_by_room: Dict[str, List[EventBase]] = {} diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 7127d5aefc..d52ebada6b 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py
@@ -16,6 +16,8 @@ import logging import string from typing import TYPE_CHECKING, Iterable, List, Optional +from typing_extensions import Literal + from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.errors import ( AuthError, @@ -429,7 +431,10 @@ class DirectoryHandler: return await self.auth.check_can_change_room_list(room_id, requester) async def edit_published_room_list( - self, requester: Requester, room_id: str, visibility: str + self, + requester: Requester, + room_id: str, + visibility: Literal["public", "private"], ) -> None: """Edit the entry of the room in the published room list. @@ -451,9 +456,6 @@ class DirectoryHandler: if requester.is_guest: raise AuthError(403, "Guests cannot edit the published room list") - if visibility not in ["public", "private"]: - raise SynapseError(400, "Invalid visibility setting") - if visibility == "public" and not self.enable_room_list_search: # The room list has been disabled. raise AuthError( @@ -505,7 +507,11 @@ class DirectoryHandler: await self.store.set_room_is_public(room_id, making_public) async def edit_published_appservice_room_list( - self, appservice_id: str, network_id: str, room_id: str, visibility: str + self, + appservice_id: str, + network_id: str, + room_id: str, + visibility: Literal["public", "private"], ) -> None: """Add or remove a room from the appservice/network specific public room list. @@ -516,9 +522,6 @@ class DirectoryHandler: room_id visibility: either "public" or "private" """ - if visibility not in ["public", "private"]: - raise SynapseError(400, "Invalid visibility setting") - await self.store.set_room_is_public_appservice( room_id, appservice_id, network_id, visibility == "public" ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 986ffed3d5..5f7e0a1f79 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import ( Codes, FederationDeniedError, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, LimitExceededError, NotFoundError, @@ -781,15 +782,27 @@ class FederationHandler: # Send the signed event back to the room, and potentially receive some # further information about the room in the form of partial state events - stripped_room_state = await self.federation_client.send_knock( - target_hosts, event - ) + knock_response = await self.federation_client.send_knock(target_hosts, event) # Store any stripped room state events in the "unsigned" key of the event. # This is a bit of a hack and is cribbing off of invites. Basically we # store the room state here and retrieve it again when this event appears # in the invitee's sync stream. It is stripped out for all other local users. - event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] + stripped_room_state = ( + knock_response.get("knock_room_state") + # Since v1.37, Synapse incorrectly used "knock_state_events" for this field. + # Thus, we also check for a 'knock_state_events' to support old instances. + # See https://github.com/matrix-org/synapse/issues/14088. + or knock_response.get("knock_state_events") + ) + + if stripped_room_state is None: + raise KeyError( + "Missing 'knock_room_state' (or legacy 'knock_state_events') field in " + "send_knock response" + ) + + event.unsigned["knock_room_state"] = stripped_room_state context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( @@ -1708,7 +1721,22 @@ class FederationHandler: destination, event ) break + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_sync_partial_state_room: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: + # TODO: We should `record_event_failed_pull_attempt` here, + # see https://github.com/matrix-org/synapse/issues/13700 + if attempt == len(destinations) - 1: # We have tried every remote server for this event. Give up. # TODO(faster_joins) giving up isn't the right thing to do diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index b94641a87d..06e41b5cc0 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py
@@ -44,6 +44,7 @@ from synapse.api.errors import ( AuthError, Codes, FederationError, + FederationPullAttemptBackoffError, HttpResponseException, RequestSendFailed, SynapseError, @@ -414,7 +415,9 @@ class FederationEventHandler: # First, precalculate the joined hosts so that the federation sender doesn't # need to. - await self._event_creation_handler.cache_joined_hosts_for_event(event, context) + await self._event_creation_handler.cache_joined_hosts_for_events( + [(event, context)] + ) await self._check_for_soft_fail(event, context=context, origin=origin) await self._run_push_actions_and_persist_event(event, context) @@ -565,6 +568,9 @@ class FederationEventHandler: event: partial-state event to be de-partial-stated Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to request state from the remote server. """ logger.info("Updating state for %s", event.event_id) @@ -920,6 +926,18 @@ class FederationEventHandler: context, backfilled=backfilled, ) + except FederationPullAttemptBackoffError as exc: + # Log a warning about why we failed to process the event (the error message + # for `FederationPullAttemptBackoffError` is pretty good) + logger.warning("_process_pulled_event: %s", exc) + # We do not record a failed pull attempt when we backoff fetching a missing + # `prev_event` because not being able to fetch the `prev_events` just means + # we won't be able to de-outlier the pulled event. But we can still use an + # `outlier` in the state/auth chain for another event. So we shouldn't stop + # a downstream event from trying to pull it. + # + # This avoids a cascade of backoff for all events in the DAG downstream from + # one event backoff upstream. except FederationError as e: await self._store.record_event_failed_pull_attempt( event.room_id, event_id, str(e) @@ -966,6 +984,9 @@ class FederationEventHandler: The event context. Raises: + FederationPullAttemptBackoffError if we are are deliberately not attempting + to pull the given event over federation because we've already done so + recently and are backing off. FederationError if we fail to get the state from the remote server after any missing `prev_event`s. """ @@ -976,6 +997,18 @@ class FederationEventHandler: seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen + # If we've already recently attempted to pull this missing event, don't + # try it again so soon. Since we have to fetch all of the prev_events, we can + # bail early here if we find any to ignore. + prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff( + room_id, missing_prevs + ) + if len(prevs_to_ignore) > 0: + raise FederationPullAttemptBackoffError( + event_ids=prevs_to_ignore, + message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).", + ) + if not missing_prevs: return await self._state_handler.compute_event_context(event) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 860c82c110..9c335e6863 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -57,13 +57,7 @@ class InitialSyncHandler: self.validator = EventValidator() self.snapshot_cache: ResponseCache[ Tuple[ - str, - Optional[StreamToken], - Optional[StreamToken], - str, - Optional[int], - bool, - bool, + str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() @@ -154,11 +148,6 @@ class InitialSyncHandler: public_room_ids = await self.store.get_public_room_ids() - if pagin_config.limit is not None: - limit = pagin_config.limit - else: - limit = 10 - serializer_options = SerializeEventConfig(as_client_event=as_client_event) async def handle_room(event: RoomsForUser) -> None: @@ -210,7 +199,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, event.room_id, - limit=limit, + limit=pagin_config.limit, end_token=room_end_token, ), deferred_room_state, @@ -360,15 +349,11 @@ class InitialSyncHandler: member_event_id ) - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - leave_position = await self.store.get_position_for_event(member_event_id) stream_token = leave_position.to_room_stream_token() messages, token = await self.store.get_recent_events_for_room( - room_id, limit=limit, end_token=stream_token + room_id, limit=pagin_config.limit, end_token=stream_token ) messages = await filter_events_for_client( @@ -420,10 +405,6 @@ class InitialSyncHandler: now_token = self.hs.get_event_sources().get_current_token() - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - room_members = [ m for m in current_state.values() @@ -467,7 +448,7 @@ class InitialSyncHandler: run_in_background( self.store.get_recent_events_for_room, room_id, - limit=limit, + limit=pagin_config.limit, end_token=now_token.room_key, ), ), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da1acea275..4e55ebba0b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -1390,7 +1390,7 @@ class EventCreationHandler: extra_users=extra_users, ), run_in_background( - self.cache_joined_hosts_for_event, event, context + self.cache_joined_hosts_for_events, events_and_context ).addErrback( log_failure, "cache_joined_hosts_for_event failed" ), @@ -1491,62 +1491,65 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise - async def cache_joined_hosts_for_event( - self, event: EventBase, context: EventContext + async def cache_joined_hosts_for_events( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Precalculate the joined hosts at the event, when using Redis, so that + """Precalculate the joined hosts at each of the given events, when using Redis, so that external federation senders don't have to recalculate it themselves. """ - if not self._external_cache.is_enabled(): - return - - # If external cache is enabled we should always have this. - assert self._external_cache_joined_hosts_updates is not None + for event, _ in events_and_context: + if not self._external_cache.is_enabled(): + return - # We actually store two mappings, event ID -> prev state group, - # state group -> joined hosts, which is much more space efficient - # than event ID -> joined hosts. - # - # Note: We have to cache event ID -> prev state group, as we don't - # store that in the DB. - # - # Note: We set the state group -> joined hosts cache if it hasn't been - # set for a while, so that the expiry time is reset. + # If external cache is enabled we should always have this. + assert self._external_cache_joined_hosts_updates is not None - state_entry = await self.state.resolve_state_groups_for_events( - event.room_id, event_ids=event.prev_event_ids() - ) + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We set the state group -> joined hosts cache if it hasn't been + # set for a while, so that the expiry time is reset. - if state_entry.state_group: - await self._external_cache.set( - "event_to_prev_state_group", - event.event_id, - state_entry.state_group, - expiry_ms=60 * 60 * 1000, + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() ) - if state_entry.state_group in self._external_cache_joined_hosts_updates: - return + if state_entry.state_group: + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) - state = await state_entry.get_state( - self._storage_controllers.state, StateFilter.all() - ) - with opentracing.start_active_span("get_joined_hosts"): - joined_hosts = await self.store.get_joined_hosts( - event.room_id, state, state_entry + if state_entry.state_group in self._external_cache_joined_hosts_updates: + return + + state = await state_entry.get_state( + self._storage_controllers.state, StateFilter.all() ) + with opentracing.start_active_span("get_joined_hosts"): + joined_hosts = await self.store.get_joined_hosts( + event.room_id, state, state_entry + ) - # Note that the expiry times must be larger than the expiry time in - # _external_cache_joined_hosts_updates. - await self._external_cache.set( - "get_joined_hosts", - str(state_entry.state_group), - list(joined_hosts), - expiry_ms=60 * 60 * 1000, - ) + # Note that the expiry times must be larger than the expiry time in + # _external_cache_joined_hosts_updates. + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) - self._external_cache_joined_hosts_updates[state_entry.state_group] = None + self._external_cache_joined_hosts_updates[ + state_entry.state_group + ] = None async def _validate_canonical_alias( self, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1f83bab836..a4ca9cb8b4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -458,11 +458,6 @@ class PaginationHandler: # `/messages` should still works with live tokens when manually provided. assert from_token.room_key.topological is not None - if pagin_config.limit is None: - # This shouldn't happen as we've set a default limit before this - # gets called. - raise Exception("limit not set") - room_token = from_token.room_key async with self.pagination_lock.read(room_id): diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 4e575ffbaa..2670e561d7 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -1596,7 +1596,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self, user: UserID, from_key: Optional[int], - limit: Optional[int] = None, + # Having a default limit doesn't match the EventSource API, but some + # callers do not provide it. It is unused in this class. + limit: int = 0, room_ids: Optional[Collection[str]] = None, is_guest: bool = False, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4a7ec9e426..ac01582442 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py
@@ -257,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index cc5e45c241..0a0c6d938e 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple @@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -32,6 +33,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ThreadsListInclude(str, enum.Enum): + """Valid values for the 'include' flag of /threads.""" + + all = "all" + participated = "participated" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: # The latest event in the thread. @@ -108,9 +116,6 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") - # TODO Update pagination config to not allow None limits. - assert pagin_config.limit is not None - # Note that ignored users are not passed into get_relations_for_event # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). @@ -482,3 +487,79 @@ class RelationsHandler: results.setdefault(event_id, BundledAggregations()).replace = edit return results + + async def get_threads( + self, + requester: Requester, + room_id: str, + include: ThreadsListInclude, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + Args: + requester: The user requesting the relations. + room_id: The room the event belongs to. + include: One of "all" or "participated" to indicate which threads should + be returned. + limit: Only fetch the most recent `limit` events. + from_token: Fetch rows from the given token, or from the start if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) + + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). + thread_roots, next_batch = await self._main_store.get_threads( + room_id=room_id, limit=limit, from_token=from_token + ) + + events = await self._main_store.get_events_as_list(thread_roots) + + if include == ThreadsListInclude.participated: + # Pre-seed thread participation with whether the requester sent the event. + participated = {event.event_id: event.sender == user_id for event in events} + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [eid for eid, p in participated.items() if not p], + user_id, + ) + ) + + # Limit the returned threads to those the user has participated in. + events = [event for event in events if participated[event.event_id]] + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + + now = self._clock.time_msec() + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value: JsonDict = {"chunk": serialized_events} + + if next_batch: + return_value["next_batch"] = str(next_batch) + + return return_value diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 57ab05ad25..4e1aacb408 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -1646,7 +1646,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]): self, user: UserID, from_key: RoomStreamToken, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index f953691669..a0ea719430 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): self, user: UserID, from_key: int, - limit: Optional[int], + limit: int, room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index eced182fd5..a75386f6a0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, - Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -38,7 +35,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator +from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -117,9 +114,6 @@ class BulkPushRuleEvaluator: resizable=False, ) - # Whether to support MSC3772 is supported. - self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled - async def _get_rules_for_event( self, event: EventBase, @@ -200,51 +194,6 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _get_mutual_relations( - self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - parent_id: The event ID which is targeted by relations. - rules: The push rules which will be processed for this event. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - - # If the experimental feature is not enabled, skip fetching relations. - if not self._relations_match_enabled: - return {} - - # Pre-filter to figure out which relation types are interesting. - rel_types = set() - for rule, enabled in rules: - if not enabled: - continue - - for condition in rule.conditions: - if condition["kind"] != "org.matrix.msc3772.relation_match": - continue - - # rel_type is required. - rel_type = condition.get("rel_type") - if rel_type: - rel_types.add(rel_type) - - # If no valid rules were found, no mutual relations. - if not rel_types: - return {} - - # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations(parent_id, rel_types) - @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -276,23 +225,18 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + # Find the event's thread ID. relation = relation_from_event(event) - # If the event does not have a relation, then cannot have any mutual - # relations or thread ID. - relations = {} + # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE if relation: - relations = await self._get_mutual_relations( - relation.parent_id, - itertools.chain(*(r.rules() for r in rules_by_user.values())), - ) # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id else: # Since the event has not yet been persisted we check whether # the parent is part of a thread. - thread_id = await self.store.get_thread_id(relation.parent_id) or "main" + thread_id = await self.store.get_thread_id(relation.parent_id) # It's possible that old room versions have non-integer power levels (floats or # strings). Workaround this by explicitly converting to int. @@ -306,8 +250,6 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, - relations, - self._relations_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 61abb529c8..976c283360 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py
@@ -39,6 +39,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint): self.store = hs.get_datastores().main self.registration_handler = hs.get_registration_handler() + # Default value if the worker that sent the replication request did not include + # an 'approved' property. + if ( + hs.config.experimental.msc3866.enabled + and hs.config.experimental.msc3866.require_approval_for_new_accounts + ): + self._approval_default = False + else: + self._approval_default = True + @staticmethod async def _serialize_payload( # type: ignore[override] user_id: str, @@ -92,6 +102,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint): await self.registration_handler.check_registration_ratelimit(content["address"]) + # Always default admin users to approved (since it means they were created by + # an admin). + approved_default = self._approval_default + if content["admin"]: + approved_default = True + await self.registration_handler.register_with_store( user_id=user_id, password_hash=content["password_hash"], @@ -103,7 +119,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type=content["user_type"], address=content["address"], shadow_banned=content["shadow_banned"], - approved=content["approved"], + approved=content.get("approved", approved_default), ) return 200, {} diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index bc1b18c92d..f17b4c8d22 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py
@@ -13,15 +13,22 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple + +from pydantic import StrictStr +from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.servlet import ( + RestServlet, + parse_and_validate_json_object_from_request, +) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.rest.models import RequestBodyModel from synapse.types import JsonDict, RoomAlias if TYPE_CHECKING: @@ -54,6 +61,12 @@ class ClientDirectoryServer(RestServlet): return 200, res + class PutBody(RequestBodyModel): + # TODO: get Pydantic to validate that this is a valid room id? + room_id: StrictStr + # `servers` is unspecced + servers: Optional[List[StrictStr]] = None + async def on_PUT( self, request: SynapseRequest, room_alias: str ) -> Tuple[int, JsonDict]: @@ -61,31 +74,22 @@ class ClientDirectoryServer(RestServlet): raise SynapseError(400, "Room alias invalid", errcode=Codes.INVALID_PARAM) room_alias_obj = RoomAlias.from_string(room_alias) - content = parse_json_object_from_request(request) - if "room_id" not in content: - raise SynapseError( - 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON - ) + content = parse_and_validate_json_object_from_request(request, self.PutBody) logger.debug("Got content: %s", content) logger.debug("Got room name: %s", room_alias_obj.to_string()) - room_id = content["room_id"] - servers = content["servers"] if "servers" in content else None - - logger.debug("Got room_id: %s", room_id) - logger.debug("Got servers: %s", servers) + logger.debug("Got room_id: %s", content.room_id) + logger.debug("Got servers: %s", content.servers) - # TODO(erikj): Check types. - - room = await self.store.get_room(room_id) + room = await self.store.get_room(content.room_id) if room is None: raise SynapseError(400, "Room does not exist") requester = await self.auth.get_user_by_req(request) await self.directory_handler.create_association( - requester, room_alias_obj, room_id, servers + requester, room_alias_obj, content.room_id, content.servers ) return 200, {} @@ -137,16 +141,18 @@ class ClientDirectoryListServer(RestServlet): return 200, {"visibility": "public" if room["is_public"] else "private"} + class PutBody(RequestBodyModel): + visibility: Literal["public", "private"] = "public" + async def on_PUT( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") + content = parse_and_validate_json_object_from_request(request, self.PutBody) await self.directory_handler.edit_published_room_list( - requester, room_id, visibility + requester, room_id, content.visibility ) return 200, {} @@ -163,12 +169,14 @@ class ClientAppserviceDirectoryListServer(RestServlet): self.directory_handler = hs.get_directory_handler() self.auth = hs.get_auth() + class PutBody(RequestBodyModel): + visibility: Literal["public", "private"] = "public" + async def on_PUT( self, request: SynapseRequest, network_id: str, room_id: str ) -> Tuple[int, JsonDict]: - content = parse_json_object_from_request(request) - visibility = content.get("visibility", "public") - return await self._edit(request, network_id, room_id, visibility) + content = parse_and_validate_json_object_from_request(request, self.PutBody) + return await self._edit(request, network_id, room_id, content.visibility) async def on_DELETE( self, request: SynapseRequest, network_id: str, room_id: str @@ -176,7 +184,11 @@ class ClientAppserviceDirectoryListServer(RestServlet): return await self._edit(request, network_id, room_id, "private") async def _edit( - self, request: SynapseRequest, network_id: str, room_id: str, visibility: str + self, + request: SynapseRequest, + network_id: str, + room_id: str, + visibility: Literal["public", "private"], ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not requester.app_service: diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 916f5230f1..782e7d14e8 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py
@@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet): raise SynapseError(400, "Guest users must specify room_id param") room_id = parse_string(request, "room_id") - pagin_config = await PaginationConfig.from_request(self.store, request) + pagin_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in args: try: diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py
index cfadcb8e50..9b1bb8b521 100644 --- a/synapse/rest/client/initial_sync.py +++ b/synapse/rest/client/initial_sync.py
@@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) args: Dict[bytes, List[bytes]] = request.args # type: ignore as_client_event = b"raw" not in args - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 14dec7ac4e..18a282b22c 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py
@@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet): ) # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): + if not await self._is_event_in_thread(event_id, thread_id): raise SynapseError( 400, f"event_id {event_id} is not related to thread {thread_id}", @@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet): return 200, {} + async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: + """ + The event must be related to the thread ID (in a vague sense) to ensure + clients aren't sending bogus receipts. + + A thread ID is considered valid for a given event E if: + + 1. E has a thread relation which matches the thread ID; + 2. E has another event which has a thread relation to E matching the + thread ID; or + 3. E is recursively related (via any rel_type) to an event which + satisfies 1 or 2. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + It is valid to send a receipt for thread A on A, B, C, D, or E. + + It is valid to send a receipt for the main timeline on A, D, and E. + + Args: + event_id: The event ID to check. + thread_id: The thread ID the event is potentially part of. + + Returns: + True if the event belongs to the given thread, otherwise False. + """ + + # If the receipt is on the main timeline, it is enough to check whether + # the event is directly related to a thread. + if thread_id == MAIN_TIMELINE: + return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id) + + # Otherwise, check if the event is directly part of a thread, or is the + # root message (or related to the root message) of a thread. + return thread_id == await self._main_store.get_thread_id_for_receipts(event_id) + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index b31ce5a0d3..9dd59196d9 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py
@@ -13,12 +13,15 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.storage.databases.main.relations import ThreadsNextBatch from synapse.streams.config import PaginationConfig from synapse.types import JsonDict @@ -78,5 +81,45 @@ class RelationPaginationServlet(RestServlet): return 200, result +class ThreadsServlet(RestServlet): + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + include = parse_string( + request, + "include", + default=ThreadsListInclude.all.value, + allowed_values=[v.value for v in ThreadsListInclude], + ) + + # Return the relations + from_token = None + if from_token_str: + from_token = ThreadsNextBatch.from_string(from_token_str) + + result = await self._relations_handler.get_threads( + requester=requester, + room_id=room_id, + include=ThreadsListInclude(include), + limit=limit, + from_token=from_token, + ) + + return 200, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + ThreadsServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index b6dedbed04..01e5079963 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py
@@ -729,7 +729,9 @@ class RoomInitialSyncRestServlet(RestServlet): self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = await PaginationConfig.from_request(self.store, request) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index d1d2e5f7e3..4b87ee978a 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
@@ -76,6 +76,7 @@ class VersionsRestServlet(RestServlet): "v1.1", "v1.2", "v1.3", + "v1.4", ], # as per MSC1497: "unstable_features": { @@ -113,6 +114,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3882": self.config.experimental.msc3882_enabled, # Adds support for remotely enabling/disabling pushers, as per MSC3881 "org.matrix.msc3881": self.config.experimental.msc3881_enabled, + # Adds support for filtering /messages by event relation. + "org.matrix.msc3874": self.config.experimental.msc3874_enabled, }, }, ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 3b8ed1f7ee..ddb7397714 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py
@@ -244,12 +244,18 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # redacted. self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) self._attempt_to_invalidate_cache( "get_invited_rooms_for_local_user", (state_key,) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (state_key,) + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,)) if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) @@ -259,9 +265,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6b9a629edd..309a4ba664 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_id: The event that failed to be fetched or processed cause: The error message or reason that we failed to pull the event """ + logger.debug( + "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s", + room_id, + event_id, + cause, + ) await self.db_pool.runInteraction( "record_event_failed_pull_attempt", self._record_event_failed_pull_attempt_upsert_txn, @@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + @trace + async def get_event_ids_to_not_pull_from_backoff( + self, + room_id: str, + event_ids: Collection[str], + ) -> List[str]: + """ + Filter down the events to ones that we've failed to pull before recently. Uses + exponential backoff. + + Args: + room_id: The room that the events belong to + event_ids: A list of events to filter down + + Returns: + List of event_ids that should not be attempted to be pulled + """ + event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( + table="event_failed_pull_attempts", + column="event_id", + iterable=event_ids, + keyvalues={}, + retcols=( + "event_id", + "last_attempt_ts", + "num_attempts", + ), + desc="get_event_ids_to_not_pull_from_backoff", + ) + + current_time = self._clock.time_msec() + return [ + event_failed_pull_attempt["event_id"] + for event_failed_pull_attempt in event_failed_pull_attempts + # Exponential back-off (up to the upper bound) so we don't try to + # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. + if current_time + < event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) + ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS + ] + async def get_missing_events( self, room_id: str, diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 332e13d1c9..f070e6e88a 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -310,11 +310,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas event_push_actions_done = progress.get("event_push_actions_done", False) def add_thread_id_txn( - txn: LoggingTransaction, table_name: str, start_stream_ordering: int + txn: LoggingTransaction, start_stream_ordering: int ) -> int: - sql = f""" + sql = """ SELECT stream_ordering - FROM {table_name} + FROM event_push_actions WHERE thread_id IS NULL AND stream_ordering > ? @@ -326,7 +326,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # No more rows to process. rows = txn.fetchall() if not rows: - progress[f"{table_name}_done"] = True + progress["event_push_actions_done"] = True self.db_pool.updates._background_update_progress_txn( txn, "event_push_backfill_thread_id", progress ) @@ -335,16 +335,65 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # Update the thread ID for any of those rows. max_stream_ordering = rows[-1][0] - sql = f""" - UPDATE {table_name} + sql = """ + UPDATE event_push_actions SET thread_id = 'main' - WHERE stream_ordering <= ? AND thread_id IS NULL + WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL """ - txn.execute(sql, (max_stream_ordering,)) + txn.execute( + sql, + ( + start_stream_ordering, + max_stream_ordering, + ), + ) # Update progress. processed_rows = txn.rowcount - progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering + progress["max_event_push_actions_stream_ordering"] = max_stream_ordering + self.db_pool.updates._background_update_progress_txn( + txn, "event_push_backfill_thread_id", progress + ) + + return processed_rows + + def add_thread_id_summary_txn(txn: LoggingTransaction) -> int: + min_user_id = progress.get("max_summary_user_id", "") + min_room_id = progress.get("max_summary_room_id", "") + + # Slightly overcomplicated query for getting the Nth user ID / room + # ID tuple, or the last if there are less than N remaining. + sql = """ + SELECT user_id, room_id FROM ( + SELECT user_id, room_id FROM event_push_summary + WHERE (user_id, room_id) > (?, ?) + AND thread_id IS NULL + ORDER BY user_id, room_id + LIMIT ? + ) AS e + ORDER BY user_id DESC, room_id DESC + LIMIT 1 + """ + + txn.execute(sql, (min_user_id, min_room_id, batch_size)) + row = txn.fetchone() + if not row: + return 0 + + max_user_id, max_room_id = row + + sql = """ + UPDATE event_push_summary + SET thread_id = 'main' + WHERE + (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?) + AND thread_id IS NULL + """ + txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id)) + processed_rows = txn.rowcount + + progress["max_summary_user_id"] = max_user_id + progress["max_summary_room_id"] = max_room_id self.db_pool.updates._background_update_progress_txn( txn, "event_push_backfill_thread_id", progress ) @@ -360,15 +409,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas result = await self.db_pool.runInteraction( "event_push_backfill_thread_id", add_thread_id_txn, - "event_push_actions", progress.get("max_event_push_actions_stream_ordering", 0), ) else: result = await self.db_pool.runInteraction( "event_push_backfill_thread_id", - add_thread_id_txn, - "event_push_summary", - progress.get("max_event_push_summary_stream_ordering", 0), + add_thread_id_summary_txn, ) # Only done after the event_push_summary table is done. diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3e15827986..6698cbf664 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -1616,7 +1616,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1866,6 +1866,34 @@ class PersistEventsStore: }, ) + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ + + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) + def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -1989,13 +2017,14 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ @@ -2025,9 +2054,7 @@ class PersistEventsStore: txn, self.store.get_thread_participated, (redacted_relates_to,) ) self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), + txn, self.store.get_threads, (room_id,) ) self.db_pool.simple_delete_txn( diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 221864d9cc..7bc7f2f33e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore): return [] # there may be duplicates so we cast the list to a set - event_entry_map = await self._get_events_from_cache_or_db( + event_entry_map = await self.get_unredacted_events_from_cache_or_db( set(event_ids), allow_rejected=allow_rejected ) @@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore): continue redacted_event_id = entry.event.redacts - event_map = await self._get_events_from_cache_or_db([redacted_event_id]) + event_map = await self.get_unredacted_events_from_cache_or_db( + [redacted_event_id] + ) original_event_entry = event_map.get(redacted_event_id) if not original_event_entry: # we don't have the redacted event (or it was rejected). @@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore): return events @cancellable - async def _get_events_from_cache_or_db( - self, event_ids: Iterable[str], allow_rejected: bool = False + async def get_unredacted_events_from_cache_or_db( + self, + event_ids: Iterable[str], + allow_rejected: bool = False, ) -> Dict[str, EventCacheEntry]: """Fetch a bunch of events from the cache or the database. + Note that the events pulled by this function will not have any redactions + applied, and no guarantee is made about the ordering of the events returned. + If events are pulled from the database, they will be cached for future lookups. Unknown events are omitted from the response. @@ -1495,21 +1502,15 @@ class EventsWorkerStore(SQLBaseStore): Returns: a dict {event_id -> bool} """ - # if the event cache contains the event, obviously we've seen it. - - cache_results = { - event_id - for event_id in event_ids - if await self._get_event_cache.contains((event_id,)) - } - results = dict.fromkeys(cache_results, True) - remaining = [ - event_id for event_id in event_ids if event_id not in cache_results - ] - if not remaining: - return results + # TODO: We used to query the _get_event_cache here as a fast-path before + # hitting the database. For if an event were in the cache, we've presumably + # seen it before. + # + # But this is currently an invalid assumption due to the _get_event_cache + # not being invalidated when purging events from a room. The optimisation can + # be re-added after https://github.com/matrix-org/synapse/issues/13476 - def have_seen_events_txn(txn: LoggingTransaction) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1517,16 +1518,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", remaining + txn.database_engine, "e.event_id", event_ids ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update({eid: (eid in found_events) for eid in remaining}) + return {eid: (eid in found_events) for eid in event_ids} - await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) - return results + return await self.db_pool.runInteraction( + "have_seen_events", have_seen_events_txn + ) @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 8295322b0e..51416b2236 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -29,7 +29,6 @@ from typing import ( ) from synapse.api.errors import StoreError -from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -63,9 +62,7 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], - enabled_map: Dict[str, bool], - experimental_config: ExperimentalConfig, + rawrules: List[JsonDict], enabled_map: Dict[str, bool] ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -83,9 +80,7 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled - ) + filtered_rules = FilteredPushRules(push_rules, enabled_map) return filtered_rules @@ -165,7 +160,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules(rows, enabled_map) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -224,9 +219,7 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 246f78ac1f..dc6989527e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -418,6 +418,8 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type[row["user_id"]] = db_to_json(row["data"]) + if row["thread_id"]: + receipt_type[row["user_id"]]["thread_id"] = row["thread_id"] results = { room_id: [results[room_id]] if room_id in results else [] diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 116abef9de..1de62ee9df 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -14,6 +14,7 @@ import logging from typing import ( + TYPE_CHECKING, Collection, Dict, FrozenSet, @@ -28,19 +29,48 @@ from typing import ( import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadsNextBatch: + topological_ordering: int + stream_ordering: int + + def __str__(self) -> str: + return f"{self.topological_ordering}_{self.stream_ordering}" + + @classmethod + def from_string(cls, string: str) -> "ThreadsNextBatch": + """ + Creates a ThreadsNextBatch from its textual representation. + """ + try: + keys = (int(s) for s in string.split("_")) + return cls(*keys) + except Exception: + raise SynapseError(400, "Invalid threads token") + + +@attr.s(slots=True, frozen=True, auto_attribs=True) class _RelatedEvent: """ Contains enough information about a related event in order to properly filter @@ -56,6 +86,76 @@ class _RelatedEvent: class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "threads_backfill", self._backfill_threads + ) + + async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int: + """Backfill the threads table.""" + + def threads_backfill_txn(txn: LoggingTransaction) -> int: + last_thread_id = progress.get("last_thread_id", "") + + # Get the latest event in each thread by topo ordering / stream ordering. + # + # Note that the MAX(event_id) is needed to abide by the rules of group by, + # but doesn't actually do anything since there should only be a single event + # ID per topo/stream ordering pair. + sql = f""" + SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id > ? AND + relation_type = '{RelationTypes.THREAD}' + GROUP BY room_id, relates_to_id + ORDER BY relates_to_id + LIMIT ? + """ + txn.execute(sql, (last_thread_id, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + return 0 + + # Insert the rows into the threads table. If a matching thread already exists, + # assume it is from a newer event. + sql = """ + INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id) + VALUES %s + ON CONFLICT (room_id, thread_id) + DO NOTHING + """ + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows) + + # Mark the progress. + self.db_pool.updates._background_update_progress_txn( + txn, "threads_backfill", {"last_thread_id": rows[-1][1]} + ) + + return txn.rowcount + + result = await self.db_pool.runInteraction( + "threads_backfill", threads_backfill_txn + ) + + if not result: + await self.db_pool.updates._end_background_update("threads_backfill") + + return result + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, @@ -776,95 +876,194 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from a previous next_batch, or from the start if None. Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type + A tuple of: + A list of thread root event IDs. + + The next_batch, if one exists. """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is equal / before the last + # thread's topo ordering and earlier in stream ordering. + pagination_clause = "" + pagination_args: tuple = () + if from_token: + pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?" + pagination_args = ( + from_token.topological_ordering, + from_token.stream_ordering, + ) sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} + SELECT thread_id, topological_ordering, stream_ordering + FROM threads + WHERE + room_id = ? + {pagination_clause} + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT ? """ - def _get_event_relations( + def _get_threads_txn( txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + txn.execute(sql, (room_id, *pagination_args, limit + 1)) - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations - ) + rows = cast(List[Tuple[str, int, int]], txn.fetchall()) + thread_ids = [r[0] for r in rows] + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) @cached() - async def get_thread_id(self, event_id: str) -> Optional[str]: + async def get_thread_id(self, event_id: str) -> str: """ Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. + + See also get_thread_id_for_receipts. + Args: event_id: The event ID to fetch the thread ID for. Returns: The event ID of the root event in the thread, if this event is part - of a thread. None, otherwise. + of a thread. "main", otherwise. """ - # Since event relations form a tree, we should only ever find 0 or 1 - # results from the below query. + + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type + SELECT event_id, relates_to_id, relation_type, 0 depth FROM event_relations WHERE event_id = ? - UNION SELECT e.event_id, e.relates_to_id, e.relation_type + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id - ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ - def _get_thread_id(txn: LoggingTransaction) -> Optional[str]: + def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) - # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] - return None + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction( + "get_related_thread_id", _get_related_thread_id + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2337289d88..2ed6ad754f 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -666,7 +666,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): cached_method_name="get_rooms_for_user", list_name="user_ids", ) - async def get_rooms_for_users( + async def _get_rooms_for_users( self, user_ids: Collection[str] ) -> Dict[str, FrozenSet[str]]: """A batched version of `get_rooms_for_user`. @@ -697,6 +697,21 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {key: frozenset(rooms) for key, rooms in user_rooms.items()} + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched wrapper around `_get_rooms_for_users`, to prevent locking + other calls to `get_rooms_for_user` for large user lists. + """ + all_user_rooms: Dict[str, FrozenSet[str]] = {} + + # 250 users is pretty arbitrary but the data can be quite large if users + # are in many rooms. + for user_ids in batch_iter(user_ids, 250): + all_user_rooms.update(await self._get_rooms_for_users(user_ids)) + + return all_user_rooms + @cached(max_entries=10000) async def does_pair_of_users_share_a_room( self, user_id: str, other_user_id: str diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 530f04e149..09ce855aa8 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
@@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: ) args.extend(event_filter.related_by_rel_types) + if event_filter.rel_types: + clauses.append( + "(%s)" + % " OR ".join( + "event_relation.relation_type = ?" for _ in event_filter.rel_types + ) + ) + args.extend(event_filter.rel_types) + + if event_filter.not_rel_types: + clauses.append( + "((%s) OR event_relation.relation_type IS NULL)" + % " AND ".join( + "event_relation.relation_type != ?" for _ in event_filter.not_rel_types + ) + ) + args.extend(event_filter.not_rel_types) + return " AND ".join(clauses), args @@ -1024,28 +1042,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "after": {"event_ids": events_after, "token": end_token}, } - async def get_all_new_events_stream( - self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False - ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: + async def get_all_new_event_ids_stream( + self, + from_id: int, + current_id: int, + limit: int, + ) -> Tuple[int, Dict[str, Optional[int]]]: """Get all new events - Returns all events with from_id < stream_ordering <= current_id. + Returns all event ids with from_id < stream_ordering <= current_id. Args: from_id: the stream_ordering of the last event we processed current_id: the stream_ordering of the most recently processed event limit: the maximum number of events to return - get_prev_content: whether to fetch previous event content Returns: - A tuple of (next_id, events, event_to_received_ts), where `next_id` + A tuple of (next_id, event_to_received_ts), where `next_id` is the next value to pass as `from_id` (it will either be the stream_ordering of the last returned event, or, if fewer than `limit` events were found, the `current_id`). The `event_to_received_ts` is - a dictionary mapping event ID to the event `received_ts`. + a dictionary mapping event ID to the event `received_ts`, sorted by ascending + stream_ordering. """ - def get_all_new_events_stream_txn( + def get_all_new_event_ids_stream_txn( txn: LoggingTransaction, ) -> Tuple[int, Dict[str, Optional[int]]]: sql = ( @@ -1070,15 +1091,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return upper_bound, event_to_received_ts upper_bound, event_to_received_ts = await self.db_pool.runInteraction( - "get_all_new_events_stream", get_all_new_events_stream_txn - ) - - events = await self.get_events_as_list( - event_to_received_ts.keys(), - get_prev_content=get_prev_content, + "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn ) - return upper_bound, events, event_to_received_ts + return upper_bound, event_to_received_ts async def get_federation_out_pos(self, typ: str) -> int: if self._need_to_reset_federation_stream_positions: @@ -1202,8 +1218,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): `to_token`), or `limit` is zero. """ - assert int(limit) >= 0 - # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. @@ -1282,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # Multiple labels could cause the same event to appear multiple times. needs_distinct = True - # If there is a filter on relation_senders and relation_types join to the - # relations table. + # If there is a relation_senders and relation_types filter join to the + # relations table to get events related to the current event. if event_filter and ( event_filter.related_by_senders or event_filter.related_by_rel_types ): @@ -1298,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ + # If there is a not_rel_types filter join to the relations table to get + # the event's relation information. + if event_filter and (event_filter.rel_types or event_filter.not_rel_types): + join_clause += """ + LEFT JOIN event_relations AS event_relation USING (event_id) + """ + if needs_distinct: select_keywords += " DISTINCT" diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql new file mode 100644
index 0000000000..aa7c5e9a2e --- /dev/null +++ b/synapse/storage/schema/main/delta/73/09threads_table.sql
@@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE threads ( + room_id TEXT NOT NULL, + -- The event ID of the root event in the thread. + thread_id TEXT NOT NULL, + -- The latest event ID and corresponding topo / stream ordering. + latest_event_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id) +); + +CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7309, 'threads_backfill', '{}'); diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 806b671305..2dcd43d0a2 100644 --- a/synapse/streams/__init__.py +++ b/synapse/streams/__init__.py
@@ -27,7 +27,7 @@ class EventSource(Generic[K, R]): self, user: UserID, from_key: K, - limit: Optional[int], + limit: int, room_ids: Collection[str], is_guest: bool, explicit_room_id: Optional[str] = None, diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index f6f7bf3d8b..6df2de919c 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py
@@ -35,14 +35,14 @@ class PaginationConfig: from_token: Optional[StreamToken] to_token: Optional[StreamToken] direction: str - limit: Optional[int] + limit: int @classmethod async def from_request( cls, store: "DataStore", request: SynapseRequest, - default_limit: Optional[int] = None, + default_limit: int, default_dir: str = "f", ) -> "PaginationConfig": direction = parse_string( @@ -69,12 +69,10 @@ class PaginationConfig: raise SynapseError(400, "'to' parameter is invalid") limit = parse_integer(request, "limit", default=default_limit) + if limit < 0: + raise SynapseError(400, "Limit must be 0 or above") - if limit: - if limit < 0: - raise SynapseError(400, "Limit must be 0 or above") - - limit = min(int(limit), MAX_LIMIT) + limit = min(limit, MAX_LIMIT) try: return PaginationConfig(from_tok, to_tok, direction, limit) diff --git a/synapse/visibility.py b/synapse/visibility.py
index c4048d2477..40a9c5b53f 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py
@@ -84,7 +84,15 @@ async def filter_events_for_client( """ # Filter out events that have been soft failed so that we don't relay them # to clients. + events_before_filtering = events events = [e for e in events if not e.internal_metadata.is_soft_failed()] + if len(events_before_filtering) != len(events): + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "filter_events_for_client: Filtered out soft-failed events: Before=%s, After=%s", + [event.event_id for event in events_before_filtering], + [event.event_id for event in events], + ) types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id)) @@ -301,6 +309,10 @@ def _check_client_allowed_to_see_event( _check_filter_send_to_client(event, clock, retention_policy, sender_ignored) == _CheckFilter.DENIED ): + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because `_check_filter_send_to_client` returned `_CheckFilter.DENIED`", + event.event_id, + ) return None if event.event_id in always_include_ids: @@ -312,9 +324,17 @@ def _check_client_allowed_to_see_event( # for out-of-band membership events (eg, incoming invites, or rejections of # said invite) for the user themselves. if event.type == EventTypes.Member and event.state_key == user_id: - logger.debug("Returning out-of-band-membership event %s", event) + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Returning out-of-band-membership event %s", + event.event_id, + event, + ) return event + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because it's an outlier", + event.event_id, + ) return None if state is None: @@ -337,11 +357,21 @@ def _check_client_allowed_to_see_event( membership_result = _check_membership(user_id, event, visibility, state, is_peeking) if not membership_result.allowed: + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Filtered out event because the user can't see the event because of their membership, membership_result.allowed=%s membership_result.joined=%s", + event.event_id, + membership_result.allowed, + membership_result.joined, + ) return None # If the sender has been erased and the user was not joined at the time, we # must only return the redacted form. if sender_erased and not membership_result.joined: + logger.debug( + "_check_client_allowed_to_see_event(event=%s): Returning pruned event because `sender_erased` and the user was not joined at the time", + event.event_id, + ) event = prune_event(event) return event