diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 5fa599e70e..d850e54e17 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
+from synapse.storage.databases.main.relations import RelationsWorkerStore
from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
@@ -206,6 +207,7 @@ class Store(
PusherWorkerStore,
PresenceBackgroundUpdateStore,
ReceiptsBackgroundUpdateStore,
+ RelationsWorkerStore,
):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index c606207569..e0873b1913 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -640,6 +640,27 @@ class FederationError(RuntimeError):
}
+class FederationPullAttemptBackoffError(RuntimeError):
+ """
+ Raised to indicate that we are are deliberately not attempting to pull the given
+ event over federation because we've already done so recently and are backing off.
+
+ Attributes:
+ event_id: The event_id which we are refusing to pull
+ message: A custom error message that gives more context
+ """
+
+ def __init__(self, event_ids: List[str], message: Optional[str]):
+ self.event_ids = event_ids
+
+ if message:
+ error_message = message
+ else:
+ error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)."
+
+ super().__init__(error_message)
+
+
class HttpResponseException(CodeMessageException):
"""
Represents an HTTP-level failure of an outbound request
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index cc31cf8cc7..26be377d03 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -36,7 +36,7 @@ from jsonschema import FormatChecker
from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
-from synapse.events import EventBase
+from synapse.events import EventBase, relation_from_event
from synapse.types import JsonDict, RoomID, UserID
if TYPE_CHECKING:
@@ -53,6 +53,12 @@ FILTER_SCHEMA = {
# check types are valid event types
"types": {"type": "array", "items": {"type": "string"}},
"not_types": {"type": "array", "items": {"type": "string"}},
+ # MSC3874, filtering /messages.
+ "org.matrix.msc3874.rel_types": {"type": "array", "items": {"type": "string"}},
+ "org.matrix.msc3874.not_rel_types": {
+ "type": "array",
+ "items": {"type": "string"},
+ },
},
}
@@ -334,8 +340,15 @@ class Filter:
self.labels = filter_json.get("org.matrix.labels", None)
self.not_labels = filter_json.get("org.matrix.not_labels", [])
- self.related_by_senders = self.filter_json.get("related_by_senders", None)
- self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None)
+ self.related_by_senders = filter_json.get("related_by_senders", None)
+ self.related_by_rel_types = filter_json.get("related_by_rel_types", None)
+
+ # For compatibility with _check_fields.
+ self.rel_types = None
+ self.not_rel_types = []
+ if hs.config.experimental.msc3874_enabled:
+ self.rel_types = filter_json.get("org.matrix.msc3874.rel_types", None)
+ self.not_rel_types = filter_json.get("org.matrix.msc3874.not_rel_types", [])
def filters_all_types(self) -> bool:
return "*" in self.not_types
@@ -386,11 +399,19 @@ class Filter:
# check if there is a string url field in the content for filtering purposes
labels = content.get(EventContentFields.LABELS, [])
+ # Check if the event has a relation.
+ rel_type = None
+ if isinstance(event, EventBase):
+ relation = relation_from_event(event)
+ if relation:
+ rel_type = relation.rel_type
+
field_matchers = {
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
"types": lambda v: _matches_wildcard(ev_type, v),
"labels": lambda v: v in labels,
+ "rel_types": lambda v: rel_type == v,
}
result = self._check_fields(field_matchers)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5e3825fca6..dc49840f73 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -65,6 +65,7 @@ from synapse.rest.client import (
push_rule,
read_marker,
receipts,
+ relations,
room,
room_batch,
room_keys,
@@ -308,6 +309,7 @@ class GenericWorkerServer(HomeServer):
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
room.register_servlets(self, resource, is_worker=True)
+ relations.register_servlets(self, resource)
room.register_deprecated_servlets(self, resource)
initial_sync.register_servlets(self, resource)
room_batch.register_servlets(self, resource)
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index e00cb7096c..f9a49451d8 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -95,8 +95,6 @@ class ExperimentalConfig(Config):
# MSC2815 (allow room moderators to view redacted event content)
self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False)
- # MSC3772: A push rule for mutual relations.
- self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
@@ -119,3 +117,6 @@ class ExperimentalConfig(Config):
self.msc3882_token_timeout = self.parse_duration(
experimental.get("msc3882_token_timeout", "5m")
)
+
+ # MSC3874: Filtering /messages with rel_types / not_rel_types.
+ self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 1033496bb4..e4759711ed 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -205,7 +205,7 @@ class ContentRepositoryConfig(Config):
)
self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled:
- check_requirements("url_preview")
+ check_requirements("url-preview")
proxy_env = getproxies_environment()
if "url_preview_ip_range_blacklist" not in config:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4dca711cd2..b220ab43fc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1294,7 +1294,7 @@ class FederationClient(FederationBase):
return resp[1]
async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
- """Attempts to send a knock event to given a list of servers. Iterates
+ """Attempts to send a knock event to a given list of servers. Iterates
through the list until one attempt succeeds.
Doing so will cause the remote server to add the event to the graph,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 907940e19e..28097664b4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -824,7 +824,14 @@ class FederationServer(FederationBase):
context, self._room_prejoin_state_types
)
)
- return {"knock_state_events": stripped_room_state}
+ return {
+ "knock_room_state": stripped_room_state,
+ # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
+ # Thus, we also populate a 'knock_state_events' with the same content to
+ # support old instances.
+ # See https://github.com/matrix-org/synapse/issues/14088.
+ "knock_state_events": stripped_room_state,
+ }
async def _on_send_membership_event(
self, origin: str, content: JsonDict, membership_type: str, room_id: str
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a6cb3ba58f..774ecd81b6 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -353,21 +353,25 @@ class FederationSender(AbstractFederationSender):
last_token = await self.store.get_federation_out_pos("events")
(
next_token,
- events,
event_to_received_ts,
- ) = await self.store.get_all_new_events_stream(
+ ) = await self.store.get_all_new_event_ids_stream(
last_token, self._last_poked_id, limit=100
)
+ event_ids = event_to_received_ts.keys()
+ event_entries = await self.store.get_unredacted_events_from_cache_or_db(
+ event_ids
+ )
+
logger.debug(
"Handling %i -> %i: %i events to send (current id %i)",
last_token,
next_token,
- len(events),
+ len(event_entries),
self._last_poked_id,
)
- if not events and next_token >= self._last_poked_id:
+ if not event_entries and next_token >= self._last_poked_id:
logger.debug("All events processed")
break
@@ -508,8 +512,14 @@ class FederationSender(AbstractFederationSender):
await handle_event(event)
events_by_room: Dict[str, List[EventBase]] = {}
- for event in events:
- events_by_room.setdefault(event.room_id, []).append(event)
+
+ for event_id in event_ids:
+ # `event_entries` is unsorted, so we have to iterate over `event_ids`
+ # to ensure the events are in the right order
+ event_cache = event_entries.get(event_id)
+ if event_cache:
+ event = event_cache.event
+ events_by_room.setdefault(event.room_id, []).append(event)
await make_deferred_yieldable(
defer.gatherResults(
@@ -524,9 +534,10 @@ class FederationSender(AbstractFederationSender):
logger.debug("Successfully handled up to %i", next_token)
await self.store.update_federation_out_pos("events", next_token)
- if events:
+ if event_entries:
now = self.clock.time_msec()
- ts = event_to_received_ts[events[-1].event_id]
+ last_id = next(reversed(event_ids))
+ ts = event_to_received_ts[last_id]
assert ts is not None
synapse.metrics.event_processing_lag.labels(
@@ -536,7 +547,7 @@ class FederationSender(AbstractFederationSender):
"federation_sender"
).set(ts)
- events_processed_counter.inc(len(events))
+ events_processed_counter.inc(len(event_entries))
event_processing_loop_room_count.labels("federation_sender").inc(
len(events_by_room)
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 6bb4659c4c..6f11138b57 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -489,7 +489,7 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
room_version = content["room_version"]
event = content["event"]
- invite_room_state = content["invite_room_state"]
+ invite_room_state = content.get("invite_room_state", [])
# Synapse expects invite_room_state to be in unsigned, as it is in v1
# API
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 0478448b47..fc21d58001 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 203b62e015..66f5b8d108 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -109,10 +109,13 @@ class ApplicationServicesHandler:
last_token = await self.store.get_appservice_last_pos()
(
upper_bound,
- events,
event_to_received_ts,
- ) = await self.store.get_all_new_events_stream(
- last_token, self.current_max, limit=100, get_prev_content=True
+ ) = await self.store.get_all_new_event_ids_stream(
+ last_token, self.current_max, limit=100
+ )
+
+ events = await self.store.get_events_as_list(
+ event_to_received_ts.keys(), get_prev_content=True
)
events_by_room: Dict[str, List[EventBase]] = {}
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 7127d5aefc..d52ebada6b 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -16,6 +16,8 @@ import logging
import string
from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing_extensions import Literal
+
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import (
AuthError,
@@ -429,7 +431,10 @@ class DirectoryHandler:
return await self.auth.check_can_change_room_list(room_id, requester)
async def edit_published_room_list(
- self, requester: Requester, room_id: str, visibility: str
+ self,
+ requester: Requester,
+ room_id: str,
+ visibility: Literal["public", "private"],
) -> None:
"""Edit the entry of the room in the published room list.
@@ -451,9 +456,6 @@ class DirectoryHandler:
if requester.is_guest:
raise AuthError(403, "Guests cannot edit the published room list")
- if visibility not in ["public", "private"]:
- raise SynapseError(400, "Invalid visibility setting")
-
if visibility == "public" and not self.enable_room_list_search:
# The room list has been disabled.
raise AuthError(
@@ -505,7 +507,11 @@ class DirectoryHandler:
await self.store.set_room_is_public(room_id, making_public)
async def edit_published_appservice_room_list(
- self, appservice_id: str, network_id: str, room_id: str, visibility: str
+ self,
+ appservice_id: str,
+ network_id: str,
+ room_id: str,
+ visibility: Literal["public", "private"],
) -> None:
"""Add or remove a room from the appservice/network specific public
room list.
@@ -516,9 +522,6 @@ class DirectoryHandler:
room_id
visibility: either "public" or "private"
"""
- if visibility not in ["public", "private"]:
- raise SynapseError(400, "Invalid visibility setting")
-
await self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public"
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 986ffed3d5..5f7e0a1f79 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
Codes,
FederationDeniedError,
FederationError,
+ FederationPullAttemptBackoffError,
HttpResponseException,
LimitExceededError,
NotFoundError,
@@ -781,15 +782,27 @@ class FederationHandler:
# Send the signed event back to the room, and potentially receive some
# further information about the room in the form of partial state events
- stripped_room_state = await self.federation_client.send_knock(
- target_hosts, event
- )
+ knock_response = await self.federation_client.send_knock(target_hosts, event)
# Store any stripped room state events in the "unsigned" key of the event.
# This is a bit of a hack and is cribbing off of invites. Basically we
# store the room state here and retrieve it again when this event appears
# in the invitee's sync stream. It is stripped out for all other local users.
- event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+ stripped_room_state = (
+ knock_response.get("knock_room_state")
+ # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
+ # Thus, we also check for a 'knock_state_events' to support old instances.
+ # See https://github.com/matrix-org/synapse/issues/14088.
+ or knock_response.get("knock_state_events")
+ )
+
+ if stripped_room_state is None:
+ raise KeyError(
+ "Missing 'knock_room_state' (or legacy 'knock_state_events') field in "
+ "send_knock response"
+ )
+
+ event.unsigned["knock_room_state"] = stripped_room_state
context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
@@ -1708,7 +1721,22 @@ class FederationHandler:
destination, event
)
break
+ except FederationPullAttemptBackoffError as exc:
+ # Log a warning about why we failed to process the event (the error message
+ # for `FederationPullAttemptBackoffError` is pretty good)
+ logger.warning("_sync_partial_state_room: %s", exc)
+ # We do not record a failed pull attempt when we backoff fetching a missing
+ # `prev_event` because not being able to fetch the `prev_events` just means
+ # we won't be able to de-outlier the pulled event. But we can still use an
+ # `outlier` in the state/auth chain for another event. So we shouldn't stop
+ # a downstream event from trying to pull it.
+ #
+ # This avoids a cascade of backoff for all events in the DAG downstream from
+ # one event backoff upstream.
except FederationError as e:
+ # TODO: We should `record_event_failed_pull_attempt` here,
+ # see https://github.com/matrix-org/synapse/issues/13700
+
if attempt == len(destinations) - 1:
# We have tried every remote server for this event. Give up.
# TODO(faster_joins) giving up isn't the right thing to do
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index b94641a87d..06e41b5cc0 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -44,6 +44,7 @@ from synapse.api.errors import (
AuthError,
Codes,
FederationError,
+ FederationPullAttemptBackoffError,
HttpResponseException,
RequestSendFailed,
SynapseError,
@@ -414,7 +415,9 @@ class FederationEventHandler:
# First, precalculate the joined hosts so that the federation sender doesn't
# need to.
- await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
+ await self._event_creation_handler.cache_joined_hosts_for_events(
+ [(event, context)]
+ )
await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context)
@@ -565,6 +568,9 @@ class FederationEventHandler:
event: partial-state event to be de-partial-stated
Raises:
+ FederationPullAttemptBackoffError if we are are deliberately not attempting
+ to pull the given event over federation because we've already done so
+ recently and are backing off.
FederationError if we fail to request state from the remote server.
"""
logger.info("Updating state for %s", event.event_id)
@@ -920,6 +926,18 @@ class FederationEventHandler:
context,
backfilled=backfilled,
)
+ except FederationPullAttemptBackoffError as exc:
+ # Log a warning about why we failed to process the event (the error message
+ # for `FederationPullAttemptBackoffError` is pretty good)
+ logger.warning("_process_pulled_event: %s", exc)
+ # We do not record a failed pull attempt when we backoff fetching a missing
+ # `prev_event` because not being able to fetch the `prev_events` just means
+ # we won't be able to de-outlier the pulled event. But we can still use an
+ # `outlier` in the state/auth chain for another event. So we shouldn't stop
+ # a downstream event from trying to pull it.
+ #
+ # This avoids a cascade of backoff for all events in the DAG downstream from
+ # one event backoff upstream.
except FederationError as e:
await self._store.record_event_failed_pull_attempt(
event.room_id, event_id, str(e)
@@ -966,6 +984,9 @@ class FederationEventHandler:
The event context.
Raises:
+ FederationPullAttemptBackoffError if we are are deliberately not attempting
+ to pull the given event over federation because we've already done so
+ recently and are backing off.
FederationError if we fail to get the state from the remote server after any
missing `prev_event`s.
"""
@@ -976,6 +997,18 @@ class FederationEventHandler:
seen = await self._store.have_events_in_timeline(prevs)
missing_prevs = prevs - seen
+ # If we've already recently attempted to pull this missing event, don't
+ # try it again so soon. Since we have to fetch all of the prev_events, we can
+ # bail early here if we find any to ignore.
+ prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff(
+ room_id, missing_prevs
+ )
+ if len(prevs_to_ignore) > 0:
+ raise FederationPullAttemptBackoffError(
+ event_ids=prevs_to_ignore,
+ message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).",
+ )
+
if not missing_prevs:
return await self._state_handler.compute_event_context(event)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 860c82c110..9c335e6863 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -57,13 +57,7 @@ class InitialSyncHandler:
self.validator = EventValidator()
self.snapshot_cache: ResponseCache[
Tuple[
- str,
- Optional[StreamToken],
- Optional[StreamToken],
- str,
- Optional[int],
- bool,
- bool,
+ str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
@@ -154,11 +148,6 @@ class InitialSyncHandler:
public_room_ids = await self.store.get_public_room_ids()
- if pagin_config.limit is not None:
- limit = pagin_config.limit
- else:
- limit = 10
-
serializer_options = SerializeEventConfig(as_client_event=as_client_event)
async def handle_room(event: RoomsForUser) -> None:
@@ -210,7 +199,7 @@ class InitialSyncHandler:
run_in_background(
self.store.get_recent_events_for_room,
event.room_id,
- limit=limit,
+ limit=pagin_config.limit,
end_token=room_end_token,
),
deferred_room_state,
@@ -360,15 +349,11 @@ class InitialSyncHandler:
member_event_id
)
- limit = pagin_config.limit if pagin_config else None
- if limit is None:
- limit = 10
-
leave_position = await self.store.get_position_for_event(member_event_id)
stream_token = leave_position.to_room_stream_token()
messages, token = await self.store.get_recent_events_for_room(
- room_id, limit=limit, end_token=stream_token
+ room_id, limit=pagin_config.limit, end_token=stream_token
)
messages = await filter_events_for_client(
@@ -420,10 +405,6 @@ class InitialSyncHandler:
now_token = self.hs.get_event_sources().get_current_token()
- limit = pagin_config.limit if pagin_config else None
- if limit is None:
- limit = 10
-
room_members = [
m
for m in current_state.values()
@@ -467,7 +448,7 @@ class InitialSyncHandler:
run_in_background(
self.store.get_recent_events_for_room,
room_id,
- limit=limit,
+ limit=pagin_config.limit,
end_token=now_token.room_key,
),
),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da1acea275..4e55ebba0b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1390,7 +1390,7 @@ class EventCreationHandler:
extra_users=extra_users,
),
run_in_background(
- self.cache_joined_hosts_for_event, event, context
+ self.cache_joined_hosts_for_events, events_and_context
).addErrback(
log_failure, "cache_joined_hosts_for_event failed"
),
@@ -1491,62 +1491,65 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
- async def cache_joined_hosts_for_event(
- self, event: EventBase, context: EventContext
+ async def cache_joined_hosts_for_events(
+ self, events_and_context: List[Tuple[EventBase, EventContext]]
) -> None:
- """Precalculate the joined hosts at the event, when using Redis, so that
+ """Precalculate the joined hosts at each of the given events, when using Redis, so that
external federation senders don't have to recalculate it themselves.
"""
- if not self._external_cache.is_enabled():
- return
-
- # If external cache is enabled we should always have this.
- assert self._external_cache_joined_hosts_updates is not None
+ for event, _ in events_and_context:
+ if not self._external_cache.is_enabled():
+ return
- # We actually store two mappings, event ID -> prev state group,
- # state group -> joined hosts, which is much more space efficient
- # than event ID -> joined hosts.
- #
- # Note: We have to cache event ID -> prev state group, as we don't
- # store that in the DB.
- #
- # Note: We set the state group -> joined hosts cache if it hasn't been
- # set for a while, so that the expiry time is reset.
+ # If external cache is enabled we should always have this.
+ assert self._external_cache_joined_hosts_updates is not None
- state_entry = await self.state.resolve_state_groups_for_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We set the state group -> joined hosts cache if it hasn't been
+ # set for a while, so that the expiry time is reset.
- if state_entry.state_group:
- await self._external_cache.set(
- "event_to_prev_state_group",
- event.event_id,
- state_entry.state_group,
- expiry_ms=60 * 60 * 1000,
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
)
- if state_entry.state_group in self._external_cache_joined_hosts_updates:
- return
+ if state_entry.state_group:
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
- state = await state_entry.get_state(
- self._storage_controllers.state, StateFilter.all()
- )
- with opentracing.start_active_span("get_joined_hosts"):
- joined_hosts = await self.store.get_joined_hosts(
- event.room_id, state, state_entry
+ if state_entry.state_group in self._external_cache_joined_hosts_updates:
+ return
+
+ state = await state_entry.get_state(
+ self._storage_controllers.state, StateFilter.all()
)
+ with opentracing.start_active_span("get_joined_hosts"):
+ joined_hosts = await self.store.get_joined_hosts(
+ event.room_id, state, state_entry
+ )
- # Note that the expiry times must be larger than the expiry time in
- # _external_cache_joined_hosts_updates.
- await self._external_cache.set(
- "get_joined_hosts",
- str(state_entry.state_group),
- list(joined_hosts),
- expiry_ms=60 * 60 * 1000,
- )
+ # Note that the expiry times must be larger than the expiry time in
+ # _external_cache_joined_hosts_updates.
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
- self._external_cache_joined_hosts_updates[state_entry.state_group] = None
+ self._external_cache_joined_hosts_updates[
+ state_entry.state_group
+ ] = None
async def _validate_canonical_alias(
self,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1f83bab836..a4ca9cb8b4 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -458,11 +458,6 @@ class PaginationHandler:
# `/messages` should still works with live tokens when manually provided.
assert from_token.room_key.topological is not None
- if pagin_config.limit is None:
- # This shouldn't happen as we've set a default limit before this
- # gets called.
- raise Exception("limit not set")
-
room_token = from_token.room_key
async with self.pagination_lock.read(room_id):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 4e575ffbaa..2670e561d7 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1596,7 +1596,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
self,
user: UserID,
from_key: Optional[int],
- limit: Optional[int] = None,
+ # Having a default limit doesn't match the EventSource API, but some
+ # callers do not provide it. It is unused in this class.
+ limit: int = 0,
room_ids: Optional[Collection[str]] = None,
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4a7ec9e426..ac01582442 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -257,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index cc5e45c241..0a0c6d938e 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
@@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
@@ -32,6 +33,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class ThreadsListInclude(str, enum.Enum):
+ """Valid values for the 'include' flag of /threads."""
+
+ all = "all"
+ participated = "participated"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
@@ -108,9 +116,6 @@ class RelationsHandler:
if event is None:
raise SynapseError(404, "Unknown parent event.")
- # TODO Update pagination config to not allow None limits.
- assert pagin_config.limit is not None
-
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
@@ -482,3 +487,79 @@ class RelationsHandler:
results.setdefault(event_id, BundledAggregations()).replace = edit
return results
+
+ async def get_threads(
+ self,
+ requester: Requester,
+ room_id: str,
+ include: ThreadsListInclude,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ Args:
+ requester: The user requesting the relations.
+ room_id: The room the event belongs to.
+ include: One of "all" or "participated" to indicate which threads should
+ be returned.
+ limit: Only fetch the most recent `limit` events.
+ from_token: Fetch rows from the given token, or from the start if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ # TODO Properly handle a user leaving a room.
+ (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+ room_id, requester, allow_departed_users=True
+ )
+
+ # Note that ignored users are not passed into get_relations_for_event
+ # below. Ignored users are handled in filter_events_for_client (and by
+ # not passing them in here we should get a better cache hit rate).
+ thread_roots, next_batch = await self._main_store.get_threads(
+ room_id=room_id, limit=limit, from_token=from_token
+ )
+
+ events = await self._main_store.get_events_as_list(thread_roots)
+
+ if include == ThreadsListInclude.participated:
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {event.event_id: event.sender == user_id for event in events}
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [eid for eid, p in participated.items() if not p],
+ user_id,
+ )
+ )
+
+ # Limit the returned threads to those the user has participated in.
+ events = [event for event in events if participated[event.event_id]]
+
+ events = await filter_events_for_client(
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
+ aggregations = await self.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+
+ now = self._clock.time_msec()
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value: JsonDict = {"chunk": serialized_events}
+
+ if next_batch:
+ return_value["next_batch"] = str(next_batch)
+
+ return return_value
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 57ab05ad25..4e1aacb408 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1646,7 +1646,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
self,
user: UserID,
from_key: RoomStreamToken,
- limit: Optional[int],
+ limit: int,
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index f953691669..a0ea719430 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index eced182fd5..a75386f6a0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
- Iterable,
List,
Mapping,
Optional,
- Set,
Tuple,
Union,
)
@@ -38,7 +35,7 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.storage.state import StateFilter
-from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator
+from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
@@ -117,9 +114,6 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- # Whether to support MSC3772 is supported.
- self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
-
async def _get_rules_for_event(
self,
event: EventBase,
@@ -200,51 +194,6 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def _get_mutual_relations(
- self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]]
- ) -> Dict[str, Set[Tuple[str, str]]]:
- """
- Fetch event metadata for events which related to the same event as the given event.
-
- If the given event has no relation information, returns an empty dictionary.
-
- Args:
- parent_id: The event ID which is targeted by relations.
- rules: The push rules which will be processed for this event.
-
- Returns:
- A dictionary of relation type to:
- A set of tuples of:
- The sender
- The event type
- """
-
- # If the experimental feature is not enabled, skip fetching relations.
- if not self._relations_match_enabled:
- return {}
-
- # Pre-filter to figure out which relation types are interesting.
- rel_types = set()
- for rule, enabled in rules:
- if not enabled:
- continue
-
- for condition in rule.conditions:
- if condition["kind"] != "org.matrix.msc3772.relation_match":
- continue
-
- # rel_type is required.
- rel_type = condition.get("rel_type")
- if rel_type:
- rel_types.add(rel_type)
-
- # If no valid rules were found, no mutual relations.
- if not rel_types:
- return {}
-
- # If any valid rules were found, fetch the mutual relations.
- return await self.store.get_mutual_event_relations(parent_id, rel_types)
-
@measure_func("action_for_event_by_user")
async def action_for_event_by_user(
self, event: EventBase, context: EventContext
@@ -276,23 +225,18 @@ class BulkPushRuleEvaluator:
sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context)
+ # Find the event's thread ID.
relation = relation_from_event(event)
- # If the event does not have a relation, then cannot have any mutual
- # relations or thread ID.
- relations = {}
+ # If the event does not have a relation, then it cannot have a thread ID.
thread_id = MAIN_TIMELINE
if relation:
- relations = await self._get_mutual_relations(
- relation.parent_id,
- itertools.chain(*(r.rules() for r in rules_by_user.values())),
- )
# Recursively attempt to find the thread this event relates to.
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
else:
# Since the event has not yet been persisted we check whether
# the parent is part of a thread.
- thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
+ thread_id = await self.store.get_thread_id(relation.parent_id)
# It's possible that old room versions have non-integer power levels (floats or
# strings). Workaround this by explicitly converting to int.
@@ -306,8 +250,6 @@ class BulkPushRuleEvaluator:
room_member_count,
sender_power_level,
notification_levels,
- relations,
- self._relations_match_enabled,
)
users = rules_by_user.keys()
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 61abb529c8..976c283360 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -39,6 +39,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
self.store = hs.get_datastores().main
self.registration_handler = hs.get_registration_handler()
+ # Default value if the worker that sent the replication request did not include
+ # an 'approved' property.
+ if (
+ hs.config.experimental.msc3866.enabled
+ and hs.config.experimental.msc3866.require_approval_for_new_accounts
+ ):
+ self._approval_default = False
+ else:
+ self._approval_default = True
+
@staticmethod
async def _serialize_payload( # type: ignore[override]
user_id: str,
@@ -92,6 +102,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
await self.registration_handler.check_registration_ratelimit(content["address"])
+ # Always default admin users to approved (since it means they were created by
+ # an admin).
+ approved_default = self._approval_default
+ if content["admin"]:
+ approved_default = True
+
await self.registration_handler.register_with_store(
user_id=user_id,
password_hash=content["password_hash"],
@@ -103,7 +119,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type=content["user_type"],
address=content["address"],
shadow_banned=content["shadow_banned"],
- approved=content["approved"],
+ approved=content.get("approved", approved_default),
)
return 200, {}
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index bc1b18c92d..f17b4c8d22 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -13,15 +13,22 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from pydantic import StrictStr
+from typing_extensions import Literal
from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+ RestServlet,
+ parse_and_validate_json_object_from_request,
+)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict, RoomAlias
if TYPE_CHECKING:
@@ -54,6 +61,12 @@ class ClientDirectoryServer(RestServlet):
return 200, res
+ class PutBody(RequestBodyModel):
+ # TODO: get Pydantic to validate that this is a valid room id?
+ room_id: StrictStr
+ # `servers` is unspecced
+ servers: Optional[List[StrictStr]] = None
+
async def on_PUT(
self, request: SynapseRequest, room_alias: str
) -> Tuple[int, JsonDict]:
@@ -61,31 +74,22 @@ class ClientDirectoryServer(RestServlet):
raise SynapseError(400, "Room alias invalid", errcode=Codes.INVALID_PARAM)
room_alias_obj = RoomAlias.from_string(room_alias)
- content = parse_json_object_from_request(request)
- if "room_id" not in content:
- raise SynapseError(
- 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON
- )
+ content = parse_and_validate_json_object_from_request(request, self.PutBody)
logger.debug("Got content: %s", content)
logger.debug("Got room name: %s", room_alias_obj.to_string())
- room_id = content["room_id"]
- servers = content["servers"] if "servers" in content else None
-
- logger.debug("Got room_id: %s", room_id)
- logger.debug("Got servers: %s", servers)
+ logger.debug("Got room_id: %s", content.room_id)
+ logger.debug("Got servers: %s", content.servers)
- # TODO(erikj): Check types.
-
- room = await self.store.get_room(room_id)
+ room = await self.store.get_room(content.room_id)
if room is None:
raise SynapseError(400, "Room does not exist")
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.create_association(
- requester, room_alias_obj, room_id, servers
+ requester, room_alias_obj, content.room_id, content.servers
)
return 200, {}
@@ -137,16 +141,18 @@ class ClientDirectoryListServer(RestServlet):
return 200, {"visibility": "public" if room["is_public"] else "private"}
+ class PutBody(RequestBodyModel):
+ visibility: Literal["public", "private"] = "public"
+
async def on_PUT(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- content = parse_json_object_from_request(request)
- visibility = content.get("visibility", "public")
+ content = parse_and_validate_json_object_from_request(request, self.PutBody)
await self.directory_handler.edit_published_room_list(
- requester, room_id, visibility
+ requester, room_id, content.visibility
)
return 200, {}
@@ -163,12 +169,14 @@ class ClientAppserviceDirectoryListServer(RestServlet):
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
+ class PutBody(RequestBodyModel):
+ visibility: Literal["public", "private"] = "public"
+
async def on_PUT(
self, request: SynapseRequest, network_id: str, room_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
- visibility = content.get("visibility", "public")
- return await self._edit(request, network_id, room_id, visibility)
+ content = parse_and_validate_json_object_from_request(request, self.PutBody)
+ return await self._edit(request, network_id, room_id, content.visibility)
async def on_DELETE(
self, request: SynapseRequest, network_id: str, room_id: str
@@ -176,7 +184,11 @@ class ClientAppserviceDirectoryListServer(RestServlet):
return await self._edit(request, network_id, room_id, "private")
async def _edit(
- self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
+ self,
+ request: SynapseRequest,
+ network_id: str,
+ room_id: str,
+ visibility: Literal["public", "private"],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 916f5230f1..782e7d14e8 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet):
raise SynapseError(400, "Guest users must specify room_id param")
room_id = parse_string(request, "room_id")
- pagin_config = await PaginationConfig.from_request(self.store, request)
+ pagin_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in args:
try:
diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py
index cfadcb8e50..9b1bb8b521 100644
--- a/synapse/rest/client/initial_sync.py
+++ b/synapse/rest/client/initial_sync.py
@@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
args: Dict[bytes, List[bytes]] = request.args # type: ignore
as_client_event = b"raw" not in args
- pagination_config = await PaginationConfig.from_request(self.store, request)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 14dec7ac4e..18a282b22c 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet):
)
# Ensure the event ID roughly correlates to the thread ID.
- if thread_id != await self._main_store.get_thread_id(event_id):
+ if not await self._is_event_in_thread(event_id, thread_id):
raise SynapseError(
400,
f"event_id {event_id} is not related to thread {thread_id}",
@@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet):
return 200, {}
+ async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool:
+ """
+ The event must be related to the thread ID (in a vague sense) to ensure
+ clients aren't sending bogus receipts.
+
+ A thread ID is considered valid for a given event E if:
+
+ 1. E has a thread relation which matches the thread ID;
+ 2. E has another event which has a thread relation to E matching the
+ thread ID; or
+ 3. E is recursively related (via any rel_type) to an event which
+ satisfies 1 or 2.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ It is valid to send a receipt for thread A on A, B, C, D, or E.
+
+ It is valid to send a receipt for the main timeline on A, D, and E.
+
+ Args:
+ event_id: The event ID to check.
+ thread_id: The thread ID the event is potentially part of.
+
+ Returns:
+ True if the event belongs to the given thread, otherwise False.
+ """
+
+ # If the receipt is on the main timeline, it is enough to check whether
+ # the event is directly related to a thread.
+ if thread_id == MAIN_TIMELINE:
+ return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
+
+ # Otherwise, check if the event is directly part of a thread, or is the
+ # root message (or related to the root message) of a thread.
+ return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index b31ce5a0d3..9dd59196d9 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -13,12 +13,15 @@
# limitations under the License.
import logging
+import re
from typing import TYPE_CHECKING, Optional, Tuple
+from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.storage.databases.main.relations import ThreadsNextBatch
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict
@@ -78,5 +81,45 @@ class RelationPaginationServlet(RestServlet):
return 200, result
+class ThreadsServlet(RestServlet):
+ PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),)
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self._relations_handler = hs.get_relations_handler()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token_str = parse_string(request, "from")
+ include = parse_string(
+ request,
+ "include",
+ default=ThreadsListInclude.all.value,
+ allowed_values=[v.value for v in ThreadsListInclude],
+ )
+
+ # Return the relations
+ from_token = None
+ if from_token_str:
+ from_token = ThreadsNextBatch.from_string(from_token_str)
+
+ result = await self._relations_handler.get_threads(
+ requester=requester,
+ room_id=room_id,
+ include=ThreadsListInclude(include),
+ limit=limit,
+ from_token=from_token,
+ )
+
+ return 200, result
+
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
+ ThreadsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index b6dedbed04..01e5079963 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -729,7 +729,9 @@ class RoomInitialSyncRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = await PaginationConfig.from_request(self.store, request)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index d1d2e5f7e3..4b87ee978a 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -76,6 +76,7 @@ class VersionsRestServlet(RestServlet):
"v1.1",
"v1.2",
"v1.3",
+ "v1.4",
],
# as per MSC1497:
"unstable_features": {
@@ -113,6 +114,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3882": self.config.experimental.msc3882_enabled,
# Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
+ # Adds support for filtering /messages by event relation.
+ "org.matrix.msc3874": self.config.experimental.msc3874_enabled,
},
},
)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 3b8ed1f7ee..ddb7397714 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -244,12 +244,18 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
# redacted.
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
+ self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
+ self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self._attempt_to_invalidate_cache(
"get_invited_rooms_for_local_user", (state_key,)
)
+ self._attempt_to_invalidate_cache(
+ "get_rooms_for_user_with_stream_ordering", (state_key,)
+ )
+ self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
@@ -259,9 +265,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
- self._attempt_to_invalidate_cache(
- "get_mutual_event_relations_for_rel_type", (relates_to,)
- )
+ self._attempt_to_invalidate_cache("get_threads", (room_id,))
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6b9a629edd..309a4ba664 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1501,6 +1501,12 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_id: The event that failed to be fetched or processed
cause: The error message or reason that we failed to pull the event
"""
+ logger.debug(
+ "record_event_failed_pull_attempt room_id=%s, event_id=%s, cause=%s",
+ room_id,
+ event_id,
+ cause,
+ )
await self.db_pool.runInteraction(
"record_event_failed_pull_attempt",
self._record_event_failed_pull_attempt_upsert_txn,
@@ -1530,6 +1536,54 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause))
+ @trace
+ async def get_event_ids_to_not_pull_from_backoff(
+ self,
+ room_id: str,
+ event_ids: Collection[str],
+ ) -> List[str]:
+ """
+ Filter down the events to ones that we've failed to pull before recently. Uses
+ exponential backoff.
+
+ Args:
+ room_id: The room that the events belong to
+ event_ids: A list of events to filter down
+
+ Returns:
+ List of event_ids that should not be attempted to be pulled
+ """
+ event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
+ table="event_failed_pull_attempts",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=(
+ "event_id",
+ "last_attempt_ts",
+ "num_attempts",
+ ),
+ desc="get_event_ids_to_not_pull_from_backoff",
+ )
+
+ current_time = self._clock.time_msec()
+ return [
+ event_failed_pull_attempt["event_id"]
+ for event_failed_pull_attempt in event_failed_pull_attempts
+ # Exponential back-off (up to the upper bound) so we don't try to
+ # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
+ if current_time
+ < event_failed_pull_attempt["last_attempt_ts"]
+ + (
+ 2
+ ** min(
+ event_failed_pull_attempt["num_attempts"],
+ BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
+ )
+ )
+ * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS
+ ]
+
async def get_missing_events(
self,
room_id: str,
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 332e13d1c9..f070e6e88a 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -310,11 +310,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
event_push_actions_done = progress.get("event_push_actions_done", False)
def add_thread_id_txn(
- txn: LoggingTransaction, table_name: str, start_stream_ordering: int
+ txn: LoggingTransaction, start_stream_ordering: int
) -> int:
- sql = f"""
+ sql = """
SELECT stream_ordering
- FROM {table_name}
+ FROM event_push_actions
WHERE
thread_id IS NULL
AND stream_ordering > ?
@@ -326,7 +326,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# No more rows to process.
rows = txn.fetchall()
if not rows:
- progress[f"{table_name}_done"] = True
+ progress["event_push_actions_done"] = True
self.db_pool.updates._background_update_progress_txn(
txn, "event_push_backfill_thread_id", progress
)
@@ -335,16 +335,65 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Update the thread ID for any of those rows.
max_stream_ordering = rows[-1][0]
- sql = f"""
- UPDATE {table_name}
+ sql = """
+ UPDATE event_push_actions
SET thread_id = 'main'
- WHERE stream_ordering <= ? AND thread_id IS NULL
+ WHERE ? < stream_ordering AND stream_ordering <= ? AND thread_id IS NULL
"""
- txn.execute(sql, (max_stream_ordering,))
+ txn.execute(
+ sql,
+ (
+ start_stream_ordering,
+ max_stream_ordering,
+ ),
+ )
# Update progress.
processed_rows = txn.rowcount
- progress[f"max_{table_name}_stream_ordering"] = max_stream_ordering
+ progress["max_event_push_actions_stream_ordering"] = max_stream_ordering
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "event_push_backfill_thread_id", progress
+ )
+
+ return processed_rows
+
+ def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:
+ min_user_id = progress.get("max_summary_user_id", "")
+ min_room_id = progress.get("max_summary_room_id", "")
+
+ # Slightly overcomplicated query for getting the Nth user ID / room
+ # ID tuple, or the last if there are less than N remaining.
+ sql = """
+ SELECT user_id, room_id FROM (
+ SELECT user_id, room_id FROM event_push_summary
+ WHERE (user_id, room_id) > (?, ?)
+ AND thread_id IS NULL
+ ORDER BY user_id, room_id
+ LIMIT ?
+ ) AS e
+ ORDER BY user_id DESC, room_id DESC
+ LIMIT 1
+ """
+
+ txn.execute(sql, (min_user_id, min_room_id, batch_size))
+ row = txn.fetchone()
+ if not row:
+ return 0
+
+ max_user_id, max_room_id = row
+
+ sql = """
+ UPDATE event_push_summary
+ SET thread_id = 'main'
+ WHERE
+ (?, ?) < (user_id, room_id) AND (user_id, room_id) <= (?, ?)
+ AND thread_id IS NULL
+ """
+ txn.execute(sql, (min_user_id, min_room_id, max_user_id, max_room_id))
+ processed_rows = txn.rowcount
+
+ progress["max_summary_user_id"] = max_user_id
+ progress["max_summary_room_id"] = max_room_id
self.db_pool.updates._background_update_progress_txn(
txn, "event_push_backfill_thread_id", progress
)
@@ -360,15 +409,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
result = await self.db_pool.runInteraction(
"event_push_backfill_thread_id",
add_thread_id_txn,
- "event_push_actions",
progress.get("max_event_push_actions_stream_ordering", 0),
)
else:
result = await self.db_pool.runInteraction(
"event_push_backfill_thread_id",
- add_thread_id_txn,
- "event_push_summary",
- progress.get("max_event_push_summary_stream_ordering", 0),
+ add_thread_id_summary_txn,
)
# Only done after the event_push_summary table is done.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3e15827986..6698cbf664 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@ import attr
from prometheus_client import Counter
import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
@@ -1616,7 +1616,7 @@ class PersistEventsStore:
)
# Remove from relations table.
- self._handle_redact_relations(txn, event.redacts)
+ self._handle_redact_relations(txn, event.room_id, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
@@ -1866,6 +1866,34 @@ class PersistEventsStore:
},
)
+ if relation.rel_type == RelationTypes.THREAD:
+ # Upsert into the threads table, but only overwrite the value if the
+ # new event is of a later topological order OR if the topological
+ # ordering is equal, but the stream ordering is later.
+ sql = """
+ INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
+ VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (room_id, thread_id)
+ DO UPDATE SET
+ latest_event_id = excluded.latest_event_id,
+ topological_ordering = excluded.topological_ordering,
+ stream_ordering = excluded.stream_ordering
+ WHERE
+ threads.topological_ordering <= excluded.topological_ordering AND
+ threads.stream_ordering < excluded.stream_ordering
+ """
+
+ txn.execute(
+ sql,
+ (
+ event.room_id,
+ relation.parent_id,
+ event.event_id,
+ event.depth,
+ event.internal_metadata.stream_ordering,
+ ),
+ )
+
def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
) -> None:
@@ -1989,13 +2017,14 @@ class PersistEventsStore:
txn.execute(sql, (batch_id,))
def _handle_redact_relations(
- self, txn: LoggingTransaction, redacted_event_id: str
+ self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.
Args:
txn
+ room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted.
"""
@@ -2025,9 +2054,7 @@ class PersistEventsStore:
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
self.store._invalidate_cache_and_stream(
- txn,
- self.store.get_mutual_event_relations_for_rel_type,
- (redacted_relates_to,),
+ txn, self.store.get_threads, (room_id,)
)
self.db_pool.simple_delete_txn(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 221864d9cc..7bc7f2f33e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = await self._get_events_from_cache_or_db(
+ event_entry_map = await self.get_unredacted_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = await self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self.get_unredacted_events_from_cache_or_db(
+ [redacted_event_id]
+ )
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore):
return events
@cancellable
- async def _get_events_from_cache_or_db(
- self, event_ids: Iterable[str], allow_rejected: bool = False
+ async def get_unredacted_events_from_cache_or_db(
+ self,
+ event_ids: Iterable[str],
+ allow_rejected: bool = False,
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
+ Note that the events pulled by this function will not have any redactions
+ applied, and no guarantee is made about the ordering of the events returned.
+
If events are pulled from the database, they will be cached for future lookups.
Unknown events are omitted from the response.
@@ -1495,21 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
a dict {event_id -> bool}
"""
- # if the event cache contains the event, obviously we've seen it.
-
- cache_results = {
- event_id
- for event_id in event_ids
- if await self._get_event_cache.contains((event_id,))
- }
- results = dict.fromkeys(cache_results, True)
- remaining = [
- event_id for event_id in event_ids if event_id not in cache_results
- ]
- if not remaining:
- return results
+ # TODO: We used to query the _get_event_cache here as a fast-path before
+ # hitting the database. For if an event were in the cache, we've presumably
+ # seen it before.
+ #
+ # But this is currently an invalid assumption due to the _get_event_cache
+ # not being invalidated when purging events from a room. The optimisation can
+ # be re-added after https://github.com/matrix-org/synapse/issues/13476
- def have_seen_events_txn(txn: LoggingTransaction) -> None:
+ def have_seen_events_txn(txn: LoggingTransaction) -> Dict[str, bool]:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@@ -1517,16 +1518,17 @@ class EventsWorkerStore(SQLBaseStore):
sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
- txn.database_engine, "e.event_id", remaining
+ txn.database_engine, "e.event_id", event_ids
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}
# ... and then we can update the results for each key
- results.update({eid: (eid in found_events) for eid in remaining})
+ return {eid: (eid in found_events) for eid in event_ids}
- await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
- return results
+ return await self.db_pool.runInteraction(
+ "have_seen_events", have_seen_events_txn
+ )
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 8295322b0e..51416b2236 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -29,7 +29,6 @@ from typing import (
)
from synapse.api.errors import StoreError
-from synapse.config.homeserver import ExperimentalConfig
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -63,9 +62,7 @@ logger = logging.getLogger(__name__)
def _load_rules(
- rawrules: List[JsonDict],
- enabled_map: Dict[str, bool],
- experimental_config: ExperimentalConfig,
+ rawrules: List[JsonDict], enabled_map: Dict[str, bool]
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
@@ -83,9 +80,7 @@ def _load_rules(
push_rules = PushRules(ruleslist)
- filtered_rules = FilteredPushRules(
- push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled
- )
+ filtered_rules = FilteredPushRules(push_rules, enabled_map)
return filtered_rules
@@ -165,7 +160,7 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- return _load_rules(rows, enabled_map, self.hs.config.experimental)
+ return _load_rules(rows, enabled_map)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
@@ -224,9 +219,7 @@ class PushRulesWorkerStore(
results: Dict[str, FilteredPushRules] = {}
for user_id, rules in raw_rules.items():
- results[user_id] = _load_rules(
- rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
- )
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
return results
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 246f78ac1f..dc6989527e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -418,6 +418,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
+ if row["thread_id"]:
+ receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
results = {
room_id: [results[room_id]] if room_id in results else []
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 116abef9de..1de62ee9df 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,6 +14,7 @@
import logging
from typing import (
+ TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
@@ -28,19 +29,48 @@ from typing import (
import attr
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadsNextBatch:
+ topological_ordering: int
+ stream_ordering: int
+
+ def __str__(self) -> str:
+ return f"{self.topological_ordering}_{self.stream_ordering}"
+
+ @classmethod
+ def from_string(cls, string: str) -> "ThreadsNextBatch":
+ """
+ Creates a ThreadsNextBatch from its textual representation.
+ """
+ try:
+ keys = (int(s) for s in string.split("_"))
+ return cls(*keys)
+ except Exception:
+ raise SynapseError(400, "Invalid threads token")
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RelatedEvent:
"""
Contains enough information about a related event in order to properly filter
@@ -56,6 +86,76 @@ class _RelatedEvent:
class RelationsWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_update_handler(
+ "threads_backfill", self._backfill_threads
+ )
+
+ async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
+ """Backfill the threads table."""
+
+ def threads_backfill_txn(txn: LoggingTransaction) -> int:
+ last_thread_id = progress.get("last_thread_id", "")
+
+ # Get the latest event in each thread by topo ordering / stream ordering.
+ #
+ # Note that the MAX(event_id) is needed to abide by the rules of group by,
+ # but doesn't actually do anything since there should only be a single event
+ # ID per topo/stream ordering pair.
+ sql = f"""
+ SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id > ? AND
+ relation_type = '{RelationTypes.THREAD}'
+ GROUP BY room_id, relates_to_id
+ ORDER BY relates_to_id
+ LIMIT ?
+ """
+ txn.execute(sql, (last_thread_id, batch_size))
+
+ # No more rows to process.
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ # Insert the rows into the threads table. If a matching thread already exists,
+ # assume it is from a newer event.
+ sql = """
+ INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
+ VALUES %s
+ ON CONFLICT (room_id, thread_id)
+ DO NOTHING
+ """
+ if isinstance(txn.database_engine, PostgresEngine):
+ txn.execute_values(sql % ("?",), rows, fetch=False)
+ else:
+ txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows)
+
+ # Mark the progress.
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
+ )
+
+ return txn.rowcount
+
+ result = await self.db_pool.runInteraction(
+ "threads_backfill", threads_backfill_txn
+ )
+
+ if not result:
+ await self.db_pool.updates._end_background_update("threads_backfill")
+
+ return result
+
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
@@ -776,95 +876,194 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
- @cached(iterable=True)
- async def get_mutual_event_relations_for_rel_type(
- self, event_id: str, relation_type: str
- ) -> Set[Tuple[str, str]]:
- raise NotImplementedError()
-
- @cachedList(
- cached_method_name="get_mutual_event_relations_for_rel_type",
- list_name="relation_types",
- )
- async def get_mutual_event_relations(
- self, event_id: str, relation_types: Collection[str]
- ) -> Dict[str, Set[Tuple[str, str]]]:
- """
- Fetch event metadata for events which related to the same event as the given event.
-
- If the given event has no relation information, returns an empty dictionary.
+ @cached(tree=True)
+ async def get_threads(
+ self,
+ room_id: str,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ """Get a list of thread IDs, ordered by topological ordering of their
+ latest reply.
Args:
- event_id: The event ID which is targeted by relations.
- relation_types: The relation types to check for mutual relations.
+ room_id: The room the event belongs to.
+ limit: Only fetch the most recent `limit` threads.
+ from_token: Fetch rows from a previous next_batch, or from the start if None.
Returns:
- A dictionary of relation type to:
- A set of tuples of:
- The sender
- The event type
+ A tuple of:
+ A list of thread root event IDs.
+
+ The next_batch, if one exists.
"""
- rel_type_sql, rel_type_args = make_in_list_sql_clause(
- self.database_engine, "relation_type", relation_types
- )
+ # Generate the pagination clause, if necessary.
+ #
+ # Find any threads where the latest reply is equal / before the last
+ # thread's topo ordering and earlier in stream ordering.
+ pagination_clause = ""
+ pagination_args: tuple = ()
+ if from_token:
+ pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
+ pagination_args = (
+ from_token.topological_ordering,
+ from_token.stream_ordering,
+ )
sql = f"""
- SELECT DISTINCT relation_type, sender, type FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE relates_to_id = ? AND {rel_type_sql}
+ SELECT thread_id, topological_ordering, stream_ordering
+ FROM threads
+ WHERE
+ room_id = ?
+ {pagination_clause}
+ ORDER BY topological_ordering DESC, stream_ordering DESC
+ LIMIT ?
"""
- def _get_event_relations(
+ def _get_threads_txn(
txn: LoggingTransaction,
- ) -> Dict[str, Set[Tuple[str, str]]]:
- txn.execute(sql, [event_id] + rel_type_args)
- result: Dict[str, Set[Tuple[str, str]]] = {
- rel_type: set() for rel_type in relation_types
- }
- for rel_type, sender, type in txn.fetchall():
- result[rel_type].add((sender, type))
- return result
+ ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ txn.execute(sql, (room_id, *pagination_args, limit + 1))
- return await self.db_pool.runInteraction(
- "get_event_relations", _get_event_relations
- )
+ rows = cast(List[Tuple[str, int, int]], txn.fetchall())
+ thread_ids = [r[0] for r in rows]
+
+ # If there are more events, generate the next pagination key from the
+ # last thread which will be returned.
+ next_token = None
+ if len(thread_ids) > limit:
+ last_topo_id = rows[-2][1]
+ last_stream_id = rows[-2][2]
+ next_token = ThreadsNextBatch(last_topo_id, last_stream_id)
+
+ return thread_ids[:limit], next_token
+
+ return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
@cached()
- async def get_thread_id(self, event_id: str) -> Optional[str]:
+ async def get_thread_id(self, event_id: str) -> str:
"""
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
+ It only searches up the relations tree, i.e. it only searches for events
+ which the given event is related to (and which those events are related
+ to, etc.)
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ get_thread_id(X) considers events B and C as part of thread A.
+
+ See also get_thread_id_for_receipts.
+
Args:
event_id: The event ID to fetch the thread ID for.
Returns:
The event ID of the root event in the thread, if this event is part
- of a thread. None, otherwise.
+ of a thread. "main", otherwise.
"""
- # Since event relations form a tree, we should only ever find 0 or 1
- # results from the below query.
+
+ # Recurse event relations up to the *root* event, then search that chain
+ # of relations for a thread relation. If one is found, the root event is
+ # returned.
+ #
+ # Note that this should only ever find 0 or 1 entries since it is invalid
+ # for an event to have a thread relation to an event which also has a
+ # relation.
sql = """
WITH RECURSIVE related_events AS (
- SELECT event_id, relates_to_id, relation_type
+ SELECT event_id, relates_to_id, relation_type, 0 depth
FROM event_relations
WHERE event_id = ?
- UNION SELECT e.event_id, e.relates_to_id, e.relation_type
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
- ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ WHERE relation_type = 'm.thread'
+ ORDER BY depth DESC
+ LIMIT 1;
"""
- def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
+ def _get_thread_id(txn: LoggingTransaction) -> str:
txn.execute(sql, (event_id,))
- # TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
- return None
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+ @cached()
+ async def get_thread_id_for_receipts(self, event_id: str) -> str:
+ """
+ Get the thread ID for an event by traversing to the top-most related event
+ and confirming any children events form a thread.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part
+ of thread A.
+
+ See also get_thread_id.
+
+ Args:
+ event_id: The event ID to fetch the thread ID for.
+
+ Returns:
+ The event ID of the root event in the thread, if this event is part
+ of a thread. "main", otherwise.
+ """
+
+ # Recurse event relations up to the *root* event, then search for any events
+ # related to that root node for a thread relation. If one is found, the
+ # root event is returned.
+ #
+ # Note that there cannot be thread relations in the middle of the chain since
+ # it is invalid for an event to have a thread relation to an event which also
+ # has a relation.
+ sql = """
+ SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type, 0 depth
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ ORDER BY depth DESC
+ LIMIT 1
+ ), ?) AND relation_type = 'm.thread' LIMIT 1;
+ """
+
+ def _get_related_thread_id(txn: LoggingTransaction) -> str:
+ txn.execute(sql, (event_id, event_id))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
+
+ return await self.db_pool.runInteraction(
+ "get_related_thread_id", _get_related_thread_id
+ )
+
class RelationsStore(RelationsWorkerStore):
pass
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2337289d88..2ed6ad754f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -666,7 +666,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cached_method_name="get_rooms_for_user",
list_name="user_ids",
)
- async def get_rooms_for_users(
+ async def _get_rooms_for_users(
self, user_ids: Collection[str]
) -> Dict[str, FrozenSet[str]]:
"""A batched version of `get_rooms_for_user`.
@@ -697,6 +697,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
+ async def get_rooms_for_users(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[str]]:
+ """A batched wrapper around `_get_rooms_for_users`, to prevent locking
+ other calls to `get_rooms_for_user` for large user lists.
+ """
+ all_user_rooms: Dict[str, FrozenSet[str]] = {}
+
+ # 250 users is pretty arbitrary but the data can be quite large if users
+ # are in many rooms.
+ for user_ids in batch_iter(user_ids, 250):
+ all_user_rooms.update(await self._get_rooms_for_users(user_ids))
+
+ return all_user_rooms
+
@cached(max_entries=10000)
async def does_pair_of_users_share_a_room(
self, user_id: str, other_user_id: str
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 530f04e149..09ce855aa8 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
)
args.extend(event_filter.related_by_rel_types)
+ if event_filter.rel_types:
+ clauses.append(
+ "(%s)"
+ % " OR ".join(
+ "event_relation.relation_type = ?" for _ in event_filter.rel_types
+ )
+ )
+ args.extend(event_filter.rel_types)
+
+ if event_filter.not_rel_types:
+ clauses.append(
+ "((%s) OR event_relation.relation_type IS NULL)"
+ % " AND ".join(
+ "event_relation.relation_type != ?" for _ in event_filter.not_rel_types
+ )
+ )
+ args.extend(event_filter.not_rel_types)
+
return " AND ".join(clauses), args
@@ -1024,28 +1042,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
- async def get_all_new_events_stream(
- self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
- ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
+ async def get_all_new_event_ids_stream(
+ self,
+ from_id: int,
+ current_id: int,
+ limit: int,
+ ) -> Tuple[int, Dict[str, Optional[int]]]:
"""Get all new events
- Returns all events with from_id < stream_ordering <= current_id.
+ Returns all event ids with from_id < stream_ordering <= current_id.
Args:
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
- get_prev_content: whether to fetch previous event content
Returns:
- A tuple of (next_id, events, event_to_received_ts), where `next_id`
+ A tuple of (next_id, event_to_received_ts), where `next_id`
is the next value to pass as `from_id` (it will either be the
stream_ordering of the last returned event, or, if fewer than `limit`
events were found, the `current_id`). The `event_to_received_ts` is
- a dictionary mapping event ID to the event `received_ts`.
+ a dictionary mapping event ID to the event `received_ts`, sorted by ascending
+ stream_ordering.
"""
- def get_all_new_events_stream_txn(
+ def get_all_new_event_ids_stream_txn(
txn: LoggingTransaction,
) -> Tuple[int, Dict[str, Optional[int]]]:
sql = (
@@ -1070,15 +1091,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, event_to_received_ts
upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
- "get_all_new_events_stream", get_all_new_events_stream_txn
- )
-
- events = await self.get_events_as_list(
- event_to_received_ts.keys(),
- get_prev_content=get_prev_content,
+ "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn
)
- return upper_bound, events, event_to_received_ts
+ return upper_bound, event_to_received_ts
async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions:
@@ -1202,8 +1218,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
`to_token`), or `limit` is zero.
"""
- assert int(limit) >= 0
-
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
@@ -1282,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# Multiple labels could cause the same event to appear multiple times.
needs_distinct = True
- # If there is a filter on relation_senders and relation_types join to the
- # relations table.
+ # If there is a relation_senders and relation_types filter join to the
+ # relations table to get events related to the current event.
if event_filter and (
event_filter.related_by_senders or event_filter.related_by_rel_types
):
@@ -1298,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
"""
+ # If there is a not_rel_types filter join to the relations table to get
+ # the event's relation information.
+ if event_filter and (event_filter.rel_types or event_filter.not_rel_types):
+ join_clause += """
+ LEFT JOIN event_relations AS event_relation USING (event_id)
+ """
+
if needs_distinct:
select_keywords += " DISTINCT"
diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql
new file mode 100644
index 0000000000..aa7c5e9a2e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/09threads_table.sql
@@ -0,0 +1,30 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE threads (
+ room_id TEXT NOT NULL,
+ -- The event ID of the root event in the thread.
+ thread_id TEXT NOT NULL,
+ -- The latest event ID and corresponding topo / stream ordering.
+ latest_event_id TEXT NOT NULL,
+ topological_ordering BIGINT NOT NULL,
+ stream_ordering BIGINT NOT NULL,
+ CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
+);
+
+CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering);
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7309, 'threads_backfill', '{}');
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 806b671305..2dcd43d0a2 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -27,7 +27,7 @@ class EventSource(Generic[K, R]):
self,
user: UserID,
from_key: K,
- limit: Optional[int],
+ limit: int,
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index f6f7bf3d8b..6df2de919c 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -35,14 +35,14 @@ class PaginationConfig:
from_token: Optional[StreamToken]
to_token: Optional[StreamToken]
direction: str
- limit: Optional[int]
+ limit: int
@classmethod
async def from_request(
cls,
store: "DataStore",
request: SynapseRequest,
- default_limit: Optional[int] = None,
+ default_limit: int,
default_dir: str = "f",
) -> "PaginationConfig":
direction = parse_string(
@@ -69,12 +69,10 @@ class PaginationConfig:
raise SynapseError(400, "'to' parameter is invalid")
limit = parse_integer(request, "limit", default=default_limit)
+ if limit < 0:
+ raise SynapseError(400, "Limit must be 0 or above")
- if limit:
- if limit < 0:
- raise SynapseError(400, "Limit must be 0 or above")
-
- limit = min(int(limit), MAX_LIMIT)
+ limit = min(limit, MAX_LIMIT)
try:
return PaginationConfig(from_tok, to_tok, direction, limit)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index c4048d2477..40a9c5b53f 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -84,7 +84,15 @@ async def filter_events_for_client(
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
+ events_before_filtering = events
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
+ if len(events_before_filtering) != len(events):
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(
+ "filter_events_for_client: Filtered out soft-failed events: Before=%s, After=%s",
+ [event.event_id for event in events_before_filtering],
+ [event.event_id for event in events],
+ )
types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id))
@@ -301,6 +309,10 @@ def _check_client_allowed_to_see_event(
_check_filter_send_to_client(event, clock, retention_policy, sender_ignored)
== _CheckFilter.DENIED
):
+ logger.debug(
+ "_check_client_allowed_to_see_event(event=%s): Filtered out event because `_check_filter_send_to_client` returned `_CheckFilter.DENIED`",
+ event.event_id,
+ )
return None
if event.event_id in always_include_ids:
@@ -312,9 +324,17 @@ def _check_client_allowed_to_see_event(
# for out-of-band membership events (eg, incoming invites, or rejections of
# said invite) for the user themselves.
if event.type == EventTypes.Member and event.state_key == user_id:
- logger.debug("Returning out-of-band-membership event %s", event)
+ logger.debug(
+ "_check_client_allowed_to_see_event(event=%s): Returning out-of-band-membership event %s",
+ event.event_id,
+ event,
+ )
return event
+ logger.debug(
+ "_check_client_allowed_to_see_event(event=%s): Filtered out event because it's an outlier",
+ event.event_id,
+ )
return None
if state is None:
@@ -337,11 +357,21 @@ def _check_client_allowed_to_see_event(
membership_result = _check_membership(user_id, event, visibility, state, is_peeking)
if not membership_result.allowed:
+ logger.debug(
+ "_check_client_allowed_to_see_event(event=%s): Filtered out event because the user can't see the event because of their membership, membership_result.allowed=%s membership_result.joined=%s",
+ event.event_id,
+ membership_result.allowed,
+ membership_result.joined,
+ )
return None
# If the sender has been erased and the user was not joined at the time, we
# must only return the redacted form.
if sender_erased and not membership_result.joined:
+ logger.debug(
+ "_check_client_allowed_to_see_event(event=%s): Returning pruned event because `sender_erased` and the user was not joined at the time",
+ event.event_id,
+ )
event = prune_event(event)
return event
|